1use chrono::{Duration, Utc};
2use std::{
3 collections::HashMap,
4 net::SocketAddr,
5 sync::{Arc, RwLock},
6};
7use tokio::{net::TcpListener, task::JoinHandle};
8use uuid::Uuid;
9
10use crate::{
11 config::IssuerConfig,
12 crypto::{build_jwks_json, generate_token_string, issue_jwt, Keys},
13 models::{AuthorizationCode, Client, DeviceAuthorization, Token},
14};
15
16#[async_trait::async_trait]
17pub trait OauthStore: Send + Sync {
18 async fn get_client(&self, client_id: &str) -> Option<Client>;
19 async fn insert_client(&self, client: Client);
20
21 async fn get_code(&self, code: &str) -> Option<AuthorizationCode>;
22 async fn remove_code(&self, code: &str) -> Option<AuthorizationCode>;
23 async fn insert_code(&self, code: String, auth_code: AuthorizationCode);
24 async fn cleanup_expired_codes(&self) -> usize;
25
26 async fn get_token(&self, token: &str) -> Option<Token>;
27 async fn insert_token(&self, token: String, value: Token);
28 async fn update_token(&self, token: &str, value: Token);
29 async fn cleanup_expired_tokens(&self) -> usize;
30
31 async fn get_refresh_token(&self, token: &str) -> Option<Token>;
32 async fn insert_refresh_token(&self, token: String, value: Token);
33 async fn update_refresh_token(&self, token: &str, value: Token);
34 async fn cleanup_expired_refresh_tokens(&self) -> usize;
35
36 async fn get_device_code(&self, device_code: &str) -> Option<DeviceAuthorization>;
37 async fn insert_device_code(&self, device_code: String, auth: DeviceAuthorization);
38 async fn update_device_code(&self, device_code: &str, auth: DeviceAuthorization);
39 async fn cleanup_expired_device_codes(&self) -> usize;
40
41 async fn get_all_clients(&self) -> Vec<Client>;
42 async fn get_all_codes(&self) -> Vec<AuthorizationCode>;
43 async fn get_all_tokens(&self) -> Vec<Token>;
44 async fn get_all_refresh_tokens(&self) -> Vec<Token>;
45
46 async fn clear_clients(&self);
47 async fn clear_codes(&self);
48 async fn clear_tokens(&self);
49 async fn clear_refresh_tokens(&self);
50 async fn clear_device_codes(&self);
51 async fn clear_all(&self);
52}
53
54#[derive(Clone)]
55pub struct InMemoryStore {
56 clients: Arc<RwLock<HashMap<String, Client>>>,
57 codes: Arc<RwLock<HashMap<String, AuthorizationCode>>>,
58 tokens: Arc<RwLock<HashMap<String, Token>>>,
59 refresh_tokens: Arc<RwLock<HashMap<String, Token>>>,
60 device_codes: Arc<RwLock<HashMap<String, DeviceAuthorization>>>,
61}
62
63impl Default for InMemoryStore {
64 fn default() -> Self {
65 Self {
66 clients: Arc::new(RwLock::new(HashMap::new())),
67 codes: Arc::new(RwLock::new(HashMap::new())),
68 tokens: Arc::new(RwLock::new(HashMap::new())),
69 refresh_tokens: Arc::new(RwLock::new(HashMap::new())),
70 device_codes: Arc::new(RwLock::new(HashMap::new())),
71 }
72 }
73}
74
75#[async_trait::async_trait]
76impl OauthStore for InMemoryStore {
77 async fn get_client(&self, client_id: &str) -> Option<Client> {
78 self.clients.read().unwrap().get(client_id).cloned()
79 }
80
81 async fn insert_client(&self, client: Client) {
82 self.clients
83 .write()
84 .unwrap()
85 .insert(client.client_id.clone(), client);
86 }
87
88 async fn get_code(&self, code: &str) -> Option<AuthorizationCode> {
89 self.codes.read().unwrap().get(code).cloned()
90 }
91
92 async fn remove_code(&self, code: &str) -> Option<AuthorizationCode> {
93 self.codes.write().unwrap().remove(code)
94 }
95
96 async fn insert_code(&self, code: String, auth_code: AuthorizationCode) {
97 self.codes.write().unwrap().insert(code, auth_code);
98 }
99
100 async fn get_token(&self, token: &str) -> Option<Token> {
101 self.tokens.read().unwrap().get(token).cloned()
102 }
103
104 async fn insert_token(&self, token: String, value: Token) {
105 self.tokens.write().unwrap().insert(token, value);
106 }
107
108 async fn update_token(&self, token: &str, value: Token) {
109 if let Some(t) = self.tokens.write().unwrap().get_mut(token) {
110 *t = value;
111 }
112 }
113
114 async fn get_refresh_token(&self, token: &str) -> Option<Token> {
115 self.refresh_tokens.read().unwrap().get(token).cloned()
116 }
117
118 async fn insert_refresh_token(&self, token: String, value: Token) {
119 self.refresh_tokens.write().unwrap().insert(token, value);
120 }
121
122 async fn update_refresh_token(&self, token: &str, value: Token) {
123 if let Some(t) = self.refresh_tokens.write().unwrap().get_mut(token) {
124 *t = value;
125 }
126 }
127
128 async fn get_device_code(&self, device_code: &str) -> Option<DeviceAuthorization> {
129 self.device_codes.read().unwrap().get(device_code).cloned()
130 }
131
132 async fn insert_device_code(&self, device_code: String, auth: DeviceAuthorization) {
133 self.device_codes.write().unwrap().insert(device_code, auth);
134 }
135
136 async fn update_device_code(&self, device_code: &str, auth: DeviceAuthorization) {
137 if let Some(a) = self.device_codes.write().unwrap().get_mut(device_code) {
138 *a = auth;
139 }
140 }
141
142 async fn cleanup_expired_codes(&self) -> usize {
143 let now = Utc::now();
144 let mut count = 0;
145 let mut codes = self.codes.write().unwrap();
146 codes.retain(|_, code| {
147 if code.expires_at < now {
148 count += 1;
149 false
150 } else {
151 true
152 }
153 });
154 count
155 }
156
157 async fn cleanup_expired_tokens(&self) -> usize {
158 let now = Utc::now();
159 let mut count = 0;
160 let mut tokens = self.tokens.write().unwrap();
161 tokens.retain(|_, token| {
162 if token.expires_at < now {
163 count += 1;
164 false
165 } else {
166 true
167 }
168 });
169 count
170 }
171
172 async fn cleanup_expired_refresh_tokens(&self) -> usize {
173 let now = Utc::now();
174 let mut count = 0;
175 let mut tokens = self.refresh_tokens.write().unwrap();
176 tokens.retain(|_, token| {
177 if token.expires_at < now {
178 count += 1;
179 false
180 } else {
181 true
182 }
183 });
184 count
185 }
186
187 async fn cleanup_expired_device_codes(&self) -> usize {
188 let now = Utc::now();
189 let mut count = 0;
190 let mut codes = self.device_codes.write().unwrap();
191 codes.retain(|_, code| {
192 if code.expires_at < now {
193 count += 1;
194 false
195 } else {
196 true
197 }
198 });
199 count
200 }
201
202 async fn get_all_clients(&self) -> Vec<Client> {
203 self.clients.read().unwrap().values().cloned().collect()
204 }
205
206 async fn get_all_codes(&self) -> Vec<AuthorizationCode> {
207 self.codes.read().unwrap().values().cloned().collect()
208 }
209
210 async fn get_all_tokens(&self) -> Vec<Token> {
211 self.tokens.read().unwrap().values().cloned().collect()
212 }
213
214 async fn get_all_refresh_tokens(&self) -> Vec<Token> {
215 self.refresh_tokens
216 .read()
217 .unwrap()
218 .values()
219 .cloned()
220 .collect()
221 }
222
223 async fn clear_clients(&self) {
224 self.clients.write().unwrap().clear();
225 }
226
227 async fn clear_codes(&self) {
228 self.codes.write().unwrap().clear();
229 }
230
231 async fn clear_tokens(&self) {
232 self.tokens.write().unwrap().clear();
233 }
234
235 async fn clear_refresh_tokens(&self) {
236 self.refresh_tokens.write().unwrap().clear();
237 }
238
239 async fn clear_device_codes(&self) {
240 self.device_codes.write().unwrap().clear();
241 }
242
243 async fn clear_all(&self) {
244 self.clear_clients().await;
245 self.clear_codes().await;
246 self.clear_tokens().await;
247 self.clear_refresh_tokens().await;
248 self.clear_device_codes().await;
249 }
250}
251
252#[derive(Clone)]
256pub struct AppState {
257 pub config: Arc<IssuerConfig>,
258 pub base_url: String,
259 pub store: Arc<dyn OauthStore>,
260 pub keys: Arc<Keys>,
261 pub jwks_json: Arc<serde_json::Value>,
262}
263
264impl AppState {
265 pub fn new(config: IssuerConfig) -> Self {
267 Self::with_store(config, Arc::new(InMemoryStore::default()))
268 }
269
270 pub fn with_store(config: IssuerConfig, store: Arc<dyn OauthStore>) -> Self {
272 let base_url = format!("{}://{}:{}", config.scheme, config.host, config.port);
273 let keys = Arc::new(Keys::generate());
274 let jwks_json = Arc::new(build_jwks_json(&keys));
275 Self {
276 config: Arc::new(config),
277 store,
278 base_url,
279 keys,
280 jwks_json,
281 }
282 }
283
284 pub fn issuer(&self) -> &str {
286 self.base_url.as_str()
287 }
288
289 pub async fn register_client(
291 &self,
292 metadata: serde_json::Value,
293 ) -> Result<Client, crate::error::OauthError> {
294 let requested_scope = metadata
295 .get("scope")
296 .and_then(|v| v.as_str())
297 .unwrap_or("openid");
298
299 self.config
300 .validate_scope(requested_scope)
301 .map_err(crate::error::OauthError::InvalidScope)?;
302
303 let client_id = Uuid::new_v4().to_string();
304
305 let client_secret = if self.config.generate_client_secret_for_dcr
306 || metadata
307 .get("token_endpoint_auth_method")
308 .and_then(|v| v.as_str())
309 != Some("none")
310 {
311 Some(generate_token_string())
312 } else {
313 None
314 };
315
316 let redirect_uris = metadata
317 .get("redirect_uris")
318 .and_then(|v| v.as_array())
319 .map(|arr| {
320 arr.iter()
321 .filter_map(|u| u.as_str().map(|s| s.to_string()))
322 .collect::<Vec<String>>()
323 })
324 .unwrap_or_default();
325
326 let grant_types = metadata
327 .get("grant_types")
328 .and_then(|v| v.as_array())
329 .map(|arr| {
330 arr.iter()
331 .filter_map(|v| v.as_str().map(|s| s.to_string()))
332 .collect::<Vec<String>>()
333 })
334 .unwrap_or_else(|| vec!["authorization_code".to_string()]);
335
336 let requires_redirect_uri = grant_types.iter().all(|g| {
337 !matches!(
338 g.as_str(),
339 "client_credentials" | "urn:ietf:params:oauth:grant-type:device_code"
340 )
341 });
342
343 if redirect_uris.is_empty() && requires_redirect_uri {
344 return Err(crate::error::OauthError::InvalidRequest(Some(
345 "redirect_uris required".to_string(),
346 )));
347 }
348
349 let client = Client {
350 client_id: client_id.clone(),
351 client_secret,
352 redirect_uris,
353 grant_types: metadata
354 .get("grant_types")
355 .and_then(|v| v.as_array())
356 .map(|arr| {
357 arr.iter()
358 .filter_map(|v| v.as_str().map(|s| s.to_string()))
359 .collect()
360 })
361 .unwrap_or_else(|| vec!["authorization_code".to_string()]),
362 response_types: metadata
363 .get("response_types")
364 .and_then(|v| v.as_array())
365 .map(|arr| {
366 arr.iter()
367 .filter_map(|v| v.as_str().map(|s| s.to_string()))
368 .collect()
369 })
370 .unwrap_or_else(|| vec!["code".to_string()]),
371 scope: metadata
372 .get("scope")
373 .and_then(|v| v.as_str())
374 .unwrap_or("")
375 .to_string(),
376 token_endpoint_auth_method: metadata
377 .get("token_endpoint_auth_method")
378 .and_then(|v| v.as_str())
379 .unwrap_or("client_secret_basic")
380 .to_string(),
381 client_name: metadata
382 .get("client_name")
383 .and_then(|v| v.as_str())
384 .map(|s| s.to_string()),
385 client_uri: metadata
386 .get("client_uri")
387 .and_then(|v| v.as_str())
388 .map(|s| s.to_string()),
389 logo_uri: metadata
390 .get("logo_uri")
391 .and_then(|v| v.as_str())
392 .map(|s| s.to_string()),
393 contacts: metadata
394 .get("contacts")
395 .and_then(|v| v.as_array())
396 .map(|arr| {
397 arr.iter()
398 .filter_map(|v| v.as_str().map(|s| s.to_string()))
399 .collect()
400 })
401 .unwrap_or_default(),
402 policy_uri: metadata
403 .get("policy_uri")
404 .and_then(|v| v.as_str())
405 .map(|s| s.to_string()),
406 tos_uri: metadata
407 .get("tos_uri")
408 .and_then(|v| v.as_str())
409 .map(|s| s.to_string()),
410 jwks: metadata.get("jwks").cloned(),
411 jwks_uri: metadata
412 .get("jwks_uri")
413 .and_then(|v| v.as_str())
414 .map(|s| s.to_string()),
415 software_id: metadata
416 .get("software_id")
417 .and_then(|v| v.as_str())
418 .map(|s| s.to_string()),
419 software_version: metadata
420 .get("software_version")
421 .and_then(|v| v.as_str())
422 .map(|s| s.to_string()),
423 registration_access_token: None,
424 registration_client_uri: Some(format!("{}/register/{}", self.issuer(), client_id)),
425 };
426
427 self.store.insert_client(client.clone()).await;
428
429 Ok(client)
430 }
431
432 #[cfg(feature = "testing")]
434 pub async fn generate_token(
435 &self,
436 client: &Client,
437 options: crate::testkit::JwtOptions,
438 ) -> Result<Token, jsonwebtoken::errors::Error> {
439 let user_id = options.user_id.clone();
440 let jwt = self.generate_jwt(client, options)?;
441 let refresh_token = generate_token_string();
442 let token = Token {
443 access_token: jwt.clone(),
444 refresh_token: Some(refresh_token.clone()),
445 client_id: client.client_id.clone(),
446 scope: client.scope.clone(),
447 expires_at: Utc::now() + Duration::hours(1),
448 user_id,
449 revoked: false,
450 };
451 self.store.insert_token(jwt.clone(), token.clone()).await;
452 self.store
453 .insert_refresh_token(refresh_token, token.clone())
454 .await;
455 Ok(token)
456 }
457
458 #[cfg(feature = "testing")]
460 pub fn generate_jwt(
461 &self,
462 client: &Client,
463 options: crate::testkit::JwtOptions,
464 ) -> Result<String, jsonwebtoken::errors::Error> {
465 let scope = options.scope.unwrap_or_else(|| client.scope.clone());
466 issue_jwt(
467 self.issuer(),
468 &client.client_id,
469 &options.user_id,
470 &scope,
471 options.expires_in,
472 &self.keys,
473 )
474 }
475
476 #[cfg(feature = "testing")]
477 pub async fn approve_device_code(&self, device_code: &str, user_id: &str) -> Option<()> {
478 let mut device_auth = self.store.get_device_code(device_code).await?;
479 device_auth.approved = true;
480 device_auth.user_id = Some(user_id.to_string());
481 self.store
482 .update_device_code(device_code, device_auth)
483 .await;
484 Some(())
485 }
486
487 pub async fn start(mut self) -> (SocketAddr, JoinHandle<()>) {
489 let port = self.config.port;
490 let addr = SocketAddr::from(([127, 0, 0, 1], port));
491 let listener = TcpListener::bind(addr)
492 .await
493 .expect("Failed to bind TCP listener - port may be in use");
494 let local_addr = listener
495 .local_addr()
496 .expect("Failed to get local address from listener");
497 let base_url = format!(
498 "{}://{}:{}",
499 self.config.scheme,
500 self.config.host,
501 local_addr.port()
502 );
503 self.base_url = base_url;
504
505 let cleanup_interval = self.config.cleanup_interval_secs;
506 let store = self.store.clone();
507
508 let cleanup_handle = if cleanup_interval > 0 {
509 Some(tokio::spawn(async move {
510 let mut interval =
511 tokio::time::interval(tokio::time::Duration::from_secs(cleanup_interval));
512 loop {
513 interval.tick().await;
514 let codes_cleaned = store.cleanup_expired_codes().await;
515 let tokens_cleaned = store.cleanup_expired_tokens().await;
516 let refresh_cleaned = store.cleanup_expired_refresh_tokens().await;
517 let device_codes_cleaned = store.cleanup_expired_device_codes().await;
518 if codes_cleaned > 0
519 || tokens_cleaned > 0
520 || refresh_cleaned > 0
521 || device_codes_cleaned > 0
522 {
523 tracing::debug!(
524 "Cleaned up expired entries: codes={}, tokens={}, refresh={}, device_codes={}",
525 codes_cleaned, tokens_cleaned, refresh_cleaned, device_codes_cleaned
526 );
527 }
528 }
529 }))
530 } else {
531 None
532 };
533
534 let router = crate::router::build_router(self);
535 let server_handle = tokio::spawn(async move {
536 axum::serve(listener, router).await.unwrap();
537 });
538
539 let combined_handle = tokio::spawn(async move {
540 if let Some(cleanup) = cleanup_handle {
541 tokio::select! {
542 _ = server_handle => {}
543 _ = cleanup => {}
544 }
545 } else {
546 server_handle.await.unwrap();
547 }
548 });
549 (local_addr, combined_handle)
550 }
551}