1use std::sync::Arc;
14
15use axum::extract::Request;
16use axum::http::StatusCode;
17use axum::middleware::Next;
18use axum::response::{IntoResponse, Response};
19use serde::Serialize;
20
21use lago_auth::jwt::{extract_bearer_token, validate_jwt};
22
/// Authentication settings for the protected routes.
///
/// The `require_auth` middleware in this file receives it wrapped in an `Arc`
/// as router state.
#[derive(Clone)]
pub struct AuthConfig {
    /// Shared secret used to validate bearer JWTs.
    ///
    /// `None` disables authentication: `require_auth` then forwards every
    /// request without inspecting it.
    pub jwt_secret: Option<String>,
}

// Hand-written instead of derived so that `{:?}` (logs, panic messages,
// `dbg!`) reveals only whether a secret is configured, never its value.
impl std::fmt::Debug for AuthConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let redacted = self.jwt_secret.as_ref().map(|_| "<redacted>");
        f.debug_struct("AuthConfig")
            .field("jwt_secret", &redacted)
            .finish()
    }
}
29
30impl AuthConfig {
31 pub fn from_env() -> Self {
36 let secret = std::env::var("HAIMA_JWT_SECRET")
37 .ok()
38 .or_else(|| std::env::var("AUTH_SECRET").ok())
39 .filter(|s| !s.is_empty());
40
41 if secret.is_none() {
42 tracing::warn!(
43 "no HAIMA_JWT_SECRET or AUTH_SECRET configured — auth DISABLED (local dev mode)"
44 );
45 } else {
46 tracing::info!("JWT auth enabled for protected routes");
47 }
48
49 Self { jwt_secret: secret }
50 }
51}
52
53#[derive(Serialize)]
55struct AuthErrorBody {
56 error: String,
57 message: String,
58}
59
60fn auth_error(status: StatusCode, message: impl Into<String>) -> Response {
61 let body = AuthErrorBody {
62 error: "unauthorized".to_string(),
63 message: message.into(),
64 };
65 (status, axum::Json(body)).into_response()
66}
67
68pub async fn require_auth(
73 axum::extract::State(config): axum::extract::State<Arc<AuthConfig>>,
74 request: Request,
75 next: Next,
76) -> Response {
77 let Some(secret) = &config.jwt_secret else {
79 return next.run(request).await;
80 };
81
82 let auth_header = match request.headers().get("authorization") {
84 Some(h) => match h.to_str() {
85 Ok(s) => s.to_string(),
86 Err(_) => return auth_error(StatusCode::UNAUTHORIZED, "invalid authorization header"),
87 },
88 None => return auth_error(StatusCode::UNAUTHORIZED, "missing authorization header"),
89 };
90
91 let token = match extract_bearer_token(&auth_header) {
93 Ok(t) => t,
94 Err(e) => return auth_error(StatusCode::UNAUTHORIZED, e.to_string()),
95 };
96
97 match validate_jwt(token, secret) {
99 Ok(claims) => {
100 tracing::debug!(
101 user_id = %claims.sub,
102 email = %claims.email,
103 "authenticated request"
104 );
105 }
106 Err(e) => return auth_error(StatusCode::UNAUTHORIZED, e.to_string()),
107 }
108
109 next.run(request).await
110}
111
#[cfg(test)]
mod tests {
    use super::*;
    use axum::Router;
    use axum::body::Body;
    use axum::http::Request as HttpRequest;
    use axum::routing::get;
    use jsonwebtoken::{EncodingKey, Header, encode};
    use lago_auth::BroomvaClaims;
    use tower::ServiceExt;

    /// Secret the guarded test application is configured with.
    const TEST_SECRET: &str = "test-haima-secret-key";

    /// Signs a token for `sub`/`email` with `secret`, issued now and expiring
    /// one hour later.
    fn make_token(sub: &str, email: &str, secret: &str) -> String {
        let issued_at = std::time::UNIX_EPOCH.elapsed().unwrap().as_secs();
        let claims = BroomvaClaims {
            sub: sub.to_owned(),
            email: email.to_owned(),
            exp: issued_at + 3600,
            iat: issued_at,
        };
        let signing_key = EncodingKey::from_secret(secret.as_bytes());
        encode(&Header::default(), &claims, &signing_key).unwrap()
    }

    /// Router with `/protected` behind `require_auth` and `/public` outside it.
    ///
    /// The layer is added before `/public` is registered, so it only wraps the
    /// protected route.
    fn test_app(auth_config: AuthConfig) -> Router {
        let state = Arc::new(auth_config);
        Router::new()
            .route(
                "/protected",
                get(|| async { axum::Json(serde_json::json!({"ok": true})) }),
            )
            .layer(axum::middleware::from_fn_with_state(state, require_auth))
            .route(
                "/public",
                get(|| async { axum::Json(serde_json::json!({"public": true})) }),
            )
    }

    /// Test application with authentication enabled using `TEST_SECRET`.
    fn secured_app() -> Router {
        test_app(AuthConfig {
            jwt_secret: Some(TEST_SECRET.to_string()),
        })
    }

    /// Sends a body-less request to `uri`, optionally carrying an
    /// `Authorization` header, and returns the response status.
    async fn status_for(app: Router, uri: &str, authorization: Option<&str>) -> StatusCode {
        let mut builder = HttpRequest::builder().uri(uri);
        if let Some(value) = authorization {
            builder = builder.header("Authorization", value);
        }
        let request = builder.body(Body::empty()).unwrap();
        app.oneshot(request).await.unwrap().status()
    }

    #[tokio::test]
    async fn public_route_no_auth_needed() {
        let status = status_for(secured_app(), "/public", None).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn protected_route_without_token_returns_401() {
        let status = status_for(secured_app(), "/protected", None).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn protected_route_with_valid_token_returns_200() {
        let token = make_token("user1", "user1@broomva.tech", TEST_SECRET);
        let bearer = format!("Bearer {token}");
        let status = status_for(secured_app(), "/protected", Some(&bearer)).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn protected_route_with_wrong_secret_returns_401() {
        let token = make_token("user1", "user1@broomva.tech", "wrong-secret");
        let bearer = format!("Bearer {token}");
        let status = status_for(secured_app(), "/protected", Some(&bearer)).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn auth_disabled_passes_through() {
        let app = test_app(AuthConfig { jwt_secret: None });
        let status = status_for(app, "/protected", None).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn invalid_auth_header_returns_401() {
        let status = status_for(secured_app(), "/protected", Some("Basic abc123")).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
    }
}