Skip to main content

mini_apm_admin/api/
auth.rs

1use axum::{
2    body::Body,
3    extract::State,
4    http::{Request, StatusCode, header},
5    middleware::Next,
6    response::Response,
7};
8
9use mini_apm::DbPool;
10
/// Per-request project identity produced by API-key authentication.
///
/// [`auth_middleware`] inserts one of these into the request extensions after
/// resolving the bearer token to a project, so downstream handlers can extract
/// it with `Extension<ProjectContext>`.
#[derive(Debug, Clone)]
pub struct ProjectContext {
    /// ID of the authenticated project; `auth_middleware` always sets this to
    /// `Some(..)`.
    pub project_id: Option<i64>,
}
16
17pub async fn auth_middleware(
18    State(pool): State<DbPool>,
19    mut request: Request<Body>,
20    next: Next,
21) -> Result<Response, StatusCode> {
22    // Extract Authorization header
23    let auth_header = request
24        .headers()
25        .get(header::AUTHORIZATION)
26        .and_then(|h| h.to_str().ok());
27
28    let api_key = match auth_header {
29        Some(h) if h.starts_with("Bearer ") => &h[7..],
30        _ => return Err(StatusCode::UNAUTHORIZED),
31    };
32
33    // Always authenticate against project API keys
34    // A default project is always created on startup
35    match mini_apm::models::project::find_by_api_key(&pool, api_key) {
36        Ok(Some(project)) => {
37            request.extensions_mut().insert(ProjectContext {
38                project_id: Some(project.id),
39            });
40            Ok(next.run(request).await)
41        }
42        Ok(None) => Err(StatusCode::UNAUTHORIZED),
43        Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
44    }
45}
46
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        Extension, Router,
        body::Body,
        http::{Request, StatusCode},
        middleware,
        routing::get,
    };
    use r2d2::Pool;
    use r2d2_sqlite::SqliteConnectionManager;
    use tower::util::ServiceExt;

    /// Builds an in-memory SQLite pool containing only the `projects` table.
    ///
    /// `max_size(1)` is deliberate: every SQLite `:memory:` connection opens
    /// its own private database, so a single pooled connection guarantees the
    /// schema created here is the one the middleware later queries.
    fn create_test_pool() -> DbPool {
        let manager = SqliteConnectionManager::memory();
        let pool = Pool::builder().max_size(1).build(manager).unwrap();

        let conn = pool.get().unwrap();
        conn.execute_batch(
            r#"
            CREATE TABLE projects (
                id INTEGER PRIMARY KEY,
                name TEXT NOT NULL UNIQUE,
                slug TEXT NOT NULL UNIQUE,
                api_key TEXT NOT NULL UNIQUE,
                created_at TEXT NOT NULL
            );
            "#,
        )
        .unwrap();

        pool
    }

    async fn handler() -> &'static str {
        "ok"
    }

    /// Echoes the `ProjectContext` attached by the middleware so tests can
    /// inspect it through the response body.
    async fn project_context_handler(Extension(ctx): Extension<ProjectContext>) -> String {
        format!("{:?}", ctx.project_id)
    }

    /// Wraps `router` in `auth_middleware` and supplies `pool` as state.
    fn with_auth(router: Router<DbPool>, pool: DbPool) -> Router {
        router
            .layer(middleware::from_fn_with_state(
                pool.clone(),
                auth_middleware,
            ))
            .with_state(pool)
    }

    fn create_app(pool: DbPool) -> Router {
        with_auth(Router::new().route("/test", get(handler)), pool)
    }

    #[tokio::test]
    async fn test_auth_requires_authorization_header() {
        let pool = create_test_pool();
        let app = create_app(pool);

        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();

        let response = app.oneshot(req).await.unwrap();
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_auth_requires_bearer_prefix() {
        let pool = create_test_pool();
        let app = create_app(pool);

        let req = Request::builder()
            .uri("/test")
            .header("Authorization", "Basic xyz")
            .body(Body::empty())
            .unwrap();

        let response = app.oneshot(req).await.unwrap();
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_auth_rejects_invalid_key() {
        let pool = create_test_pool();
        // Create a valid project API key first
        mini_apm::models::project::ensure_default_project(&pool).unwrap();

        let app = create_app(pool);

        let req = Request::builder()
            .uri("/test")
            .header("Authorization", "Bearer wrong_key")
            .body(Body::empty())
            .unwrap();

        let response = app.oneshot(req).await.unwrap();
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_auth_accepts_valid_project_key() {
        let pool = create_test_pool();
        let project = mini_apm::models::project::ensure_default_project(&pool).unwrap();

        let app = create_app(pool);

        let req = Request::builder()
            .uri("/test")
            .header("Authorization", format!("Bearer {}", project.api_key))
            .body(Body::empty())
            .unwrap();

        let response = app.oneshot(req).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    /// The middleware must hand the authenticated project's ID to downstream
    /// handlers via the `ProjectContext` request extension.
    #[tokio::test]
    async fn test_auth_attaches_project_context() {
        let pool = create_test_pool();
        let project = mini_apm::models::project::ensure_default_project(&pool).unwrap();

        let app = with_auth(
            Router::new().route("/ctx", get(project_context_handler)),
            pool,
        );

        let req = Request::builder()
            .uri("/ctx")
            .header("Authorization", format!("Bearer {}", project.api_key))
            .body(Body::empty())
            .unwrap();

        let response = app.oneshot(req).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        let body = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        assert_eq!(&body[..], format!("{:?}", Some(project.id)).as_bytes());
    }
}
156}