1use std::iter;
2use std::sync::Arc;
3use std::time::Duration;
4
5use axum::extract::FromRequestParts;
6use axum::http::request::Parts;
7use axum::http::{HeaderMap, StatusCode};
8use base64::Engine;
9use base64::engine::general_purpose::STANDARD;
10use kellnr_appstate::AppStateData;
11use kellnr_common::token_cache::{CachedTokenData, TokenCacheManager};
12use kellnr_db::DbProvider;
13use kellnr_db::error::DbError;
14use kellnr_settings::Settings;
15use rand::distr::Alphanumeric;
16use rand::{Rng, rng};
17use serde::Deserialize;
18use tokio::time::sleep;
19use tracing::warn;
20use utoipa::ToSchema;
21
/// An authenticated caller, resolved from a request's `Authorization` header.
///
/// `Debug` is implemented by hand (instead of derived) so that the credential
/// stored in `value` is never written to logs or panic messages.
pub struct Token {
    /// The credential the client presented: the API token for token/Bearer
    /// auth, or the password part of `user:password` for Basic auth.
    pub value: String,
    /// Name of the authenticated user.
    pub user: String,
    /// Admin flag copied from the user record (or the token cache entry).
    pub is_admin: bool,
    /// Read-only flag copied from the user record (or the token cache entry).
    pub is_read_only: bool,
}

impl std::fmt::Debug for Token {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Token")
            // Never print the secret itself.
            .field("value", &"<redacted>")
            .field("user", &self.user)
            .field("is_admin", &self.is_admin)
            .field("is_read_only", &self.is_read_only)
            .finish()
    }
}
29
/// Result of optional authentication.
///
/// Produced by this type's `FromRequestParts` extractor: `None` when the
/// request carried no `Authorization` header, `Some` with the authenticated
/// [`Token`] otherwise. Any other authentication failure rejects the request
/// instead of producing a value.
#[derive(Debug)]
pub enum OptionToken {
    /// The request had no `Authorization` header.
    None,
    /// The request was successfully authenticated.
    Some(Token),
}
36
37pub fn generate_token() -> String {
38 let mut rng = rng();
39 iter::repeat(())
40 .map(|()| rng.sample(Alphanumeric))
41 .map(char::from)
42 .take(32)
43 .collect::<String>()
44}
45
46impl Token {
47 pub async fn from_header(
48 headers: &HeaderMap,
49 db: &Arc<dyn DbProvider>,
50 cache: &Arc<TokenCacheManager>,
51 settings: &Arc<Settings>,
52 ) -> Result<Self, StatusCode> {
53 Self::extract_token(headers, db, cache, settings).await
54 }
55
56 async fn extract_token(
57 headers: &HeaderMap,
58 db: &Arc<dyn DbProvider>,
59 cache: &Arc<TokenCacheManager>,
60 settings: &Arc<Settings>,
61 ) -> Result<Token, StatusCode> {
62 let mut token = headers
64 .get("Authorization")
65 .ok_or(StatusCode::UNAUTHORIZED)?
66 .to_str()
67 .map_err(|_| StatusCode::BAD_REQUEST)?;
68
69 if token.starts_with("Basic ") || token.starts_with("basic ") {
71 let decoded = STANDARD
72 .decode(&token[6..])
73 .map_err(|_| StatusCode::BAD_REQUEST)?;
74 let decoded_str = String::from_utf8(decoded).map_err(|_| StatusCode::BAD_REQUEST)?;
75 let (user, token) = decoded_str.split_once(':').ok_or(StatusCode::BAD_REQUEST)?;
76
77 let user = db.get_user(user).await.map_err(|_| StatusCode::FORBIDDEN)?;
78 if db.authenticate_user(&user.name, token).await.is_err() {
79 return Err(StatusCode::FORBIDDEN);
80 }
81
82 return Ok(Token {
83 value: token.to_string(),
84 user: user.name,
85 is_admin: user.is_admin,
86 is_read_only: user.is_read_only,
87 });
88 }
89
90 if token.starts_with("Bearer ") || token.starts_with("bearer ") {
92 token = &token[7..];
93 }
94
95 if let Some(cached) = cache.get(token).await {
97 return Ok(Token {
98 value: token.to_string(),
99 user: cached.user,
100 is_admin: cached.is_admin,
101 is_read_only: cached.is_read_only,
102 });
103 }
104
105 let Ok(user) = get_user_with_retry(
107 db,
108 token,
109 settings.registry.token_db_retry_count,
110 settings.registry.token_db_retry_delay_ms,
111 )
112 .await
113 else {
114 return Err(StatusCode::FORBIDDEN);
115 };
116
117 cache
119 .insert(
120 token.to_string(),
121 CachedTokenData {
122 user: user.name.clone(),
123 is_admin: user.is_admin,
124 is_read_only: user.is_read_only,
125 },
126 )
127 .await;
128
129 Ok(Token {
130 value: token.to_string(),
131 user: user.name,
132 is_admin: user.is_admin,
133 is_read_only: user.is_read_only,
134 })
135 }
136}
137
138async fn get_user_with_retry(
139 db: &Arc<dyn DbProvider>,
140 token: &str,
141 max_retries: u32,
142 delay_ms: u64,
143) -> Result<kellnr_db::User, DbError> {
144 let mut attempts = 0;
145
146 loop {
147 match db.get_user_from_token(token).await {
148 Ok(user) => return Ok(user),
149 Err(e) => {
150 if matches!(e, DbError::TokenNotFound | DbError::UserNotFound(_)) {
152 return Err(e);
153 }
154
155 attempts += 1;
156 if attempts > max_retries {
157 warn!(
158 "Failed to get user from token after {} retries: {}",
159 max_retries, e
160 );
161 return Err(e);
162 }
163
164 warn!(
165 "Transient DB error on attempt {}/{}, retrying in {}ms: {}",
166 attempts,
167 max_retries + 1,
168 delay_ms,
169 e
170 );
171 sleep(Duration::from_millis(delay_ms)).await;
172 }
173 }
174 }
175}
176
177impl FromRequestParts<AppStateData> for Token {
178 type Rejection = StatusCode;
179
180 async fn from_request_parts(
181 parts: &mut Parts,
182 state: &AppStateData,
183 ) -> Result<Self, Self::Rejection> {
184 Self::extract_token(
185 &parts.headers,
186 &state.db,
187 &state.token_cache,
188 &state.settings,
189 )
190 .await
191 }
192}
193
194impl FromRequestParts<AppStateData> for OptionToken {
195 type Rejection = StatusCode;
196
197 async fn from_request_parts(
198 parts: &mut Parts,
199 state: &AppStateData,
200 ) -> Result<Self, Self::Rejection> {
201 match Token::extract_token(
202 &parts.headers,
203 &state.db,
204 &state.token_cache,
205 &state.settings,
206 )
207 .await
208 {
209 Ok(token) => Ok(OptionToken::Some(token)),
210 Err(StatusCode::UNAUTHORIZED) => Ok(OptionToken::None),
211 Err(status_code) => Err(status_code),
212 }
213 }
214}
215
/// Deserialized request body carrying the name for a new token.
#[derive(Deserialize, ToSchema)]
pub struct NewTokenReqData {
    /// Name for the new token.
    pub name: String,
}
220
#[cfg(test)]
mod tests {
    //! Tests for `get_user_with_retry` and `Token::from_header`.
    //!
    //! The database is replaced by `MockDb`; each test pins how often
    //! `get_user_from_token` may be called through mockall's `.times(n)`.
    //! Transient failures are simulated with a connection-acquire timeout.

    use std::sync::atomic::{AtomicU32, Ordering};

    use kellnr_db::User;
    use kellnr_db::error::DbError;
    use kellnr_db::mock::MockDb;
    use mockall::predicate::*;

    use super::*;

    /// Non-admin, writable user returned by the mocked DB lookups.
    fn test_user() -> User {
        User {
            id: 1,
            name: "test_user".to_string(),
            pwd: String::new(),
            salt: String::new(),
            is_admin: false,
            is_read_only: false,
            created: String::new(),
        }
    }

    /// Settings with the token cache enabled, 3 DB retries, and a 1ms retry
    /// delay so retry paths stay fast.
    fn test_settings() -> Arc<Settings> {
        Arc::new(Settings {
            registry: kellnr_settings::Registry {
                token_cache_enabled: true,
                token_cache_ttl_seconds: 60,
                token_cache_max_capacity: 100,
                token_db_retry_count: 3,
                token_db_retry_delay_ms: 1,
                ..kellnr_settings::Registry::default()
            },
            ..Settings::default()
        })
    }

    // A successful lookup returns after a single DB call.
    #[tokio::test]
    async fn test_retry_succeeds_on_first_attempt() {
        let mut mock_db = MockDb::new();
        mock_db
            .expect_get_user_from_token()
            .with(eq("valid_token"))
            .times(1)
            .returning(|_| Ok(test_user()));

        let db: Arc<dyn DbProvider> = Arc::new(mock_db);
        let result = get_user_with_retry(&db, "valid_token", 3, 10).await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap().name, "test_user");
    }

    // One transient failure followed by success: exactly two DB calls.
    #[tokio::test]
    async fn test_retry_succeeds_after_transient_error() {
        let call_count = Arc::new(AtomicU32::new(0));
        let call_count_clone = call_count.clone();

        let mut mock_db = MockDb::new();
        mock_db
            .expect_get_user_from_token()
            .with(eq("token"))
            .times(2)
            .returning(move |_| {
                let count = call_count_clone.fetch_add(1, Ordering::SeqCst);
                // Fail only on the first call.
                if count == 0 {
                    Err(DbError::PostgresError(sea_orm::DbErr::ConnectionAcquire(
                        sea_orm::error::ConnAcquireErr::Timeout,
                    )))
                } else {
                    Ok(test_user())
                }
            });

        let db: Arc<dyn DbProvider> = Arc::new(mock_db);
        let result = get_user_with_retry(&db, "token", 3, 1).await;

        assert!(result.is_ok());
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }

    // `TokenNotFound` is definitive and must not be retried.
    #[tokio::test]
    async fn test_no_retry_on_token_not_found() {
        let mut mock_db = MockDb::new();
        mock_db
            .expect_get_user_from_token()
            .with(eq("invalid_token"))
            .times(1)
            .returning(|_| Err(DbError::TokenNotFound));

        let db: Arc<dyn DbProvider> = Arc::new(mock_db);
        let result = get_user_with_retry(&db, "invalid_token", 3, 10).await;

        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), DbError::TokenNotFound));
    }

    // `UserNotFound` is definitive and must not be retried.
    #[tokio::test]
    async fn test_no_retry_on_user_not_found() {
        let mut mock_db = MockDb::new();
        mock_db
            .expect_get_user_from_token()
            .with(eq("orphan_token"))
            .times(1)
            .returning(|_| Err(DbError::UserNotFound("orphan".to_string())));

        let db: Arc<dyn DbProvider> = Arc::new(mock_db);
        let result = get_user_with_retry(&db, "orphan_token", 3, 10).await;

        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), DbError::UserNotFound(_)));
    }

    // A persistent transient error uses 1 initial attempt + 3 retries = 4 calls.
    #[tokio::test]
    async fn test_exhausts_retries_on_persistent_error() {
        let call_count = Arc::new(AtomicU32::new(0));
        let call_count_clone = call_count.clone();

        let mut mock_db = MockDb::new();
        mock_db
            .expect_get_user_from_token()
            .with(eq("token"))
            .times(4)
            .returning(move |_| {
                call_count_clone.fetch_add(1, Ordering::SeqCst);
                Err(DbError::PostgresError(sea_orm::DbErr::ConnectionAcquire(
                    sea_orm::error::ConnAcquireErr::Timeout,
                )))
            });

        let db: Arc<dyn DbProvider> = Arc::new(mock_db);
        let result = get_user_with_retry(&db, "token", 3, 1).await;

        assert!(result.is_err());
        assert_eq!(call_count.load(Ordering::SeqCst), 4);
    }

    // A pre-populated cache entry is returned without touching the DB: the mock
    // has no expectations registered for this test.
    #[tokio::test]
    async fn test_cache_hit_returns_cached_token() {
        let cache = Arc::new(TokenCacheManager::new(true, 60, 100));
        cache
            .insert(
                "cached_token".to_string(),
                CachedTokenData {
                    user: "cached_user".to_string(),
                    is_admin: true,
                    is_read_only: false,
                },
            )
            .await;

        let mock_db = MockDb::new();
        let db: Arc<dyn DbProvider> = Arc::new(mock_db);
        let settings = test_settings();

        let mut headers = HeaderMap::new();
        headers.insert("Authorization", "Bearer cached_token".parse().unwrap());

        let result = Token::from_header(&headers, &db, &cache, &settings).await;

        assert!(result.is_ok());
        let token = result.unwrap();
        assert_eq!(token.user, "cached_user");
        assert!(token.is_admin);
    }

    // On a cache miss the DB is queried once and the result is cached.
    #[tokio::test]
    async fn test_cache_miss_queries_db_and_caches() {
        let cache = Arc::new(TokenCacheManager::new(true, 60, 100));

        let mut mock_db = MockDb::new();
        mock_db
            .expect_get_user_from_token()
            .with(eq("new_token"))
            .times(1)
            .returning(|_| {
                Ok(User {
                    id: 1,
                    name: "db_user".to_string(),
                    pwd: String::new(),
                    salt: String::new(),
                    is_admin: false,
                    is_read_only: true,
                    created: String::new(),
                })
            });

        let db: Arc<dyn DbProvider> = Arc::new(mock_db);
        let settings = test_settings();

        let mut headers = HeaderMap::new();
        headers.insert("Authorization", "Bearer new_token".parse().unwrap());

        let result = Token::from_header(&headers, &db, &cache, &settings).await;

        assert!(result.is_ok());
        let token = result.unwrap();
        assert_eq!(token.user, "db_user");
        assert!(token.is_read_only);

        // The successful lookup must now be served from the cache.
        let cached = cache.get("new_token").await;
        assert!(cached.is_some());
        assert_eq!(cached.unwrap().user, "db_user");
    }

    // An unknown token yields 403 and is not cached.
    #[tokio::test]
    async fn test_cache_miss_with_invalid_token_returns_forbidden() {
        let cache = Arc::new(TokenCacheManager::new(true, 60, 100));

        let mut mock_db = MockDb::new();
        mock_db
            .expect_get_user_from_token()
            .with(eq("bad_token"))
            .times(1)
            .returning(|_| Err(DbError::TokenNotFound));

        let db: Arc<dyn DbProvider> = Arc::new(mock_db);
        let settings = test_settings();

        let mut headers = HeaderMap::new();
        headers.insert("Authorization", "Bearer bad_token".parse().unwrap());

        let result = Token::from_header(&headers, &db, &cache, &settings).await;

        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), StatusCode::FORBIDDEN);

        // Failed lookups must not poison the cache.
        let cached = cache.get("bad_token").await;
        assert!(cached.is_none());
    }

    // A missing `Authorization` header yields 401.
    #[tokio::test]
    async fn test_no_authorization_header_returns_unauthorized() {
        let cache = Arc::new(TokenCacheManager::new(true, 60, 100));
        let mock_db = MockDb::new();
        let db: Arc<dyn DbProvider> = Arc::new(mock_db);
        let settings = test_settings();

        let headers = HeaderMap::new();
        let result = Token::from_header(&headers, &db, &cache, &settings).await;

        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), StatusCode::UNAUTHORIZED);
    }

    // With the cache disabled, two identical requests both hit the DB.
    #[tokio::test]
    async fn test_disabled_cache_always_queries_db() {
        let cache = Arc::new(TokenCacheManager::new(false, 60, 100));
        let call_count = Arc::new(AtomicU32::new(0));
        let call_count_clone = call_count.clone();

        let mut mock_db = MockDb::new();
        mock_db
            .expect_get_user_from_token()
            .with(eq("token"))
            .times(2)
            .returning(move |_| {
                call_count_clone.fetch_add(1, Ordering::SeqCst);
                Ok(test_user())
            });

        let db: Arc<dyn DbProvider> = Arc::new(mock_db);
        let settings = test_settings();

        let mut headers = HeaderMap::new();
        headers.insert("Authorization", "Bearer token".parse().unwrap());

        let _ = Token::from_header(&headers, &db, &cache, &settings).await;
        let _ = Token::from_header(&headers, &db, &cache, &settings).await;

        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }

    // A lowercase `bearer` prefix is stripped before the lookup.
    #[tokio::test]
    async fn test_lowercase_bearer_prefix_works() {
        let cache = Arc::new(TokenCacheManager::new(true, 60, 100));

        let mut mock_db = MockDb::new();
        mock_db
            .expect_get_user_from_token()
            .with(eq("lowercase_token"))
            .times(1)
            .returning(|_| Ok(test_user()));

        let db: Arc<dyn DbProvider> = Arc::new(mock_db);
        let settings = test_settings();

        let mut headers = HeaderMap::new();
        headers.insert("Authorization", "bearer lowercase_token".parse().unwrap());

        let result = Token::from_header(&headers, &db, &cache, &settings).await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap().user, "test_user");
    }

    // `max_retries == 0` means exactly one attempt, even on a transient error.
    #[tokio::test]
    async fn test_zero_retries_only_attempts_once() {
        let call_count = Arc::new(AtomicU32::new(0));
        let call_count_clone = call_count.clone();

        let mut mock_db = MockDb::new();
        mock_db
            .expect_get_user_from_token()
            .with(eq("token"))
            .times(1)
            .returning(move |_| {
                call_count_clone.fetch_add(1, Ordering::SeqCst);
                Err(DbError::PostgresError(sea_orm::DbErr::ConnectionAcquire(
                    sea_orm::error::ConnAcquireErr::Timeout,
                )))
            });

        let db: Arc<dyn DbProvider> = Arc::new(mock_db);
        let result = get_user_with_retry(&db, "token", 0, 1).await;

        assert!(result.is_err());
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }
}
563}