use std::path::Path;
use std::sync::Arc;
use axum::response::IntoResponse;
use axum::routing::any;
use axum::Router;
use futures::future::BoxFuture;
use sqlx::SqlitePool;
use tower_http::trace::TraceLayer;
use crate::config::Config;
use crate::cors::build_cors_layer;
use crate::error::AtrgError;
use crate::state::AppState;
/// One-shot callback handed a clone of the database pool while the app starts up.
type CleanupFn = Box<dyn FnOnce(SqlitePool) + Send>;

/// Builder for an atrg application.
///
/// Collects user routes, an optional auth router, optional stream handlers and an optional
/// cleanup task, then starts everything with [`AtrgApp::run`].
///
/// Every builder method takes `self` by value and returns the updated builder, so discarding
/// the return value silently drops the configuration built so far.
#[must_use = "AtrgApp is a builder; keep the returned value and call `.run().await` on it"]
pub struct AtrgApp {
    /// User routes accumulated by `mount` (and the feed-generator / labeler helpers).
    router: Router<AppState>,
    /// Built-in auth routes set by `with_auth_routes`; merged ahead of the user routes.
    builtin_router: Option<Router<AppState>>,
    /// Optional task invoked once with the database pool during `run`.
    cleanup_fn: Option<CleanupFn>,
    /// Jetstream handler; only used when Jetstream is configured.
    event_handler: Option<atrg_stream::EventHandler<AppState>>,
    /// Firehose handler; only used when the firehose is configured.
    #[cfg(feature = "firehose")]
    firehose_handler: Option<atrg_firehose::FirehoseHandler<AppState>>,
}
impl AtrgApp {
pub fn new() -> Self {
Self {
router: Router::new(),
builtin_router: None,
cleanup_fn: None,
event_handler: None,
#[cfg(feature = "firehose")]
firehose_handler: None,
}
}
pub fn mount(mut self, router: Router<AppState>) -> Self {
self.router = self.router.merge(router);
self
}
pub fn with_auth_routes(mut self, router: Router<AppState>) -> Self {
self.builtin_router = Some(router);
self
}
pub fn with_cleanup_task<F>(mut self, f: F) -> Self
where
F: FnOnce(SqlitePool) + Send + 'static,
{
self.cleanup_fn = Some(Box::new(f));
self
}
pub fn on_event<F, Fut>(mut self, handler: F) -> Self
where
F: Fn(atrg_stream::JetstreamEvent, AppState) -> Fut + Send + Sync + 'static,
Fut: std::future::Future<Output = anyhow::Result<()>> + Send + 'static,
{
self.event_handler = Some(Arc::new(move |event, state| {
Box::pin(handler(event, state)) as BoxFuture<'static, anyhow::Result<()>>
}));
self
}
#[cfg(feature = "firehose")]
pub fn on_firehose_event<F, Fut>(mut self, handler: F) -> Self
where
F: Fn(atrg_firehose::FirehoseEvent, AppState) -> Fut + Send + Sync + 'static,
Fut: std::future::Future<Output = anyhow::Result<()>> + Send + 'static,
{
self.firehose_handler = Some(std::sync::Arc::new(move |event, state| {
Box::pin(handler(event, state)) as BoxFuture<'static, anyhow::Result<()>>
}));
self
}
pub fn with_feed_generator(self, feed_router: Router<AppState>) -> Self {
self.mount(feed_router)
}
pub fn with_labeler(self, labeler_router: Router<AppState>) -> Self {
self.mount(labeler_router)
}
pub async fn run(self) -> anyhow::Result<()> {
let env_filter = tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| {
tracing_subscriber::EnvFilter::new(
"info,atrg_core=debug,atrg_db=debug,atrg_auth=debug,atrg_cli=debug,tower_http=debug",
)
});
let _ = tracing_subscriber::fmt()
.with_env_filter(env_filter)
.try_init();
let config_path = std::env::var("ATRG_CONFIG").unwrap_or_else(|_| "./atrg.toml".into());
tracing::info!(path = %config_path, "loading configuration");
let config = Config::load(&config_path)?;
let config = Arc::new(config);
let db = atrg_db::connect(&config.database.url).await?;
atrg_db::run_internal_migrations(&db).await?;
let user_migrations = Path::new("./migrations");
if user_migrations.is_dir() {
atrg_db::run_user_migrations(&db, user_migrations).await?;
}
let http = reqwest::Client::builder()
.user_agent(format!("atrg/{}", crate::version()))
.build()?;
let identity = Arc::new(atrg_identity::IdentityResolver::with_defaults(http.clone()));
let state = AppState {
config: config.clone(),
db,
http,
identity,
};
let cors = build_cors_layer(&config.app.cors_origins);
let mut router = Router::new();
router = router
.route("/healthz", axum::routing::get(crate::health::healthz))
.route("/readyz", axum::routing::get(crate::health::readyz));
if let Some(builtin) = self.builtin_router {
router = router.merge(builtin);
}
let mut router = router
.merge(self.router)
.fallback(any(fallback_not_found))
.with_state(state.clone())
.layer(cors)
.layer(axum::middleware::from_fn(
crate::request_id::request_id_middleware,
))
.layer(TraceLayer::new_for_http());
if config.app.environment != "development" {
router = router.layer(axum::middleware::from_fn(
crate::security::security_headers_middleware,
));
}
if let Some(ref js_config) = config.jetstream {
if let Some(handler) = self.event_handler {
let stream_config = atrg_stream::StreamConfig {
host: js_config.host.clone(),
collections: js_config.collections.clone(),
zstd_dict: js_config.zstd_dict.clone(),
channel_capacity: js_config.channel_capacity,
max_lag_events: js_config.max_lag_events,
};
atrg_stream::spawn_consumer(&stream_config, state.clone(), handler).await?;
} else {
tracing::warn!("jetstream configured but no on_event handler registered");
}
}
#[cfg(feature = "firehose")]
if let Some(ref fh_config) = config.firehose {
if let Some(handler) = self.firehose_handler {
let firehose_config = atrg_firehose::FirehoseConfig {
relay: fh_config.relay.clone(),
cursor: fh_config.cursor,
channel_capacity: fh_config.channel_capacity,
};
atrg_firehose::spawn_firehose(&firehose_config, state.clone(), handler).await?;
} else {
tracing::warn!("firehose configured but no on_firehose_event handler registered");
}
}
if let Some(cleanup) = self.cleanup_fn {
cleanup(state.db.clone());
}
let addr = format!("{}:{}", config.app.host, config.app.port);
tracing::info!(addr = %addr, name = %config.app.name, "at-rust-go API serving");
let listener = tokio::net::TcpListener::bind(&addr).await?;
axum::serve(listener, router).await?;
Ok(())
}
}
impl Default for AtrgApp {
fn default() -> Self {
Self::new()
}
}
/// Handler for requests that match no route.
///
/// Always responds with [`AtrgError::NotFound`]; the tests in this file check that this
/// renders as a JSON 404 body.
async fn fallback_not_found() -> impl IntoResponse {
    AtrgError::NotFound
}
/// Test helper: builds a router from `user_router` with no auth routes.
#[cfg(test)]
pub(crate) fn build_test_router(user_router: Router<AppState>, state: AppState) -> Router {
    let no_auth_routes: Option<Router<AppState>> = None;
    build_test_router_with_auth(no_auth_routes, user_router, state)
}
/// Test helper: builds a router that mirrors the layering used by `AtrgApp::run`.
///
/// Layer order matches `run`: CORS, request-id and trace layers, plus security headers
/// outside `development`. Tests therefore exercise the same middleware stack as production.
/// Health routes, database setup and tracing initialisation are not included.
#[cfg(test)]
pub(crate) fn build_test_router_with_auth(
    auth_router: Option<Router<AppState>>,
    user_router: Router<AppState>,
    state: AppState,
) -> Router {
    let cors = build_cors_layer(&state.config.app.cors_origins);
    // Decide this before `state` is moved into the router.
    let add_security_headers = state.config.app.environment != "development";

    let mut router = Router::new();
    if let Some(auth) = auth_router {
        router = router.merge(auth);
    }
    let router = router
        .merge(user_router)
        .fallback(any(fallback_not_found))
        .with_state(state)
        .layer(cors)
        .layer(axum::middleware::from_fn(
            crate::request_id::request_id_middleware,
        ))
        .layer(TraceLayer::new_for_http());

    if add_security_headers {
        router.layer(axum::middleware::from_fn(
            crate::security::security_headers_middleware,
        ))
    } else {
        router
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{AppConfig, AuthConfig, Config, DatabaseConfig};
    use axum::body::Body;
    use axum::routing::get;
    use axum::Json;
    use http_body_util::BodyExt;
    use hyper::Request;
    use tower::ServiceExt;

    /// Builds an [`AppState`] with an in-memory SQLite database and a minimal
    /// development-mode configuration.
    ///
    /// Internal migrations are applied to the database.
    async fn test_state() -> AppState {
        let db = atrg_db::connect("sqlite::memory:").await.unwrap();
        atrg_db::run_internal_migrations(&db).await.unwrap();
        let config = Config {
            app: AppConfig {
                name: "test-app".into(),
                host: "127.0.0.1".into(),
                port: 3000,
                secret_key: "a]3)FRd9-x4bQ7Y!kN2mW#pL8v$Tz0cS".into(),
                cors_origins: vec![],
                environment: "development".into(),
            },
            auth: AuthConfig {
                client_id: "http://localhost:3000/client-metadata.json".into(),
                redirect_uri: "http://localhost:3000/auth/callback".into(),
                scope: "atproto transition:generic".into(),
            },
            database: DatabaseConfig {
                url: "sqlite::memory:".into(),
            },
            jetstream: None,
            firehose: None,
            feed_generator: None,
            labeler: None,
            rate_limit: None,
        };
        AppState {
            config: Arc::new(config),
            db,
            http: reqwest::Client::new(),
            identity: Arc::new(atrg_identity::IdentityResolver::with_defaults(
                reqwest::Client::new(),
            )),
        }
    }

    /// Collects the full response body into a byte vector.
    async fn body_bytes(response: axum::response::Response) -> Vec<u8> {
        response
            .into_body()
            .collect()
            .await
            .unwrap()
            .to_bytes()
            .to_vec()
    }

    /// Returns the response's `content-type` header as an owned string.
    ///
    /// # Panics
    ///
    /// Panics if the header is missing or is not valid visible ASCII.
    fn content_type(response: &axum::response::Response) -> String {
        response
            .headers()
            .get("content-type")
            .expect("response has a content-type header")
            .to_str()
            .expect("content-type header is visible ASCII")
            .to_owned()
    }

    /// Builds a GET request for `uri` with an empty body.
    fn get_request(uri: &str) -> Request<Body> {
        Request::builder().uri(uri).body(Body::empty()).unwrap()
    }

    #[test]
    fn atrg_app_default_is_new() {
        let app = AtrgApp::default();
        assert!(app.builtin_router.is_none());
        assert!(app.cleanup_fn.is_none());
        assert!(app.event_handler.is_none());
        #[cfg(feature = "firehose")]
        assert!(app.firehose_handler.is_none());
    }

    #[test]
    fn on_event_sets_handler() {
        let app = AtrgApp::new().on_event(|_event, _state| async { Ok(()) });
        assert!(app.event_handler.is_some());
    }

    #[test]
    fn with_cleanup_task_sets_fn() {
        let app = AtrgApp::new().with_cleanup_task(|_db| {});
        assert!(app.cleanup_fn.is_some());
    }

    #[tokio::test]
    async fn mount_ping_returns_200_json() {
        let state = test_state().await;
        let user_router: Router<AppState> = Router::new().route(
            "/ping",
            get(|| async { Json(serde_json::json!({"pong": true})) }),
        );
        let app = build_test_router(user_router, state);
        let response = app.oneshot(get_request("/ping")).await.unwrap();
        assert_eq!(response.status(), 200);
        let ct = content_type(&response);
        assert!(
            ct.contains("application/json"),
            "expected application/json, got {ct}"
        );
        let bytes = body_bytes(response).await;
        let body: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
        assert_eq!(body["pong"], true);
    }

    #[tokio::test]
    async fn unknown_route_returns_404_json() {
        let state = test_state().await;
        let app = build_test_router(Router::new(), state);
        let response = app.oneshot(get_request("/does-not-exist")).await.unwrap();
        assert_eq!(response.status(), 404);
        let ct = content_type(&response);
        assert!(
            ct.contains("application/json"),
            "expected application/json, got {ct}"
        );
        let bytes = body_bytes(response).await;
        let body: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
        assert_eq!(body["error"], "not_found");
        assert_eq!(body["message"], "Not found");
    }

    #[tokio::test]
    async fn multiple_mounts_accumulate_routes() {
        let state = test_state().await;
        let r1: Router<AppState> = Router::new().route(
            "/a",
            get(|| async { Json(serde_json::json!({"route": "a"})) }),
        );
        let r2: Router<AppState> = Router::new().route(
            "/b",
            get(|| async { Json(serde_json::json!({"route": "b"})) }),
        );
        // Go through `AtrgApp::mount`, so the builder is tested rather than `Router::merge`.
        let builder = AtrgApp::new().mount(r1).mount(r2);
        let app = build_test_router(builder.router, state);
        for (uri, expected) in [("/a", "a"), ("/b", "b")] {
            let resp = app.clone().oneshot(get_request(uri)).await.unwrap();
            assert_eq!(resp.status(), 200, "GET {uri}");
            let body: serde_json::Value =
                serde_json::from_slice(&body_bytes(resp).await).unwrap();
            assert_eq!(body["route"], expected, "GET {uri}");
        }
    }

    #[tokio::test]
    async fn with_auth_routes_merges_builtin() {
        let state = test_state().await;
        let auth_router: Router<AppState> = Router::new().route(
            "/auth/test",
            get(|| async { Json(serde_json::json!({"auth": true})) }),
        );
        let user_router: Router<AppState> = Router::new().route(
            "/ping",
            get(|| async { Json(serde_json::json!({"pong": true})) }),
        );
        // Go through the builder, so `with_auth_routes` itself is exercised.
        let builder = AtrgApp::new()
            .with_auth_routes(auth_router)
            .mount(user_router);
        let app = build_test_router_with_auth(builder.builtin_router, builder.router, state);
        for uri in ["/auth/test", "/ping"] {
            let resp = app.clone().oneshot(get_request(uri)).await.unwrap();
            assert_eq!(resp.status(), 200, "GET {uri}");
        }
    }

    #[tokio::test]
    async fn with_auth_routes_replaces_previous_router() {
        let state = test_state().await;
        let first: Router<AppState> = Router::new().route(
            "/auth/one",
            get(|| async { Json(serde_json::json!({"auth": 1})) }),
        );
        let second: Router<AppState> = Router::new().route(
            "/auth/two",
            get(|| async { Json(serde_json::json!({"auth": 2})) }),
        );
        let builder = AtrgApp::new()
            .with_auth_routes(first)
            .with_auth_routes(second);
        let app = build_test_router_with_auth(builder.builtin_router, builder.router, state);
        let replaced = app.clone().oneshot(get_request("/auth/one")).await.unwrap();
        assert_eq!(replaced.status(), 404);
        let current = app.oneshot(get_request("/auth/two")).await.unwrap();
        assert_eq!(current.status(), 200);
    }
}