Skip to main content

heliosdb_proxy/auth/
oauth.rs

//! OAuth Token Introspection
//!
//! Validates OAuth access tokens using RFC 7662 token introspection.
4
5use std::collections::HashMap;
6use std::sync::Arc;
7use std::time::{Duration, Instant};
8
9use parking_lot::RwLock;
10use thiserror::Error;
11
12use super::config::{Identity, OAuthConfig};
13
14/// OAuth errors
15#[derive(Debug, Error)]
16pub enum OAuthError {
17    #[error("Token introspection failed: {0}")]
18    IntrospectionFailed(String),
19
20    #[error("Token is not active")]
21    TokenNotActive,
22
23    #[error("Token expired")]
24    TokenExpired,
25
26    #[error("Invalid token scope")]
27    InvalidScope,
28
29    #[error("Network error: {0}")]
30    NetworkError(String),
31
32    #[error("Invalid response: {0}")]
33    InvalidResponse(String),
34
35    #[error("Configuration error: {0}")]
36    ConfigurationError(String),
37}
38
39/// OAuth client for token introspection
40pub struct OAuthClient {
41    /// Configuration
42    config: OAuthConfig,
43
44    /// Token cache
45    cache: Arc<RwLock<TokenCache>>,
46
47    /// HTTP client (placeholder - would use reqwest in real impl)
48    client_id: String,
49    #[allow(dead_code)]
50    client_secret: String,
51}
52
53/// Token introspection response
54#[derive(Debug, Clone, serde::Deserialize)]
55pub struct IntrospectionResponse {
56    /// Whether the token is active
57    pub active: bool,
58
59    /// Token scopes
60    #[serde(default)]
61    pub scope: Option<String>,
62
63    /// Client ID
64    #[serde(default)]
65    pub client_id: Option<String>,
66
67    /// Username
68    #[serde(default)]
69    pub username: Option<String>,
70
71    /// Token type
72    #[serde(default)]
73    pub token_type: Option<String>,
74
75    /// Expiration time (Unix timestamp)
76    #[serde(default)]
77    pub exp: Option<i64>,
78
79    /// Issued at time (Unix timestamp)
80    #[serde(default)]
81    pub iat: Option<i64>,
82
83    /// Not before time (Unix timestamp)
84    #[serde(default)]
85    pub nbf: Option<i64>,
86
87    /// Subject
88    #[serde(default)]
89    pub sub: Option<String>,
90
91    /// Audience
92    #[serde(default)]
93    pub aud: Option<String>,
94
95    /// Issuer
96    #[serde(default)]
97    pub iss: Option<String>,
98
99    /// JWT ID
100    #[serde(default)]
101    pub jti: Option<String>,
102
103    /// Additional claims
104    #[serde(flatten)]
105    pub extra: HashMap<String, serde_json::Value>,
106}
107
108impl IntrospectionResponse {
109    /// Convert to Identity
110    pub fn to_identity(&self) -> Identity {
111        let roles = self
112            .scope
113            .as_ref()
114            .map(|s| s.split_whitespace().map(String::from).collect())
115            .unwrap_or_default();
116
117        Identity {
118            user_id: self
119                .sub
120                .clone()
121                .or_else(|| self.username.clone())
122                .unwrap_or_else(|| "unknown".to_string()),
123            name: self.username.clone(),
124            email: self
125                .extra
126                .get("email")
127                .and_then(|v| v.as_str())
128                .map(String::from),
129            roles,
130            groups: self
131                .extra
132                .get("groups")
133                .and_then(|v| v.as_array())
134                .map(|arr| {
135                    arr.iter()
136                        .filter_map(|v| v.as_str().map(String::from))
137                        .collect()
138                })
139                .unwrap_or_default(),
140            tenant_id: self
141                .extra
142                .get("tenant_id")
143                .and_then(|v| v.as_str())
144                .map(String::from),
145            claims: self.extra.clone(),
146            auth_method: "oauth".to_string(),
147            authenticated_at: chrono::Utc::now(),
148        }
149    }
150
151    /// Check if token is valid
152    pub fn is_valid(&self) -> bool {
153        if !self.active {
154            return false;
155        }
156
157        // Check expiration
158        if let Some(exp) = self.exp {
159            let now = chrono::Utc::now().timestamp();
160            if now > exp {
161                return false;
162            }
163        }
164
165        // Check not-before
166        if let Some(nbf) = self.nbf {
167            let now = chrono::Utc::now().timestamp();
168            if now < nbf {
169                return false;
170            }
171        }
172
173        true
174    }
175
176    /// Get scopes as a list
177    pub fn scopes(&self) -> Vec<String> {
178        self.scope
179            .as_ref()
180            .map(|s| s.split_whitespace().map(String::from).collect())
181            .unwrap_or_default()
182    }
183
184    /// Check if token has a specific scope
185    pub fn has_scope(&self, scope: &str) -> bool {
186        self.scopes().iter().any(|s| s == scope)
187    }
188}
189
190/// Token cache entry
191struct CachedToken {
192    response: IntrospectionResponse,
193    cached_at: Instant,
194}
195
196/// Token cache
197struct TokenCache {
198    entries: HashMap<String, CachedToken>,
199    max_size: usize,
200    ttl: Duration,
201}
202
203impl TokenCache {
204    fn new(max_size: usize, ttl: Duration) -> Self {
205        Self {
206            entries: HashMap::new(),
207            max_size,
208            ttl,
209        }
210    }
211
212    fn get(&self, token: &str) -> Option<&IntrospectionResponse> {
213        self.entries.get(token).and_then(|cached| {
214            if cached.cached_at.elapsed() < self.ttl {
215                Some(&cached.response)
216            } else {
217                None
218            }
219        })
220    }
221
222    fn insert(&mut self, token: String, response: IntrospectionResponse) {
223        if self.entries.len() >= self.max_size {
224            self.evict_expired();
225        }
226        self.entries.insert(
227            token,
228            CachedToken {
229                response,
230                cached_at: Instant::now(),
231            },
232        );
233    }
234
235    fn evict_expired(&mut self) {
236        self.entries
237            .retain(|_, cached| cached.cached_at.elapsed() < self.ttl);
238    }
239
240    fn invalidate(&mut self, token: &str) {
241        self.entries.remove(token);
242    }
243
244    fn clear(&mut self) {
245        self.entries.clear();
246    }
247}
248
249impl OAuthClient {
250    /// Create a new OAuth client
251    pub fn new(config: OAuthConfig) -> Self {
252        let client_id = config.client_id.clone();
253        let client_secret = config.client_secret.clone();
254        let cache_ttl = config.cache_ttl;
255
256        Self {
257            config,
258            cache: Arc::new(RwLock::new(TokenCache::new(10000, cache_ttl))),
259            client_id,
260            client_secret,
261        }
262    }
263
264    /// Introspect a token
265    pub async fn introspect(&self, token: &str) -> Result<IntrospectionResponse, OAuthError> {
266        // Check cache first
267        if let Some(cached) = self.cache.read().get(token) {
268            if cached.is_valid() {
269                return Ok(cached.clone());
270            }
271        }
272
273        // Perform introspection
274        let response = self.do_introspect(token).await?;
275
276        // Validate response
277        if !response.active {
278            return Err(OAuthError::TokenNotActive);
279        }
280
281        if !response.is_valid() {
282            return Err(OAuthError::TokenExpired);
283        }
284
285        // Cache successful response
286        self.cache
287            .write()
288            .insert(token.to_string(), response.clone());
289
290        Ok(response)
291    }
292
293    /// Perform the actual introspection request
294    async fn do_introspect(&self, token: &str) -> Result<IntrospectionResponse, OAuthError> {
295        // In a real implementation, this would make an HTTP POST request to the
296        // introspection endpoint. For demonstration, we return a placeholder.
297        //
298        // Real implementation would look like:
299        //
300        // let response = reqwest::Client::new()
301        //     .post(&self.config.introspection_url)
302        //     .basic_auth(&self.client_id, Some(&self.client_secret))
303        //     .form(&[("token", token)])
304        //     .send()
305        //     .await
306        //     .map_err(|e| OAuthError::NetworkError(e.to_string()))?;
307        //
308        // let body: IntrospectionResponse = response
309        //     .json()
310        //     .await
311        //     .map_err(|e| OAuthError::InvalidResponse(e.to_string()))?;
312
313        // Placeholder: create a demo response
314        // In production, this would be the actual HTTP call
315        let _ = token; // Suppress unused warning
316
317        Ok(IntrospectionResponse {
318            active: true,
319            scope: Some("read write".to_string()),
320            client_id: Some(self.client_id.clone()),
321            username: Some("oauth_user".to_string()),
322            token_type: Some("Bearer".to_string()),
323            exp: Some(chrono::Utc::now().timestamp() + 3600),
324            iat: Some(chrono::Utc::now().timestamp()),
325            nbf: None,
326            sub: Some("user123".to_string()),
327            aud: self.config.audience.clone(),
328            iss: Some(self.config.issuer.clone()),
329            jti: Some("token-id-123".to_string()),
330            extra: HashMap::new(),
331        })
332    }
333
334    /// Validate a token and return identity
335    pub async fn validate_to_identity(&self, token: &str) -> Result<Identity, OAuthError> {
336        let response = self.introspect(token).await?;
337
338        // Check required scopes
339        if !self.config.required_scopes.is_empty() {
340            for scope in &self.config.required_scopes {
341                if !response.has_scope(scope) {
342                    return Err(OAuthError::InvalidScope);
343                }
344            }
345        }
346
347        Ok(response.to_identity())
348    }
349
350    /// Invalidate a cached token
351    pub fn invalidate_token(&self, token: &str) {
352        self.cache.write().invalidate(token);
353    }
354
355    /// Clear the token cache
356    pub fn clear_cache(&self) {
357        self.cache.write().clear();
358    }
359
360    /// Get cache statistics
361    pub fn cache_size(&self) -> usize {
362        self.cache.read().entries.len()
363    }
364
365    /// Get introspection URL
366    pub fn introspection_url(&self) -> &str {
367        &self.config.introspection_url
368    }
369
370    /// Get issuer
371    pub fn issuer(&self) -> &str {
372        &self.config.issuer
373    }
374}
375
376/// OAuth token exchange
377pub struct TokenExchange {
378    /// Configuration
379    config: OAuthConfig,
380}
381
382impl TokenExchange {
383    /// Create a new token exchange
384    pub fn new(config: OAuthConfig) -> Self {
385        Self { config }
386    }
387
388    /// Exchange an authorization code for tokens
389    pub async fn exchange_code(
390        &self,
391        code: &str,
392        redirect_uri: &str,
393    ) -> Result<TokenResponse, OAuthError> {
394        // In a real implementation, this would make an HTTP POST to the token endpoint
395        // For demonstration, return a placeholder
396        let _ = (code, redirect_uri);
397
398        Ok(TokenResponse {
399            access_token: "access_token_placeholder".to_string(),
400            token_type: "Bearer".to_string(),
401            expires_in: Some(3600),
402            refresh_token: Some("refresh_token_placeholder".to_string()),
403            scope: Some("read write".to_string()),
404            id_token: None,
405        })
406    }
407
408    /// Refresh an access token
409    pub async fn refresh_token(&self, refresh_token: &str) -> Result<TokenResponse, OAuthError> {
410        // In a real implementation, this would make an HTTP POST to the token endpoint
411        let _ = refresh_token;
412
413        Ok(TokenResponse {
414            access_token: "new_access_token".to_string(),
415            token_type: "Bearer".to_string(),
416            expires_in: Some(3600),
417            refresh_token: Some("new_refresh_token".to_string()),
418            scope: Some("read write".to_string()),
419            id_token: None,
420        })
421    }
422
423    /// Get authorization URL
424    pub fn authorization_url(&self, state: &str, scopes: &[&str]) -> String {
425        let scope = scopes.join(" ");
426        format!(
427            "{}?response_type=code&client_id={}&state={}&scope={}",
428            self.config.authorization_url.as_deref().unwrap_or(""),
429            self.config.client_id,
430            state,
431            urlencoding::encode(&scope),
432        )
433    }
434}
435
436/// Token response from OAuth server
437#[derive(Debug, Clone, serde::Deserialize)]
438pub struct TokenResponse {
439    /// Access token
440    pub access_token: String,
441
442    /// Token type (usually "Bearer")
443    pub token_type: String,
444
445    /// Expires in seconds
446    pub expires_in: Option<u64>,
447
448    /// Refresh token
449    pub refresh_token: Option<String>,
450
451    /// Granted scopes
452    pub scope: Option<String>,
453
454    /// ID token (for OpenID Connect)
455    pub id_token: Option<String>,
456}
457
/// Minimal percent-encoding helpers.
mod urlencoding {
    /// Uppercase hexadecimal digits used for `%XX` escapes.
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";

    /// Percent-encodes `s`. Only the RFC 3986 unreserved characters
    /// (`A-Z`, `a-z`, `0-9`, `-`, `_`, `.`, `~`) are left as-is.
    ///
    /// Every other character becomes one `%XX` escape per byte of its UTF-8
    /// encoding. A space is therefore `%20`, never `+`.
    pub fn encode(s: &str) -> String {
        let mut encoded = String::with_capacity(s.len());
        // Non-ASCII characters only produce bytes >= 0x80, which are never
        // unreserved, so a byte-wise walk escapes each of their bytes.
        for byte in s.bytes() {
            if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~') {
                encoded.push(char::from(byte));
            } else {
                encoded.push('%');
                encoded.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
                encoded.push(char::from(HEX_DIGITS[usize::from(byte & 0x0F)]));
            }
        }
        encoded
    }
}
480
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Configuration pointing at a fictional authorization server.
    fn test_config() -> OAuthConfig {
        OAuthConfig {
            introspection_url: "https://auth.example.com/introspect".to_string(),
            client_id: "test-client".to_string(),
            client_secret: "test-secret".to_string(),
            issuer: "https://auth.example.com".to_string(),
            audience: Some("test-api".to_string()),
            required_scopes: vec!["read".to_string()],
            scopes: Vec::new(),
            cache_ttl: Duration::from_secs(60),
            authorization_url: Some("https://auth.example.com/authorize".to_string()),
            token_url: Some("https://auth.example.com/token".to_string()),
        }
    }

    /// A response with only `active` set. Tests fill in the claims they care
    /// about using struct-update syntax.
    fn bare_response(active: bool) -> IntrospectionResponse {
        IntrospectionResponse {
            active,
            scope: None,
            client_id: None,
            username: None,
            token_type: None,
            exp: None,
            iat: None,
            nbf: None,
            sub: None,
            aud: None,
            iss: None,
            jti: None,
            extra: HashMap::new(),
        }
    }

    /// Current Unix time shifted by `offset_secs`.
    fn unix_now_plus(offset_secs: i64) -> i64 {
        chrono::Utc::now().timestamp() + offset_secs
    }

    #[test]
    fn test_introspection_response_validity() {
        let response = IntrospectionResponse {
            scope: Some("read write".to_string()),
            username: Some("testuser".to_string()),
            exp: Some(unix_now_plus(3600)),
            sub: Some("user123".to_string()),
            ..bare_response(true)
        };

        assert!(response.is_valid());
        for granted in ["read", "write"] {
            assert!(response.has_scope(granted));
        }
        assert!(!response.has_scope("admin"));
    }

    #[test]
    fn test_introspection_response_expired() {
        let response = IntrospectionResponse {
            exp: Some(unix_now_plus(-3600)), // expired an hour ago
            ..bare_response(true)
        };

        assert!(!response.is_valid());
    }

    #[test]
    fn test_introspection_response_inactive() {
        assert!(!bare_response(false).is_valid());
    }

    #[test]
    fn test_introspection_to_identity() {
        let extra: HashMap<String, serde_json::Value> = vec![
            ("email".to_string(), serde_json::json!("test@example.com")),
            ("tenant_id".to_string(), serde_json::json!("tenant1")),
        ]
        .into_iter()
        .collect();

        let response = IntrospectionResponse {
            scope: Some("read write".to_string()),
            username: Some("testuser".to_string()),
            sub: Some("user123".to_string()),
            extra,
            ..bare_response(true)
        };

        let identity = response.to_identity();
        assert_eq!(identity.user_id, "user123");
        assert_eq!(identity.name, Some("testuser".to_string()));
        assert_eq!(identity.email, Some("test@example.com".to_string()));
        assert_eq!(identity.tenant_id, Some("tenant1".to_string()));
        assert!(identity.roles.contains(&"read".to_string()));
    }

    #[tokio::test]
    async fn test_oauth_client_introspect() {
        let client = OAuthClient::new(test_config());
        let response = client.introspect("test_token").await.unwrap();

        assert!(response.active);
        assert!(response.is_valid());
    }

    #[tokio::test]
    async fn test_oauth_client_cache() {
        let client = OAuthClient::new(test_config());

        // (token, expected cache size afterwards). The first call populates the
        // cache, a repeat is served from it, and a new token adds an entry.
        let steps = [("test_token", 1), ("test_token", 1), ("another_token", 2)];
        for (token, expected_size) in steps {
            let _ = client.introspect(token).await.unwrap();
            assert_eq!(client.cache_size(), expected_size);
        }

        client.clear_cache();
        assert_eq!(client.cache_size(), 0);
    }

    #[test]
    fn test_authorization_url() {
        let exchange = TokenExchange::new(test_config());
        let url = exchange.authorization_url("state123", &["read", "write"]);

        for fragment in ["response_type=code", "client_id=test-client", "state=state123"] {
            assert!(url.contains(fragment), "missing {} in {}", fragment, url);
        }
    }

    #[test]
    fn test_url_encoding() {
        let cases = [
            ("hello world", "hello%20world"),
            ("test-value", "test-value"),
            ("a=b&c=d", "a%3Db%26c%3Dd"),
        ];
        for (input, expected) in cases {
            assert_eq!(urlencoding::encode(input), expected);
        }
    }
}