kellnr_auth/
token.rs

1use std::iter;
2use std::sync::Arc;
3use std::time::Duration;
4
5use axum::extract::FromRequestParts;
6use axum::http::request::Parts;
7use axum::http::{HeaderMap, StatusCode};
8use base64::Engine;
9use base64::engine::general_purpose::STANDARD;
10use kellnr_appstate::AppStateData;
11use kellnr_common::token_cache::{CachedTokenData, TokenCacheManager};
12use kellnr_db::error::DbError;
13use kellnr_db::DbProvider;
14use kellnr_settings::Settings;
15use rand::distr::Alphanumeric;
16use rand::{Rng, rng};
17use serde::Deserialize;
18use tokio::time::sleep;
19use tracing::{debug, warn};
20
/// An identity resolved from a request's `Authorization` header.
///
/// `Debug` is implemented by hand (instead of derived) so that the credential
/// stored in `value` is never written to logs or panic messages.
pub struct Token {
    /// The credential the caller presented: the bearer/raw token, or the
    /// secret half of a basic-auth `user:secret` pair.
    pub value: String,
    /// Name of the user the credential resolved to.
    pub user: String,
    /// Admin flag of the resolved user.
    pub is_admin: bool,
    /// Read-only flag of the resolved user.
    pub is_read_only: bool,
}

impl std::fmt::Debug for Token {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Token")
            // Never print the credential itself.
            .field("value", &"<redacted>")
            .field("user", &self.user)
            .field("is_admin", &self.is_admin)
            .field("is_read_only", &self.is_read_only)
            .finish()
    }
}
28
29// See https://github.com/tokio-rs/axum/discussions/2281
30#[derive(Debug)]
31pub enum OptionToken {
32    None,
33    Some(Token),
34}
35
36pub fn generate_token() -> String {
37    let mut rng = rng();
38    iter::repeat(())
39        .map(|()| rng.sample(Alphanumeric))
40        .map(char::from)
41        .take(32)
42        .collect::<String>()
43}
44
45impl Token {
46    pub async fn from_header(
47        headers: &HeaderMap,
48        db: &Arc<dyn DbProvider>,
49        cache: &Arc<TokenCacheManager>,
50        settings: &Arc<Settings>,
51    ) -> Result<Self, StatusCode> {
52        Self::extract_token(headers, db, cache, settings).await
53    }
54
55    async fn extract_token(
56        headers: &HeaderMap,
57        db: &Arc<dyn DbProvider>,
58        cache: &Arc<TokenCacheManager>,
59        settings: &Arc<Settings>,
60    ) -> Result<Token, StatusCode> {
61        // OptionToken code expects UNAUTHORIZED when no token is found
62        let mut token = headers
63            .get("Authorization")
64            .ok_or(StatusCode::UNAUTHORIZED)?
65            .to_str()
66            .map_err(|_| StatusCode::BAD_REQUEST)?;
67
68        // Handle basic authentication (does NOT use token cache - queries DB directly)
69        if token.starts_with("Basic ") || token.starts_with("basic ") {
70            let decoded = STANDARD
71                .decode(&token[6..])
72                .map_err(|_| StatusCode::BAD_REQUEST)?;
73            let decoded_str = String::from_utf8(decoded).map_err(|_| StatusCode::BAD_REQUEST)?;
74            let (user, token) = decoded_str.split_once(':').ok_or(StatusCode::BAD_REQUEST)?;
75
76            let user = db.get_user(user).await.map_err(|_| StatusCode::FORBIDDEN)?;
77            if db.authenticate_user(&user.name, token).await.is_err() {
78                return Err(StatusCode::FORBIDDEN);
79            }
80
81            return Ok(Token {
82                value: token.to_string(),
83                user: user.name,
84                is_admin: user.is_admin,
85                is_read_only: user.is_read_only,
86            });
87        }
88
89        // Handle bearer authentication (uses token cache)
90        if token.starts_with("Bearer ") || token.starts_with("bearer ") {
91            token = &token[7..];
92        }
93
94        // Check cache first
95        if let Some(cached) = cache.get(token).await {
96            debug!("Token cache hit for user: {}", cached.user);
97            return Ok(Token {
98                value: token.to_string(),
99                user: cached.user,
100                is_admin: cached.is_admin,
101                is_read_only: cached.is_read_only,
102            });
103        }
104
105        // Cache miss - query DB with retry logic
106        let user = match get_user_with_retry(
107            db,
108            token,
109            settings.registry.token_db_retry_count,
110            settings.registry.token_db_retry_delay_ms,
111        )
112        .await
113        {
114            Ok(user) => user,
115            Err(e) => {
116                debug!("Token cache miss, DB lookup failed: {}", e);
117                return Err(StatusCode::FORBIDDEN);
118            }
119        };
120
121        debug!("Token cache miss, queried DB for user: {}", user.name);
122
123        // Insert into cache on successful DB lookup
124        cache
125            .insert(
126                token.to_string(),
127                CachedTokenData {
128                    user: user.name.clone(),
129                    is_admin: user.is_admin,
130                    is_read_only: user.is_read_only,
131                },
132            )
133            .await;
134
135        Ok(Token {
136            value: token.to_string(),
137            user: user.name,
138            is_admin: user.is_admin,
139            is_read_only: user.is_read_only,
140        })
141    }
142}
143
144async fn get_user_with_retry(
145    db: &Arc<dyn DbProvider>,
146    token: &str,
147    max_retries: u32,
148    delay_ms: u64,
149) -> Result<kellnr_db::User, DbError> {
150    let mut attempts = 0;
151
152    loop {
153        match db.get_user_from_token(token).await {
154            Ok(user) => return Ok(user),
155            Err(e) => {
156                // Do not retry on "not found" errors - these are definitive
157                if matches!(e, DbError::TokenNotFound | DbError::UserNotFound(_)) {
158                    return Err(e);
159                }
160
161                attempts += 1;
162                if attempts > max_retries {
163                    warn!(
164                        "Failed to get user from token after {} retries: {}",
165                        max_retries, e
166                    );
167                    return Err(e);
168                }
169
170                warn!(
171                    "Transient DB error on attempt {}/{}, retrying in {}ms: {}",
172                    attempts,
173                    max_retries + 1,
174                    delay_ms,
175                    e
176                );
177                sleep(Duration::from_millis(delay_ms)).await;
178            }
179        }
180    }
181}
182
183impl FromRequestParts<AppStateData> for Token {
184    type Rejection = StatusCode;
185
186    async fn from_request_parts(
187        parts: &mut Parts,
188        state: &AppStateData,
189    ) -> Result<Self, Self::Rejection> {
190        Self::extract_token(&parts.headers, &state.db, &state.token_cache, &state.settings).await
191    }
192}
193
194impl FromRequestParts<AppStateData> for OptionToken {
195    type Rejection = StatusCode;
196
197    async fn from_request_parts(
198        parts: &mut Parts,
199        state: &AppStateData,
200    ) -> Result<Self, Self::Rejection> {
201        match Token::extract_token(&parts.headers, &state.db, &state.token_cache, &state.settings).await {
202            Ok(token) => Ok(OptionToken::Some(token)),
203            Err(StatusCode::UNAUTHORIZED) => Ok(OptionToken::None),
204            Err(status_code) => Err(status_code),
205        }
206    }
207}
208
209#[derive(Deserialize)]
210pub struct NewTokenReqData {
211    pub name: String,
212}
213
#[cfg(test)]
mod tests {
    use super::*;
    use kellnr_db::User;
    use kellnr_db::error::DbError;
    use kellnr_db::mock::MockDb;
    use mockall::predicate::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Baseline user returned by most mocked lookups: non-admin, writable.
    fn test_user() -> User {
        User {
            id: 1,
            name: "test_user".to_string(),
            pwd: String::new(),
            salt: String::new(),
            is_admin: false,
            is_read_only: false,
        }
    }

    /// Settings with the token cache enabled, 3 retries and a 1 ms retry delay.
    fn test_settings() -> Arc<Settings> {
        let registry = kellnr_settings::Registry {
            token_cache_enabled: true,
            token_cache_ttl_seconds: 60,
            token_cache_max_capacity: 100,
            token_db_retry_count: 3,
            token_db_retry_delay_ms: 1,
            ..kellnr_settings::Registry::default()
        };
        Arc::new(Settings {
            registry,
            ..Settings::default()
        })
    }

    /// A connection-acquire timeout: not a "not found" error, so it is retried.
    fn transient_error() -> DbError {
        DbError::PostgresError(sea_orm::DbErr::ConnectionAcquire(
            sea_orm::error::ConnAcquireErr::Timeout,
        ))
    }

    /// Token cache with the TTL (60) and capacity (100) used by every test here.
    fn new_cache(enabled: bool) -> Arc<TokenCacheManager> {
        Arc::new(TokenCacheManager::new(enabled, 60, 100))
    }

    /// Headers carrying `value` verbatim as the `Authorization` header.
    fn auth_headers(value: &str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert("Authorization", value.parse().unwrap());
        headers
    }

    /// Erases the mock into the trait object the code under test expects.
    fn as_provider(mock: MockDb) -> Arc<dyn DbProvider> {
        Arc::new(mock)
    }

    // ===================
    // Retry Logic Tests
    // ===================

    #[tokio::test]
    async fn test_retry_succeeds_on_first_attempt() {
        // A healthy DB answers right away, so exactly one lookup happens.
        let mut mock_db = MockDb::new();
        mock_db
            .expect_get_user_from_token()
            .with(eq("valid_token"))
            .times(1)
            .returning(|_| Ok(test_user()));

        let db = as_provider(mock_db);
        let user = get_user_with_retry(&db, "valid_token", 3, 10)
            .await
            .expect("lookup should succeed on the first attempt");

        assert_eq!(user.name, "test_user");
    }

    #[tokio::test]
    async fn test_retry_succeeds_after_transient_error() {
        // The first lookup fails transiently; the retry succeeds.
        let calls = Arc::new(AtomicU32::new(0));
        let calls_in_mock = Arc::clone(&calls);

        let mut mock_db = MockDb::new();
        mock_db
            .expect_get_user_from_token()
            .with(eq("token"))
            .times(2)
            .returning(move |_| match calls_in_mock.fetch_add(1, Ordering::SeqCst) {
                0 => Err(transient_error()),
                _ => Ok(test_user()),
            });

        let db = as_provider(mock_db);
        let result = get_user_with_retry(&db, "token", 3, 1).await;

        assert!(result.is_ok());
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_no_retry_on_token_not_found() {
        // TokenNotFound is definitive: one lookup, error passed straight through.
        let mut mock_db = MockDb::new();
        mock_db
            .expect_get_user_from_token()
            .with(eq("invalid_token"))
            .times(1)
            .returning(|_| Err(DbError::TokenNotFound));

        let db = as_provider(mock_db);
        let result = get_user_with_retry(&db, "invalid_token", 3, 10).await;

        assert!(matches!(result, Err(DbError::TokenNotFound)));
    }

    #[tokio::test]
    async fn test_no_retry_on_user_not_found() {
        // UserNotFound is definitive as well.
        let mut mock_db = MockDb::new();
        mock_db
            .expect_get_user_from_token()
            .with(eq("orphan_token"))
            .times(1)
            .returning(|_| Err(DbError::UserNotFound("orphan".to_string())));

        let db = as_provider(mock_db);
        let result = get_user_with_retry(&db, "orphan_token", 3, 10).await;

        assert!(matches!(result, Err(DbError::UserNotFound(_))));
    }

    #[tokio::test]
    async fn test_exhausts_retries_on_persistent_error() {
        // Every attempt fails transiently: 1 initial attempt + 3 retries, then give up.
        let calls = Arc::new(AtomicU32::new(0));
        let calls_in_mock = Arc::clone(&calls);

        let mut mock_db = MockDb::new();
        mock_db
            .expect_get_user_from_token()
            .with(eq("token"))
            .times(4)
            .returning(move |_| {
                calls_in_mock.fetch_add(1, Ordering::SeqCst);
                Err(transient_error())
            });

        let db = as_provider(mock_db);
        let result = get_user_with_retry(&db, "token", 3, 1).await;

        assert!(result.is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 4);
    }

    // =====================================
    // Token Extraction with Cache Tests
    // =====================================

    #[tokio::test]
    async fn test_cache_hit_returns_cached_token() {
        // The token is pre-cached; the mock has no expectations, so any DB call would fail.
        let cache = new_cache(true);
        let cached_entry = CachedTokenData {
            user: "cached_user".to_string(),
            is_admin: true,
            is_read_only: false,
        };
        cache.insert("cached_token".to_string(), cached_entry).await;

        let db = as_provider(MockDb::new());
        let headers = auth_headers("Bearer cached_token");

        let token = Token::from_header(&headers, &db, &cache, &test_settings())
            .await
            .expect("cached token should be accepted");

        assert_eq!(token.user, "cached_user");
        assert!(token.is_admin);
    }

    #[tokio::test]
    async fn test_cache_miss_queries_db_and_caches() {
        let cache = new_cache(true);

        let mut mock_db = MockDb::new();
        mock_db
            .expect_get_user_from_token()
            .with(eq("new_token"))
            .times(1)
            .returning(|_| {
                Ok(User {
                    name: "db_user".to_string(),
                    is_read_only: true,
                    ..test_user()
                })
            });

        let db = as_provider(mock_db);
        let headers = auth_headers("Bearer new_token");

        let token = Token::from_header(&headers, &db, &cache, &test_settings())
            .await
            .expect("DB lookup should succeed");
        assert_eq!(token.user, "db_user");
        assert!(token.is_read_only);

        // The successful lookup must have been written back to the cache.
        let cached = cache
            .get("new_token")
            .await
            .expect("token should now be cached");
        assert_eq!(cached.user, "db_user");
    }

    #[tokio::test]
    async fn test_cache_miss_with_invalid_token_returns_forbidden() {
        let cache = new_cache(true);

        let mut mock_db = MockDb::new();
        mock_db
            .expect_get_user_from_token()
            .with(eq("bad_token"))
            .times(1) // TokenNotFound is never retried
            .returning(|_| Err(DbError::TokenNotFound));

        let db = as_provider(mock_db);
        let headers = auth_headers("Bearer bad_token");

        let result = Token::from_header(&headers, &db, &cache, &test_settings()).await;
        assert_eq!(result.err(), Some(StatusCode::FORBIDDEN));

        // Rejected tokens must never end up in the cache.
        assert!(cache.get("bad_token").await.is_none());
    }

    #[tokio::test]
    async fn test_no_authorization_header_returns_unauthorized() {
        let cache = new_cache(true);
        let db = as_provider(MockDb::new());
        let headers = HeaderMap::new();

        let result = Token::from_header(&headers, &db, &cache, &test_settings()).await;

        assert_eq!(result.err(), Some(StatusCode::UNAUTHORIZED));
    }

    #[tokio::test]
    async fn test_disabled_cache_always_queries_db() {
        // With caching switched off, every request goes to the DB.
        let cache = new_cache(false);

        let calls = Arc::new(AtomicU32::new(0));
        let calls_in_mock = Arc::clone(&calls);

        let mut mock_db = MockDb::new();
        mock_db
            .expect_get_user_from_token()
            .with(eq("token"))
            .times(2)
            .returning(move |_| {
                calls_in_mock.fetch_add(1, Ordering::SeqCst);
                Ok(test_user())
            });

        let db = as_provider(mock_db);
        let settings = test_settings();
        let headers = auth_headers("Bearer token");

        for _ in 0..2 {
            let _ = Token::from_header(&headers, &db, &cache, &settings).await;
        }

        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_lowercase_bearer_prefix_works() {
        // A lowercase scheme name must be stripped just like "Bearer".
        let cache = new_cache(true);

        let mut mock_db = MockDb::new();
        mock_db
            .expect_get_user_from_token()
            .with(eq("lowercase_token"))
            .times(1)
            .returning(|_| Ok(test_user()));

        let db = as_provider(mock_db);
        let headers = auth_headers("bearer lowercase_token");

        let token = Token::from_header(&headers, &db, &cache, &test_settings())
            .await
            .expect("lowercase scheme should be accepted");
        assert_eq!(token.user, "test_user");
    }

    #[tokio::test]
    async fn test_zero_retries_only_attempts_once() {
        // max_retries = 0 means a single attempt, even on a transient error.
        let calls = Arc::new(AtomicU32::new(0));
        let calls_in_mock = Arc::clone(&calls);

        let mut mock_db = MockDb::new();
        mock_db
            .expect_get_user_from_token()
            .with(eq("token"))
            .times(1)
            .returning(move |_| {
                calls_in_mock.fetch_add(1, Ordering::SeqCst);
                Err(transient_error())
            });

        let db = as_provider(mock_db);
        let result = get_user_with_retry(&db, "token", 0, 1).await;

        assert!(result.is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }
}