// corevpn_auth/token.rs

//! Token Management and Validation
2
3use chrono::{DateTime, Utc};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6
7use crate::{AuthError, Result};
8
9/// OAuth2/OIDC Token Set
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct TokenSet {
12    /// Access token
13    pub access_token: String,
14    /// Refresh token (optional)
15    pub refresh_token: Option<String>,
16    /// ID token (for OIDC)
17    pub id_token: Option<String>,
18    /// Token expiration time
19    pub expires_at: DateTime<Utc>,
20    /// Token type (usually "Bearer")
21    pub token_type: String,
22    /// Granted scopes
23    pub scopes: Vec<String>,
24}
25
26impl TokenSet {
27    /// Check if access token is expired
28    pub fn is_expired(&self) -> bool {
29        Utc::now() > self.expires_at
30    }
31
32    /// Check if access token will expire within given duration
33    pub fn expires_within(&self, duration: chrono::Duration) -> bool {
34        Utc::now() + duration > self.expires_at
35    }
36
37    /// Get remaining lifetime
38    pub fn remaining_lifetime(&self) -> chrono::Duration {
39        self.expires_at - Utc::now()
40    }
41}
42
43/// Claims extracted from ID token
44#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct IdTokenClaims {
46    /// Issuer
47    pub iss: String,
48    /// Subject (user ID)
49    pub sub: String,
50    /// Audience
51    pub aud: StringOrArray,
52    /// Expiration time
53    pub exp: i64,
54    /// Issued at time
55    pub iat: i64,
56    /// Nonce
57    #[serde(default)]
58    pub nonce: Option<String>,
59    /// Email
60    #[serde(default)]
61    pub email: Option<String>,
62    /// Email verified
63    #[serde(default)]
64    pub email_verified: Option<bool>,
65    /// Name
66    #[serde(default)]
67    pub name: Option<String>,
68    /// Given name
69    #[serde(default)]
70    pub given_name: Option<String>,
71    /// Family name
72    #[serde(default)]
73    pub family_name: Option<String>,
74    /// Picture URL
75    #[serde(default)]
76    pub picture: Option<String>,
77    /// Groups
78    #[serde(default)]
79    pub groups: Vec<String>,
80    /// Additional claims
81    #[serde(flatten)]
82    pub additional: HashMap<String, serde_json::Value>,
83}
84
85/// String or array of strings (for audience claim)
86#[derive(Debug, Clone, Serialize, Deserialize)]
87#[serde(untagged)]
88pub enum StringOrArray {
89    /// Single string
90    String(String),
91    /// Array of strings
92    Array(Vec<String>),
93}
94
95impl StringOrArray {
96    /// Check if contains a value
97    pub fn contains(&self, value: &str) -> bool {
98        match self {
99            StringOrArray::String(s) => s == value,
100            StringOrArray::Array(arr) => arr.iter().any(|s| s == value),
101        }
102    }
103}
104
/// Validates ID-token claims against an expected issuer and audience.
pub struct TokenValidator {
    /// Issuer (`iss`) a token must carry to be accepted.
    issuer: String,
    /// Value (the client ID) that must appear in the token's `aud`.
    audience: String,
    /// Tolerance, in seconds, applied to the `exp` and `iat` checks.
    clock_skew: i64,
}
114
115impl TokenValidator {
116    /// Create a new token validator
117    pub fn new(issuer: &str, audience: &str) -> Self {
118        Self {
119            issuer: issuer.to_string(),
120            audience: audience.to_string(),
121            clock_skew: 60, // 1 minute tolerance
122        }
123    }
124
125    /// Set clock skew tolerance
126    pub fn with_clock_skew(mut self, seconds: i64) -> Self {
127        self.clock_skew = seconds;
128        self
129    }
130
131    /// Validate ID token claims (without cryptographic verification)
132    ///
133    /// Note: For production, you should also verify the JWT signature
134    /// using the provider's JWKS.
135    pub fn validate_claims(&self, claims: &IdTokenClaims, expected_nonce: Option<&str>) -> Result<()> {
136        // Check issuer
137        if claims.iss != self.issuer {
138            return Err(AuthError::TokenValidationFailed(format!(
139                "invalid issuer: expected {}, got {}",
140                self.issuer, claims.iss
141            )));
142        }
143
144        // Check audience
145        if !claims.aud.contains(&self.audience) {
146            return Err(AuthError::TokenValidationFailed(
147                "token audience mismatch".into(),
148            ));
149        }
150
151        // Check expiration
152        let now = Utc::now().timestamp();
153        if claims.exp < now - self.clock_skew {
154            return Err(AuthError::TokenExpired);
155        }
156
157        // Check issued at (not in the future)
158        if claims.iat > now + self.clock_skew {
159            return Err(AuthError::TokenValidationFailed(
160                "token issued in the future".into(),
161            ));
162        }
163
164        // Check nonce if provided
165        if let Some(expected) = expected_nonce {
166            if claims.nonce.as_deref() != Some(expected) {
167                return Err(AuthError::InvalidNonce);
168            }
169        }
170
171        Ok(())
172    }
173
174    /// Decode JWT without verification (for extracting claims)
175    ///
176    /// WARNING: This does not verify the signature. For production use,
177    /// you should use a proper JWT library that verifies signatures.
178    pub fn decode_jwt_claims(token: &str) -> Result<IdTokenClaims> {
179        use base64::Engine;
180
181        let parts: Vec<&str> = token.split('.').collect();
182        if parts.len() != 3 {
183            return Err(AuthError::TokenValidationFailed("invalid JWT format".into()));
184        }
185
186        let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD
187            .decode(parts[1])
188            .map_err(|e| AuthError::TokenValidationFailed(format!("base64 decode error: {}", e)))?;
189
190        let claims: IdTokenClaims = serde_json::from_slice(&payload)?;
191
192        Ok(claims)
193    }
194}
195
196/// User information extracted from tokens
197#[derive(Debug, Clone, Serialize, Deserialize)]
198pub struct UserInfo {
199    /// Subject (user ID from provider)
200    pub sub: String,
201    /// Email
202    pub email: Option<String>,
203    /// Email verified
204    pub email_verified: bool,
205    /// Display name
206    pub name: Option<String>,
207    /// Given name
208    pub given_name: Option<String>,
209    /// Family name
210    pub family_name: Option<String>,
211    /// Picture URL
212    pub picture: Option<String>,
213    /// Groups
214    pub groups: Vec<String>,
215    /// Provider type
216    pub provider: String,
217}
218
219impl UserInfo {
220    /// Extract from ID token claims
221    pub fn from_claims(claims: &IdTokenClaims, provider: &str) -> Self {
222        Self {
223            sub: claims.sub.clone(),
224            email: claims.email.clone(),
225            email_verified: claims.email_verified.unwrap_or(false),
226            name: claims.name.clone(),
227            given_name: claims.given_name.clone(),
228            family_name: claims.family_name.clone(),
229            picture: claims.picture.clone(),
230            groups: claims.groups.clone(),
231            provider: provider.to_string(),
232        }
233    }
234
235    /// Get email domain
236    pub fn email_domain(&self) -> Option<&str> {
237        self.email.as_ref().and_then(|e| e.split('@').nth(1))
238    }
239}
240
#[cfg(test)]
mod tests {
    use super::*;

    const ISSUER: &str = "https://accounts.google.com";
    const CLIENT_ID: &str = "client-id";

    fn validator() -> TokenValidator {
        TokenValidator::new(ISSUER, CLIENT_ID)
    }

    /// Claims that pass `validator()` with expected nonce `"test-nonce"`.
    fn valid_claims() -> IdTokenClaims {
        let now = Utc::now().timestamp();
        IdTokenClaims {
            iss: ISSUER.to_string(),
            sub: "user123".to_string(),
            aud: StringOrArray::String(CLIENT_ID.to_string()),
            exp: now + 3600,
            iat: now,
            nonce: Some("test-nonce".to_string()),
            email: Some("user@example.com".to_string()),
            email_verified: Some(true),
            name: Some("Test User".to_string()),
            given_name: None,
            family_name: None,
            picture: None,
            groups: vec![],
            additional: HashMap::new(),
        }
    }

    #[test]
    fn test_token_expiration() {
        let token = TokenSet {
            access_token: "test".to_string(),
            refresh_token: None,
            id_token: None,
            expires_at: Utc::now() + chrono::Duration::hours(1),
            token_type: "Bearer".to_string(),
            scopes: vec![],
        };

        assert!(!token.is_expired());
        assert!(!token.expires_within(chrono::Duration::minutes(30)));
        assert!(token.expires_within(chrono::Duration::hours(2)));
    }

    #[test]
    fn test_string_or_array() {
        let single = StringOrArray::String("test".to_string());
        assert!(single.contains("test"));
        assert!(!single.contains("other"));

        let array = StringOrArray::Array(vec!["one".to_string(), "two".to_string()]);
        assert!(array.contains("one"));
        assert!(array.contains("two"));
        assert!(!array.contains("three"));
    }

    #[test]
    fn test_claim_validation() {
        let claims = valid_claims();
        assert!(validator().validate_claims(&claims, Some("test-nonce")).is_ok());
        assert!(validator().validate_claims(&claims, Some("wrong-nonce")).is_err());
    }

    #[test]
    fn test_nonce_not_required_when_not_expected() {
        assert!(validator().validate_claims(&valid_claims(), None).is_ok());
    }

    #[test]
    fn test_missing_nonce_rejected_when_expected() {
        let mut claims = valid_claims();
        claims.nonce = None;
        assert!(matches!(
            validator().validate_claims(&claims, Some("test-nonce")),
            Err(AuthError::InvalidNonce)
        ));
    }

    #[test]
    fn test_issuer_mismatch_rejected() {
        let mut claims = valid_claims();
        claims.iss = "https://evil.example".to_string();
        assert!(matches!(
            validator().validate_claims(&claims, None),
            Err(AuthError::TokenValidationFailed(_))
        ));
    }

    #[test]
    fn test_audience_checks() {
        let mut claims = valid_claims();
        claims.aud = StringOrArray::Array(vec!["other".to_string()]);
        assert!(validator().validate_claims(&claims, None).is_err());

        claims.aud = StringOrArray::Array(vec!["other".to_string(), CLIENT_ID.to_string()]);
        assert!(validator().validate_claims(&claims, None).is_ok());
    }

    #[test]
    fn test_expired_token_rejected() {
        let now = Utc::now().timestamp();
        let mut claims = valid_claims();
        claims.iat = now - 7200;
        claims.exp = now - 3600;
        assert!(matches!(
            validator().validate_claims(&claims, None),
            Err(AuthError::TokenExpired)
        ));
    }

    #[test]
    fn test_future_iat_rejected() {
        let now = Utc::now().timestamp();
        let mut claims = valid_claims();
        claims.iat = now + 3600;
        claims.exp = now + 7200;
        assert!(validator().validate_claims(&claims, None).is_err());
    }

    #[test]
    fn test_decode_jwt_claims() {
        use base64::Engine;

        let payload = r#"{"iss":"https://issuer.example","sub":"u1","aud":["a","b"],"exp":2000000000,"iat":1000000000,"groups":["admins"],"azp":"a"}"#;
        let encoded = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(payload);
        let token = format!("e30.{}.sig", encoded);

        let claims = TokenValidator::decode_jwt_claims(&token).expect("well-formed token");
        assert_eq!(claims.sub, "u1");
        assert!(claims.aud.contains("b"));
        assert_eq!(claims.groups, vec!["admins".to_string()]);
        assert!(claims.nonce.is_none());
        assert_eq!(
            claims.additional.get("azp").and_then(|v| v.as_str()),
            Some("a")
        );

        assert!(TokenValidator::decode_jwt_claims("not-a-jwt").is_err());
        assert!(TokenValidator::decode_jwt_claims("a.b.c.d").is_err());
    }

    #[test]
    fn test_user_info_from_claims() {
        let info = UserInfo::from_claims(&valid_claims(), "google");
        assert_eq!(info.sub, "user123");
        assert!(info.email_verified);
        assert_eq!(info.provider, "google");
        assert_eq!(info.email_domain(), Some("example.com"));
    }
}