use crate::auth::AuthProvider;
use parking_lot::RwLock;
use polaris_system::api::API;
use std::sync::Arc;
/// Registry for HTTP routes and an optional auth provider.
///
/// Both are stored behind `RwLock`s so they can be registered through `&self`
/// and later drained by the crate via `take_routes` / `take_auth`.
pub struct HttpRouter {
/// Routers registered via `add_routes`, in registration order; emptied by `take_routes`.
routes: RwLock<Vec<axum::Router>>,
/// Auth provider set via `set_auth` (a later call replaces an earlier one); removed by `take_auth`.
auth: RwLock<Option<Arc<dyn AuthProvider>>>,
}
impl std::fmt::Debug for HttpRouter {
    /// Formats a diagnostic summary: the number of registered route groups and
    /// the registered auth provider, if any.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Read each lock in its own statement so the `routes` read guard is
        // released before the `auth` read guard is acquired.
        let route_count = self.routes.read().len();
        let auth = self.auth.read();
        f.debug_struct("HttpRouter")
            .field("route_count", &route_count)
            // Delegate to the provider's own `Debug` impl rather than
            // pre-rendering it into a `String`; a `String` would be printed
            // quoted and escaped, and would allocate on every call.
            .field("auth", &auth.as_deref())
            .finish()
    }
}
// `HttpRouter` implements `polaris_system`'s `API` trait with an empty body,
// so it overrides none of the trait's items. Presumably this is what lets it be
// registered/looked up as a system API — confirm against `polaris_system::api`.
impl API for HttpRouter {}
impl HttpRouter {
    /// Creates an empty registry with no routes and no auth provider.
    pub(crate) fn new() -> Self {
        Self {
            routes: RwLock::new(Vec::new()),
            auth: RwLock::new(None),
        }
    }

    /// Registers an [`axum::Router`] to be returned by the next `take_routes`.
    ///
    /// Routers are kept in registration order.
    pub fn add_routes(&self, router: axum::Router) {
        self.routes.write().push(router);
    }

    /// Registers `provider` as the auth provider, replacing any previously
    /// registered one.
    ///
    /// Replacing an existing provider emits a `tracing` warning.
    pub fn set_auth(&self, provider: impl AuthProvider) {
        let provider: Arc<dyn AuthProvider> = Arc::new(provider);
        // Hold the write lock only for the swap itself: the guard is a
        // temporary dropped at the end of this statement, so the warning below
        // and the drop of the replaced provider both happen outside the lock.
        let previous = self.auth.write().replace(provider);
        if previous.is_some() {
            tracing::warn!("overwriting previously registered AuthProvider");
        }
    }

    /// Removes and returns every router registered so far, in registration
    /// order, leaving the registry's route list empty.
    pub(crate) fn take_routes(&self) -> Vec<axum::Router> {
        std::mem::take(&mut *self.routes.write())
    }

    /// Removes and returns the registered auth provider, if any, leaving none
    /// registered.
    pub(crate) fn take_auth(&self) -> Option<Arc<dyn AuthProvider>> {
        self.auth.write().take()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::routing::get;

    #[test]
    fn register_and_take_routes() {
        let api = HttpRouter::new();
        api.add_routes(axum::Router::new().route("/a", get(|| async { "a" })));
        api.add_routes(axum::Router::new().route("/b", get(|| async { "b" })));
        let routes = api.take_routes();
        assert_eq!(routes.len(), 2);
        // A second take finds nothing: the first one drained the list.
        let routes = api.take_routes();
        assert!(routes.is_empty());
    }

    #[test]
    fn routes_added_after_take_are_collected_by_next_take() {
        let api = HttpRouter::new();
        api.add_routes(axum::Router::new().route("/a", get(|| async { "a" })));
        assert_eq!(api.take_routes().len(), 1);
        api.add_routes(axum::Router::new().route("/b", get(|| async { "b" })));
        assert_eq!(api.take_routes().len(), 1);
        assert!(api.take_routes().is_empty());
    }

    #[test]
    fn take_auth_is_none_when_unset() {
        let api = HttpRouter::new();
        assert!(api.take_auth().is_none());
    }

    #[test]
    fn debug_reports_route_count_and_missing_auth() {
        let api = HttpRouter::new();
        assert_eq!(
            format!("{api:?}"),
            "HttpRouter { route_count: 0, auth: None }"
        );
        api.add_routes(axum::Router::new().route("/a", get(|| async { "a" })));
        assert_eq!(
            format!("{api:?}"),
            "HttpRouter { route_count: 1, auth: None }"
        );
    }
}