1use std::sync::Arc;
14
15use axum::extract::Request;
16use axum::http::StatusCode;
17use axum::middleware::Next;
18use axum::response::{IntoResponse, Response};
19use serde::Serialize;
20
21use lago_auth::jwt::{extract_bearer_token, validate_jwt};
22
/// Authentication settings consumed by the [`require_auth`] middleware.
///
/// The secret is handed to `validate_jwt` for every protected request; when
/// it is `None` the middleware forwards requests without any checks.
#[derive(Clone)]
pub struct AuthConfig {
    /// Shared secret for validating bearer JWTs; `None` disables auth.
    pub jwt_secret: Option<String>,
}
29
30impl AuthConfig {
31 pub fn from_env() -> Self {
36 let secret = std::env::var("HAIMA_JWT_SECRET")
37 .ok()
38 .or_else(|| std::env::var("AUTH_SECRET").ok())
39 .filter(|s| !s.is_empty());
40
41 if secret.is_none() {
42 tracing::warn!(
43 "no HAIMA_JWT_SECRET or AUTH_SECRET configured — auth DISABLED (local dev mode)"
44 );
45 } else {
46 tracing::info!("JWT auth enabled for protected routes");
47 }
48
49 Self { jwt_secret: secret }
50 }
51}
52
53#[derive(Serialize)]
55struct AuthErrorBody {
56 error: String,
57 message: String,
58}
59
60fn auth_error(status: StatusCode, message: impl Into<String>) -> Response {
61 let body = AuthErrorBody {
62 error: "unauthorized".to_string(),
63 message: message.into(),
64 };
65 (status, axum::Json(body)).into_response()
66}
67
68pub async fn require_auth(
73 axum::extract::State(config): axum::extract::State<Arc<AuthConfig>>,
74 request: Request,
75 next: Next,
76) -> Response {
77 let Some(secret) = &config.jwt_secret else {
79 return next.run(request).await;
80 };
81
82 let auth_header = match request.headers().get("authorization") {
84 Some(h) => match h.to_str() {
85 Ok(s) => s.to_string(),
86 Err(_) => return auth_error(StatusCode::UNAUTHORIZED, "invalid authorization header"),
87 },
88 None => return auth_error(StatusCode::UNAUTHORIZED, "missing authorization header"),
89 };
90
91 let token = match extract_bearer_token(&auth_header) {
93 Ok(t) => t,
94 Err(e) => return auth_error(StatusCode::UNAUTHORIZED, e.to_string()),
95 };
96
97 match validate_jwt(token, secret) {
99 Ok(claims) => {
100 tracing::debug!(
101 user_id = %claims.sub,
102 email = %claims.email,
103 "authenticated request"
104 );
105 }
106 Err(e) => return auth_error(StatusCode::UNAUTHORIZED, e.to_string()),
107 }
108
109 next.run(request).await
110}
111
#[cfg(test)]
mod tests {
    use super::*;
    use axum::Router;
    use axum::body::Body;
    use axum::http::{HeaderValue, Request as HttpRequest};
    use axum::routing::get;
    use jsonwebtoken::{EncodingKey, Header, encode};
    use lago_auth::BroomvaClaims;
    use tower::ServiceExt;

    const TEST_SECRET: &str = "test-haima-secret-key";

    /// Encodes `BroomvaClaims` for `sub`/`email`, issued now and expiring in
    /// one hour, signed with `secret` using `Header::default()`.
    fn make_token(sub: &str, email: &str, secret: &str) -> String {
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs();
        let claims = BroomvaClaims {
            sub: sub.to_string(),
            email: email.to_string(),
            exp: now + 3600,
            iat: now,
        };
        let key = EncodingKey::from_secret(secret.as_bytes());
        encode(&Header::default(), &claims, &key).unwrap()
    }

    /// Router with `/protected` behind `require_auth`; `/public` is added
    /// after the layer so the middleware does not wrap it.
    fn test_app(auth_config: AuthConfig) -> Router {
        let config = Arc::new(auth_config);
        Router::new()
            .route(
                "/protected",
                get(|| async { axum::Json(serde_json::json!({"ok": true})) }),
            )
            .layer(axum::middleware::from_fn_with_state(config, require_auth))
            .route(
                "/public",
                get(|| async { axum::Json(serde_json::json!({"public": true})) }),
            )
    }

    /// Auth configuration that enables validation with `TEST_SECRET`.
    fn enabled_config() -> AuthConfig {
        AuthConfig {
            jwt_secret: Some(TEST_SECRET.to_string()),
        }
    }

    /// `Authorization` header value `Bearer <token>`.
    fn bearer(token: &str) -> HeaderValue {
        HeaderValue::from_str(&format!("Bearer {token}")).unwrap()
    }

    /// Sends one GET to `uri` through `app`, optionally with an
    /// `Authorization` header, and returns the response status.
    async fn status_of(app: Router, uri: &str, authorization: Option<HeaderValue>) -> StatusCode {
        let mut builder = HttpRequest::builder().uri(uri);
        if let Some(value) = authorization {
            builder = builder.header("Authorization", value);
        }
        let req = builder.body(Body::empty()).unwrap();
        app.oneshot(req).await.unwrap().status()
    }

    #[tokio::test]
    async fn public_route_no_auth_needed() {
        let status = status_of(test_app(enabled_config()), "/public", None).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn protected_route_without_token_returns_401() {
        let status = status_of(test_app(enabled_config()), "/protected", None).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn protected_route_with_valid_token_returns_200() {
        let token = make_token("user1", "user1@broomva.tech", TEST_SECRET);
        let status = status_of(test_app(enabled_config()), "/protected", Some(bearer(&token))).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn protected_route_with_wrong_secret_returns_401() {
        let token = make_token("user1", "user1@broomva.tech", "wrong-secret");
        let status = status_of(test_app(enabled_config()), "/protected", Some(bearer(&token))).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn auth_disabled_passes_through() {
        let app = test_app(AuthConfig { jwt_secret: None });
        let status = status_of(app, "/protected", None).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn invalid_auth_header_returns_401() {
        let value = HeaderValue::from_static("Basic abc123");
        let status = status_of(test_app(enabled_config()), "/protected", Some(value)).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn non_visible_ascii_auth_header_returns_401() {
        // 0xFF is accepted by `HeaderValue::from_bytes` but rejected by
        // `HeaderValue::to_str`, exercising the "invalid authorization
        // header" branch of the middleware.
        let value = HeaderValue::from_bytes(b"Bearer \xff").unwrap();
        let status = status_of(test_app(enabled_config()), "/protected", Some(value)).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn malformed_bearer_token_returns_401() {
        let status =
            status_of(test_app(enabled_config()), "/protected", Some(bearer("not-a-jwt"))).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
    }
}