use crate::api::rate_limit::{IpKeyExtractor, log_rate_limit_events};
use crate::config::Config;
use crate::core::account_service::AccountService;
use crate::core::attachment_service::AttachmentService;
use crate::core::key_service::KeyService;
use crate::core::message_service::MessageService;
use crate::core::notification::Notifier;
use crate::storage::{
DbPool, attachment_repo::AttachmentRepository, key_repo::KeyRepository, message_repo::MessageRepository,
refresh_token_repo::RefreshTokenRepository, user_repo::UserRepository,
};
use axum::{
Router,
middleware::from_fn_with_state,
routing::{get, post},
};
use std::sync::Arc;
use tower_governor::{GovernorLayer, governor::GovernorConfigBuilder};
pub mod attachments;
pub mod auth;
pub mod docs;
pub mod gateway;
pub mod keys;
pub mod messages;
pub mod middleware;
pub mod rate_limit;
/// Shared state attached to the router with `Router::with_state` and made
/// available to every handler.
///
/// The type is `Clone` because axum hands out copies of the state.
/// NOTE(review): every field is cloned along with the state, so each field is
/// presumably a cheap, handle-like type (pool, `Arc`, SDK client, services) —
/// confirm for the service types.
#[derive(Clone)]
pub struct AppState {
    /// Database pool; each service was constructed with its own clone of it.
    pub pool: DbPool,
    /// Application configuration that was passed to `app_router`.
    pub config: Config,
    /// Notifier, also shared with the key and message services.
    pub notifier: Arc<dyn Notifier>,
    /// Client-IP key extractor built from `config.server.trusted_proxies`;
    /// the same extractor keys the rate limiters.
    pub extractor: IpKeyExtractor,
    /// S3 client; the attachment service holds a clone of it as well.
    pub s3_client: aws_sdk_s3::Client,
    /// Key service.
    pub key_service: KeyService,
    /// Attachment service (pool, S3 client, S3 config, TTL).
    pub attachment_service: AttachmentService,
    /// Account service (wraps a clone of `key_service`).
    pub account_service: AccountService,
    /// Message service.
    pub message_service: MessageService,
}
/// Builds the complete HTTP router for the API.
///
/// Constructs the repositories and services, builds two rate limiters keyed
/// by [`IpKeyExtractor`] (an "auth" limiter and a "standard" limiter), and
/// mounts:
///
/// * `GET /openapi.yaml` — not rate limited;
/// * `/v1/users`, `/v1/sessions`, `/v1/sessions/refresh` — auth limiter;
/// * `/v1/keys…`, `/v1/messages…`, `/v1/gateway`, `/v1/attachments…` —
///   standard limiter.
///
/// The `log_rate_limit_events` middleware wraps every route above.
///
/// # Panics
///
/// Panics if `config.rate_limit.burst` or `config.rate_limit.auth_burst` is
/// zero, because `GovernorConfigBuilder::finish` rejects an empty bucket.
pub fn app_router(pool: DbPool, config: Config, notifier: Arc<dyn Notifier>, s3_client: aws_sdk_s3::Client) -> Router {
    let extractor = IpKeyExtractor::new(config.server.trusted_proxies.clone());

    // Repositories take no constructor arguments; the pool is handed to the
    // services instead.
    let key_repo = KeyRepository::new();
    let message_repo = MessageRepository::new();
    let user_repo = UserRepository::new();
    let refresh_repo = RefreshTokenRepository::new();
    let attachment_repo = AttachmentRepository::new();

    let key_service =
        KeyService::new(pool.clone(), key_repo, message_repo.clone(), notifier.clone(), config.messaging.clone());
    let attachment_service =
        AttachmentService::new(pool.clone(), attachment_repo, s3_client.clone(), config.s3.clone(), config.ttl_days);
    let account_service =
        AccountService::new(pool.clone(), config.auth.clone(), key_service.clone(), user_repo, refresh_repo);
    let message_service = MessageService::new(
        pool.clone(),
        message_repo.clone(),
        notifier.clone(),
        config.messaging.clone(),
        config.ttl_days,
    );

    // tower_governor expresses a quota as "one token every N nanoseconds".
    // The inner `.max(1)` avoids dividing by zero for a configured rate of 0;
    // the outer `.max(1)` stops rates above 1e9/s from truncating the
    // interval to 0 ns, which `finish()` would reject (and we would panic).
    // NOTE(review): tower_governor's keyed limiter state grows with the number
    // of distinct client keys; confirm something periodically calls
    // `retain_recent()` on these limiters.
    let std_interval_ns = (1_000_000_000 / config.rate_limit.per_second.max(1)).max(1);
    let standard_conf = Arc::new(
        GovernorConfigBuilder::default()
            .per_nanosecond(std_interval_ns as u64)
            .burst_size(config.rate_limit.burst)
            .key_extractor(extractor.clone())
            .finish()
            .expect("invalid rate limit config: rate_limit.burst must be greater than zero"),
    );

    let auth_interval_ns = (1_000_000_000 / config.rate_limit.auth_per_second.max(1)).max(1);
    let auth_conf = Arc::new(
        GovernorConfigBuilder::default()
            .per_nanosecond(auth_interval_ns as u64)
            .burst_size(config.rate_limit.auth_burst)
            .key_extractor(extractor.clone())
            .finish()
            .expect("invalid rate limit config: rate_limit.auth_burst must be greater than zero"),
    );

    let state = AppState {
        pool,
        config,
        notifier,
        extractor,
        s3_client,
        key_service,
        attachment_service,
        account_service,
        message_service,
    };

    // Authentication endpoints get their own limiter. Both `/sessions`
    // methods share one MethodRouter rather than two `.route` registrations.
    let auth_routes = Router::new()
        .route("/users", post(auth::register))
        .route("/sessions", post(auth::login).delete(auth::logout))
        .route("/sessions/refresh", post(auth::refresh))
        .layer(GovernorLayer::new(auth_conf));

    let api_routes = Router::new()
        .route("/keys", post(keys::upload_keys))
        .route("/keys/{userId}", get(keys::get_pre_key_bundle))
        .route("/messages/{recipientId}", post(messages::send_message))
        .route("/gateway", get(gateway::websocket_handler))
        .route("/attachments", post(attachments::upload_attachment))
        .route("/attachments/{id}", get(attachments::download_attachment))
        .layer(GovernorLayer::new(standard_conf));

    Router::new()
        .route("/openapi.yaml", get(docs::openapi_yaml))
        .nest("/v1", auth_routes.merge(api_routes))
        // Added last so it wraps every route registered above.
        .layer(from_fn_with_state(state.clone(), log_rate_limit_events))
        .with_state(state)
}