Skip to main content

haima_api/
auth.rs

1//! JWT authentication middleware for Haima.
2//!
3//! Reuses [`lago_auth::jwt`] for token validation. Haima does not need
4//! Lago session mapping — it only validates that the caller holds a valid
5//! JWT signed with the shared secret.
6//!
7//! **Behaviour**:
8//! - If no JWT secret is configured, auth is **disabled** (local dev mode)
9//!   and a warning is logged on startup.
10//! - If a secret IS configured, all protected routes require a valid
11//!   `Authorization: Bearer <token>` header.
12
13use std::sync::Arc;
14
15use axum::extract::Request;
16use axum::http::StatusCode;
17use axum::middleware::Next;
18use axum::response::{IntoResponse, Response};
19use serde::Serialize;
20
21use lago_auth::jwt::{extract_bearer_token, validate_jwt};
22
/// Auth configuration injected as axum state.
///
/// Holds the shared secret used to validate incoming JWTs. The type is
/// `Clone` so it can be handed to the router, and it implements `Debug`
/// by hand so the secret itself is never printed.
#[derive(Clone)]
pub struct AuthConfig {
    /// `None` means auth is disabled (local dev).
    pub jwt_secret: Option<String>,
}

// Hand-written instead of derived: a derived `Debug` would print the raw
// secret, which could then leak through `{:?}` in logs or panic messages.
impl std::fmt::Debug for AuthConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Keep the `Some`/`None` distinction visible (it tells the reader
        // whether auth is enabled) while hiding the secret's value.
        let redacted = self.jwt_secret.as_ref().map(|_| "<redacted>");
        f.debug_struct("AuthConfig")
            .field("jwt_secret", &redacted)
            .finish()
    }
}
29
30impl AuthConfig {
31    /// Create auth config from environment variables.
32    ///
33    /// Checks `HAIMA_JWT_SECRET` first, then falls back to `AUTH_SECRET`.
34    /// Returns `None` secret (auth disabled) if neither is set.
35    pub fn from_env() -> Self {
36        let secret = std::env::var("HAIMA_JWT_SECRET")
37            .ok()
38            .or_else(|| std::env::var("AUTH_SECRET").ok())
39            .filter(|s| !s.is_empty());
40
41        if secret.is_none() {
42            tracing::warn!(
43                "no HAIMA_JWT_SECRET or AUTH_SECRET configured — auth DISABLED (local dev mode)"
44            );
45        } else {
46            tracing::info!("JWT auth enabled for protected routes");
47        }
48
49        Self { jwt_secret: secret }
50    }
51}
52
/// Auth error response body.
///
/// Serialized via `serde` and sent as the JSON payload of a rejected request.
#[derive(Serialize)]
struct AuthErrorBody {
    /// Short error code for the failure category.
    error: String,
    /// Human-readable explanation of why the request was rejected.
    message: String,
}
59
60fn auth_error(status: StatusCode, message: impl Into<String>) -> Response {
61    let body = AuthErrorBody {
62        error: "unauthorized".to_string(),
63        message: message.into(),
64    };
65    (status, axum::Json(body)).into_response()
66}
67
68/// Axum middleware that validates JWT bearer tokens on protected routes.
69///
70/// If auth is disabled (no secret configured), requests pass through.
71/// If auth is enabled, a valid `Authorization: Bearer <token>` is required.
72pub async fn require_auth(
73    axum::extract::State(config): axum::extract::State<Arc<AuthConfig>>,
74    request: Request,
75    next: Next,
76) -> Response {
77    // If no secret configured, auth is disabled — pass through
78    let Some(secret) = &config.jwt_secret else {
79        return next.run(request).await;
80    };
81
82    // Extract Authorization header
83    let auth_header = match request.headers().get("authorization") {
84        Some(h) => match h.to_str() {
85            Ok(s) => s.to_string(),
86            Err(_) => return auth_error(StatusCode::UNAUTHORIZED, "invalid authorization header"),
87        },
88        None => return auth_error(StatusCode::UNAUTHORIZED, "missing authorization header"),
89    };
90
91    // Extract bearer token
92    let token = match extract_bearer_token(&auth_header) {
93        Ok(t) => t,
94        Err(e) => return auth_error(StatusCode::UNAUTHORIZED, e.to_string()),
95    };
96
97    // Validate JWT
98    match validate_jwt(token, secret) {
99        Ok(claims) => {
100            tracing::debug!(
101                user_id = %claims.sub,
102                email = %claims.email,
103                "authenticated request"
104            );
105        }
106        Err(e) => return auth_error(StatusCode::UNAUTHORIZED, e.to_string()),
107    }
108
109    next.run(request).await
110}
111
#[cfg(test)]
mod tests {
    use super::*;
    use axum::Router;
    use axum::body::Body;
    use axum::http::Request as HttpRequest;
    use axum::routing::get;
    use jsonwebtoken::{EncodingKey, Header, encode};
    use lago_auth::BroomvaClaims;
    use tower::ServiceExt;

    /// Secret the middleware under test is configured with.
    const TEST_SECRET: &str = "test-haima-secret-key";

    /// Encode a token for `sub`/`email` keyed with `secret`, issued now and
    /// expiring one hour later.
    fn make_token(sub: &str, email: &str, secret: &str) -> String {
        let issued_at = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs();
        let claims = BroomvaClaims {
            sub: sub.to_string(),
            email: email.to_string(),
            exp: issued_at + 3600,
            iat: issued_at,
        };
        encode(
            &Header::default(),
            &claims,
            &EncodingKey::from_secret(secret.as_bytes()),
        )
        .unwrap()
    }

    /// Router with `/protected` behind `require_auth` and `/public` outside it.
    fn test_app(auth_config: AuthConfig) -> Router {
        let auth_layer =
            axum::middleware::from_fn_with_state(Arc::new(auth_config), require_auth);

        Router::new()
            .route(
                "/protected",
                get(|| async { axum::Json(serde_json::json!({"ok": true})) }),
            )
            .layer(auth_layer)
            // Registered after the layer, so this route is not wrapped by it.
            .route(
                "/public",
                get(|| async { axum::Json(serde_json::json!({"public": true})) }),
            )
    }

    /// Test app with auth enabled using [`TEST_SECRET`].
    fn secured_app() -> Router {
        test_app(AuthConfig {
            jwt_secret: Some(TEST_SECRET.to_string()),
        })
    }

    /// Send one GET to `path`, optionally with an `Authorization` header,
    /// and return the response status.
    async fn status_of(app: Router, path: &str, authorization: Option<String>) -> StatusCode {
        let mut builder = HttpRequest::builder().uri(path);
        if let Some(value) = authorization {
            builder = builder.header("Authorization", value);
        }
        let request = builder.body(Body::empty()).unwrap();
        app.oneshot(request).await.unwrap().status()
    }

    #[tokio::test]
    async fn public_route_no_auth_needed() {
        let status = status_of(secured_app(), "/public", None).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn protected_route_without_token_returns_401() {
        let status = status_of(secured_app(), "/protected", None).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn protected_route_with_valid_token_returns_200() {
        let token = make_token("user1", "user1@broomva.tech", TEST_SECRET);
        let status = status_of(
            secured_app(),
            "/protected",
            Some(format!("Bearer {token}")),
        )
        .await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn protected_route_with_wrong_secret_returns_401() {
        let token = make_token("user1", "user1@broomva.tech", "wrong-secret");
        let status = status_of(
            secured_app(),
            "/protected",
            Some(format!("Bearer {token}")),
        )
        .await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn auth_disabled_passes_through() {
        let app = test_app(AuthConfig { jwt_secret: None });
        let status = status_of(app, "/protected", None).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn invalid_auth_header_returns_401() {
        let status = status_of(
            secured_app(),
            "/protected",
            Some("Basic abc123".to_string()),
        )
        .await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
    }
}