Skip to main content

heliosdb_proxy/auth/
oauth.rs

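//! OAuth Token Introspection
//!
//! Validates OAuth access tokens using RFC 7662 token introspection (caching
//! successful results), and provides helpers for the authorization-code flow:
//! building the authorization URL, exchanging a code, and refreshing a token.
//!
//! Note: the HTTP calls to the introspection and token endpoints are not yet
//! implemented; those methods currently return fixed placeholder responses.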
4
5use std::collections::HashMap;
6use std::sync::Arc;
7use std::time::{Duration, Instant};
8
9use parking_lot::RwLock;
10use thiserror::Error;
11
12use super::config::{Identity, OAuthConfig};
13
/// Errors returned by [`OAuthClient`] and [`TokenExchange`] operations.
#[derive(Debug, Error)]
pub enum OAuthError {
    /// Token introspection could not be completed; carries a description of the cause.
    /// NOTE(review): not constructed anywhere in this module at present.
    #[error("Token introspection failed: {0}")]
    IntrospectionFailed(String),

    /// The token is not currently active (e.g. the authorization server
    /// reported `active: false`).
    #[error("Token is not active")]
    TokenNotActive,

    /// The token is expired.
    #[error("Token expired")]
    TokenExpired,

    /// The token lacks a scope that the configuration requires.
    #[error("Invalid token scope")]
    InvalidScope,

    /// Transport-level failure while talking to the authorization server.
    /// NOTE(review): intended for the (not yet implemented) HTTP transport.
    #[error("Network error: {0}")]
    NetworkError(String),

    /// The authorization server's response could not be parsed.
    /// NOTE(review): intended for the (not yet implemented) HTTP transport.
    #[error("Invalid response: {0}")]
    InvalidResponse(String),

    /// The OAuth configuration is unusable.
    /// NOTE(review): not constructed anywhere in this module at present.
    #[error("Configuration error: {0}")]
    ConfigurationError(String),
}
38
/// OAuth client that validates access tokens via RFC 7662 token introspection
/// and caches successful introspection results.
pub struct OAuthClient {
    /// Configuration (introspection URL, issuer, audience, required scopes, cache TTL, ...).
    config: OAuthConfig,

    /// Introspection results keyed by the raw token string, shared behind a lock.
    cache: Arc<RwLock<TokenCache>>,

    /// Client credentials copied out of `config` at construction. They are meant for
    /// authenticating to the introspection endpoint. No HTTP client exists yet; a real
    /// implementation would use e.g. reqwest with basic auth.
    client_id: String,
    client_secret: String,
}
51
52/// Token introspection response
53#[derive(Debug, Clone, serde::Deserialize)]
54pub struct IntrospectionResponse {
55    /// Whether the token is active
56    pub active: bool,
57
58    /// Token scopes
59    #[serde(default)]
60    pub scope: Option<String>,
61
62    /// Client ID
63    #[serde(default)]
64    pub client_id: Option<String>,
65
66    /// Username
67    #[serde(default)]
68    pub username: Option<String>,
69
70    /// Token type
71    #[serde(default)]
72    pub token_type: Option<String>,
73
74    /// Expiration time (Unix timestamp)
75    #[serde(default)]
76    pub exp: Option<i64>,
77
78    /// Issued at time (Unix timestamp)
79    #[serde(default)]
80    pub iat: Option<i64>,
81
82    /// Not before time (Unix timestamp)
83    #[serde(default)]
84    pub nbf: Option<i64>,
85
86    /// Subject
87    #[serde(default)]
88    pub sub: Option<String>,
89
90    /// Audience
91    #[serde(default)]
92    pub aud: Option<String>,
93
94    /// Issuer
95    #[serde(default)]
96    pub iss: Option<String>,
97
98    /// JWT ID
99    #[serde(default)]
100    pub jti: Option<String>,
101
102    /// Additional claims
103    #[serde(flatten)]
104    pub extra: HashMap<String, serde_json::Value>,
105}
106
107impl IntrospectionResponse {
108    /// Convert to Identity
109    pub fn to_identity(&self) -> Identity {
110        let roles = self.scope
111            .as_ref()
112            .map(|s| s.split_whitespace().map(String::from).collect())
113            .unwrap_or_default();
114
115        Identity {
116            user_id: self.sub.clone()
117                .or_else(|| self.username.clone())
118                .unwrap_or_else(|| "unknown".to_string()),
119            name: self.username.clone(),
120            email: self.extra.get("email")
121                .and_then(|v| v.as_str())
122                .map(String::from),
123            roles,
124            groups: self.extra.get("groups")
125                .and_then(|v| v.as_array())
126                .map(|arr| arr.iter().filter_map(|v| v.as_str().map(String::from)).collect())
127                .unwrap_or_default(),
128            tenant_id: self.extra.get("tenant_id")
129                .and_then(|v| v.as_str())
130                .map(String::from),
131            claims: self.extra.clone(),
132            auth_method: "oauth".to_string(),
133            authenticated_at: chrono::Utc::now(),
134        }
135    }
136
137    /// Check if token is valid
138    pub fn is_valid(&self) -> bool {
139        if !self.active {
140            return false;
141        }
142
143        // Check expiration
144        if let Some(exp) = self.exp {
145            let now = chrono::Utc::now().timestamp();
146            if now > exp {
147                return false;
148            }
149        }
150
151        // Check not-before
152        if let Some(nbf) = self.nbf {
153            let now = chrono::Utc::now().timestamp();
154            if now < nbf {
155                return false;
156            }
157        }
158
159        true
160    }
161
162    /// Get scopes as a list
163    pub fn scopes(&self) -> Vec<String> {
164        self.scope
165            .as_ref()
166            .map(|s| s.split_whitespace().map(String::from).collect())
167            .unwrap_or_default()
168    }
169
170    /// Check if token has a specific scope
171    pub fn has_scope(&self, scope: &str) -> bool {
172        self.scopes().iter().any(|s| s == scope)
173    }
174}
175
/// A cached introspection response plus the instant it was stored, used to
/// enforce the cache TTL.
struct CachedToken {
    /// The introspection response being cached.
    response: IntrospectionResponse,
    /// When the entry was inserted (or last overwritten).
    cached_at: Instant,
}
181
/// In-memory cache of introspection responses, keyed by the raw token string.
struct TokenCache {
    /// Cached entries keyed by token.
    entries: HashMap<String, CachedToken>,
    /// Entry count at which `insert` starts evicting.
    max_size: usize,
    /// Age after which an entry is no longer returned by `get`.
    ttl: Duration,
}
188
189impl TokenCache {
190    fn new(max_size: usize, ttl: Duration) -> Self {
191        Self {
192            entries: HashMap::new(),
193            max_size,
194            ttl,
195        }
196    }
197
198    fn get(&self, token: &str) -> Option<&IntrospectionResponse> {
199        self.entries.get(token).and_then(|cached| {
200            if cached.cached_at.elapsed() < self.ttl {
201                Some(&cached.response)
202            } else {
203                None
204            }
205        })
206    }
207
208    fn insert(&mut self, token: String, response: IntrospectionResponse) {
209        if self.entries.len() >= self.max_size {
210            self.evict_expired();
211        }
212        self.entries.insert(token, CachedToken {
213            response,
214            cached_at: Instant::now(),
215        });
216    }
217
218    fn evict_expired(&mut self) {
219        self.entries.retain(|_, cached| cached.cached_at.elapsed() < self.ttl);
220    }
221
222    fn invalidate(&mut self, token: &str) {
223        self.entries.remove(token);
224    }
225
226    fn clear(&mut self) {
227        self.entries.clear();
228    }
229}
230
231impl OAuthClient {
232    /// Create a new OAuth client
233    pub fn new(config: OAuthConfig) -> Self {
234        let client_id = config.client_id.clone();
235        let client_secret = config.client_secret.clone();
236        let cache_ttl = config.cache_ttl;
237
238        Self {
239            config,
240            cache: Arc::new(RwLock::new(TokenCache::new(10000, cache_ttl))),
241            client_id,
242            client_secret,
243        }
244    }
245
246    /// Introspect a token
247    pub async fn introspect(&self, token: &str) -> Result<IntrospectionResponse, OAuthError> {
248        // Check cache first
249        if let Some(cached) = self.cache.read().get(token) {
250            if cached.is_valid() {
251                return Ok(cached.clone());
252            }
253        }
254
255        // Perform introspection
256        let response = self.do_introspect(token).await?;
257
258        // Validate response
259        if !response.active {
260            return Err(OAuthError::TokenNotActive);
261        }
262
263        if !response.is_valid() {
264            return Err(OAuthError::TokenExpired);
265        }
266
267        // Cache successful response
268        self.cache.write().insert(token.to_string(), response.clone());
269
270        Ok(response)
271    }
272
273    /// Perform the actual introspection request
274    async fn do_introspect(&self, token: &str) -> Result<IntrospectionResponse, OAuthError> {
275        // In a real implementation, this would make an HTTP POST request to the
276        // introspection endpoint. For demonstration, we return a placeholder.
277        //
278        // Real implementation would look like:
279        //
280        // let response = reqwest::Client::new()
281        //     .post(&self.config.introspection_url)
282        //     .basic_auth(&self.client_id, Some(&self.client_secret))
283        //     .form(&[("token", token)])
284        //     .send()
285        //     .await
286        //     .map_err(|e| OAuthError::NetworkError(e.to_string()))?;
287        //
288        // let body: IntrospectionResponse = response
289        //     .json()
290        //     .await
291        //     .map_err(|e| OAuthError::InvalidResponse(e.to_string()))?;
292
293        // Placeholder: create a demo response
294        // In production, this would be the actual HTTP call
295        let _ = token; // Suppress unused warning
296
297        Ok(IntrospectionResponse {
298            active: true,
299            scope: Some("read write".to_string()),
300            client_id: Some(self.client_id.clone()),
301            username: Some("oauth_user".to_string()),
302            token_type: Some("Bearer".to_string()),
303            exp: Some(chrono::Utc::now().timestamp() + 3600),
304            iat: Some(chrono::Utc::now().timestamp()),
305            nbf: None,
306            sub: Some("user123".to_string()),
307            aud: self.config.audience.clone(),
308            iss: Some(self.config.issuer.clone()),
309            jti: Some("token-id-123".to_string()),
310            extra: HashMap::new(),
311        })
312    }
313
314    /// Validate a token and return identity
315    pub async fn validate_to_identity(&self, token: &str) -> Result<Identity, OAuthError> {
316        let response = self.introspect(token).await?;
317
318        // Check required scopes
319        if !self.config.required_scopes.is_empty() {
320            for scope in &self.config.required_scopes {
321                if !response.has_scope(scope) {
322                    return Err(OAuthError::InvalidScope);
323                }
324            }
325        }
326
327        Ok(response.to_identity())
328    }
329
330    /// Invalidate a cached token
331    pub fn invalidate_token(&self, token: &str) {
332        self.cache.write().invalidate(token);
333    }
334
335    /// Clear the token cache
336    pub fn clear_cache(&self) {
337        self.cache.write().clear();
338    }
339
340    /// Get cache statistics
341    pub fn cache_size(&self) -> usize {
342        self.cache.read().entries.len()
343    }
344
345    /// Get introspection URL
346    pub fn introspection_url(&self) -> &str {
347        &self.config.introspection_url
348    }
349
350    /// Get issuer
351    pub fn issuer(&self) -> &str {
352        &self.config.issuer
353    }
354}
355
/// Helper for the OAuth authorization-code flow: builds authorization URLs and
/// exchanges authorization codes / refresh tokens for new tokens.
pub struct TokenExchange {
    /// Configuration (client id, authorization URL, ...).
    config: OAuthConfig,
}
361
362impl TokenExchange {
363    /// Create a new token exchange
364    pub fn new(config: OAuthConfig) -> Self {
365        Self { config }
366    }
367
368    /// Exchange an authorization code for tokens
369    pub async fn exchange_code(
370        &self,
371        code: &str,
372        redirect_uri: &str,
373    ) -> Result<TokenResponse, OAuthError> {
374        // In a real implementation, this would make an HTTP POST to the token endpoint
375        // For demonstration, return a placeholder
376        let _ = (code, redirect_uri);
377
378        Ok(TokenResponse {
379            access_token: "access_token_placeholder".to_string(),
380            token_type: "Bearer".to_string(),
381            expires_in: Some(3600),
382            refresh_token: Some("refresh_token_placeholder".to_string()),
383            scope: Some("read write".to_string()),
384            id_token: None,
385        })
386    }
387
388    /// Refresh an access token
389    pub async fn refresh_token(&self, refresh_token: &str) -> Result<TokenResponse, OAuthError> {
390        // In a real implementation, this would make an HTTP POST to the token endpoint
391        let _ = refresh_token;
392
393        Ok(TokenResponse {
394            access_token: "new_access_token".to_string(),
395            token_type: "Bearer".to_string(),
396            expires_in: Some(3600),
397            refresh_token: Some("new_refresh_token".to_string()),
398            scope: Some("read write".to_string()),
399            id_token: None,
400        })
401    }
402
403    /// Get authorization URL
404    pub fn authorization_url(&self, state: &str, scopes: &[&str]) -> String {
405        let scope = scopes.join(" ");
406        format!(
407            "{}?response_type=code&client_id={}&state={}&scope={}",
408            self.config.authorization_url
409                .as_deref()
410                .unwrap_or(""),
411            self.config.client_id,
412            state,
413            urlencoding::encode(&scope),
414        )
415    }
416}
417
418/// Token response from OAuth server
419#[derive(Debug, Clone, serde::Deserialize)]
420pub struct TokenResponse {
421    /// Access token
422    pub access_token: String,
423
424    /// Token type (usually "Bearer")
425    pub token_type: String,
426
427    /// Expires in seconds
428    pub expires_in: Option<u64>,
429
430    /// Refresh token
431    pub refresh_token: Option<String>,
432
433    /// Granted scopes
434    pub scope: Option<String>,
435
436    /// ID token (for OpenID Connect)
437    pub id_token: Option<String>,
438}
439
/// Minimal percent-encoding for URL query components (RFC 3986).
mod urlencoding {
    /// Upper-case hex digits used for `%XX` escapes.
    const HEX: &[u8; 16] = b"0123456789ABCDEF";

    /// Percent-encodes `s` for use inside a URL query component.
    ///
    /// Unreserved characters (`A-Z a-z 0-9 - _ . ~`) are copied as-is. Every other
    /// byte of the UTF-8 encoding, including space, is written as `%XX` with
    /// upper-case hex digits (space becomes `%20`, not `+`).
    pub fn encode(s: &str) -> String {
        let mut out = String::with_capacity(s.len());
        for byte in s.bytes() {
            if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~') {
                out.push(char::from(byte));
            } else {
                out.push('%');
                out.push(char::from(HEX[usize::from(byte >> 4)]));
                out.push(char::from(HEX[usize::from(byte & 0x0F)]));
            }
        }
        out
    }
}
462
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Configuration pointing at a fictitious `auth.example.com` server.
    fn test_config() -> OAuthConfig {
        OAuthConfig {
            introspection_url: "https://auth.example.com/introspect".to_string(),
            client_id: "test-client".to_string(),
            client_secret: "test-secret".to_string(),
            issuer: "https://auth.example.com".to_string(),
            audience: Some("test-api".to_string()),
            required_scopes: vec!["read".to_string()],
            scopes: Vec::new(),
            cache_ttl: Duration::from_secs(60),
            authorization_url: Some("https://auth.example.com/authorize".to_string()),
            token_url: Some("https://auth.example.com/token".to_string()),
        }
    }

    /// Introspection response with the given `active` flag and every optional member
    /// unset; tests fill in individual fields with struct-update syntax.
    fn bare_response(active: bool) -> IntrospectionResponse {
        IntrospectionResponse {
            active,
            scope: None,
            client_id: None,
            username: None,
            token_type: None,
            exp: None,
            iat: None,
            nbf: None,
            sub: None,
            aud: None,
            iss: None,
            jti: None,
            extra: HashMap::new(),
        }
    }

    /// Current Unix time shifted by `offset_secs`.
    fn unix_now_plus(offset_secs: i64) -> i64 {
        chrono::Utc::now().timestamp() + offset_secs
    }

    #[test]
    fn test_introspection_response_validity() {
        let response = IntrospectionResponse {
            scope: Some("read write".to_string()),
            username: Some("testuser".to_string()),
            exp: Some(unix_now_plus(3600)),
            sub: Some("user123".to_string()),
            ..bare_response(true)
        };

        assert!(response.is_valid());
        assert!(["read", "write"].iter().all(|granted| response.has_scope(granted)));
        assert!(!response.has_scope("admin"));
    }

    #[test]
    fn test_introspection_response_expired() {
        // Expired one hour ago.
        let response = IntrospectionResponse {
            exp: Some(unix_now_plus(-3600)),
            ..bare_response(true)
        };

        assert!(!response.is_valid());
    }

    #[test]
    fn test_introspection_response_inactive() {
        assert!(!bare_response(false).is_valid());
    }

    #[test]
    fn test_introspection_to_identity() {
        let extra: HashMap<String, serde_json::Value> = vec![
            ("email".to_string(), serde_json::json!("test@example.com")),
            ("tenant_id".to_string(), serde_json::json!("tenant1")),
        ]
        .into_iter()
        .collect();

        let identity = IntrospectionResponse {
            scope: Some("read write".to_string()),
            username: Some("testuser".to_string()),
            sub: Some("user123".to_string()),
            extra,
            ..bare_response(true)
        }
        .to_identity();

        assert_eq!(identity.user_id, "user123");
        assert_eq!(identity.name, Some("testuser".to_string()));
        assert_eq!(identity.email, Some("test@example.com".to_string()));
        assert_eq!(identity.tenant_id, Some("tenant1".to_string()));
        assert!(identity.roles.contains(&"read".to_string()));
    }

    #[tokio::test]
    async fn test_oauth_client_introspect() {
        let response = OAuthClient::new(test_config())
            .introspect("test_token")
            .await
            .unwrap();

        assert!(response.active);
        assert!(response.is_valid());
    }

    #[tokio::test]
    async fn test_oauth_client_cache() {
        let client = OAuthClient::new(test_config());

        // (token, expected cache size afterwards): the first sighting is cached, a
        // repeat is served from the cache, and a new token adds a second entry.
        let steps: [(&str, usize); 3] = [
            ("test_token", 1),
            ("test_token", 1),
            ("another_token", 2),
        ];
        for &(token, expected_size) in steps.iter() {
            client.introspect(token).await.unwrap();
            assert_eq!(client.cache_size(), expected_size);
        }

        client.clear_cache();
        assert_eq!(client.cache_size(), 0);
    }

    #[test]
    fn test_authorization_url() {
        let url = TokenExchange::new(test_config()).authorization_url("state123", &["read", "write"]);

        for fragment in ["response_type=code", "client_id=test-client", "state=state123"].iter() {
            assert!(url.contains(fragment));
        }
    }

    #[test]
    fn test_url_encoding() {
        let cases = [
            ("hello world", "hello%20world"),
            ("test-value", "test-value"),
            ("a=b&c=d", "a%3Db%26c%3Dd"),
        ];
        for (input, expected) in cases.iter() {
            assert_eq!(urlencoding::encode(input), *expected);
        }
    }
}