use crate::config::Config;
use crate::core::notification::Notifier;
use crate::storage::DbPool;
use axum::{
Router,
body::Body,
extract::State,
http::{Request, StatusCode},
middleware::{Next, from_fn_with_state},
response::Response,
routing::{get, post},
};
use std::sync::Arc;
use tower_governor::{GovernorLayer, governor::GovernorConfigBuilder};
use tracing::warn;
pub mod auth;
pub mod gateway;
pub mod keys;
pub mod messages;
pub mod middleware;
pub mod rate_limit;
/// Shared state attached to the router via `with_state` in `app_router`.
///
/// axum requires router state to be `Clone`, hence the derive. Every field is
/// public so handlers in the submodules can reach them through `State<AppState>`.
#[derive(Clone)]
pub struct AppState {
/// Database connection pool (`crate::storage::DbPool`).
pub pool: DbPool,
/// Application configuration (`crate::config::Config`); `app_router` reads the
/// rate-limit settings and `trusted_proxies` from it.
pub config: Config,
/// Notification backend, shared behind an `Arc` trait object.
pub notifier: Arc<dyn Notifier>,
/// Client-IP extractor built from `config.trusted_proxies`; the same extractor
/// keys the rate limiters and is used by `log_rate_limit_events` to report the
/// offending client's IP.
pub extractor: rate_limit::IpKeyExtractor,
}
/// Builds the application router, with every route nested under `/v1`.
///
/// Two independent rate limiters are applied:
/// - `POST /v1/accounts` and `POST /v1/sessions` use the "auth" limiter, configured by
///   `config.auth_rate_limit_per_second` / `config.auth_rate_limit_burst`;
/// - `POST /v1/keys`, `GET /v1/keys/{userId}`, `POST /v1/messages/{recipientId}` and
///   `GET /v1/gateway` use the "standard" limiter, configured by
///   `config.rate_limit_per_second` / `config.rate_limit_burst`.
///
/// Both limiters key clients with a `rate_limit::IpKeyExtractor` built from
/// `config.trusted_proxies`. The whole router is wrapped in `log_rate_limit_events`,
/// so 429 responses from either limiter are logged.
///
/// A configured rate of `0` requests/second is treated as `1`. A rate above 10^9/s
/// is capped at one token per nanosecond, the finest period the limiter supports.
///
/// # Panics
///
/// Panics if either configured burst size is `0`. `tower_governor` refuses to build
/// a limiter with an empty bucket.
pub fn app_router(pool: DbPool, config: Config, notifier: Arc<dyn Notifier>) -> Router {
    let extractor = rate_limit::IpKeyExtractor::new(&config.trusted_proxies);

    // Requests-per-second -> token replenish interval in nanoseconds. The rate is
    // clamped to >= 1 so the division cannot be by zero.
    let std_interval_ns = (1_000_000_000 / config.rate_limit_per_second.max(1)) as u64;
    let auth_interval_ns = (1_000_000_000 / config.auth_rate_limit_per_second.max(1)) as u64;

    // Both limiters share the same key extractor; only rate and burst differ.
    let limiter_config = |interval_ns: u64, burst: u32, label: &str| {
        Arc::new(
            GovernorConfigBuilder::default()
                // Rates above 1e9/s truncate to a 0ns interval above, and `finish()`
                // rejects a zero period, so never go below 1ns.
                .per_nanosecond(interval_ns.max(1))
                .burst_size(burst)
                .key_extractor(extractor.clone())
                .finish()
                // With the period clamped, `finish()` can only fail on a zero burst.
                .unwrap_or_else(|| {
                    panic!("{label} rate limiter config rejected: burst size must be non-zero")
                }),
        )
    };
    let standard_conf = limiter_config(std_interval_ns, config.rate_limit_burst, "standard");
    let auth_conf = limiter_config(auth_interval_ns, config.auth_rate_limit_burst, "auth");

    let state = AppState { pool, config, notifier, extractor };

    let auth_routes = Router::new()
        .route("/accounts", post(auth::register))
        .route("/sessions", post(auth::login))
        .layer(GovernorLayer::new(auth_conf));

    let api_routes = Router::new()
        .route("/keys", post(keys::upload_keys))
        .route("/keys/{userId}", get(keys::get_pre_key_bundle))
        .route("/messages/{recipientId}", post(messages::send_message))
        .route("/gateway", get(gateway::websocket_handler))
        .layer(GovernorLayer::new(standard_conf));

    Router::new()
        .nest("/v1", auth_routes.merge(api_routes))
        // Outermost layer, so it observes 429 responses produced by either limiter.
        .layer(from_fn_with_state(state.clone(), log_rate_limit_events))
        .with_state(state)
}
/// Middleware that logs rate-limited requests and mirrors the limiter's wait hint
/// into a `retry-after` header.
///
/// `next.run` consumes the request. So the method, path, headers and socket peer
/// address (from the `ConnectInfo<SocketAddr>` extension, if present) are captured
/// first. Non-429 responses are returned untouched.
///
/// For a `429 Too Many Requests` response:
/// - the client IP is resolved with `IpKeyExtractor::identify_client_ip`, or is
///   reported as `unknown` when no peer address was available;
/// - a warning is logged;
/// - any `x-ratelimit-after` value is copied into `retry-after`, overwriting an
///   existing value.
async fn log_rate_limit_events(State(state): State<AppState>, req: Request<Body>, next: Next) -> Response {
    // Snapshot everything the log line may need before the request is consumed.
    let (method, path) = (req.method().clone(), req.uri().path().to_owned());
    let headers = req.headers().clone();
    let peer_ip = req
        .extensions()
        .get::<axum::extract::ConnectInfo<std::net::SocketAddr>>()
        .map(|connect_info| connect_info.0.ip());

    let mut response = next.run(req).await;
    if response.status() != StatusCode::TOO_MANY_REQUESTS {
        return response;
    }

    let client_ip: String = match peer_ip {
        Some(addr) => state.extractor.identify_client_ip(&headers, addr).to_string(),
        None => "unknown".into(),
    };
    warn!("Rate limit hit: client_ip={}, method={}, path={}", client_ip, method, path);

    // Expose the limiter's wait time under the standard header name too.
    if let Some(wait) = response.headers().get("x-ratelimit-after").cloned() {
        response.headers_mut().insert("retry-after", wait);
    }
    response
}