1use chrono::{DateTime, Utc};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6
7use crate::{AuthError, Result};
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct TokenSet {
12 pub access_token: String,
14 pub refresh_token: Option<String>,
16 pub id_token: Option<String>,
18 pub expires_at: DateTime<Utc>,
20 pub token_type: String,
22 pub scopes: Vec<String>,
24}
25
26impl TokenSet {
27 pub fn is_expired(&self) -> bool {
29 Utc::now() > self.expires_at
30 }
31
32 pub fn expires_within(&self, duration: chrono::Duration) -> bool {
34 Utc::now() + duration > self.expires_at
35 }
36
37 pub fn remaining_lifetime(&self) -> chrono::Duration {
39 self.expires_at - Utc::now()
40 }
41}
42
43#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct IdTokenClaims {
46 pub iss: String,
48 pub sub: String,
50 pub aud: StringOrArray,
52 pub exp: i64,
54 pub iat: i64,
56 #[serde(default)]
58 pub nonce: Option<String>,
59 #[serde(default)]
61 pub email: Option<String>,
62 #[serde(default)]
64 pub email_verified: Option<bool>,
65 #[serde(default)]
67 pub name: Option<String>,
68 #[serde(default)]
70 pub given_name: Option<String>,
71 #[serde(default)]
73 pub family_name: Option<String>,
74 #[serde(default)]
76 pub picture: Option<String>,
77 #[serde(default)]
79 pub groups: Vec<String>,
80 #[serde(flatten)]
82 pub additional: HashMap<String, serde_json::Value>,
83}
84
85#[derive(Debug, Clone, Serialize, Deserialize)]
87#[serde(untagged)]
88pub enum StringOrArray {
89 String(String),
91 Array(Vec<String>),
93}
94
95impl StringOrArray {
96 pub fn contains(&self, value: &str) -> bool {
98 match self {
99 StringOrArray::String(s) => s == value,
100 StringOrArray::Array(arr) => arr.iter().any(|s| s == value),
101 }
102 }
103}
104
/// Validates the registered claims of an ID token against an expected issuer and audience.
///
/// Construct with `TokenValidator::new`; the time checks allow `clock_skew` seconds of leeway.
#[derive(Debug, Clone)]
pub struct TokenValidator {
    /// Required value of the `iss` claim.
    issuer: String,
    /// Value that must appear in the `aud` claim.
    audience: String,
    /// Leeway in seconds applied to the `exp` and `iat` checks.
    clock_skew: i64,
}
114
115impl TokenValidator {
116 pub fn new(issuer: &str, audience: &str) -> Self {
118 Self {
119 issuer: issuer.to_string(),
120 audience: audience.to_string(),
121 clock_skew: 60, }
123 }
124
125 pub fn with_clock_skew(mut self, seconds: i64) -> Self {
127 self.clock_skew = seconds;
128 self
129 }
130
131 pub fn validate_claims(&self, claims: &IdTokenClaims, expected_nonce: Option<&str>) -> Result<()> {
136 if claims.iss != self.issuer {
138 return Err(AuthError::TokenValidationFailed(format!(
139 "invalid issuer: expected {}, got {}",
140 self.issuer, claims.iss
141 )));
142 }
143
144 if !claims.aud.contains(&self.audience) {
146 return Err(AuthError::TokenValidationFailed(
147 "token audience mismatch".into(),
148 ));
149 }
150
151 let now = Utc::now().timestamp();
153 if claims.exp < now - self.clock_skew {
154 return Err(AuthError::TokenExpired);
155 }
156
157 if claims.iat > now + self.clock_skew {
159 return Err(AuthError::TokenValidationFailed(
160 "token issued in the future".into(),
161 ));
162 }
163
164 if let Some(expected) = expected_nonce {
166 if claims.nonce.as_deref() != Some(expected) {
167 return Err(AuthError::InvalidNonce);
168 }
169 }
170
171 Ok(())
172 }
173
174 pub fn decode_jwt_claims(token: &str) -> Result<IdTokenClaims> {
179 use base64::Engine;
180
181 let parts: Vec<&str> = token.split('.').collect();
182 if parts.len() != 3 {
183 return Err(AuthError::TokenValidationFailed("invalid JWT format".into()));
184 }
185
186 let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD
187 .decode(parts[1])
188 .map_err(|e| AuthError::TokenValidationFailed(format!("base64 decode error: {}", e)))?;
189
190 let claims: IdTokenClaims = serde_json::from_slice(&payload)?;
191
192 Ok(claims)
193 }
194}
195
196#[derive(Debug, Clone, Serialize, Deserialize)]
198pub struct UserInfo {
199 pub sub: String,
201 pub email: Option<String>,
203 pub email_verified: bool,
205 pub name: Option<String>,
207 pub given_name: Option<String>,
209 pub family_name: Option<String>,
211 pub picture: Option<String>,
213 pub groups: Vec<String>,
215 pub provider: String,
217}
218
219impl UserInfo {
220 pub fn from_claims(claims: &IdTokenClaims, provider: &str) -> Self {
222 Self {
223 sub: claims.sub.clone(),
224 email: claims.email.clone(),
225 email_verified: claims.email_verified.unwrap_or(false),
226 name: claims.name.clone(),
227 given_name: claims.given_name.clone(),
228 family_name: claims.family_name.clone(),
229 picture: claims.picture.clone(),
230 groups: claims.groups.clone(),
231 provider: provider.to_string(),
232 }
233 }
234
235 pub fn email_domain(&self) -> Option<&str> {
237 self.email.as_ref().and_then(|e| e.split('@').nth(1))
238 }
239}
240
#[cfg(test)]
mod tests {
    use super::*;

    const ISSUER: &str = "https://accounts.google.com";
    const CLIENT_ID: &str = "client-id";
    const NONCE: &str = "test-nonce";

    /// Claims that pass `validator()` when checked against `NONCE`.
    fn valid_claims() -> IdTokenClaims {
        let now = Utc::now().timestamp();
        IdTokenClaims {
            iss: ISSUER.to_string(),
            sub: "user123".to_string(),
            aud: StringOrArray::String(CLIENT_ID.to_string()),
            exp: now + 3600,
            iat: now,
            nonce: Some(NONCE.to_string()),
            email: Some("user@example.com".to_string()),
            email_verified: Some(true),
            name: Some("Test User".to_string()),
            given_name: None,
            family_name: None,
            picture: None,
            groups: vec![],
            additional: HashMap::new(),
        }
    }

    /// Validator matching `valid_claims()`, using the default clock skew.
    fn validator() -> TokenValidator {
        TokenValidator::new(ISSUER, CLIENT_ID)
    }

    /// A bearer token set expiring `expires_in` from now.
    fn token_expiring_in(expires_in: chrono::Duration) -> TokenSet {
        TokenSet {
            access_token: "test".to_string(),
            refresh_token: None,
            id_token: None,
            expires_at: Utc::now() + expires_in,
            token_type: "Bearer".to_string(),
            scopes: vec![],
        }
    }

    #[test]
    fn test_token_expiration() {
        let token = token_expiring_in(chrono::Duration::hours(1));
        assert!(!token.is_expired());
        assert!(!token.expires_within(chrono::Duration::minutes(30)));
        assert!(token.expires_within(chrono::Duration::hours(2)));
        assert!(token.remaining_lifetime() > chrono::Duration::minutes(59));

        let stale = token_expiring_in(chrono::Duration::minutes(-5));
        assert!(stale.is_expired());
        assert!(stale.remaining_lifetime() < chrono::Duration::zero());
    }

    #[test]
    fn test_string_or_array() {
        let single = StringOrArray::String("test".to_string());
        assert!(single.contains("test"));
        assert!(!single.contains("other"));

        let array = StringOrArray::Array(vec!["one".to_string(), "two".to_string()]);
        assert!(array.contains("one"));
        assert!(array.contains("two"));
        assert!(!array.contains("three"));
    }

    #[test]
    fn test_claim_validation() {
        let validator = validator();
        let claims = valid_claims();

        assert!(validator.validate_claims(&claims, Some(NONCE)).is_ok());
        assert!(validator.validate_claims(&claims, Some("wrong-nonce")).is_err());
        // Without an expected nonce the claim's nonce is not checked.
        assert!(validator.validate_claims(&claims, None).is_ok());
    }

    #[test]
    fn test_claim_validation_rejects_wrong_issuer_and_audience() {
        let validator = validator();

        let mut wrong_issuer = valid_claims();
        wrong_issuer.iss = "https://evil.example".to_string();
        assert!(matches!(
            validator.validate_claims(&wrong_issuer, None),
            Err(AuthError::TokenValidationFailed(_))
        ));

        let mut wrong_audience = valid_claims();
        wrong_audience.aud = StringOrArray::Array(vec!["other-client".to_string()]);
        assert!(matches!(
            validator.validate_claims(&wrong_audience, None),
            Err(AuthError::TokenValidationFailed(_))
        ));

        let mut multi_audience = valid_claims();
        multi_audience.aud =
            StringOrArray::Array(vec!["other-client".to_string(), CLIENT_ID.to_string()]);
        assert!(validator.validate_claims(&multi_audience, None).is_ok());
    }

    #[test]
    fn test_claim_validation_time_checks() {
        let validator = validator();
        let now = Utc::now().timestamp();

        let mut expired = valid_claims();
        expired.exp = now - 3600;
        assert!(matches!(
            validator.validate_claims(&expired, None),
            Err(AuthError::TokenExpired)
        ));

        // Within the default 60-second leeway.
        let mut barely_expired = valid_claims();
        barely_expired.exp = now - 30;
        assert!(validator.validate_claims(&barely_expired, None).is_ok());

        let mut issued_in_future = valid_claims();
        issued_in_future.iat = now + 3600;
        assert!(matches!(
            validator.validate_claims(&issued_in_future, None),
            Err(AuthError::TokenValidationFailed(_))
        ));
    }

    #[test]
    fn test_claim_validation_missing_nonce() {
        let mut claims = valid_claims();
        claims.nonce = None;
        assert!(matches!(
            validator().validate_claims(&claims, Some(NONCE)),
            Err(AuthError::InvalidNonce)
        ));
    }

    #[test]
    fn test_decode_jwt_claims() {
        use base64::Engine;

        let payload = r#"{"iss":"https://issuer","sub":"abc","aud":["a","b"],"exp":2000000000,"iat":1000000000,"custom":"value"}"#;
        let encoded = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(payload);
        let token = format!("header.{}.signature", encoded);

        let claims = match TokenValidator::decode_jwt_claims(&token) {
            Ok(claims) => claims,
            Err(_) => panic!("a well-formed token should decode"),
        };
        assert_eq!(claims.iss, "https://issuer");
        assert_eq!(claims.sub, "abc");
        assert!(claims.aud.contains("b"));
        assert!(claims.nonce.is_none());
        assert!(claims.groups.is_empty());
        assert_eq!(
            claims.additional.get("custom"),
            Some(&serde_json::Value::String("value".to_string()))
        );

        assert!(TokenValidator::decode_jwt_claims("not-a-jwt").is_err());
        assert!(TokenValidator::decode_jwt_claims("a.!!!.c").is_err());
    }

    #[test]
    fn test_user_info_from_claims() {
        let info = UserInfo::from_claims(&valid_claims(), "google");
        assert_eq!(info.sub, "user123");
        assert_eq!(info.provider, "google");
        assert!(info.email_verified);
        assert_eq!(info.email_domain(), Some("example.com"));

        let mut anonymous = valid_claims();
        anonymous.email = None;
        anonymous.email_verified = None;
        let info = UserInfo::from_claims(&anonymous, "google");
        assert!(!info.email_verified);
        assert_eq!(info.email_domain(), None);
    }
}