// kellnr_auth/token.rs

1use std::iter;
2use std::sync::Arc;
3use std::time::Duration;
4
5use axum::extract::FromRequestParts;
6use axum::http::request::Parts;
7use axum::http::{HeaderMap, StatusCode};
8use base64::Engine;
9use base64::engine::general_purpose::STANDARD;
10use kellnr_appstate::AppStateData;
11use kellnr_common::token_cache::{CachedTokenData, TokenCacheManager};
12use kellnr_db::DbProvider;
13use kellnr_db::error::DbError;
14use kellnr_settings::Settings;
15use rand::distr::Alphanumeric;
16use rand::{Rng, rng};
17use serde::Deserialize;
18use tokio::time::sleep;
19use tracing::warn;
20use utoipa::ToSchema;
21
/// Identity and permission flags resolved from a request's `Authorization` header.
#[derive(Debug)]
pub struct Token {
    /// The credential presented by the client: the token string itself, or the
    /// part after the first `:` of the decoded credentials for Basic auth.
    pub value: String,
    /// Name of the user the credential resolved to.
    pub user: String,
    /// Copy of the resolved user's `is_admin` flag.
    pub is_admin: bool,
    /// Copy of the resolved user's `is_read_only` flag.
    pub is_read_only: bool,
}
29
// See https://github.com/tokio-rs/axum/discussions/2281
/// Optional counterpart of [`Token`] with its own extractor implementation.
///
/// Its extractor maps an `UNAUTHORIZED` outcome of token extraction (no
/// `Authorization` header) to [`OptionToken::None`]; any other failure is
/// still rejected with the original status code.
#[derive(Debug)]
pub enum OptionToken {
    /// The request carried no `Authorization` header.
    None,
    /// The request carried credentials that were resolved successfully.
    Some(Token),
}
36
37pub fn generate_token() -> String {
38    let mut rng = rng();
39    iter::repeat(())
40        .map(|()| rng.sample(Alphanumeric))
41        .map(char::from)
42        .take(32)
43        .collect::<String>()
44}
45
46impl Token {
47    pub async fn from_header(
48        headers: &HeaderMap,
49        db: &Arc<dyn DbProvider>,
50        cache: &Arc<TokenCacheManager>,
51        settings: &Arc<Settings>,
52    ) -> Result<Self, StatusCode> {
53        Self::extract_token(headers, db, cache, settings).await
54    }
55
56    async fn extract_token(
57        headers: &HeaderMap,
58        db: &Arc<dyn DbProvider>,
59        cache: &Arc<TokenCacheManager>,
60        settings: &Arc<Settings>,
61    ) -> Result<Token, StatusCode> {
62        // OptionToken code expects UNAUTHORIZED when no token is found
63        let mut token = headers
64            .get("Authorization")
65            .ok_or(StatusCode::UNAUTHORIZED)?
66            .to_str()
67            .map_err(|_| StatusCode::BAD_REQUEST)?;
68
69        // Handle basic authentication (does NOT use token cache - queries DB directly)
70        if token.starts_with("Basic ") || token.starts_with("basic ") {
71            let decoded = STANDARD
72                .decode(&token[6..])
73                .map_err(|_| StatusCode::BAD_REQUEST)?;
74            let decoded_str = String::from_utf8(decoded).map_err(|_| StatusCode::BAD_REQUEST)?;
75            let (user, token) = decoded_str.split_once(':').ok_or(StatusCode::BAD_REQUEST)?;
76
77            let user = db.get_user(user).await.map_err(|_| StatusCode::FORBIDDEN)?;
78            if db.authenticate_user(&user.name, token).await.is_err() {
79                return Err(StatusCode::FORBIDDEN);
80            }
81
82            return Ok(Token {
83                value: token.to_string(),
84                user: user.name,
85                is_admin: user.is_admin,
86                is_read_only: user.is_read_only,
87            });
88        }
89
90        // Handle bearer authentication (uses token cache)
91        if token.starts_with("Bearer ") || token.starts_with("bearer ") {
92            token = &token[7..];
93        }
94
95        // Check cache first
96        if let Some(cached) = cache.get(token).await {
97            return Ok(Token {
98                value: token.to_string(),
99                user: cached.user,
100                is_admin: cached.is_admin,
101                is_read_only: cached.is_read_only,
102            });
103        }
104
105        // Cache miss - query DB with retry logic
106        let Ok(user) = get_user_with_retry(
107            db,
108            token,
109            settings.registry.token_db_retry_count,
110            settings.registry.token_db_retry_delay_ms,
111        )
112        .await
113        else {
114            return Err(StatusCode::FORBIDDEN);
115        };
116
117        // Insert into cache on successful DB lookup
118        cache
119            .insert(
120                token.to_string(),
121                CachedTokenData {
122                    user: user.name.clone(),
123                    is_admin: user.is_admin,
124                    is_read_only: user.is_read_only,
125                },
126            )
127            .await;
128
129        Ok(Token {
130            value: token.to_string(),
131            user: user.name,
132            is_admin: user.is_admin,
133            is_read_only: user.is_read_only,
134        })
135    }
136}
137
138async fn get_user_with_retry(
139    db: &Arc<dyn DbProvider>,
140    token: &str,
141    max_retries: u32,
142    delay_ms: u64,
143) -> Result<kellnr_db::User, DbError> {
144    let mut attempts = 0;
145
146    loop {
147        match db.get_user_from_token(token).await {
148            Ok(user) => return Ok(user),
149            Err(e) => {
150                // Do not retry on "not found" errors - these are definitive
151                if matches!(e, DbError::TokenNotFound | DbError::UserNotFound(_)) {
152                    return Err(e);
153                }
154
155                attempts += 1;
156                if attempts > max_retries {
157                    warn!(
158                        "Failed to get user from token after {} retries: {}",
159                        max_retries, e
160                    );
161                    return Err(e);
162                }
163
164                warn!(
165                    "Transient DB error on attempt {}/{}, retrying in {}ms: {}",
166                    attempts,
167                    max_retries + 1,
168                    delay_ms,
169                    e
170                );
171                sleep(Duration::from_millis(delay_ms)).await;
172            }
173        }
174    }
175}
176
177impl FromRequestParts<AppStateData> for Token {
178    type Rejection = StatusCode;
179
180    async fn from_request_parts(
181        parts: &mut Parts,
182        state: &AppStateData,
183    ) -> Result<Self, Self::Rejection> {
184        Self::extract_token(
185            &parts.headers,
186            &state.db,
187            &state.token_cache,
188            &state.settings,
189        )
190        .await
191    }
192}
193
194impl FromRequestParts<AppStateData> for OptionToken {
195    type Rejection = StatusCode;
196
197    async fn from_request_parts(
198        parts: &mut Parts,
199        state: &AppStateData,
200    ) -> Result<Self, Self::Rejection> {
201        match Token::extract_token(
202            &parts.headers,
203            &state.db,
204            &state.token_cache,
205            &state.settings,
206        )
207        .await
208        {
209            Ok(token) => Ok(OptionToken::Some(token)),
210            Err(StatusCode::UNAUTHORIZED) => Ok(OptionToken::None),
211            Err(status_code) => Err(status_code),
212        }
213    }
214}
215
/// Deserializable request payload holding a single `name` field.
// NOTE(review): presumably the body of a "create token" request - confirm against the handler that consumes it.
#[derive(Deserialize, ToSchema)]
pub struct NewTokenReqData {
    /// The name supplied in the request.
    pub name: String,
}
220
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicU32, Ordering};

    use kellnr_db::User;
    use kellnr_db::error::DbError;
    use kellnr_db::mock::MockDb;
    use mockall::predicate::*;

    use super::*;

    // ===================
    // Shared fixtures
    // ===================

    /// A plain, non-admin, read-write user returned by the mocked database.
    fn test_user() -> User {
        User {
            id: 1,
            name: "test_user".to_string(),
            pwd: String::new(),
            salt: String::new(),
            is_admin: false,
            is_read_only: false,
            created: String::new(),
        }
    }

    /// Settings with three retries and a 1 ms delay so retry tests stay fast.
    fn test_settings() -> Arc<Settings> {
        Arc::new(Settings {
            registry: kellnr_settings::Registry {
                token_cache_enabled: true,
                token_cache_ttl_seconds: 60,
                token_cache_max_capacity: 100,
                token_db_retry_count: 3,
                token_db_retry_delay_ms: 1,
                ..kellnr_settings::Registry::default()
            },
            ..Settings::default()
        })
    }

    /// An error that is neither `TokenNotFound` nor `UserNotFound`, i.e. one
    /// that the retry logic treats as transient.
    fn transient_error() -> DbError {
        DbError::PostgresError(sea_orm::DbErr::ConnectionAcquire(
            sea_orm::error::ConnAcquireErr::Timeout,
        ))
    }

    /// A token cache constructed with its first argument (enabled) set to `true`.
    fn enabled_cache() -> Arc<TokenCacheManager> {
        Arc::new(TokenCacheManager::new(true, 60, 100))
    }

    /// A header map whose `Authorization` header is set to `value`.
    fn auth_headers(value: &str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert("Authorization", value.parse().unwrap());
        headers
    }

    // ===================
    // Retry Logic Tests
    // ===================

    #[tokio::test]
    async fn test_retry_succeeds_on_first_attempt() {
        // DB returns user on first try - no retries needed
        let mut mock_db = MockDb::new();
        mock_db
            .expect_get_user_from_token()
            .with(eq("valid_token"))
            .times(1)
            .returning(|_| Ok(test_user()));

        let db: Arc<dyn DbProvider> = Arc::new(mock_db);
        let result = get_user_with_retry(&db, "valid_token", 3, 10).await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap().name, "test_user");
    }

    #[tokio::test]
    async fn test_retry_succeeds_after_transient_error() {
        // DB fails once with transient error, then succeeds
        let call_count = Arc::new(AtomicU32::new(0));
        let call_count_clone = call_count.clone();

        let mut mock_db = MockDb::new();
        mock_db
            .expect_get_user_from_token()
            .with(eq("token"))
            .times(2)
            .returning(move |_| {
                let count = call_count_clone.fetch_add(1, Ordering::SeqCst);
                if count == 0 {
                    Err(transient_error())
                } else {
                    Ok(test_user())
                }
            });

        let db: Arc<dyn DbProvider> = Arc::new(mock_db);
        let result = get_user_with_retry(&db, "token", 3, 1).await;

        assert!(result.is_ok());
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_no_retry_on_token_not_found() {
        // TokenNotFound should NOT trigger retries
        let mut mock_db = MockDb::new();
        mock_db
            .expect_get_user_from_token()
            .with(eq("invalid_token"))
            .times(1) // Should only be called once
            .returning(|_| Err(DbError::TokenNotFound));

        let db: Arc<dyn DbProvider> = Arc::new(mock_db);
        let result = get_user_with_retry(&db, "invalid_token", 3, 10).await;

        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), DbError::TokenNotFound));
    }

    #[tokio::test]
    async fn test_no_retry_on_user_not_found() {
        // UserNotFound should NOT trigger retries
        let mut mock_db = MockDb::new();
        mock_db
            .expect_get_user_from_token()
            .with(eq("orphan_token"))
            .times(1)
            .returning(|_| Err(DbError::UserNotFound("orphan".to_string())));

        let db: Arc<dyn DbProvider> = Arc::new(mock_db);
        let result = get_user_with_retry(&db, "orphan_token", 3, 10).await;

        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), DbError::UserNotFound(_)));
    }

    #[tokio::test]
    async fn test_exhausts_retries_on_persistent_error() {
        // DB fails all retries with transient error
        let call_count = Arc::new(AtomicU32::new(0));
        let call_count_clone = call_count.clone();

        let mut mock_db = MockDb::new();
        mock_db
            .expect_get_user_from_token()
            .with(eq("token"))
            .times(4) // 1 initial + 3 retries
            .returning(move |_| {
                call_count_clone.fetch_add(1, Ordering::SeqCst);
                Err(transient_error())
            });

        let db: Arc<dyn DbProvider> = Arc::new(mock_db);
        let result = get_user_with_retry(&db, "token", 3, 1).await;

        // The last transient error must be surfaced to the caller unchanged.
        assert!(matches!(result, Err(DbError::PostgresError(_))));
        assert_eq!(call_count.load(Ordering::SeqCst), 4);
    }

    #[tokio::test]
    async fn test_zero_retries_only_attempts_once() {
        // With max_retries = 0, should only attempt once
        let call_count = Arc::new(AtomicU32::new(0));
        let call_count_clone = call_count.clone();

        let mut mock_db = MockDb::new();
        mock_db
            .expect_get_user_from_token()
            .with(eq("token"))
            .times(1)
            .returning(move |_| {
                call_count_clone.fetch_add(1, Ordering::SeqCst);
                Err(transient_error())
            });

        let db: Arc<dyn DbProvider> = Arc::new(mock_db);
        let result = get_user_with_retry(&db, "token", 0, 1).await;

        assert!(matches!(result, Err(DbError::PostgresError(_))));
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    // =====================================
    // Token Extraction with Cache Tests
    // =====================================

    #[tokio::test]
    async fn test_cache_hit_returns_cached_token() {
        // Pre-populate cache, DB should NOT be called
        let cache = enabled_cache();
        cache
            .insert(
                "cached_token".to_string(),
                CachedTokenData {
                    user: "cached_user".to_string(),
                    is_admin: true,
                    is_read_only: false,
                },
            )
            .await;

        let mock_db = MockDb::new(); // No expectations - should not be called
        let db: Arc<dyn DbProvider> = Arc::new(mock_db);
        let settings = test_settings();
        let headers = auth_headers("Bearer cached_token");

        let result = Token::from_header(&headers, &db, &cache, &settings).await;

        assert!(result.is_ok());
        let token = result.unwrap();
        assert_eq!(token.user, "cached_user");
        assert!(token.is_admin);
    }

    #[tokio::test]
    async fn test_cache_miss_queries_db_and_caches() {
        let cache = enabled_cache();

        let mut mock_db = MockDb::new();
        mock_db
            .expect_get_user_from_token()
            .with(eq("new_token"))
            .times(1)
            .returning(|_| {
                Ok(User {
                    id: 1,
                    name: "db_user".to_string(),
                    pwd: String::new(),
                    salt: String::new(),
                    is_admin: false,
                    is_read_only: true,
                    created: String::new(),
                })
            });

        let db: Arc<dyn DbProvider> = Arc::new(mock_db);
        let settings = test_settings();
        let headers = auth_headers("Bearer new_token");

        let result = Token::from_header(&headers, &db, &cache, &settings).await;

        assert!(result.is_ok());
        let token = result.unwrap();
        assert_eq!(token.user, "db_user");
        assert!(token.is_read_only);

        // Verify token was cached
        let cached = cache.get("new_token").await;
        assert!(cached.is_some());
        assert_eq!(cached.unwrap().user, "db_user");
    }

    #[tokio::test]
    async fn test_cache_miss_with_invalid_token_returns_forbidden() {
        let cache = enabled_cache();

        let mut mock_db = MockDb::new();
        mock_db
            .expect_get_user_from_token()
            .with(eq("bad_token"))
            .times(1) // Should not retry on TokenNotFound
            .returning(|_| Err(DbError::TokenNotFound));

        let db: Arc<dyn DbProvider> = Arc::new(mock_db);
        let settings = test_settings();
        let headers = auth_headers("Bearer bad_token");

        let result = Token::from_header(&headers, &db, &cache, &settings).await;

        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), StatusCode::FORBIDDEN);

        // Verify invalid token was NOT cached
        let cached = cache.get("bad_token").await;
        assert!(cached.is_none());
    }

    #[tokio::test]
    async fn test_no_authorization_header_returns_unauthorized() {
        let cache = enabled_cache();
        let mock_db = MockDb::new();
        let db: Arc<dyn DbProvider> = Arc::new(mock_db);
        let settings = test_settings();

        let headers = HeaderMap::new(); // No Authorization header

        let result = Token::from_header(&headers, &db, &cache, &settings).await;

        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_disabled_cache_always_queries_db() {
        // Cache is disabled - should always query DB
        let cache = Arc::new(TokenCacheManager::new(false, 60, 100)); // Disabled

        let call_count = Arc::new(AtomicU32::new(0));
        let call_count_clone = call_count.clone();

        let mut mock_db = MockDb::new();
        mock_db
            .expect_get_user_from_token()
            .with(eq("token"))
            .times(2) // Called twice
            .returning(move |_| {
                call_count_clone.fetch_add(1, Ordering::SeqCst);
                Ok(test_user())
            });

        let db: Arc<dyn DbProvider> = Arc::new(mock_db);
        let settings = test_settings();
        let headers = auth_headers("Bearer token");

        // First call
        let first = Token::from_header(&headers, &db, &cache, &settings).await;
        // Second call - should hit DB again since cache is disabled
        let second = Token::from_header(&headers, &db, &cache, &settings).await;

        // Both requests must still authenticate; counting DB calls alone would
        // also pass if extraction failed after the lookup.
        assert_eq!(first.unwrap().user, "test_user");
        assert_eq!(second.unwrap().user, "test_user");
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_lowercase_bearer_prefix_works() {
        // Verifies that the lowercase "bearer" prefix is accepted
        let cache = enabled_cache();

        let mut mock_db = MockDb::new();
        mock_db
            .expect_get_user_from_token()
            .with(eq("lowercase_token"))
            .times(1)
            .returning(|_| Ok(test_user()));

        let db: Arc<dyn DbProvider> = Arc::new(mock_db);
        let settings = test_settings();
        let headers = auth_headers("bearer lowercase_token");

        let result = Token::from_header(&headers, &db, &cache, &settings).await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap().user, "test_user");
    }
}