1use std::{
35 collections::HashMap,
36 sync::{Arc, Mutex},
37 time::{Duration, SystemTime},
38};
39
40use async_trait::async_trait;
41use reqwest::Client;
42use serde::{Deserialize, Serialize};
43
44#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct OAuthToken {
49 pub access_token: String,
51 pub refresh_token: Option<String>,
53 pub expires_at: Option<u64>,
55 pub scope: Option<String>,
57 pub token_type: String,
59}
60
61impl OAuthToken {
62 pub fn is_expired(&self) -> bool {
64 let Some(exp) = self.expires_at else {
65 return false;
66 };
67 let now = SystemTime::now()
68 .duration_since(SystemTime::UNIX_EPOCH)
69 .unwrap_or_default()
70 .as_secs();
71 now + 30 >= exp
72 }
73}
74
75#[async_trait]
82pub trait OAuthTokenStore: Send + Sync + 'static {
83 async fn get(&self, user_id: &str, provider: &str) -> Option<OAuthToken>;
85 async fn set(&self, user_id: &str, provider: &str, token: OAuthToken);
87 async fn delete(&self, user_id: &str, provider: &str);
89}
90
91#[derive(Clone, Default)]
93pub struct InMemoryTokenStore {
94 tokens: Arc<Mutex<HashMap<(String, String), OAuthToken>>>,
95}
96
97impl InMemoryTokenStore {
98 pub fn new() -> Self {
100 Self::default()
101 }
102}
103
104#[async_trait]
105impl OAuthTokenStore for InMemoryTokenStore {
106 async fn get(&self, user_id: &str, provider: &str) -> Option<OAuthToken> {
107 self.tokens
108 .lock()
109 .unwrap()
110 .get(&(user_id.to_string(), provider.to_string()))
111 .cloned()
112 }
113
114 async fn set(&self, user_id: &str, provider: &str, token: OAuthToken) {
115 self.tokens
116 .lock()
117 .unwrap()
118 .insert((user_id.to_string(), provider.to_string()), token);
119 }
120
121 async fn delete(&self, user_id: &str, provider: &str) {
122 self.tokens
123 .lock()
124 .unwrap()
125 .remove(&(user_id.to_string(), provider.to_string()));
126 }
127}
128
/// The OAuth 2.0 grant an [`OAuthConfig`] is set up for, together with its endpoints.
#[derive(Debug, Clone)]
pub enum OAuthFlow {
    /// Authorization-code grant protected with PKCE (S256).
    AuthorizationCodePkce {
        /// Authorization endpoint the user is redirected to.
        auth_url: String,
        /// Token endpoint used for code exchange and refresh.
        token_url: String,
        /// Redirect URI registered with the provider.
        redirect_uri: String,
    },
    /// Client-credentials grant (no end user involved).
    ClientCredentials {
        /// Token endpoint.
        token_url: String,
    },
    /// Only refreshes tokens obtained elsewhere and stored via `OAuthClient::store_token`.
    RefreshOnly {
        /// Token endpoint used for refresh.
        token_url: String,
    },
}
154
155#[derive(Debug, Clone)]
157pub struct OAuthConfig {
158 pub provider: String,
160 pub client_id: String,
162 pub client_secret: Option<String>,
164 pub scopes: Vec<String>,
166 pub flow: OAuthFlow,
168 pub timeout: Duration,
170}
171
172impl OAuthConfig {
173 pub fn client_credentials(
175 token_url: impl Into<String>,
176 client_id: impl Into<String>,
177 client_secret: impl Into<String>,
178 scopes: &[&str],
179 ) -> Self {
180 Self {
181 provider: "custom".to_string(),
182 client_id: client_id.into(),
183 client_secret: Some(client_secret.into()),
184 scopes: scopes.iter().map(|s| s.to_string()).collect(),
185 flow: OAuthFlow::ClientCredentials {
186 token_url: token_url.into(),
187 },
188 timeout: Duration::from_secs(30),
189 }
190 }
191
192 pub fn authorization_code_pkce(
194 provider: impl Into<String>,
195 auth_url: impl Into<String>,
196 token_url: impl Into<String>,
197 redirect_uri: impl Into<String>,
198 client_id: impl Into<String>,
199 scopes: &[&str],
200 ) -> Self {
201 Self {
202 provider: provider.into(),
203 client_id: client_id.into(),
204 client_secret: None,
205 scopes: scopes.iter().map(|s| s.to_string()).collect(),
206 flow: OAuthFlow::AuthorizationCodePkce {
207 auth_url: auth_url.into(),
208 token_url: token_url.into(),
209 redirect_uri: redirect_uri.into(),
210 },
211 timeout: Duration::from_secs(30),
212 }
213 }
214
215 pub fn with_timeout(mut self, timeout: Duration) -> Self {
217 self.timeout = timeout;
218 self
219 }
220
221 pub fn with_provider(mut self, provider: impl Into<String>) -> Self {
223 self.provider = provider.into();
224 self
225 }
226}
227
/// A PKCE (RFC 7636) verifier/challenge pair for one authorization attempt.
///
/// Keep `verifier` private until the code exchange; only `challenge` goes into the
/// authorization URL.
#[derive(Debug, Clone)]
pub struct PkceChallenge {
    /// High-entropy secret sent with the token request (`code_verifier`).
    pub verifier: String,
    /// Derived value sent with the authorization request (`code_challenge`).
    pub challenge: String,
}
238
239impl PkceChallenge {
240 pub fn new() -> Self {
242 use sha2::{Digest, Sha256};
243
244 let mut raw = [0u8; 32];
246 getrandom::getrandom(&mut raw).expect("CSPRNG unavailable");
247 let verifier = base64_url_encode(&raw);
248
249 let mut hasher = Sha256::new();
251 hasher.update(verifier.as_bytes());
252 let digest = hasher.finalize();
253 let challenge = base64_url_encode(&digest);
254
255 Self {
256 verifier,
257 challenge,
258 }
259 }
260
261 pub fn authorization_url(
263 &self,
264 auth_url: &str,
265 client_id: &str,
266 redirect_uri: &str,
267 scopes: &[String],
268 state: &str,
269 ) -> String {
270 let scope = scopes.join(" ");
271 format!(
272 "{auth_url}?response_type=code\
273 &client_id={client_id}\
274 &redirect_uri={redirect_uri}\
275 &scope={scope}\
276 &state={state}\
277 &code_challenge={}\
278 &code_challenge_method=S256",
279 self.challenge
280 )
281 }
282}
283
284impl Default for PkceChallenge {
285 fn default() -> Self {
286 Self::new()
287 }
288}
289
/// Encodes `data` with the base64url alphabet (RFC 4648 §5) and no `=` padding.
///
/// Each full 3-byte group yields 4 characters. A trailing group of 1 or 2 bytes yields
/// 2 or 3 characters respectively.
fn base64_url_encode(data: &[u8]) -> String {
    const ALPHABET: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
    // Maps the low 6 bits of `bits` to its alphabet character.
    let sextet = |bits: u32| char::from(ALPHABET[(bits & 0x3f) as usize]);

    let mut encoded = String::with_capacity((data.len() * 4).div_ceil(3));
    for group in data.chunks(3) {
        // Pack up to three bytes into the low 24 bits; missing bytes stay zero.
        let packed = group
            .iter()
            .enumerate()
            .fold(0u32, |acc, (i, &byte)| acc | (u32::from(byte) << (16 - 8 * i)));
        // A group of n bytes contributes n + 1 sextets (no padding characters).
        let shifts: &[u32] = &[18, 12, 6, 0][..group.len() + 1];
        for &shift in shifts {
            encoded.push(sextet(packed >> shift));
        }
    }
    encoded
}
318
319pub struct OAuthClient<S: OAuthTokenStore> {
326 config: OAuthConfig,
327 store: S,
328 http: Client,
329}
330
331impl<S: OAuthTokenStore> OAuthClient<S> {
332 pub fn new(config: OAuthConfig, store: S) -> anyhow::Result<Self> {
334 let http = Client::builder().timeout(config.timeout).build()?;
335 Ok(Self {
336 config,
337 store,
338 http,
339 })
340 }
341
342 pub async fn access_token(&self, user_id: &str) -> anyhow::Result<String> {
349 if let Some(token) = self.store.get(user_id, &self.config.provider).await {
351 if !token.is_expired() {
352 return Ok(token.access_token.clone());
353 }
354 if let Some(refresh_token) = &token.refresh_token
356 && let Ok(refreshed) = self.refresh_token(refresh_token).await
357 {
358 self.store
359 .set(user_id, &self.config.provider, refreshed.clone())
360 .await;
361 return Ok(refreshed.access_token);
362 }
364 }
365
366 if let OAuthFlow::ClientCredentials { .. } = &self.config.flow {
368 let token = self.fetch_client_credentials().await?;
369 self.store
370 .set(user_id, &self.config.provider, token.clone())
371 .await;
372 return Ok(token.access_token);
373 }
374
375 anyhow::bail!(
376 "No valid token for user '{}' on provider '{}'. \
377 Initiate an authorization flow first via OAuthClient::authorization_url().",
378 user_id,
379 self.config.provider
380 )
381 }
382
383 pub async fn store_token(&self, user_id: &str, token: OAuthToken) {
385 self.store.set(user_id, &self.config.provider, token).await;
386 }
387
388 pub async fn revoke(&self, user_id: &str) {
390 self.store.delete(user_id, &self.config.provider).await;
391 }
392
393 pub async fn exchange_code(&self, code: &str, verifier: &str) -> anyhow::Result<OAuthToken> {
398 let token_url = match &self.config.flow {
399 OAuthFlow::AuthorizationCodePkce {
400 token_url,
401 redirect_uri,
402 ..
403 } => (token_url.clone(), Some(redirect_uri.clone())),
404 _ => anyhow::bail!("exchange_code requires AuthorizationCodePkce flow"),
405 };
406
407 let mut params = vec![
408 ("grant_type", "authorization_code".to_string()),
409 ("code", code.to_string()),
410 ("client_id", self.config.client_id.clone()),
411 ("code_verifier", verifier.to_string()),
412 ];
413 if let Some(uri) = token_url.1 {
414 params.push(("redirect_uri", uri));
415 }
416 if let Some(secret) = &self.config.client_secret {
417 params.push(("client_secret", secret.clone()));
418 }
419
420 self.post_token(&token_url.0, ¶ms).await
421 }
422
423 pub fn authorization_url(&self, state: &str) -> anyhow::Result<(String, PkceChallenge)> {
428 match &self.config.flow {
429 OAuthFlow::AuthorizationCodePkce {
430 auth_url,
431 redirect_uri,
432 ..
433 } => {
434 let pkce = PkceChallenge::new();
435 let url = pkce.authorization_url(
436 auth_url,
437 &self.config.client_id,
438 redirect_uri,
439 &self.config.scopes,
440 state,
441 );
442 Ok((url, pkce))
443 }
444 _ => anyhow::bail!("authorization_url requires AuthorizationCodePkce flow"),
445 }
446 }
447
448 async fn fetch_client_credentials(&self) -> anyhow::Result<OAuthToken> {
451 let token_url = match &self.config.flow {
452 OAuthFlow::ClientCredentials { token_url } => token_url.clone(),
453 _ => anyhow::bail!("fetch_client_credentials called on non-ClientCredentials flow"),
454 };
455
456 let mut params = vec![
457 ("grant_type", "client_credentials".to_string()),
458 ("client_id", self.config.client_id.clone()),
459 ];
460 if !self.config.scopes.is_empty() {
461 params.push(("scope", self.config.scopes.join(" ")));
462 }
463 if let Some(secret) = &self.config.client_secret {
464 params.push(("client_secret", secret.clone()));
465 }
466
467 self.post_token(&token_url, ¶ms).await
468 }
469
470 async fn refresh_token(&self, refresh_token: &str) -> anyhow::Result<OAuthToken> {
471 let token_url = match &self.config.flow {
472 OAuthFlow::AuthorizationCodePkce { token_url, .. } => token_url.clone(),
473 OAuthFlow::RefreshOnly { token_url } => token_url.clone(),
474 OAuthFlow::ClientCredentials { token_url } => token_url.clone(),
475 };
476
477 let mut params = vec![
478 ("grant_type", "refresh_token".to_string()),
479 ("refresh_token", refresh_token.to_string()),
480 ("client_id", self.config.client_id.clone()),
481 ];
482 if let Some(secret) = &self.config.client_secret {
483 params.push(("client_secret", secret.clone()));
484 }
485
486 self.post_token(&token_url, ¶ms).await
487 }
488
489 async fn post_token(&self, url: &str, params: &[(&str, String)]) -> anyhow::Result<OAuthToken> {
490 let resp = self
491 .http
492 .post(url)
493 .form(params)
494 .send()
495 .await
496 .map_err(|e| anyhow::anyhow!("Token request failed: {e}"))?;
497
498 let status = resp.status();
499 let body = resp.text().await.unwrap_or_default();
500 if !status.is_success() {
501 anyhow::bail!("Token endpoint returned {status}: {body}");
502 }
503
504 let raw: TokenResponse =
505 serde_json::from_str(&body).map_err(|e| anyhow::anyhow!("Token parse error: {e}"))?;
506
507 let expires_at = raw.expires_in.map(|secs| {
508 SystemTime::now()
509 .duration_since(SystemTime::UNIX_EPOCH)
510 .unwrap_or_default()
511 .as_secs()
512 + secs
513 });
514
515 Ok(OAuthToken {
516 access_token: raw.access_token,
517 refresh_token: raw.refresh_token,
518 expires_at,
519 scope: raw.scope,
520 token_type: raw.token_type.unwrap_or_else(|| "Bearer".to_string()),
521 })
522 }
523}
524
525#[derive(Deserialize)]
527struct TokenResponse {
528 access_token: String,
529 refresh_token: Option<String>,
530 expires_in: Option<u64>,
531 scope: Option<String>,
532 token_type: Option<String>,
533}
534
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal bearer token with the given access token and expiry.
    fn bearer(access_token: &str, expires_at: Option<u64>) -> OAuthToken {
        OAuthToken {
            access_token: access_token.to_string(),
            refresh_token: None,
            expires_at,
            scope: None,
            token_type: "Bearer".to_string(),
        }
    }

    /// Current Unix time in seconds.
    fn now_secs() -> u64 {
        SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    #[test]
    fn pkce_challenge_base64url_no_padding() {
        let pkce = PkceChallenge::new();
        assert!(!pkce.verifier.contains('='));
        assert!(!pkce.challenge.contains('='));
        assert!(!pkce.verifier.contains('+'));
        assert!(!pkce.challenge.contains('+'));
        assert!(!pkce.verifier.contains('/'));
        assert!(!pkce.challenge.contains('/'));
    }

    #[test]
    fn pkce_authorization_url_contains_required_params() {
        let pkce = PkceChallenge::new();
        let url = pkce.authorization_url(
            "https://auth.example.com/authorize",
            "client-abc",
            "https://myapp.example.com/callback",
            &["openid".to_string(), "profile".to_string()],
            "random-state",
        );
        assert!(url.contains("response_type=code"));
        assert!(url.contains("client_id=client-abc"));
        assert!(url.contains("code_challenge_method=S256"));
        assert!(url.contains(&pkce.challenge));
        assert!(url.contains("state=random-state"));
    }

    #[test]
    fn base64_url_encode_matches_rfc4648_vectors() {
        assert_eq!(base64_url_encode(b""), "");
        assert_eq!(base64_url_encode(b"f"), "Zg");
        assert_eq!(base64_url_encode(b"fo"), "Zm8");
        assert_eq!(base64_url_encode(b"foo"), "Zm9v");
        assert_eq!(base64_url_encode(b"foobar"), "Zm9vYmFy");
        // Bytes that map to '+' and '/' in standard base64 must use '-' and '_'.
        assert_eq!(base64_url_encode(&[0xfb, 0xff]), "-_8");
    }

    #[test]
    fn token_not_expired_without_expiry() {
        assert!(!bearer("tok", None).is_expired());
    }

    #[test]
    fn token_expired_in_past() {
        assert!(bearer("tok", Some(1)).is_expired());
    }

    #[test]
    fn token_within_expiry_margin_counts_as_expired() {
        let now = now_secs();
        assert!(bearer("tok", Some(now + 10)).is_expired());
        assert!(!bearer("tok", Some(now + 3600)).is_expired());
    }

    #[test]
    fn in_memory_store_operations() {
        let rt = tokio::runtime::Builder::new_current_thread()
            .build()
            .unwrap();
        rt.block_on(async {
            let store = InMemoryTokenStore::new();
            store.set("user1", "github", bearer("abc", None)).await;
            let fetched = store.get("user1", "github").await.unwrap();
            assert_eq!(fetched.access_token, "abc");

            store.delete("user1", "github").await;
            assert!(store.get("user1", "github").await.is_none());
        });
    }

    #[test]
    fn config_client_credentials_builder() {
        let cfg = OAuthConfig::client_credentials(
            "https://token.example.com",
            "id",
            "secret",
            &["read", "write"],
        );
        assert_eq!(cfg.scopes, vec!["read", "write"]);
        // Previously a bare `matches!` whose result was discarded, so nothing was checked.
        assert!(matches!(cfg.flow, OAuthFlow::ClientCredentials { .. }));
    }
}