1use std::iter;
2use std::sync::Arc;
3use std::time::Duration;
4
5use axum::extract::FromRequestParts;
6use axum::http::request::Parts;
7use axum::http::{HeaderMap, StatusCode};
8use base64::Engine;
9use base64::engine::general_purpose::STANDARD;
10use kellnr_appstate::AppStateData;
11use kellnr_common::token_cache::{CachedTokenData, TokenCacheManager};
12use kellnr_db::error::DbError;
13use kellnr_db::DbProvider;
14use kellnr_settings::Settings;
15use rand::distr::Alphanumeric;
16use rand::{Rng, rng};
17use serde::Deserialize;
18use tokio::time::sleep;
19use tracing::{debug, warn};
20
/// Credentials resolved from a request's `Authorization` header.
///
/// `Debug` is implemented by hand rather than derived so that the secret held
/// in `value` is never written to logs or panic messages.
pub struct Token {
    /// The secret presented by the client: the token string itself, or the
    /// password part when HTTP Basic authentication was used.
    pub value: String,
    /// Name of the user the credentials belong to.
    pub user: String,
    /// Administrator flag copied from the user record or the cache entry.
    pub is_admin: bool,
    /// Read-only flag copied from the user record or the cache entry.
    pub is_read_only: bool,
}

impl std::fmt::Debug for Token {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Token")
            // Never print the secret; only signal that the field exists.
            .field("value", &"<redacted>")
            .field("user", &self.user)
            .field("is_admin", &self.is_admin)
            .field("is_read_only", &self.is_read_only)
            .finish()
    }
}
28
29#[derive(Debug)]
31pub enum OptionToken {
32 None,
33 Some(Token),
34}
35
36pub fn generate_token() -> String {
37 let mut rng = rng();
38 iter::repeat(())
39 .map(|()| rng.sample(Alphanumeric))
40 .map(char::from)
41 .take(32)
42 .collect::<String>()
43}
44
45impl Token {
46 pub async fn from_header(
47 headers: &HeaderMap,
48 db: &Arc<dyn DbProvider>,
49 cache: &Arc<TokenCacheManager>,
50 settings: &Arc<Settings>,
51 ) -> Result<Self, StatusCode> {
52 Self::extract_token(headers, db, cache, settings).await
53 }
54
55 async fn extract_token(
56 headers: &HeaderMap,
57 db: &Arc<dyn DbProvider>,
58 cache: &Arc<TokenCacheManager>,
59 settings: &Arc<Settings>,
60 ) -> Result<Token, StatusCode> {
61 let mut token = headers
63 .get("Authorization")
64 .ok_or(StatusCode::UNAUTHORIZED)?
65 .to_str()
66 .map_err(|_| StatusCode::BAD_REQUEST)?;
67
68 if token.starts_with("Basic ") || token.starts_with("basic ") {
70 let decoded = STANDARD
71 .decode(&token[6..])
72 .map_err(|_| StatusCode::BAD_REQUEST)?;
73 let decoded_str = String::from_utf8(decoded).map_err(|_| StatusCode::BAD_REQUEST)?;
74 let (user, token) = decoded_str.split_once(':').ok_or(StatusCode::BAD_REQUEST)?;
75
76 let user = db.get_user(user).await.map_err(|_| StatusCode::FORBIDDEN)?;
77 if db.authenticate_user(&user.name, token).await.is_err() {
78 return Err(StatusCode::FORBIDDEN);
79 }
80
81 return Ok(Token {
82 value: token.to_string(),
83 user: user.name,
84 is_admin: user.is_admin,
85 is_read_only: user.is_read_only,
86 });
87 }
88
89 if token.starts_with("Bearer ") || token.starts_with("bearer ") {
91 token = &token[7..];
92 }
93
94 if let Some(cached) = cache.get(token).await {
96 debug!("Token cache hit for user: {}", cached.user);
97 return Ok(Token {
98 value: token.to_string(),
99 user: cached.user,
100 is_admin: cached.is_admin,
101 is_read_only: cached.is_read_only,
102 });
103 }
104
105 let user = match get_user_with_retry(
107 db,
108 token,
109 settings.registry.token_db_retry_count,
110 settings.registry.token_db_retry_delay_ms,
111 )
112 .await
113 {
114 Ok(user) => user,
115 Err(e) => {
116 debug!("Token cache miss, DB lookup failed: {}", e);
117 return Err(StatusCode::FORBIDDEN);
118 }
119 };
120
121 debug!("Token cache miss, queried DB for user: {}", user.name);
122
123 cache
125 .insert(
126 token.to_string(),
127 CachedTokenData {
128 user: user.name.clone(),
129 is_admin: user.is_admin,
130 is_read_only: user.is_read_only,
131 },
132 )
133 .await;
134
135 Ok(Token {
136 value: token.to_string(),
137 user: user.name,
138 is_admin: user.is_admin,
139 is_read_only: user.is_read_only,
140 })
141 }
142}
143
144async fn get_user_with_retry(
145 db: &Arc<dyn DbProvider>,
146 token: &str,
147 max_retries: u32,
148 delay_ms: u64,
149) -> Result<kellnr_db::User, DbError> {
150 let mut attempts = 0;
151
152 loop {
153 match db.get_user_from_token(token).await {
154 Ok(user) => return Ok(user),
155 Err(e) => {
156 if matches!(e, DbError::TokenNotFound | DbError::UserNotFound(_)) {
158 return Err(e);
159 }
160
161 attempts += 1;
162 if attempts > max_retries {
163 warn!(
164 "Failed to get user from token after {} retries: {}",
165 max_retries, e
166 );
167 return Err(e);
168 }
169
170 warn!(
171 "Transient DB error on attempt {}/{}, retrying in {}ms: {}",
172 attempts,
173 max_retries + 1,
174 delay_ms,
175 e
176 );
177 sleep(Duration::from_millis(delay_ms)).await;
178 }
179 }
180 }
181}
182
183impl FromRequestParts<AppStateData> for Token {
184 type Rejection = StatusCode;
185
186 async fn from_request_parts(
187 parts: &mut Parts,
188 state: &AppStateData,
189 ) -> Result<Self, Self::Rejection> {
190 Self::extract_token(&parts.headers, &state.db, &state.token_cache, &state.settings).await
191 }
192}
193
194impl FromRequestParts<AppStateData> for OptionToken {
195 type Rejection = StatusCode;
196
197 async fn from_request_parts(
198 parts: &mut Parts,
199 state: &AppStateData,
200 ) -> Result<Self, Self::Rejection> {
201 match Token::extract_token(&parts.headers, &state.db, &state.token_cache, &state.settings).await {
202 Ok(token) => Ok(OptionToken::Some(token)),
203 Err(StatusCode::UNAUTHORIZED) => Ok(OptionToken::None),
204 Err(status_code) => Err(status_code),
205 }
206 }
207}
208
209#[derive(Deserialize)]
210pub struct NewTokenReqData {
211 pub name: String,
212}
213
#[cfg(test)]
mod tests {
    use super::*;
    use kellnr_db::User;
    use kellnr_db::error::DbError;
    use kellnr_db::mock::MockDb;
    use mockall::predicate::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Plain user (no admin rights, not read-only) returned by the mocks.
    fn test_user() -> User {
        User {
            id: 1,
            name: String::from("test_user"),
            pwd: String::new(),
            salt: String::new(),
            is_admin: false,
            is_read_only: false,
        }
    }

    /// Settings with the token cache enabled and a 1 ms retry delay so the
    /// retry tests stay fast.
    fn test_settings() -> Arc<Settings> {
        let registry = kellnr_settings::Registry {
            token_cache_enabled: true,
            token_cache_ttl_seconds: 60,
            token_cache_max_capacity: 100,
            token_db_retry_count: 3,
            token_db_retry_delay_ms: 1,
            ..kellnr_settings::Registry::default()
        };

        Arc::new(Settings {
            registry,
            ..Settings::default()
        })
    }

    /// Database error that is neither `TokenNotFound` nor `UserNotFound`, so
    /// it triggers the retry path.
    fn timeout_error() -> DbError {
        DbError::PostgresError(sea_orm::DbErr::ConnectionAcquire(
            sea_orm::error::ConnAcquireErr::Timeout,
        ))
    }

    /// Wraps a configured mock as the trait object used by the code under test.
    fn provider(mock: MockDb) -> Arc<dyn DbProvider> {
        Arc::new(mock)
    }

    /// Token cache with a TTL of 60 and a capacity of 100.
    fn token_cache(enabled: bool) -> Arc<TokenCacheManager> {
        Arc::new(TokenCacheManager::new(enabled, 60, 100))
    }

    // ---- get_user_with_retry ------------------------------------------------

    #[tokio::test]
    async fn test_retry_succeeds_on_first_attempt() {
        let mut mock = MockDb::new();
        mock.expect_get_user_from_token()
            .with(eq("valid_token"))
            .times(1)
            .returning(|_| Ok(test_user()));
        let db = provider(mock);

        let outcome = get_user_with_retry(&db, "valid_token", 3, 10).await;

        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap().name, "test_user");
    }

    #[tokio::test]
    async fn test_retry_succeeds_after_transient_error() {
        let calls = Arc::new(AtomicU32::new(0));
        let calls_in_mock = Arc::clone(&calls);

        let mut mock = MockDb::new();
        mock.expect_get_user_from_token()
            .with(eq("token"))
            .times(2)
            .returning(move |_| {
                // Only the very first call fails.
                match calls_in_mock.fetch_add(1, Ordering::SeqCst) {
                    0 => Err(timeout_error()),
                    _ => Ok(test_user()),
                }
            });
        let db = provider(mock);

        let outcome = get_user_with_retry(&db, "token", 3, 1).await;

        assert!(outcome.is_ok());
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_no_retry_on_token_not_found() {
        let mut mock = MockDb::new();
        mock.expect_get_user_from_token()
            .with(eq("invalid_token"))
            .times(1)
            .returning(|_| Err(DbError::TokenNotFound));
        let db = provider(mock);

        let outcome = get_user_with_retry(&db, "invalid_token", 3, 10).await;

        assert!(outcome.is_err());
        assert!(matches!(outcome.unwrap_err(), DbError::TokenNotFound));
    }

    #[tokio::test]
    async fn test_no_retry_on_user_not_found() {
        let mut mock = MockDb::new();
        mock.expect_get_user_from_token()
            .with(eq("orphan_token"))
            .times(1)
            .returning(|_| Err(DbError::UserNotFound("orphan".to_string())));
        let db = provider(mock);

        let outcome = get_user_with_retry(&db, "orphan_token", 3, 10).await;

        assert!(outcome.is_err());
        assert!(matches!(outcome.unwrap_err(), DbError::UserNotFound(_)));
    }

    #[tokio::test]
    async fn test_exhausts_retries_on_persistent_error() {
        let calls = Arc::new(AtomicU32::new(0));
        let calls_in_mock = Arc::clone(&calls);

        let mut mock = MockDb::new();
        // One initial attempt plus three retries.
        mock.expect_get_user_from_token()
            .with(eq("token"))
            .times(4)
            .returning(move |_| {
                calls_in_mock.fetch_add(1, Ordering::SeqCst);
                Err(timeout_error())
            });
        let db = provider(mock);

        let outcome = get_user_with_retry(&db, "token", 3, 1).await;

        assert!(outcome.is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 4);
    }

    #[tokio::test]
    async fn test_zero_retries_only_attempts_once() {
        let calls = Arc::new(AtomicU32::new(0));
        let calls_in_mock = Arc::clone(&calls);

        let mut mock = MockDb::new();
        mock.expect_get_user_from_token()
            .with(eq("token"))
            .times(1)
            .returning(move |_| {
                calls_in_mock.fetch_add(1, Ordering::SeqCst);
                Err(timeout_error())
            });
        let db = provider(mock);

        let outcome = get_user_with_retry(&db, "token", 0, 1).await;

        assert!(outcome.is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    // ---- Token::from_header -------------------------------------------------

    #[tokio::test]
    async fn test_cache_hit_returns_cached_token() {
        let cache = token_cache(true);
        let entry = CachedTokenData {
            user: "cached_user".to_string(),
            is_admin: true,
            is_read_only: false,
        };
        cache.insert("cached_token".to_string(), entry).await;

        // The mock has no expectations: any database call fails the test.
        let db = provider(MockDb::new());
        let settings = test_settings();

        let mut headers = HeaderMap::new();
        headers.insert("Authorization", "Bearer cached_token".parse().unwrap());

        let outcome = Token::from_header(&headers, &db, &cache, &settings).await;

        assert!(outcome.is_ok());
        let token = outcome.unwrap();
        assert_eq!(token.user, "cached_user");
        assert!(token.is_admin);
    }

    #[tokio::test]
    async fn test_cache_miss_queries_db_and_caches() {
        let cache = token_cache(true);

        let mut mock = MockDb::new();
        mock.expect_get_user_from_token()
            .with(eq("new_token"))
            .times(1)
            .returning(|_| {
                Ok(User {
                    id: 1,
                    name: "db_user".to_string(),
                    pwd: String::new(),
                    salt: String::new(),
                    is_admin: false,
                    is_read_only: true,
                })
            });
        let db = provider(mock);
        let settings = test_settings();

        let mut headers = HeaderMap::new();
        headers.insert("Authorization", "Bearer new_token".parse().unwrap());

        let outcome = Token::from_header(&headers, &db, &cache, &settings).await;

        assert!(outcome.is_ok());
        let token = outcome.unwrap();
        assert_eq!(token.user, "db_user");
        assert!(token.is_read_only);

        // The successful lookup must have been written to the cache.
        let entry = cache.get("new_token").await;
        assert!(entry.is_some());
        assert_eq!(entry.unwrap().user, "db_user");
    }

    #[tokio::test]
    async fn test_cache_miss_with_invalid_token_returns_forbidden() {
        let cache = token_cache(true);

        let mut mock = MockDb::new();
        mock.expect_get_user_from_token()
            .with(eq("bad_token"))
            .times(1)
            .returning(|_| Err(DbError::TokenNotFound));
        let db = provider(mock);
        let settings = test_settings();

        let mut headers = HeaderMap::new();
        headers.insert("Authorization", "Bearer bad_token".parse().unwrap());

        let outcome = Token::from_header(&headers, &db, &cache, &settings).await;

        assert!(outcome.is_err());
        assert_eq!(outcome.unwrap_err(), StatusCode::FORBIDDEN);

        // Failed lookups must not be cached.
        let entry = cache.get("bad_token").await;
        assert!(entry.is_none());
    }

    #[tokio::test]
    async fn test_no_authorization_header_returns_unauthorized() {
        let cache = token_cache(true);
        let db = provider(MockDb::new());
        let settings = test_settings();
        let headers = HeaderMap::new();

        let outcome = Token::from_header(&headers, &db, &cache, &settings).await;

        assert!(outcome.is_err());
        assert_eq!(outcome.unwrap_err(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_disabled_cache_always_queries_db() {
        let cache = token_cache(false);
        let calls = Arc::new(AtomicU32::new(0));
        let calls_in_mock = Arc::clone(&calls);

        let mut mock = MockDb::new();
        // With the cache disabled both requests must reach the database.
        mock.expect_get_user_from_token()
            .with(eq("token"))
            .times(2)
            .returning(move |_| {
                calls_in_mock.fetch_add(1, Ordering::SeqCst);
                Ok(test_user())
            });
        let db = provider(mock);
        let settings = test_settings();

        let mut headers = HeaderMap::new();
        headers.insert("Authorization", "Bearer token".parse().unwrap());

        let _ = Token::from_header(&headers, &db, &cache, &settings).await;
        let _ = Token::from_header(&headers, &db, &cache, &settings).await;

        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_lowercase_bearer_prefix_works() {
        let cache = token_cache(true);

        let mut mock = MockDb::new();
        mock.expect_get_user_from_token()
            .with(eq("lowercase_token"))
            .times(1)
            .returning(|_| Ok(test_user()));
        let db = provider(mock);
        let settings = test_settings();

        let mut headers = HeaderMap::new();
        headers.insert("Authorization", "bearer lowercase_token".parse().unwrap());

        let outcome = Token::from_header(&headers, &db, &cache, &settings).await;

        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap().user, "test_user");
    }
}