Skip to main content

haima_api/
auth.rs

1//! JWT authentication middleware for Haima.
2//!
3//! Reuses [`lago_auth::jwt`] for token validation. Haima does not need
4//! Lago session mapping — it only validates that the caller holds a valid
5//! JWT signed with the shared secret.
6//!
7//! **Behaviour**:
8//! - If no JWT secret is configured, auth is **disabled** (local dev mode)
9//!   and a warning is logged on startup.
10//! - If a secret IS configured, all protected routes require a valid
11//!   `Authorization: Bearer <token>` header.
12
13use std::sync::Arc;
14
15use axum::extract::Request;
16use axum::http::StatusCode;
17use axum::middleware::Next;
18use axum::response::{IntoResponse, Response};
19use serde::Serialize;
20
21use lago_auth::jwt::{extract_bearer_token, validate_jwt};
22
/// Auth configuration injected as axum state.
///
/// `Debug` is implemented by hand (not derived) so the JWT secret can never leak into
/// logs, panic messages, or `{:?}` output; it only reveals whether a secret is set.
#[derive(Clone)]
pub struct AuthConfig {
    /// Shared secret used to validate bearer-token signatures.
    /// `None` means auth is disabled (local dev).
    pub jwt_secret: Option<String>,
}

impl std::fmt::Debug for AuthConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Replace the secret itself with a placeholder; keep Some/None visible.
        f.debug_struct("AuthConfig")
            .field("jwt_secret", &self.jwt_secret.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}
29
30impl AuthConfig {
31    /// Create auth config from environment variables.
32    ///
33    /// Checks `HAIMA_JWT_SECRET` first, then falls back to `AUTH_SECRET`.
34    /// Returns `None` secret (auth disabled) if neither is set.
35    pub fn from_env() -> Self {
36        let secret = std::env::var("HAIMA_JWT_SECRET")
37            .ok()
38            .or_else(|| std::env::var("AUTH_SECRET").ok())
39            .filter(|s| !s.is_empty());
40
41        if secret.is_none() {
42            tracing::warn!(
43                "no HAIMA_JWT_SECRET or AUTH_SECRET configured — auth DISABLED (local dev mode)"
44            );
45        } else {
46            tracing::info!("JWT auth enabled for protected routes");
47        }
48
49        Self { jwt_secret: secret }
50    }
51}
52
53/// Auth error response body.
54#[derive(Serialize)]
55struct AuthErrorBody {
56    error: String,
57    message: String,
58}
59
60fn auth_error(status: StatusCode, message: impl Into<String>) -> Response {
61    let body = AuthErrorBody {
62        error: "unauthorized".to_string(),
63        message: message.into(),
64    };
65    (status, axum::Json(body)).into_response()
66}
67
68/// Axum middleware that validates JWT bearer tokens on protected routes.
69///
70/// If auth is disabled (no secret configured), requests pass through.
71/// If auth is enabled, a valid `Authorization: Bearer <token>` is required.
72pub async fn require_auth(
73    axum::extract::State(config): axum::extract::State<Arc<AuthConfig>>,
74    request: Request,
75    next: Next,
76) -> Response {
77    // If no secret configured, auth is disabled — pass through
78    let Some(secret) = &config.jwt_secret else {
79        return next.run(request).await;
80    };
81
82    // Extract Authorization header
83    let auth_header = match request.headers().get("authorization") {
84        Some(h) => match h.to_str() {
85            Ok(s) => s.to_string(),
86            Err(_) => return auth_error(StatusCode::UNAUTHORIZED, "invalid authorization header"),
87        },
88        None => return auth_error(StatusCode::UNAUTHORIZED, "missing authorization header"),
89    };
90
91    // Extract bearer token
92    let token = match extract_bearer_token(&auth_header) {
93        Ok(t) => t,
94        Err(e) => return auth_error(StatusCode::UNAUTHORIZED, e.to_string()),
95    };
96
97    // Validate JWT
98    match validate_jwt(token, secret) {
99        Ok(claims) => {
100            tracing::debug!(
101                user_id = %claims.sub,
102                email = %claims.email,
103                "authenticated request"
104            );
105        }
106        Err(e) => return auth_error(StatusCode::UNAUTHORIZED, e.to_string()),
107    }
108
109    next.run(request).await
110}
111
#[cfg(test)]
mod tests {
    use super::*;
    use axum::Router;
    use axum::body::Body;
    use axum::http::Request as HttpRequest;
    use axum::routing::get;
    use jsonwebtoken::{EncodingKey, Header, encode};
    use lago_auth::BroomvaClaims;
    use tower::ServiceExt;

    const TEST_SECRET: &str = "test-haima-secret-key";

    /// Signs a token for `sub`/`email` with `secret`, issued now and valid for one hour.
    fn make_token(sub: &str, email: &str, secret: &str) -> String {
        let issued_at = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs();
        let claims = BroomvaClaims {
            sub: sub.to_owned(),
            email: email.to_owned(),
            exp: issued_at + 3600,
            iat: issued_at,
        };
        let signing_key = EncodingKey::from_secret(secret.as_bytes());
        encode(&Header::default(), &claims, &signing_key).unwrap()
    }

    /// Auth config with JWT validation enabled using `TEST_SECRET`.
    fn secured() -> AuthConfig {
        AuthConfig {
            jwt_secret: Some(TEST_SECRET.to_owned()),
        }
    }

    /// Router with `/protected` behind `require_auth` and `/public` outside it.
    ///
    /// Order matters: `.layer` only wraps routes registered before it, so `/public`
    /// is added afterwards to stay unauthenticated.
    fn test_app(auth_config: AuthConfig) -> Router {
        let state = Arc::new(auth_config);
        Router::new()
            .route(
                "/protected",
                get(|| async { axum::Json(serde_json::json!({"ok": true})) }),
            )
            .layer(axum::middleware::from_fn_with_state(state, require_auth))
            .route(
                "/public",
                get(|| async { axum::Json(serde_json::json!({"public": true})) }),
            )
    }

    /// Sends a GET to `uri`, optionally with an `Authorization` header, and returns the
    /// response status.
    async fn status_of(app: Router, uri: &str, authorization: Option<&str>) -> StatusCode {
        let mut builder = HttpRequest::builder().uri(uri);
        if let Some(value) = authorization {
            builder = builder.header("Authorization", value);
        }
        let request = builder.body(Body::empty()).unwrap();
        app.oneshot(request).await.unwrap().status()
    }

    #[tokio::test]
    async fn public_route_no_auth_needed() {
        let status = status_of(test_app(secured()), "/public", None).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn protected_route_without_token_returns_401() {
        let status = status_of(test_app(secured()), "/protected", None).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn protected_route_with_valid_token_returns_200() {
        let bearer = format!(
            "Bearer {}",
            make_token("user1", "user1@broomva.tech", TEST_SECRET)
        );
        let status = status_of(test_app(secured()), "/protected", Some(bearer.as_str())).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn protected_route_with_wrong_secret_returns_401() {
        let bearer = format!(
            "Bearer {}",
            make_token("user1", "user1@broomva.tech", "wrong-secret")
        );
        let status = status_of(test_app(secured()), "/protected", Some(bearer.as_str())).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn auth_disabled_passes_through() {
        let app = test_app(AuthConfig { jwt_secret: None });
        assert_eq!(status_of(app, "/protected", None).await, StatusCode::OK);
    }

    #[tokio::test]
    async fn invalid_auth_header_returns_401() {
        let status = status_of(test_app(secured()), "/protected", Some("Basic abc123")).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
    }
}