Skip to main content

oauth2_test_server/
store.rs

1use chrono::{Duration, Utc};
2use std::{
3    collections::HashMap,
4    net::SocketAddr,
5    sync::{Arc, RwLock},
6};
7use tokio::{net::TcpListener, task::JoinHandle};
8use uuid::Uuid;
9
10use crate::{
11    config::IssuerConfig,
12    crypto::{build_jwks_json, generate_token_string, issue_jwt, Keys},
13    models::{AuthorizationCode, Client, DeviceAuthorization, Token},
14};
15
16#[async_trait::async_trait]
17pub trait OauthStore: Send + Sync {
18    async fn get_client(&self, client_id: &str) -> Option<Client>;
19    async fn insert_client(&self, client: Client);
20
21    async fn get_code(&self, code: &str) -> Option<AuthorizationCode>;
22    async fn remove_code(&self, code: &str) -> Option<AuthorizationCode>;
23    async fn insert_code(&self, code: String, auth_code: AuthorizationCode);
24    async fn cleanup_expired_codes(&self) -> usize;
25
26    async fn get_token(&self, token: &str) -> Option<Token>;
27    async fn insert_token(&self, token: String, value: Token);
28    async fn update_token(&self, token: &str, value: Token);
29    async fn cleanup_expired_tokens(&self) -> usize;
30
31    async fn get_refresh_token(&self, token: &str) -> Option<Token>;
32    async fn insert_refresh_token(&self, token: String, value: Token);
33    async fn update_refresh_token(&self, token: &str, value: Token);
34    async fn cleanup_expired_refresh_tokens(&self) -> usize;
35
36    async fn get_device_code(&self, device_code: &str) -> Option<DeviceAuthorization>;
37    async fn insert_device_code(&self, device_code: String, auth: DeviceAuthorization);
38    async fn update_device_code(&self, device_code: &str, auth: DeviceAuthorization);
39    async fn cleanup_expired_device_codes(&self) -> usize;
40
41    async fn get_all_clients(&self) -> Vec<Client>;
42    async fn get_all_codes(&self) -> Vec<AuthorizationCode>;
43    async fn get_all_tokens(&self) -> Vec<Token>;
44    async fn get_all_refresh_tokens(&self) -> Vec<Token>;
45
46    async fn clear_clients(&self);
47    async fn clear_codes(&self);
48    async fn clear_tokens(&self);
49    async fn clear_refresh_tokens(&self);
50    async fn clear_device_codes(&self);
51    async fn clear_all(&self);
52}
53
54#[derive(Clone)]
55pub struct InMemoryStore {
56    clients: Arc<RwLock<HashMap<String, Client>>>,
57    codes: Arc<RwLock<HashMap<String, AuthorizationCode>>>,
58    tokens: Arc<RwLock<HashMap<String, Token>>>,
59    refresh_tokens: Arc<RwLock<HashMap<String, Token>>>,
60    device_codes: Arc<RwLock<HashMap<String, DeviceAuthorization>>>,
61}
62
63impl Default for InMemoryStore {
64    fn default() -> Self {
65        Self {
66            clients: Arc::new(RwLock::new(HashMap::new())),
67            codes: Arc::new(RwLock::new(HashMap::new())),
68            tokens: Arc::new(RwLock::new(HashMap::new())),
69            refresh_tokens: Arc::new(RwLock::new(HashMap::new())),
70            device_codes: Arc::new(RwLock::new(HashMap::new())),
71        }
72    }
73}
74
75#[async_trait::async_trait]
76impl OauthStore for InMemoryStore {
77    async fn get_client(&self, client_id: &str) -> Option<Client> {
78        self.clients.read().unwrap().get(client_id).cloned()
79    }
80
81    async fn insert_client(&self, client: Client) {
82        self.clients
83            .write()
84            .unwrap()
85            .insert(client.client_id.clone(), client);
86    }
87
88    async fn get_code(&self, code: &str) -> Option<AuthorizationCode> {
89        self.codes.read().unwrap().get(code).cloned()
90    }
91
92    async fn remove_code(&self, code: &str) -> Option<AuthorizationCode> {
93        self.codes.write().unwrap().remove(code)
94    }
95
96    async fn insert_code(&self, code: String, auth_code: AuthorizationCode) {
97        self.codes.write().unwrap().insert(code, auth_code);
98    }
99
100    async fn get_token(&self, token: &str) -> Option<Token> {
101        self.tokens.read().unwrap().get(token).cloned()
102    }
103
104    async fn insert_token(&self, token: String, value: Token) {
105        self.tokens.write().unwrap().insert(token, value);
106    }
107
108    async fn update_token(&self, token: &str, value: Token) {
109        if let Some(t) = self.tokens.write().unwrap().get_mut(token) {
110            *t = value;
111        }
112    }
113
114    async fn get_refresh_token(&self, token: &str) -> Option<Token> {
115        self.refresh_tokens.read().unwrap().get(token).cloned()
116    }
117
118    async fn insert_refresh_token(&self, token: String, value: Token) {
119        self.refresh_tokens.write().unwrap().insert(token, value);
120    }
121
122    async fn update_refresh_token(&self, token: &str, value: Token) {
123        if let Some(t) = self.refresh_tokens.write().unwrap().get_mut(token) {
124            *t = value;
125        }
126    }
127
128    async fn get_device_code(&self, device_code: &str) -> Option<DeviceAuthorization> {
129        self.device_codes.read().unwrap().get(device_code).cloned()
130    }
131
132    async fn insert_device_code(&self, device_code: String, auth: DeviceAuthorization) {
133        self.device_codes.write().unwrap().insert(device_code, auth);
134    }
135
136    async fn update_device_code(&self, device_code: &str, auth: DeviceAuthorization) {
137        if let Some(a) = self.device_codes.write().unwrap().get_mut(device_code) {
138            *a = auth;
139        }
140    }
141
142    async fn cleanup_expired_codes(&self) -> usize {
143        let now = Utc::now();
144        let mut count = 0;
145        let mut codes = self.codes.write().unwrap();
146        codes.retain(|_, code| {
147            if code.expires_at < now {
148                count += 1;
149                false
150            } else {
151                true
152            }
153        });
154        count
155    }
156
157    async fn cleanup_expired_tokens(&self) -> usize {
158        let now = Utc::now();
159        let mut count = 0;
160        let mut tokens = self.tokens.write().unwrap();
161        tokens.retain(|_, token| {
162            if token.expires_at < now {
163                count += 1;
164                false
165            } else {
166                true
167            }
168        });
169        count
170    }
171
172    async fn cleanup_expired_refresh_tokens(&self) -> usize {
173        let now = Utc::now();
174        let mut count = 0;
175        let mut tokens = self.refresh_tokens.write().unwrap();
176        tokens.retain(|_, token| {
177            if token.expires_at < now {
178                count += 1;
179                false
180            } else {
181                true
182            }
183        });
184        count
185    }
186
187    async fn cleanup_expired_device_codes(&self) -> usize {
188        let now = Utc::now();
189        let mut count = 0;
190        let mut codes = self.device_codes.write().unwrap();
191        codes.retain(|_, code| {
192            if code.expires_at < now {
193                count += 1;
194                false
195            } else {
196                true
197            }
198        });
199        count
200    }
201
202    async fn get_all_clients(&self) -> Vec<Client> {
203        self.clients.read().unwrap().values().cloned().collect()
204    }
205
206    async fn get_all_codes(&self) -> Vec<AuthorizationCode> {
207        self.codes.read().unwrap().values().cloned().collect()
208    }
209
210    async fn get_all_tokens(&self) -> Vec<Token> {
211        self.tokens.read().unwrap().values().cloned().collect()
212    }
213
214    async fn get_all_refresh_tokens(&self) -> Vec<Token> {
215        self.refresh_tokens
216            .read()
217            .unwrap()
218            .values()
219            .cloned()
220            .collect()
221    }
222
223    async fn clear_clients(&self) {
224        self.clients.write().unwrap().clear();
225    }
226
227    async fn clear_codes(&self) {
228        self.codes.write().unwrap().clear();
229    }
230
231    async fn clear_tokens(&self) {
232        self.tokens.write().unwrap().clear();
233    }
234
235    async fn clear_refresh_tokens(&self) {
236        self.refresh_tokens.write().unwrap().clear();
237    }
238
239    async fn clear_device_codes(&self) {
240        self.device_codes.write().unwrap().clear();
241    }
242
243    async fn clear_all(&self) {
244        self.clear_clients().await;
245        self.clear_codes().await;
246        self.clear_tokens().await;
247        self.clear_refresh_tokens().await;
248        self.clear_device_codes().await;
249    }
250}
251
252/// Shared in-memory state for the OAuth2 server.
253///
254/// Holds the issuer configuration, cryptographic keys, and abstract stores.
255#[derive(Clone)]
256pub struct AppState {
257    pub config: Arc<IssuerConfig>,
258    pub base_url: String,
259    pub store: Arc<dyn OauthStore>,
260    pub keys: Arc<Keys>,
261    pub jwks_json: Arc<serde_json::Value>,
262}
263
264impl AppState {
265    /// Create a new state with a freshly generated RSA key pair, and an in-memory store.
266    pub fn new(config: IssuerConfig) -> Self {
267        Self::with_store(config, Arc::new(InMemoryStore::default()))
268    }
269
270    /// Create a new state providing a custom store implementation.
271    pub fn with_store(config: IssuerConfig, store: Arc<dyn OauthStore>) -> Self {
272        let base_url = format!("{}://{}:{}", config.scheme, config.host, config.port);
273        let keys = Arc::new(Keys::generate());
274        let jwks_json = Arc::new(build_jwks_json(&keys));
275        Self {
276            config: Arc::new(config),
277            store,
278            base_url,
279            keys,
280            jwks_json,
281        }
282    }
283
284    /// Returns the OAuth2 issuer URL (e.g. `http://localhost:8090`).
285    pub fn issuer(&self) -> &str {
286        self.base_url.as_str()
287    }
288
289    /// Register a new client from RFC 7591 metadata JSON.
290    pub async fn register_client(
291        &self,
292        metadata: serde_json::Value,
293    ) -> Result<Client, crate::error::OauthError> {
294        let requested_scope = metadata
295            .get("scope")
296            .and_then(|v| v.as_str())
297            .unwrap_or("openid");
298
299        self.config
300            .validate_scope(requested_scope)
301            .map_err(crate::error::OauthError::InvalidScope)?;
302
303        let client_id = Uuid::new_v4().to_string();
304
305        let client_secret = if self.config.generate_client_secret_for_dcr
306            || metadata
307                .get("token_endpoint_auth_method")
308                .and_then(|v| v.as_str())
309                != Some("none")
310        {
311            Some(generate_token_string())
312        } else {
313            None
314        };
315
316        let redirect_uris = metadata
317            .get("redirect_uris")
318            .and_then(|v| v.as_array())
319            .map(|arr| {
320                arr.iter()
321                    .filter_map(|u| u.as_str().map(|s| s.to_string()))
322                    .collect::<Vec<String>>()
323            })
324            .unwrap_or_default();
325
326        let grant_types = metadata
327            .get("grant_types")
328            .and_then(|v| v.as_array())
329            .map(|arr| {
330                arr.iter()
331                    .filter_map(|v| v.as_str().map(|s| s.to_string()))
332                    .collect::<Vec<String>>()
333            })
334            .unwrap_or_else(|| vec!["authorization_code".to_string()]);
335
336        let requires_redirect_uri = grant_types.iter().all(|g| {
337            !matches!(
338                g.as_str(),
339                "client_credentials" | "urn:ietf:params:oauth:grant-type:device_code"
340            )
341        });
342
343        if redirect_uris.is_empty() && requires_redirect_uri {
344            return Err(crate::error::OauthError::InvalidRequest(Some(
345                "redirect_uris required".to_string(),
346            )));
347        }
348
349        let client = Client {
350            client_id: client_id.clone(),
351            client_secret,
352            redirect_uris,
353            grant_types: metadata
354                .get("grant_types")
355                .and_then(|v| v.as_array())
356                .map(|arr| {
357                    arr.iter()
358                        .filter_map(|v| v.as_str().map(|s| s.to_string()))
359                        .collect()
360                })
361                .unwrap_or_else(|| vec!["authorization_code".to_string()]),
362            response_types: metadata
363                .get("response_types")
364                .and_then(|v| v.as_array())
365                .map(|arr| {
366                    arr.iter()
367                        .filter_map(|v| v.as_str().map(|s| s.to_string()))
368                        .collect()
369                })
370                .unwrap_or_else(|| vec!["code".to_string()]),
371            scope: metadata
372                .get("scope")
373                .and_then(|v| v.as_str())
374                .unwrap_or("")
375                .to_string(),
376            token_endpoint_auth_method: metadata
377                .get("token_endpoint_auth_method")
378                .and_then(|v| v.as_str())
379                .unwrap_or("client_secret_basic")
380                .to_string(),
381            client_name: metadata
382                .get("client_name")
383                .and_then(|v| v.as_str())
384                .map(|s| s.to_string()),
385            client_uri: metadata
386                .get("client_uri")
387                .and_then(|v| v.as_str())
388                .map(|s| s.to_string()),
389            logo_uri: metadata
390                .get("logo_uri")
391                .and_then(|v| v.as_str())
392                .map(|s| s.to_string()),
393            contacts: metadata
394                .get("contacts")
395                .and_then(|v| v.as_array())
396                .map(|arr| {
397                    arr.iter()
398                        .filter_map(|v| v.as_str().map(|s| s.to_string()))
399                        .collect()
400                })
401                .unwrap_or_default(),
402            policy_uri: metadata
403                .get("policy_uri")
404                .and_then(|v| v.as_str())
405                .map(|s| s.to_string()),
406            tos_uri: metadata
407                .get("tos_uri")
408                .and_then(|v| v.as_str())
409                .map(|s| s.to_string()),
410            jwks: metadata.get("jwks").cloned(),
411            jwks_uri: metadata
412                .get("jwks_uri")
413                .and_then(|v| v.as_str())
414                .map(|s| s.to_string()),
415            software_id: metadata
416                .get("software_id")
417                .and_then(|v| v.as_str())
418                .map(|s| s.to_string()),
419            software_version: metadata
420                .get("software_version")
421                .and_then(|v| v.as_str())
422                .map(|s| s.to_string()),
423            registration_access_token: None,
424            registration_client_uri: Some(format!("{}/register/{}", self.issuer(), client_id)),
425        };
426
427        self.store.insert_client(client.clone()).await;
428
429        Ok(client)
430    }
431
432    /// Issue a JWT and store it; used by `testkit` helpers.
433    #[cfg(feature = "testing")]
434    pub async fn generate_token(
435        &self,
436        client: &Client,
437        options: crate::testkit::JwtOptions,
438    ) -> Result<Token, jsonwebtoken::errors::Error> {
439        let user_id = options.user_id.clone();
440        let jwt = self.generate_jwt(client, options)?;
441        let refresh_token = generate_token_string();
442        let token = Token {
443            access_token: jwt.clone(),
444            refresh_token: Some(refresh_token.clone()),
445            client_id: client.client_id.clone(),
446            scope: client.scope.clone(),
447            expires_at: Utc::now() + Duration::hours(1),
448            user_id,
449            revoked: false,
450        };
451        self.store.insert_token(jwt.clone(), token.clone()).await;
452        self.store
453            .insert_refresh_token(refresh_token, token.clone())
454            .await;
455        Ok(token)
456    }
457
458    /// Sign a JWT for the given client; used by `testkit` helpers.
459    #[cfg(feature = "testing")]
460    pub fn generate_jwt(
461        &self,
462        client: &Client,
463        options: crate::testkit::JwtOptions,
464    ) -> Result<String, jsonwebtoken::errors::Error> {
465        let scope = options.scope.unwrap_or_else(|| client.scope.clone());
466        issue_jwt(
467            self.issuer(),
468            &client.client_id,
469            &options.user_id,
470            &scope,
471            options.expires_in,
472            &self.keys,
473        )
474    }
475
476    #[cfg(feature = "testing")]
477    pub async fn approve_device_code(&self, device_code: &str, user_id: &str) -> Option<()> {
478        let mut device_auth = self.store.get_device_code(device_code).await?;
479        device_auth.approved = true;
480        device_auth.user_id = Some(user_id.to_string());
481        self.store
482            .update_device_code(device_code, device_auth)
483            .await;
484        Some(())
485    }
486
487    /// Build the Axum router and bind to a TCP listener.
488    pub async fn start(mut self) -> (SocketAddr, JoinHandle<()>) {
489        let port = self.config.port;
490        let addr = SocketAddr::from(([127, 0, 0, 1], port));
491        let listener = TcpListener::bind(addr)
492            .await
493            .expect("Failed to bind TCP listener - port may be in use");
494        let local_addr = listener
495            .local_addr()
496            .expect("Failed to get local address from listener");
497        let base_url = format!(
498            "{}://{}:{}",
499            self.config.scheme,
500            self.config.host,
501            local_addr.port()
502        );
503        self.base_url = base_url;
504
505        let cleanup_interval = self.config.cleanup_interval_secs;
506        let store = self.store.clone();
507
508        let cleanup_handle = if cleanup_interval > 0 {
509            Some(tokio::spawn(async move {
510                let mut interval =
511                    tokio::time::interval(tokio::time::Duration::from_secs(cleanup_interval));
512                loop {
513                    interval.tick().await;
514                    let codes_cleaned = store.cleanup_expired_codes().await;
515                    let tokens_cleaned = store.cleanup_expired_tokens().await;
516                    let refresh_cleaned = store.cleanup_expired_refresh_tokens().await;
517                    let device_codes_cleaned = store.cleanup_expired_device_codes().await;
518                    if codes_cleaned > 0
519                        || tokens_cleaned > 0
520                        || refresh_cleaned > 0
521                        || device_codes_cleaned > 0
522                    {
523                        tracing::debug!(
524                            "Cleaned up expired entries: codes={}, tokens={}, refresh={}, device_codes={}",
525                            codes_cleaned, tokens_cleaned, refresh_cleaned, device_codes_cleaned
526                        );
527                    }
528                }
529            }))
530        } else {
531            None
532        };
533
534        let router = crate::router::build_router(self);
535        let server_handle = tokio::spawn(async move {
536            axum::serve(listener, router).await.unwrap();
537        });
538
539        let combined_handle = tokio::spawn(async move {
540            if let Some(cleanup) = cleanup_handle {
541                tokio::select! {
542                    _ = server_handle => {}
543                    _ = cleanup => {}
544                }
545            } else {
546                server_handle.await.unwrap();
547            }
548        });
549        (local_addr, combined_handle)
550    }
551}