1use std::collections::HashMap;
6use std::sync::Arc;
7use std::time::{Duration, Instant};
8
9use parking_lot::RwLock;
10use thiserror::Error;
11
12use super::config::{Identity, OAuthConfig};
13
/// Errors produced by the OAuth token-validation and token-exchange helpers
/// in this module.
#[derive(Debug, Error)]
pub enum OAuthError {
    /// The introspection step failed; the payload carries a description.
    /// NOTE(review): not constructed in the code visible in this module.
    #[error("Token introspection failed: {0}")]
    IntrospectionFailed(String),

    /// Returned by `OAuthClient::introspect` when the introspection result
    /// reports `active: false`.
    #[error("Token is not active")]
    TokenNotActive,

    /// Returned by `OAuthClient::introspect` when an active token fails the
    /// local `exp`/`nbf` time-window check (`IntrospectionResponse::is_valid`).
    /// Note that a not-yet-valid (`nbf` in the future) token also maps here.
    #[error("Token expired")]
    TokenExpired,

    /// Returned by `OAuthClient::validate_to_identity` when the token lacks
    /// one of the configured required scopes.
    #[error("Invalid token scope")]
    InvalidScope,

    /// Transport-level failure; the payload carries a description.
    /// NOTE(review): not constructed in the code visible in this module.
    #[error("Network error: {0}")]
    NetworkError(String),

    /// The server's response could not be interpreted; the payload carries a
    /// description.
    /// NOTE(review): not constructed in the code visible in this module.
    #[error("Invalid response: {0}")]
    InvalidResponse(String),

    /// The OAuth configuration is unusable; the payload carries a description.
    /// NOTE(review): not constructed in the code visible in this module.
    #[error("Configuration error: {0}")]
    ConfigurationError(String),
}
38
/// Validates bearer tokens through token introspection and caches accepted
/// results in memory.
///
/// `Debug` is intentionally not derived: the struct holds `client_secret`.
pub struct OAuthClient {
    /// Configuration this client was constructed from.
    config: OAuthConfig,

    /// Shared cache of accepted introspection results, keyed by the raw token
    /// string passed to `introspect`.
    cache: Arc<RwLock<TokenCache>>,

    /// Copy of `config.client_id`.
    client_id: String,

    /// Copy of `config.client_secret`.
    /// NOTE(review): not read anywhere in the visible code; presumably meant to
    /// authenticate the introspection request once it is implemented.
    client_secret: String,
}
51
52#[derive(Debug, Clone, serde::Deserialize)]
54pub struct IntrospectionResponse {
55 pub active: bool,
57
58 #[serde(default)]
60 pub scope: Option<String>,
61
62 #[serde(default)]
64 pub client_id: Option<String>,
65
66 #[serde(default)]
68 pub username: Option<String>,
69
70 #[serde(default)]
72 pub token_type: Option<String>,
73
74 #[serde(default)]
76 pub exp: Option<i64>,
77
78 #[serde(default)]
80 pub iat: Option<i64>,
81
82 #[serde(default)]
84 pub nbf: Option<i64>,
85
86 #[serde(default)]
88 pub sub: Option<String>,
89
90 #[serde(default)]
92 pub aud: Option<String>,
93
94 #[serde(default)]
96 pub iss: Option<String>,
97
98 #[serde(default)]
100 pub jti: Option<String>,
101
102 #[serde(flatten)]
104 pub extra: HashMap<String, serde_json::Value>,
105}
106
107impl IntrospectionResponse {
108 pub fn to_identity(&self) -> Identity {
110 let roles = self.scope
111 .as_ref()
112 .map(|s| s.split_whitespace().map(String::from).collect())
113 .unwrap_or_default();
114
115 Identity {
116 user_id: self.sub.clone()
117 .or_else(|| self.username.clone())
118 .unwrap_or_else(|| "unknown".to_string()),
119 name: self.username.clone(),
120 email: self.extra.get("email")
121 .and_then(|v| v.as_str())
122 .map(String::from),
123 roles,
124 groups: self.extra.get("groups")
125 .and_then(|v| v.as_array())
126 .map(|arr| arr.iter().filter_map(|v| v.as_str().map(String::from)).collect())
127 .unwrap_or_default(),
128 tenant_id: self.extra.get("tenant_id")
129 .and_then(|v| v.as_str())
130 .map(String::from),
131 claims: self.extra.clone(),
132 auth_method: "oauth".to_string(),
133 authenticated_at: chrono::Utc::now(),
134 }
135 }
136
137 pub fn is_valid(&self) -> bool {
139 if !self.active {
140 return false;
141 }
142
143 if let Some(exp) = self.exp {
145 let now = chrono::Utc::now().timestamp();
146 if now > exp {
147 return false;
148 }
149 }
150
151 if let Some(nbf) = self.nbf {
153 let now = chrono::Utc::now().timestamp();
154 if now < nbf {
155 return false;
156 }
157 }
158
159 true
160 }
161
162 pub fn scopes(&self) -> Vec<String> {
164 self.scope
165 .as_ref()
166 .map(|s| s.split_whitespace().map(String::from).collect())
167 .unwrap_or_default()
168 }
169
170 pub fn has_scope(&self, scope: &str) -> bool {
172 self.scopes().iter().any(|s| s == scope)
173 }
174}
175
/// A cached introspection result together with the moment it was stored.
struct CachedToken {
    /// The result handed back by `TokenCache::get` on a cache hit.
    response: IntrospectionResponse,
    /// Insertion time; its `elapsed()` is compared against the cache TTL.
    cached_at: Instant,
}
181
/// In-memory store of introspection results with a per-entry time-to-live.
struct TokenCache {
    /// Cached results keyed by the raw token string.
    entries: HashMap<String, CachedToken>,
    /// Capacity threshold that `insert` checks before adding an entry.
    max_size: usize,
    /// Entries older than this are no longer returned by `get`.
    ttl: Duration,
}
188
189impl TokenCache {
190 fn new(max_size: usize, ttl: Duration) -> Self {
191 Self {
192 entries: HashMap::new(),
193 max_size,
194 ttl,
195 }
196 }
197
198 fn get(&self, token: &str) -> Option<&IntrospectionResponse> {
199 self.entries.get(token).and_then(|cached| {
200 if cached.cached_at.elapsed() < self.ttl {
201 Some(&cached.response)
202 } else {
203 None
204 }
205 })
206 }
207
208 fn insert(&mut self, token: String, response: IntrospectionResponse) {
209 if self.entries.len() >= self.max_size {
210 self.evict_expired();
211 }
212 self.entries.insert(token, CachedToken {
213 response,
214 cached_at: Instant::now(),
215 });
216 }
217
218 fn evict_expired(&mut self) {
219 self.entries.retain(|_, cached| cached.cached_at.elapsed() < self.ttl);
220 }
221
222 fn invalidate(&mut self, token: &str) {
223 self.entries.remove(token);
224 }
225
226 fn clear(&mut self) {
227 self.entries.clear();
228 }
229}
230
231impl OAuthClient {
232 pub fn new(config: OAuthConfig) -> Self {
234 let client_id = config.client_id.clone();
235 let client_secret = config.client_secret.clone();
236 let cache_ttl = config.cache_ttl;
237
238 Self {
239 config,
240 cache: Arc::new(RwLock::new(TokenCache::new(10000, cache_ttl))),
241 client_id,
242 client_secret,
243 }
244 }
245
246 pub async fn introspect(&self, token: &str) -> Result<IntrospectionResponse, OAuthError> {
248 if let Some(cached) = self.cache.read().get(token) {
250 if cached.is_valid() {
251 return Ok(cached.clone());
252 }
253 }
254
255 let response = self.do_introspect(token).await?;
257
258 if !response.active {
260 return Err(OAuthError::TokenNotActive);
261 }
262
263 if !response.is_valid() {
264 return Err(OAuthError::TokenExpired);
265 }
266
267 self.cache.write().insert(token.to_string(), response.clone());
269
270 Ok(response)
271 }
272
273 async fn do_introspect(&self, token: &str) -> Result<IntrospectionResponse, OAuthError> {
275 let _ = token; Ok(IntrospectionResponse {
298 active: true,
299 scope: Some("read write".to_string()),
300 client_id: Some(self.client_id.clone()),
301 username: Some("oauth_user".to_string()),
302 token_type: Some("Bearer".to_string()),
303 exp: Some(chrono::Utc::now().timestamp() + 3600),
304 iat: Some(chrono::Utc::now().timestamp()),
305 nbf: None,
306 sub: Some("user123".to_string()),
307 aud: self.config.audience.clone(),
308 iss: Some(self.config.issuer.clone()),
309 jti: Some("token-id-123".to_string()),
310 extra: HashMap::new(),
311 })
312 }
313
314 pub async fn validate_to_identity(&self, token: &str) -> Result<Identity, OAuthError> {
316 let response = self.introspect(token).await?;
317
318 if !self.config.required_scopes.is_empty() {
320 for scope in &self.config.required_scopes {
321 if !response.has_scope(scope) {
322 return Err(OAuthError::InvalidScope);
323 }
324 }
325 }
326
327 Ok(response.to_identity())
328 }
329
330 pub fn invalidate_token(&self, token: &str) {
332 self.cache.write().invalidate(token);
333 }
334
335 pub fn clear_cache(&self) {
337 self.cache.write().clear();
338 }
339
340 pub fn cache_size(&self) -> usize {
342 self.cache.read().entries.len()
343 }
344
345 pub fn introspection_url(&self) -> &str {
347 &self.config.introspection_url
348 }
349
350 pub fn issuer(&self) -> &str {
352 &self.config.issuer
353 }
354}
355
/// Helper for the authorization-code and refresh-token grant flows.
pub struct TokenExchange {
    /// Configuration supplying `client_id` and `authorization_url`.
    config: OAuthConfig,
}
361
362impl TokenExchange {
363 pub fn new(config: OAuthConfig) -> Self {
365 Self { config }
366 }
367
368 pub async fn exchange_code(
370 &self,
371 code: &str,
372 redirect_uri: &str,
373 ) -> Result<TokenResponse, OAuthError> {
374 let _ = (code, redirect_uri);
377
378 Ok(TokenResponse {
379 access_token: "access_token_placeholder".to_string(),
380 token_type: "Bearer".to_string(),
381 expires_in: Some(3600),
382 refresh_token: Some("refresh_token_placeholder".to_string()),
383 scope: Some("read write".to_string()),
384 id_token: None,
385 })
386 }
387
388 pub async fn refresh_token(&self, refresh_token: &str) -> Result<TokenResponse, OAuthError> {
390 let _ = refresh_token;
392
393 Ok(TokenResponse {
394 access_token: "new_access_token".to_string(),
395 token_type: "Bearer".to_string(),
396 expires_in: Some(3600),
397 refresh_token: Some("new_refresh_token".to_string()),
398 scope: Some("read write".to_string()),
399 id_token: None,
400 })
401 }
402
403 pub fn authorization_url(&self, state: &str, scopes: &[&str]) -> String {
405 let scope = scopes.join(" ");
406 format!(
407 "{}?response_type=code&client_id={}&state={}&scope={}",
408 self.config.authorization_url
409 .as_deref()
410 .unwrap_or(""),
411 self.config.client_id,
412 state,
413 urlencoding::encode(&scope),
414 )
415 }
416}
417
418#[derive(Debug, Clone, serde::Deserialize)]
420pub struct TokenResponse {
421 pub access_token: String,
423
424 pub token_type: String,
426
427 pub expires_in: Option<u64>,
429
430 pub refresh_token: Option<String>,
432
433 pub scope: Option<String>,
435
436 pub id_token: Option<String>,
438}
439
/// Minimal RFC 3986 percent-encoder used to build query strings.
mod urlencoding {
    use std::fmt::Write;

    /// Percent-encodes `s` for use as a URL query value.
    ///
    /// ASCII letters, digits and `-` `_` `.` `~` (the RFC 3986 unreserved set)
    /// are kept as-is. Every other UTF-8 byte, including space, becomes
    /// `%XX` with uppercase hex.
    pub fn encode(s: &str) -> String {
        let mut out = String::with_capacity(s.len());
        // Working on bytes encodes multi-byte characters one UTF-8 byte at a
        // time without allocating a temporary String per character.
        for byte in s.bytes() {
            if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~') {
                out.push(char::from(byte));
            } else {
                // Writing into a String is infallible.
                let _ = write!(out, "%{:02X}", byte);
            }
        }
        out
    }
}
462
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Fully populated configuration shared by the tests. It requires only the
    /// `read` scope and caches results for 60 seconds.
    fn test_config() -> OAuthConfig {
        OAuthConfig {
            introspection_url: "https://auth.example.com/introspect".to_string(),
            client_id: "test-client".to_string(),
            client_secret: "test-secret".to_string(),
            issuer: "https://auth.example.com".to_string(),
            audience: Some("test-api".to_string()),
            required_scopes: vec!["read".to_string()],
            scopes: Vec::new(),
            cache_ttl: Duration::from_secs(60),
            authorization_url: Some("https://auth.example.com/authorize".to_string()),
            token_url: Some("https://auth.example.com/token".to_string()),
        }
    }

    /// An active token expiring in an hour is valid, and `has_scope` matches
    /// only the space-separated scopes it was granted.
    #[test]
    fn test_introspection_response_validity() {
        let response = IntrospectionResponse {
            active: true,
            scope: Some("read write".to_string()),
            client_id: None,
            username: Some("testuser".to_string()),
            token_type: None,
            exp: Some(chrono::Utc::now().timestamp() + 3600),
            iat: None,
            nbf: None,
            sub: Some("user123".to_string()),
            aud: None,
            iss: None,
            jti: None,
            extra: HashMap::new(),
        };

        assert!(response.is_valid());
        assert!(response.has_scope("read"));
        assert!(response.has_scope("write"));
        assert!(!response.has_scope("admin"));
    }

    /// An active token whose `exp` is an hour in the past is rejected.
    #[test]
    fn test_introspection_response_expired() {
        let response = IntrospectionResponse {
            active: true,
            scope: None,
            client_id: None,
            username: None,
            token_type: None,
            exp: Some(chrono::Utc::now().timestamp() - 3600),
            iat: None,
            nbf: None,
            sub: None,
            aud: None,
            iss: None,
            jti: None,
            extra: HashMap::new(),
        };

        assert!(!response.is_valid());
    }

    /// `active: false` is rejected even with no time bounds set.
    #[test]
    fn test_introspection_response_inactive() {
        let response = IntrospectionResponse {
            active: false,
            scope: None,
            client_id: None,
            username: None,
            token_type: None,
            exp: None,
            iat: None,
            nbf: None,
            sub: None,
            aud: None,
            iss: None,
            jti: None,
            extra: HashMap::new(),
        };

        assert!(!response.is_valid());
    }

    /// `sub` becomes the user id, `username` the name, scopes become roles,
    /// and `email` / `tenant_id` are read from the extra claims.
    #[test]
    fn test_introspection_to_identity() {
        let mut extra = HashMap::new();
        extra.insert("email".to_string(), serde_json::json!("test@example.com"));
        extra.insert("tenant_id".to_string(), serde_json::json!("tenant1"));

        let response = IntrospectionResponse {
            active: true,
            scope: Some("read write".to_string()),
            client_id: None,
            username: Some("testuser".to_string()),
            token_type: None,
            exp: None,
            iat: None,
            nbf: None,
            sub: Some("user123".to_string()),
            aud: None,
            iss: None,
            jti: None,
            extra,
        };

        let identity = response.to_identity();
        assert_eq!(identity.user_id, "user123");
        assert_eq!(identity.name, Some("testuser".to_string()));
        assert_eq!(identity.email, Some("test@example.com".to_string()));
        assert_eq!(identity.tenant_id, Some("tenant1".to_string()));
        assert!(identity.roles.contains(&"read".to_string()));
    }

    /// `do_introspect` is currently a stub that reports every token as active,
    /// so this only exercises the success path of `introspect`.
    #[tokio::test]
    async fn test_oauth_client_introspect() {
        let client = OAuthClient::new(test_config());
        let result = client.introspect("test_token").await.unwrap();

        assert!(result.active);
        assert!(result.is_valid());
    }

    /// Accepted results are cached per token string, and `clear_cache`
    /// empties the cache.
    #[tokio::test]
    async fn test_oauth_client_cache() {
        let client = OAuthClient::new(test_config());

        let _ = client.introspect("test_token").await.unwrap();
        assert_eq!(client.cache_size(), 1);

        // Looking up the same token again does not add a second entry.
        let _ = client.introspect("test_token").await.unwrap();
        assert_eq!(client.cache_size(), 1);

        let _ = client.introspect("another_token").await.unwrap();
        assert_eq!(client.cache_size(), 2);

        client.clear_cache();
        assert_eq!(client.cache_size(), 0);
    }

    /// The authorization URL carries the code response type, client id and state.
    #[test]
    fn test_authorization_url() {
        let exchange = TokenExchange::new(test_config());
        let url = exchange.authorization_url("state123", &["read", "write"]);

        assert!(url.contains("response_type=code"));
        assert!(url.contains("client_id=test-client"));
        assert!(url.contains("state=state123"));
    }

    /// Spaces become `%20`, unreserved characters pass through, and reserved
    /// characters are percent-encoded in uppercase hex.
    #[test]
    fn test_url_encoding() {
        assert_eq!(urlencoding::encode("hello world"), "hello%20world");
        assert_eq!(urlencoding::encode("test-value"), "test-value");
        assert_eq!(urlencoding::encode("a=b&c=d"), "a%3Db%26c%3Dd");
    }
}