1use std::iter;
2use std::sync::Arc;
3use std::time::Duration;
4
5use axum::extract::FromRequestParts;
6use axum::http::request::Parts;
7use axum::http::{HeaderMap, StatusCode};
8use base64::Engine;
9use base64::engine::general_purpose::STANDARD;
10use kellnr_appstate::AppStateData;
11use kellnr_common::token_cache::{CachedTokenData, TokenCacheManager};
12use kellnr_db::DbProvider;
13use kellnr_db::error::DbError;
14use kellnr_settings::Settings;
15use rand::distr::Alphanumeric;
16use rand::{Rng, rng};
17use serde::Deserialize;
18use tokio::time::sleep;
19use tracing::{debug, warn};
20
/// An identity authenticated from a request's `Authorization` header.
pub struct Token {
    /// The credential that was presented: the API token, or — for Basic
    /// auth — the supplied password. Treat as a secret.
    pub value: String,
    /// Name of the authenticated user.
    pub user: String,
    /// Copied from the user's `is_admin` flag.
    pub is_admin: bool,
    /// Copied from the user's `is_read_only` flag.
    pub is_read_only: bool,
}

// `value` holds a secret (API token or Basic-auth password), so the derived
// `Debug` would leak it into any log that formats a `Token` with `{:?}`.
// This impl prints every field except the secret.
impl std::fmt::Debug for Token {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Token")
            .field("value", &"<redacted>")
            .field("user", &self.user)
            .field("is_admin", &self.is_admin)
            .field("is_read_only", &self.is_read_only)
            .finish()
    }
}
28
29#[derive(Debug)]
31pub enum OptionToken {
32 None,
33 Some(Token),
34}
35
36pub fn generate_token() -> String {
37 let mut rng = rng();
38 iter::repeat(())
39 .map(|()| rng.sample(Alphanumeric))
40 .map(char::from)
41 .take(32)
42 .collect::<String>()
43}
44
45impl Token {
46 pub async fn from_header(
47 headers: &HeaderMap,
48 db: &Arc<dyn DbProvider>,
49 cache: &Arc<TokenCacheManager>,
50 settings: &Arc<Settings>,
51 ) -> Result<Self, StatusCode> {
52 Self::extract_token(headers, db, cache, settings).await
53 }
54
55 async fn extract_token(
56 headers: &HeaderMap,
57 db: &Arc<dyn DbProvider>,
58 cache: &Arc<TokenCacheManager>,
59 settings: &Arc<Settings>,
60 ) -> Result<Token, StatusCode> {
61 let mut token = headers
63 .get("Authorization")
64 .ok_or(StatusCode::UNAUTHORIZED)?
65 .to_str()
66 .map_err(|_| StatusCode::BAD_REQUEST)?;
67
68 if token.starts_with("Basic ") || token.starts_with("basic ") {
70 let decoded = STANDARD
71 .decode(&token[6..])
72 .map_err(|_| StatusCode::BAD_REQUEST)?;
73 let decoded_str = String::from_utf8(decoded).map_err(|_| StatusCode::BAD_REQUEST)?;
74 let (user, token) = decoded_str.split_once(':').ok_or(StatusCode::BAD_REQUEST)?;
75
76 let user = db.get_user(user).await.map_err(|_| StatusCode::FORBIDDEN)?;
77 if db.authenticate_user(&user.name, token).await.is_err() {
78 return Err(StatusCode::FORBIDDEN);
79 }
80
81 return Ok(Token {
82 value: token.to_string(),
83 user: user.name,
84 is_admin: user.is_admin,
85 is_read_only: user.is_read_only,
86 });
87 }
88
89 if token.starts_with("Bearer ") || token.starts_with("bearer ") {
91 token = &token[7..];
92 }
93
94 if let Some(cached) = cache.get(token).await {
96 debug!("Token cache hit for user: {}", cached.user);
97 return Ok(Token {
98 value: token.to_string(),
99 user: cached.user,
100 is_admin: cached.is_admin,
101 is_read_only: cached.is_read_only,
102 });
103 }
104
105 let user = match get_user_with_retry(
107 db,
108 token,
109 settings.registry.token_db_retry_count,
110 settings.registry.token_db_retry_delay_ms,
111 )
112 .await
113 {
114 Ok(user) => user,
115 Err(e) => {
116 debug!("Token cache miss, DB lookup failed: {}", e);
117 return Err(StatusCode::FORBIDDEN);
118 }
119 };
120
121 debug!("Token cache miss, queried DB for user: {}", user.name);
122
123 cache
125 .insert(
126 token.to_string(),
127 CachedTokenData {
128 user: user.name.clone(),
129 is_admin: user.is_admin,
130 is_read_only: user.is_read_only,
131 },
132 )
133 .await;
134
135 Ok(Token {
136 value: token.to_string(),
137 user: user.name,
138 is_admin: user.is_admin,
139 is_read_only: user.is_read_only,
140 })
141 }
142}
143
144async fn get_user_with_retry(
145 db: &Arc<dyn DbProvider>,
146 token: &str,
147 max_retries: u32,
148 delay_ms: u64,
149) -> Result<kellnr_db::User, DbError> {
150 let mut attempts = 0;
151
152 loop {
153 match db.get_user_from_token(token).await {
154 Ok(user) => return Ok(user),
155 Err(e) => {
156 if matches!(e, DbError::TokenNotFound | DbError::UserNotFound(_)) {
158 return Err(e);
159 }
160
161 attempts += 1;
162 if attempts > max_retries {
163 warn!(
164 "Failed to get user from token after {} retries: {}",
165 max_retries, e
166 );
167 return Err(e);
168 }
169
170 warn!(
171 "Transient DB error on attempt {}/{}, retrying in {}ms: {}",
172 attempts,
173 max_retries + 1,
174 delay_ms,
175 e
176 );
177 sleep(Duration::from_millis(delay_ms)).await;
178 }
179 }
180 }
181}
182
183impl FromRequestParts<AppStateData> for Token {
184 type Rejection = StatusCode;
185
186 async fn from_request_parts(
187 parts: &mut Parts,
188 state: &AppStateData,
189 ) -> Result<Self, Self::Rejection> {
190 Self::extract_token(
191 &parts.headers,
192 &state.db,
193 &state.token_cache,
194 &state.settings,
195 )
196 .await
197 }
198}
199
200impl FromRequestParts<AppStateData> for OptionToken {
201 type Rejection = StatusCode;
202
203 async fn from_request_parts(
204 parts: &mut Parts,
205 state: &AppStateData,
206 ) -> Result<Self, Self::Rejection> {
207 match Token::extract_token(
208 &parts.headers,
209 &state.db,
210 &state.token_cache,
211 &state.settings,
212 )
213 .await
214 {
215 Ok(token) => Ok(OptionToken::Some(token)),
216 Err(StatusCode::UNAUTHORIZED) => Ok(OptionToken::None),
217 Err(status_code) => Err(status_code),
218 }
219 }
220}
221
222#[derive(Deserialize)]
223pub struct NewTokenReqData {
224 pub name: String,
225}
226
#[cfg(test)]
mod tests {
    // Unit tests for header parsing, token caching and the DB retry helper.

    use std::sync::atomic::{AtomicU32, Ordering};

    use kellnr_db::User;
    use kellnr_db::error::DbError;
    use kellnr_db::mock::MockDb;
    use mockall::predicate::*;

    use super::*;

    /// A non-admin, writable user as returned by the mocked DB.
    fn test_user() -> User {
        User {
            id: 1,
            name: "test_user".to_string(),
            pwd: String::new(),
            salt: String::new(),
            is_admin: false,
            is_read_only: false,
        }
    }

    /// Settings with the token cache enabled and a 1 ms retry delay so
    /// retrying tests stay fast.
    fn test_settings() -> Arc<Settings> {
        Arc::new(Settings {
            registry: kellnr_settings::Registry {
                token_cache_enabled: true,
                token_cache_ttl_seconds: 60,
                token_cache_max_capacity: 100,
                token_db_retry_count: 3,
                token_db_retry_delay_ms: 1,
                ..kellnr_settings::Registry::default()
            },
            ..Settings::default()
        })
    }

    /// A connection-pool timeout; `get_user_with_retry` retries it because it
    /// is neither `TokenNotFound` nor `UserNotFound`.
    fn transient_db_error() -> DbError {
        DbError::PostgresError(sea_orm::DbErr::ConnectionAcquire(
            sea_orm::error::ConnAcquireErr::Timeout,
        ))
    }

    /// Builds a header map whose `Authorization` header is exactly `value`.
    fn auth_headers(value: &str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert("Authorization", value.parse().expect("valid header value"));
        headers
    }

    #[tokio::test]
    async fn test_retry_succeeds_on_first_attempt() {
        let mut mock_db = MockDb::new();
        mock_db
            .expect_get_user_from_token()
            .with(eq("valid_token"))
            .times(1)
            .returning(|_| Ok(test_user()));

        let db: Arc<dyn DbProvider> = Arc::new(mock_db);
        let result = get_user_with_retry(&db, "valid_token", 3, 10).await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap().name, "test_user");
    }

    #[tokio::test]
    async fn test_retry_succeeds_after_transient_error() {
        let call_count = Arc::new(AtomicU32::new(0));
        let call_count_clone = call_count.clone();

        let mut mock_db = MockDb::new();
        mock_db
            .expect_get_user_from_token()
            .with(eq("token"))
            .times(2)
            .returning(move |_| {
                if call_count_clone.fetch_add(1, Ordering::SeqCst) == 0 {
                    Err(transient_db_error())
                } else {
                    Ok(test_user())
                }
            });

        let db: Arc<dyn DbProvider> = Arc::new(mock_db);
        let result = get_user_with_retry(&db, "token", 3, 1).await;

        assert!(result.is_ok());
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_no_retry_on_token_not_found() {
        let mut mock_db = MockDb::new();
        mock_db
            .expect_get_user_from_token()
            .with(eq("invalid_token"))
            .times(1)
            .returning(|_| Err(DbError::TokenNotFound));

        let db: Arc<dyn DbProvider> = Arc::new(mock_db);
        let result = get_user_with_retry(&db, "invalid_token", 3, 10).await;

        assert!(matches!(result, Err(DbError::TokenNotFound)));
    }

    #[tokio::test]
    async fn test_no_retry_on_user_not_found() {
        let mut mock_db = MockDb::new();
        mock_db
            .expect_get_user_from_token()
            .with(eq("orphan_token"))
            .times(1)
            .returning(|_| Err(DbError::UserNotFound("orphan".to_string())));

        let db: Arc<dyn DbProvider> = Arc::new(mock_db);
        let result = get_user_with_retry(&db, "orphan_token", 3, 10).await;

        assert!(matches!(result, Err(DbError::UserNotFound(_))));
    }

    #[tokio::test]
    async fn test_exhausts_retries_on_persistent_error() {
        let call_count = Arc::new(AtomicU32::new(0));
        let call_count_clone = call_count.clone();

        let mut mock_db = MockDb::new();
        mock_db
            .expect_get_user_from_token()
            .with(eq("token"))
            .times(4)
            .returning(move |_| {
                call_count_clone.fetch_add(1, Ordering::SeqCst);
                Err(transient_db_error())
            });

        let db: Arc<dyn DbProvider> = Arc::new(mock_db);
        let result = get_user_with_retry(&db, "token", 3, 1).await;

        assert!(result.is_err());
        // One initial attempt plus three retries.
        assert_eq!(call_count.load(Ordering::SeqCst), 4);
    }

    #[tokio::test]
    async fn test_cache_hit_returns_cached_token() {
        let cache = Arc::new(TokenCacheManager::new(true, 60, 100));
        cache
            .insert(
                "cached_token".to_string(),
                CachedTokenData {
                    user: "cached_user".to_string(),
                    is_admin: true,
                    is_read_only: false,
                },
            )
            .await;

        // No expectations: a cache hit must not touch the DB.
        let db: Arc<dyn DbProvider> = Arc::new(MockDb::new());
        let settings = test_settings();

        let headers = auth_headers("Bearer cached_token");
        let result = Token::from_header(&headers, &db, &cache, &settings).await;

        assert!(result.is_ok());
        let token = result.unwrap();
        assert_eq!(token.value, "cached_token");
        assert_eq!(token.user, "cached_user");
        assert!(token.is_admin);
        assert!(!token.is_read_only);
    }

    #[tokio::test]
    async fn test_cache_miss_queries_db_and_caches() {
        let cache = Arc::new(TokenCacheManager::new(true, 60, 100));

        let mut mock_db = MockDb::new();
        mock_db
            .expect_get_user_from_token()
            .with(eq("new_token"))
            .times(1)
            .returning(|_| {
                Ok(User {
                    id: 1,
                    name: "db_user".to_string(),
                    pwd: String::new(),
                    salt: String::new(),
                    is_admin: false,
                    is_read_only: true,
                })
            });

        let db: Arc<dyn DbProvider> = Arc::new(mock_db);
        let settings = test_settings();

        let headers = auth_headers("Bearer new_token");
        let result = Token::from_header(&headers, &db, &cache, &settings).await;

        assert!(result.is_ok());
        let token = result.unwrap();
        assert_eq!(token.value, "new_token");
        assert_eq!(token.user, "db_user");
        assert!(!token.is_admin);
        assert!(token.is_read_only);

        let cached = cache.get("new_token").await;
        assert!(cached.is_some());
        assert_eq!(cached.unwrap().user, "db_user");
    }

    #[tokio::test]
    async fn test_cache_miss_with_invalid_token_returns_forbidden() {
        let cache = Arc::new(TokenCacheManager::new(true, 60, 100));

        let mut mock_db = MockDb::new();
        mock_db
            .expect_get_user_from_token()
            .with(eq("bad_token"))
            .times(1)
            .returning(|_| Err(DbError::TokenNotFound));

        let db: Arc<dyn DbProvider> = Arc::new(mock_db);
        let settings = test_settings();

        let headers = auth_headers("Bearer bad_token");
        let result = Token::from_header(&headers, &db, &cache, &settings).await;

        assert_eq!(result.err(), Some(StatusCode::FORBIDDEN));
        // Failed lookups must not be cached.
        assert!(cache.get("bad_token").await.is_none());
    }

    #[tokio::test]
    async fn test_no_authorization_header_returns_unauthorized() {
        let cache = Arc::new(TokenCacheManager::new(true, 60, 100));
        let db: Arc<dyn DbProvider> = Arc::new(MockDb::new());
        let settings = test_settings();

        let headers = HeaderMap::new();
        let result = Token::from_header(&headers, &db, &cache, &settings).await;

        assert_eq!(result.err(), Some(StatusCode::UNAUTHORIZED));
    }

    #[tokio::test]
    async fn test_disabled_cache_always_queries_db() {
        let cache = Arc::new(TokenCacheManager::new(false, 60, 100));
        let call_count = Arc::new(AtomicU32::new(0));
        let call_count_clone = call_count.clone();

        let mut mock_db = MockDb::new();
        mock_db
            .expect_get_user_from_token()
            .with(eq("token"))
            .times(2)
            .returning(move |_| {
                call_count_clone.fetch_add(1, Ordering::SeqCst);
                Ok(test_user())
            });

        let db: Arc<dyn DbProvider> = Arc::new(mock_db);
        let settings = test_settings();

        let headers = auth_headers("Bearer token");
        let first = Token::from_header(&headers, &db, &cache, &settings).await;
        let second = Token::from_header(&headers, &db, &cache, &settings).await;

        // Both lookups must succeed, not merely both reach the DB.
        assert!(first.is_ok());
        assert!(second.is_ok());
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_lowercase_bearer_prefix_works() {
        let cache = Arc::new(TokenCacheManager::new(true, 60, 100));

        let mut mock_db = MockDb::new();
        mock_db
            .expect_get_user_from_token()
            .with(eq("lowercase_token"))
            .times(1)
            .returning(|_| Ok(test_user()));

        let db: Arc<dyn DbProvider> = Arc::new(mock_db);
        let settings = test_settings();

        let headers = auth_headers("bearer lowercase_token");
        let result = Token::from_header(&headers, &db, &cache, &settings).await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap().user, "test_user");
    }

    #[tokio::test]
    async fn test_token_without_scheme_is_used_verbatim() {
        let cache = Arc::new(TokenCacheManager::new(true, 60, 100));

        let mut mock_db = MockDb::new();
        mock_db
            .expect_get_user_from_token()
            .with(eq("raw_token"))
            .times(1)
            .returning(|_| Ok(test_user()));

        let db: Arc<dyn DbProvider> = Arc::new(mock_db);
        let settings = test_settings();

        let headers = auth_headers("raw_token");
        let result = Token::from_header(&headers, &db, &cache, &settings).await;

        assert!(result.is_ok());
        let token = result.unwrap();
        assert_eq!(token.value, "raw_token");
        assert_eq!(token.user, "test_user");
    }

    #[tokio::test]
    async fn test_malformed_basic_credentials_return_bad_request() {
        let cache = Arc::new(TokenCacheManager::new(true, 60, 100));
        // No expectations: malformed credentials must be rejected before any DB access.
        let db: Arc<dyn DbProvider> = Arc::new(MockDb::new());
        let settings = test_settings();

        let not_base64 = auth_headers("Basic !!!not-base64!!!");
        let result = Token::from_header(&not_base64, &db, &cache, &settings).await;
        assert_eq!(result.err(), Some(StatusCode::BAD_REQUEST));

        let missing_colon = auth_headers(&format!("Basic {}", STANDARD.encode("nocolon")));
        let result = Token::from_header(&missing_colon, &db, &cache, &settings).await;
        assert_eq!(result.err(), Some(StatusCode::BAD_REQUEST));
    }

    #[tokio::test]
    async fn test_zero_retries_only_attempts_once() {
        let call_count = Arc::new(AtomicU32::new(0));
        let call_count_clone = call_count.clone();

        let mut mock_db = MockDb::new();
        mock_db
            .expect_get_user_from_token()
            .with(eq("token"))
            .times(1)
            .returning(move |_| {
                call_count_clone.fetch_add(1, Ordering::SeqCst);
                Err(transient_db_error())
            });

        let db: Arc<dyn DbProvider> = Arc::new(mock_db);
        let result = get_user_with_retry(&db, "token", 0, 1).await;

        assert!(result.is_err());
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }
}