kellnr_auth/
token.rs

1use std::iter;
2use std::sync::Arc;
3use std::time::Duration;
4
5use axum::extract::FromRequestParts;
6use axum::http::request::Parts;
7use axum::http::{HeaderMap, StatusCode};
8use base64::Engine;
9use base64::engine::general_purpose::STANDARD;
10use kellnr_appstate::AppStateData;
11use kellnr_common::token_cache::{CachedTokenData, TokenCacheManager};
12use kellnr_db::DbProvider;
13use kellnr_db::error::DbError;
14use kellnr_settings::Settings;
15use rand::distr::Alphanumeric;
16use rand::{Rng, rng};
17use serde::Deserialize;
18use tokio::time::sleep;
19use tracing::{debug, warn};
20
/// An authenticated caller, resolved from a request's `Authorization` header.
///
/// `value` holds the credential that was presented: the API token for
/// bearer / raw-token requests, or the password for HTTP Basic requests.
pub struct Token {
    pub value: String,
    pub user: String,
    pub is_admin: bool,
    pub is_read_only: bool,
}

// `Debug` is implemented by hand (instead of derived) so that `{:?}` output,
// e.g. in tracing/log lines, never contains the secret credential in `value`.
impl std::fmt::Debug for Token {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Token")
            .field("value", &"<redacted>")
            .field("user", &self.user)
            .field("is_admin", &self.is_admin)
            .field("is_read_only", &self.is_read_only)
            .finish()
    }
}
28
29// See https://github.com/tokio-rs/axum/discussions/2281
30#[derive(Debug)]
31pub enum OptionToken {
32    None,
33    Some(Token),
34}
35
36pub fn generate_token() -> String {
37    let mut rng = rng();
38    iter::repeat(())
39        .map(|()| rng.sample(Alphanumeric))
40        .map(char::from)
41        .take(32)
42        .collect::<String>()
43}
44
45impl Token {
46    pub async fn from_header(
47        headers: &HeaderMap,
48        db: &Arc<dyn DbProvider>,
49        cache: &Arc<TokenCacheManager>,
50        settings: &Arc<Settings>,
51    ) -> Result<Self, StatusCode> {
52        Self::extract_token(headers, db, cache, settings).await
53    }
54
55    async fn extract_token(
56        headers: &HeaderMap,
57        db: &Arc<dyn DbProvider>,
58        cache: &Arc<TokenCacheManager>,
59        settings: &Arc<Settings>,
60    ) -> Result<Token, StatusCode> {
61        // OptionToken code expects UNAUTHORIZED when no token is found
62        let mut token = headers
63            .get("Authorization")
64            .ok_or(StatusCode::UNAUTHORIZED)?
65            .to_str()
66            .map_err(|_| StatusCode::BAD_REQUEST)?;
67
68        // Handle basic authentication (does NOT use token cache - queries DB directly)
69        if token.starts_with("Basic ") || token.starts_with("basic ") {
70            let decoded = STANDARD
71                .decode(&token[6..])
72                .map_err(|_| StatusCode::BAD_REQUEST)?;
73            let decoded_str = String::from_utf8(decoded).map_err(|_| StatusCode::BAD_REQUEST)?;
74            let (user, token) = decoded_str.split_once(':').ok_or(StatusCode::BAD_REQUEST)?;
75
76            let user = db.get_user(user).await.map_err(|_| StatusCode::FORBIDDEN)?;
77            if db.authenticate_user(&user.name, token).await.is_err() {
78                return Err(StatusCode::FORBIDDEN);
79            }
80
81            return Ok(Token {
82                value: token.to_string(),
83                user: user.name,
84                is_admin: user.is_admin,
85                is_read_only: user.is_read_only,
86            });
87        }
88
89        // Handle bearer authentication (uses token cache)
90        if token.starts_with("Bearer ") || token.starts_with("bearer ") {
91            token = &token[7..];
92        }
93
94        // Check cache first
95        if let Some(cached) = cache.get(token).await {
96            debug!("Token cache hit for user: {}", cached.user);
97            return Ok(Token {
98                value: token.to_string(),
99                user: cached.user,
100                is_admin: cached.is_admin,
101                is_read_only: cached.is_read_only,
102            });
103        }
104
105        // Cache miss - query DB with retry logic
106        let user = match get_user_with_retry(
107            db,
108            token,
109            settings.registry.token_db_retry_count,
110            settings.registry.token_db_retry_delay_ms,
111        )
112        .await
113        {
114            Ok(user) => user,
115            Err(e) => {
116                debug!("Token cache miss, DB lookup failed: {}", e);
117                return Err(StatusCode::FORBIDDEN);
118            }
119        };
120
121        debug!("Token cache miss, queried DB for user: {}", user.name);
122
123        // Insert into cache on successful DB lookup
124        cache
125            .insert(
126                token.to_string(),
127                CachedTokenData {
128                    user: user.name.clone(),
129                    is_admin: user.is_admin,
130                    is_read_only: user.is_read_only,
131                },
132            )
133            .await;
134
135        Ok(Token {
136            value: token.to_string(),
137            user: user.name,
138            is_admin: user.is_admin,
139            is_read_only: user.is_read_only,
140        })
141    }
142}
143
144async fn get_user_with_retry(
145    db: &Arc<dyn DbProvider>,
146    token: &str,
147    max_retries: u32,
148    delay_ms: u64,
149) -> Result<kellnr_db::User, DbError> {
150    let mut attempts = 0;
151
152    loop {
153        match db.get_user_from_token(token).await {
154            Ok(user) => return Ok(user),
155            Err(e) => {
156                // Do not retry on "not found" errors - these are definitive
157                if matches!(e, DbError::TokenNotFound | DbError::UserNotFound(_)) {
158                    return Err(e);
159                }
160
161                attempts += 1;
162                if attempts > max_retries {
163                    warn!(
164                        "Failed to get user from token after {} retries: {}",
165                        max_retries, e
166                    );
167                    return Err(e);
168                }
169
170                warn!(
171                    "Transient DB error on attempt {}/{}, retrying in {}ms: {}",
172                    attempts,
173                    max_retries + 1,
174                    delay_ms,
175                    e
176                );
177                sleep(Duration::from_millis(delay_ms)).await;
178            }
179        }
180    }
181}
182
183impl FromRequestParts<AppStateData> for Token {
184    type Rejection = StatusCode;
185
186    async fn from_request_parts(
187        parts: &mut Parts,
188        state: &AppStateData,
189    ) -> Result<Self, Self::Rejection> {
190        Self::extract_token(
191            &parts.headers,
192            &state.db,
193            &state.token_cache,
194            &state.settings,
195        )
196        .await
197    }
198}
199
200impl FromRequestParts<AppStateData> for OptionToken {
201    type Rejection = StatusCode;
202
203    async fn from_request_parts(
204        parts: &mut Parts,
205        state: &AppStateData,
206    ) -> Result<Self, Self::Rejection> {
207        match Token::extract_token(
208            &parts.headers,
209            &state.db,
210            &state.token_cache,
211            &state.settings,
212        )
213        .await
214        {
215            Ok(token) => Ok(OptionToken::Some(token)),
216            Err(StatusCode::UNAUTHORIZED) => Ok(OptionToken::None),
217            Err(status_code) => Err(status_code),
218        }
219    }
220}
221
222#[derive(Deserialize)]
223pub struct NewTokenReqData {
224    pub name: String,
225}
226
#[cfg(test)]
mod tests {
    //! Unit tests for token generation, header parsing, the token cache path,
    //! and the DB retry helper.

    use std::sync::atomic::{AtomicU32, Ordering};

    use kellnr_db::User;
    use kellnr_db::error::DbError;
    use kellnr_db::mock::MockDb;
    use mockall::predicate::*;

    use super::*;

    /// Non-admin, writable user returned by most mocked DB lookups.
    fn test_user() -> User {
        User {
            id: 1,
            name: "test_user".to_string(),
            pwd: String::new(),
            salt: String::new(),
            is_admin: false,
            is_read_only: false,
        }
    }

    /// Settings with the cache enabled and 3 retries at 1 ms, so retry tests stay fast.
    fn test_settings() -> Arc<Settings> {
        Arc::new(Settings {
            registry: kellnr_settings::Registry {
                token_cache_enabled: true,
                token_cache_ttl_seconds: 60,
                token_cache_max_capacity: 100,
                token_db_retry_count: 3,
                token_db_retry_delay_ms: 1,
                ..kellnr_settings::Registry::default()
            },
            ..Settings::default()
        })
    }

    // ===================
    // Token Generation Tests
    // ===================

    #[test]
    fn test_generate_token_is_32_alphanumeric_chars() {
        let token = generate_token();
        assert_eq!(token.len(), 32);
        assert!(token.chars().all(|c| c.is_ascii_alphanumeric()));
    }

    // ===================
    // Retry Logic Tests
    // ===================

    #[tokio::test]
    async fn test_retry_succeeds_on_first_attempt() {
        // DB returns user on first try - no retries needed
        let mut mock_db = MockDb::new();
        mock_db
            .expect_get_user_from_token()
            .with(eq("valid_token"))
            .times(1)
            .returning(|_| Ok(test_user()));

        let db: Arc<dyn DbProvider> = Arc::new(mock_db);
        let result = get_user_with_retry(&db, "valid_token", 3, 10).await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap().name, "test_user");
    }

    #[tokio::test]
    async fn test_retry_succeeds_after_transient_error() {
        // DB fails once with transient error, then succeeds
        let call_count = Arc::new(AtomicU32::new(0));
        let call_count_clone = call_count.clone();

        let mut mock_db = MockDb::new();
        mock_db
            .expect_get_user_from_token()
            .with(eq("token"))
            .times(2)
            .returning(move |_| {
                let count = call_count_clone.fetch_add(1, Ordering::SeqCst);
                if count == 0 {
                    Err(DbError::PostgresError(sea_orm::DbErr::ConnectionAcquire(
                        sea_orm::error::ConnAcquireErr::Timeout,
                    )))
                } else {
                    Ok(test_user())
                }
            });

        let db: Arc<dyn DbProvider> = Arc::new(mock_db);
        let result = get_user_with_retry(&db, "token", 3, 1).await;

        assert!(result.is_ok());
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_no_retry_on_token_not_found() {
        // TokenNotFound should NOT trigger retries
        let mut mock_db = MockDb::new();
        mock_db
            .expect_get_user_from_token()
            .with(eq("invalid_token"))
            .times(1) // Should only be called once
            .returning(|_| Err(DbError::TokenNotFound));

        let db: Arc<dyn DbProvider> = Arc::new(mock_db);
        let result = get_user_with_retry(&db, "invalid_token", 3, 10).await;

        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), DbError::TokenNotFound));
    }

    #[tokio::test]
    async fn test_no_retry_on_user_not_found() {
        // UserNotFound should NOT trigger retries
        let mut mock_db = MockDb::new();
        mock_db
            .expect_get_user_from_token()
            .with(eq("orphan_token"))
            .times(1)
            .returning(|_| Err(DbError::UserNotFound("orphan".to_string())));

        let db: Arc<dyn DbProvider> = Arc::new(mock_db);
        let result = get_user_with_retry(&db, "orphan_token", 3, 10).await;

        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), DbError::UserNotFound(_)));
    }

    #[tokio::test]
    async fn test_exhausts_retries_on_persistent_error() {
        // DB fails all retries with transient error
        let call_count = Arc::new(AtomicU32::new(0));
        let call_count_clone = call_count.clone();

        let mut mock_db = MockDb::new();
        mock_db
            .expect_get_user_from_token()
            .with(eq("token"))
            .times(4) // 1 initial + 3 retries
            .returning(move |_| {
                call_count_clone.fetch_add(1, Ordering::SeqCst);
                Err(DbError::PostgresError(sea_orm::DbErr::ConnectionAcquire(
                    sea_orm::error::ConnAcquireErr::Timeout,
                )))
            });

        let db: Arc<dyn DbProvider> = Arc::new(mock_db);
        let result = get_user_with_retry(&db, "token", 3, 1).await;

        assert!(result.is_err());
        assert_eq!(call_count.load(Ordering::SeqCst), 4);
    }

    #[tokio::test]
    async fn test_zero_retries_only_attempts_once() {
        // With max_retries = 0, should only attempt once
        let call_count = Arc::new(AtomicU32::new(0));
        let call_count_clone = call_count.clone();

        let mut mock_db = MockDb::new();
        mock_db
            .expect_get_user_from_token()
            .with(eq("token"))
            .times(1)
            .returning(move |_| {
                call_count_clone.fetch_add(1, Ordering::SeqCst);
                Err(DbError::PostgresError(sea_orm::DbErr::ConnectionAcquire(
                    sea_orm::error::ConnAcquireErr::Timeout,
                )))
            });

        let db: Arc<dyn DbProvider> = Arc::new(mock_db);
        let result = get_user_with_retry(&db, "token", 0, 1).await;

        assert!(result.is_err());
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    // =====================================
    // Token Extraction with Cache Tests
    // =====================================

    #[tokio::test]
    async fn test_cache_hit_returns_cached_token() {
        // Pre-populate cache, DB should NOT be called
        let cache = Arc::new(TokenCacheManager::new(true, 60, 100));
        cache
            .insert(
                "cached_token".to_string(),
                CachedTokenData {
                    user: "cached_user".to_string(),
                    is_admin: true,
                    is_read_only: false,
                },
            )
            .await;

        let mock_db = MockDb::new(); // No expectations - should not be called
        let db: Arc<dyn DbProvider> = Arc::new(mock_db);
        let settings = test_settings();

        let mut headers = HeaderMap::new();
        headers.insert("Authorization", "Bearer cached_token".parse().unwrap());

        let result = Token::from_header(&headers, &db, &cache, &settings).await;

        assert!(result.is_ok());
        let token = result.unwrap();
        assert_eq!(token.user, "cached_user");
        assert!(token.is_admin);
    }

    #[tokio::test]
    async fn test_cache_miss_queries_db_and_caches() {
        let cache = Arc::new(TokenCacheManager::new(true, 60, 100));

        let mut mock_db = MockDb::new();
        mock_db
            .expect_get_user_from_token()
            .with(eq("new_token"))
            .times(1)
            .returning(|_| {
                Ok(User {
                    id: 1,
                    name: "db_user".to_string(),
                    pwd: String::new(),
                    salt: String::new(),
                    is_admin: false,
                    is_read_only: true,
                })
            });

        let db: Arc<dyn DbProvider> = Arc::new(mock_db);
        let settings = test_settings();

        let mut headers = HeaderMap::new();
        headers.insert("Authorization", "Bearer new_token".parse().unwrap());

        let result = Token::from_header(&headers, &db, &cache, &settings).await;

        assert!(result.is_ok());
        let token = result.unwrap();
        assert_eq!(token.user, "db_user");
        assert!(token.is_read_only);

        // Verify token was cached
        let cached = cache.get("new_token").await;
        assert!(cached.is_some());
        assert_eq!(cached.unwrap().user, "db_user");
    }

    #[tokio::test]
    async fn test_cache_miss_with_invalid_token_returns_forbidden() {
        let cache = Arc::new(TokenCacheManager::new(true, 60, 100));

        let mut mock_db = MockDb::new();
        mock_db
            .expect_get_user_from_token()
            .with(eq("bad_token"))
            .times(1) // Should not retry on TokenNotFound
            .returning(|_| Err(DbError::TokenNotFound));

        let db: Arc<dyn DbProvider> = Arc::new(mock_db);
        let settings = test_settings();

        let mut headers = HeaderMap::new();
        headers.insert("Authorization", "Bearer bad_token".parse().unwrap());

        let result = Token::from_header(&headers, &db, &cache, &settings).await;

        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), StatusCode::FORBIDDEN);

        // Verify invalid token was NOT cached
        let cached = cache.get("bad_token").await;
        assert!(cached.is_none());
    }

    #[tokio::test]
    async fn test_no_authorization_header_returns_unauthorized() {
        let cache = Arc::new(TokenCacheManager::new(true, 60, 100));
        let mock_db = MockDb::new();
        let db: Arc<dyn DbProvider> = Arc::new(mock_db);
        let settings = test_settings();

        let headers = HeaderMap::new(); // No Authorization header

        let result = Token::from_header(&headers, &db, &cache, &settings).await;

        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_disabled_cache_always_queries_db() {
        // Cache is disabled - should always query DB
        let cache = Arc::new(TokenCacheManager::new(false, 60, 100)); // Disabled

        let call_count = Arc::new(AtomicU32::new(0));
        let call_count_clone = call_count.clone();

        let mut mock_db = MockDb::new();
        mock_db
            .expect_get_user_from_token()
            .with(eq("token"))
            .times(2) // Called twice
            .returning(move |_| {
                call_count_clone.fetch_add(1, Ordering::SeqCst);
                Ok(test_user())
            });

        let db: Arc<dyn DbProvider> = Arc::new(mock_db);
        let settings = test_settings();

        let mut headers = HeaderMap::new();
        headers.insert("Authorization", "Bearer token".parse().unwrap());

        // First call
        let first = Token::from_header(&headers, &db, &cache, &settings).await;
        // Second call - should hit DB again since cache is disabled
        let second = Token::from_header(&headers, &db, &cache, &settings).await;

        // Both lookups must still authenticate successfully without the cache.
        assert!(first.is_ok());
        assert!(second.is_ok());
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_lowercase_bearer_prefix_works() {
        // Verifies case-insensitive handling of "bearer" prefix
        let cache = Arc::new(TokenCacheManager::new(true, 60, 100));

        let mut mock_db = MockDb::new();
        mock_db
            .expect_get_user_from_token()
            .with(eq("lowercase_token"))
            .times(1)
            .returning(|_| Ok(test_user()));

        let db: Arc<dyn DbProvider> = Arc::new(mock_db);
        let settings = test_settings();

        let mut headers = HeaderMap::new();
        headers.insert("Authorization", "bearer lowercase_token".parse().unwrap());

        let result = Token::from_header(&headers, &db, &cache, &settings).await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap().user, "test_user");
    }

    #[tokio::test]
    async fn test_raw_token_without_scheme_is_used_as_is() {
        // A header without a recognised scheme is treated as the token itself
        let cache = Arc::new(TokenCacheManager::new(true, 60, 100));

        let mut mock_db = MockDb::new();
        mock_db
            .expect_get_user_from_token()
            .with(eq("raw_token"))
            .times(1)
            .returning(|_| Ok(test_user()));

        let db: Arc<dyn DbProvider> = Arc::new(mock_db);
        let settings = test_settings();

        let mut headers = HeaderMap::new();
        headers.insert("Authorization", "raw_token".parse().unwrap());

        let result = Token::from_header(&headers, &db, &cache, &settings).await;

        assert!(result.is_ok());
        let token = result.unwrap();
        assert_eq!(token.user, "test_user");
        assert_eq!(token.value, "raw_token");
    }

    #[tokio::test]
    async fn test_non_visible_ascii_header_returns_bad_request() {
        // HeaderValue::to_str rejects bytes outside visible ASCII
        let cache = Arc::new(TokenCacheManager::new(true, 60, 100));
        let db: Arc<dyn DbProvider> = Arc::new(MockDb::new()); // DB must not be called
        let settings = test_settings();

        let mut headers = HeaderMap::new();
        headers.insert(
            "Authorization",
            axum::http::HeaderValue::from_bytes(b"Bearer tok\xffen").unwrap(),
        );

        let result = Token::from_header(&headers, &db, &cache, &settings).await;

        assert_eq!(result.unwrap_err(), StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn test_basic_auth_with_invalid_base64_returns_bad_request() {
        let cache = Arc::new(TokenCacheManager::new(true, 60, 100));
        let db: Arc<dyn DbProvider> = Arc::new(MockDb::new()); // DB must not be called
        let settings = test_settings();

        let mut headers = HeaderMap::new();
        headers.insert("Authorization", "Basic not*base64".parse().unwrap());

        let result = Token::from_header(&headers, &db, &cache, &settings).await;

        assert_eq!(result.unwrap_err(), StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn test_basic_auth_without_colon_returns_bad_request() {
        // Valid base64 and UTF-8, but not in `user:password` form
        let cache = Arc::new(TokenCacheManager::new(true, 60, 100));
        let db: Arc<dyn DbProvider> = Arc::new(MockDb::new()); // DB must not be called
        let settings = test_settings();

        let value = format!("Basic {}", STANDARD.encode("no_colon_here"));
        let mut headers = HeaderMap::new();
        headers.insert("Authorization", value.parse().unwrap());

        let result = Token::from_header(&headers, &db, &cache, &settings).await;

        assert_eq!(result.unwrap_err(), StatusCode::BAD_REQUEST);
    }
}