use std::{
collections::HashMap,
sync::{Arc, Mutex},
time::{Duration, SystemTime},
};
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
/// An OAuth 2.0 token as persisted by an [`OAuthTokenStore`].
///
/// `Debug` is implemented by hand rather than derived so that credentials never end up
/// in logs: `access_token` and `refresh_token` are rendered as `<redacted>`.
#[derive(Clone, Serialize, Deserialize)]
pub struct OAuthToken {
    /// Credential presented to the resource server.
    pub access_token: String,
    /// Credential used to obtain a new access token; `None` when none was issued.
    pub refresh_token: Option<String>,
    /// Absolute expiry in seconds since the Unix epoch. `None` means the token is never
    /// considered expired by [`OAuthToken::is_expired`].
    pub expires_at: Option<u64>,
    /// Scope string reported by the token endpoint, if any.
    pub scope: Option<String>,
    /// Token type; tokens produced by `OAuthClient` default this to `"Bearer"` when the
    /// server omits it.
    pub token_type: String,
}

impl std::fmt::Debug for OAuthToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OAuthToken")
            .field("access_token", &"<redacted>")
            // Keep the Some/None distinction visible without exposing the secret.
            .field(
                "refresh_token",
                &self.refresh_token.as_ref().map(|_| "<redacted>"),
            )
            .field("expires_at", &self.expires_at)
            .field("scope", &self.scope)
            .field("token_type", &self.token_type)
            .finish()
    }
}
impl OAuthToken {
    /// Reports whether the token should be treated as expired.
    ///
    /// A token without `expires_at` never expires. Otherwise it counts as expired as soon
    /// as the current Unix time is within 30 seconds of `expires_at`, which makes callers
    /// refresh a little early (presumably to absorb clock skew and request latency).
    /// A system clock that reads earlier than the Unix epoch is treated as time zero.
    pub fn is_expired(&self) -> bool {
        const LEEWAY_SECS: u64 = 30;
        match self.expires_at {
            None => false,
            Some(expires_at) => {
                let now_secs = match SystemTime::now().duration_since(SystemTime::UNIX_EPOCH) {
                    Ok(since_epoch) => since_epoch.as_secs(),
                    Err(_) => 0,
                };
                expires_at <= now_secs + LEEWAY_SECS
            }
        }
    }
}
/// Persistence backend for OAuth tokens, keyed by `(user_id, provider)`.
///
/// Implementations must be `Send + Sync + 'static` so a store can be shared across
/// threads and tasks.
#[async_trait]
pub trait OAuthTokenStore: Send + Sync + 'static {
    /// Returns the token stored for `user_id` on `provider`, or `None` if there is none.
    async fn get(&self, user_id: &str, provider: &str) -> Option<OAuthToken>;
    /// Persists `token` for `user_id` on `provider`.
    async fn set(&self, user_id: &str, provider: &str, token: OAuthToken);
    /// Removes the token stored for `user_id` on `provider`, if any.
    async fn delete(&self, user_id: &str, provider: &str);
}
/// Process-local [`OAuthTokenStore`] backed by a mutex-guarded `HashMap`.
///
/// Entries are keyed by `(user_id, provider)`. The map sits behind an `Arc`, so every
/// clone of the store shares the same entries. Nothing is persisted to disk.
#[derive(Clone, Default)]
pub struct InMemoryTokenStore {
    tokens: Arc<Mutex<HashMap<(String, String), OAuthToken>>>,
}

impl InMemoryTokenStore {
    /// Creates an empty store.
    pub fn new() -> Self {
        Self::default()
    }

    /// Builds the owned composite key used by the map.
    fn key(user_id: &str, provider: &str) -> (String, String) {
        (user_id.to_owned(), provider.to_owned())
    }

    /// Runs `f` with exclusive access to the map.
    ///
    /// The lock is released before this returns, so it is never held across an `.await`.
    ///
    /// # Panics
    ///
    /// Panics if the mutex has been poisoned.
    fn with_tokens<R>(
        &self,
        f: impl FnOnce(&mut HashMap<(String, String), OAuthToken>) -> R,
    ) -> R {
        let mut guard = self.tokens.lock().unwrap();
        f(&mut guard)
    }
}

#[async_trait]
impl OAuthTokenStore for InMemoryTokenStore {
    async fn get(&self, user_id: &str, provider: &str) -> Option<OAuthToken> {
        let key = Self::key(user_id, provider);
        self.with_tokens(|tokens| tokens.get(&key).cloned())
    }

    async fn set(&self, user_id: &str, provider: &str, token: OAuthToken) {
        let key = Self::key(user_id, provider);
        self.with_tokens(move |tokens| {
            tokens.insert(key, token);
        });
    }

    async fn delete(&self, user_id: &str, provider: &str) {
        let key = Self::key(user_id, provider);
        self.with_tokens(|tokens| {
            tokens.remove(&key);
        });
    }
}
/// The OAuth 2.0 grant an [`OAuthClient`] is configured for, with the endpoints it uses.
#[derive(Debug, Clone)]
pub enum OAuthFlow {
    /// Authorization-code grant with PKCE. Users are sent to `auth_url`; the returned
    /// code is exchanged, and later refreshed, at `token_url`.
    AuthorizationCodePkce {
        /// Authorization endpoint the user is redirected to.
        auth_url: String,
        /// Token endpoint used for code exchange and refresh.
        token_url: String,
        /// Redirect URI sent in both the authorization URL and the code exchange.
        redirect_uri: String,
    },
    /// Client-credentials grant. Tokens are requested from `token_url` with the client's
    /// own id/secret, and refreshed there as well.
    ClientCredentials {
        token_url: String,
    },
    /// No initial grant is performed by the client; only stored refresh tokens are
    /// exchanged at `token_url`. (Presumably the tokens are seeded via
    /// `OAuthClient::store_token`; verify against callers.)
    RefreshOnly {
        token_url: String,
    },
}
/// Static configuration for an [`OAuthClient`].
///
/// `Debug` is implemented by hand rather than derived so that `client_secret` is
/// rendered as `<redacted>` instead of being written to logs.
#[derive(Clone)]
pub struct OAuthConfig {
    /// Provider name; the second half of the token-store key.
    pub provider: String,
    /// OAuth client identifier, sent with the authorization URL and every token request.
    pub client_id: String,
    /// Client secret, sent as `client_secret` in token requests when present.
    pub client_secret: Option<String>,
    /// Requested scopes, joined with single spaces when sent to the server.
    pub scopes: Vec<String>,
    /// Grant type and the endpoints it uses.
    pub flow: OAuthFlow,
    /// Per-request timeout for the HTTP client built by `OAuthClient::new`.
    pub timeout: Duration,
}

impl std::fmt::Debug for OAuthConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OAuthConfig")
            .field("provider", &self.provider)
            .field("client_id", &self.client_id)
            // Keep the Some/None distinction visible without exposing the secret.
            .field(
                "client_secret",
                &self.client_secret.as_ref().map(|_| "<redacted>"),
            )
            .field("scopes", &self.scopes)
            .field("flow", &self.flow)
            .field("timeout", &self.timeout)
            .finish()
    }
}
impl OAuthConfig {
pub fn client_credentials(
token_url: impl Into<String>,
client_id: impl Into<String>,
client_secret: impl Into<String>,
scopes: &[&str],
) -> Self {
Self {
provider: "custom".to_string(),
client_id: client_id.into(),
client_secret: Some(client_secret.into()),
scopes: scopes.iter().map(|s| s.to_string()).collect(),
flow: OAuthFlow::ClientCredentials {
token_url: token_url.into(),
},
timeout: Duration::from_secs(30),
}
}
pub fn authorization_code_pkce(
provider: impl Into<String>,
auth_url: impl Into<String>,
token_url: impl Into<String>,
redirect_uri: impl Into<String>,
client_id: impl Into<String>,
scopes: &[&str],
) -> Self {
Self {
provider: provider.into(),
client_id: client_id.into(),
client_secret: None,
scopes: scopes.iter().map(|s| s.to_string()).collect(),
flow: OAuthFlow::AuthorizationCodePkce {
auth_url: auth_url.into(),
token_url: token_url.into(),
redirect_uri: redirect_uri.into(),
},
timeout: Duration::from_secs(30),
}
}
pub fn with_timeout(mut self, timeout: Duration) -> Self {
self.timeout = timeout;
self
}
pub fn with_provider(mut self, provider: impl Into<String>) -> Self {
self.provider = provider.into();
self
}
}
/// A PKCE (RFC 7636) code verifier together with its `S256` code challenge.
#[derive(Debug, Clone)]
pub struct PkceChallenge {
    /// Random secret; it must be kept until the authorization code is exchanged.
    pub verifier: String,
    /// Unpadded base64url SHA-256 of `verifier`; sent in the authorization URL.
    pub challenge: String,
}
impl PkceChallenge {
    /// Generates a fresh PKCE verifier/challenge pair for the `S256` method (RFC 7636).
    ///
    /// The verifier is 32 bytes from the OS CSPRNG, base64url-encoded without padding
    /// (43 characters). The challenge is the unpadded base64url encoding of the SHA-256
    /// digest of the verifier's ASCII bytes.
    ///
    /// # Panics
    ///
    /// Panics if the operating system's random source is unavailable.
    pub fn new() -> Self {
        use sha2::{Digest, Sha256};
        let mut raw = [0u8; 32];
        getrandom::getrandom(&mut raw).expect("CSPRNG unavailable");
        let verifier = base64_url_encode(&raw);
        let mut hasher = Sha256::new();
        hasher.update(verifier.as_bytes());
        let digest = hasher.finalize();
        let challenge = base64_url_encode(&digest);
        Self {
            verifier,
            challenge,
        }
    }

    /// Builds the authorization-request URL that starts the PKCE flow.
    ///
    /// Every query value is percent-encoded, leaving only the RFC 3986 unreserved
    /// characters literal. This means a `redirect_uri` containing `?`/`&`/`#` cannot
    /// corrupt the query, and scopes (joined with single spaces) are sent as `%20`.
    /// If `auth_url` already carries a query string, the parameters are appended with `&`.
    pub fn authorization_url(
        &self,
        auth_url: &str,
        client_id: &str,
        redirect_uri: &str,
        scopes: &[String],
        state: &str,
    ) -> String {
        // Percent-encodes `value` byte by byte (UTF-8), keeping `A-Z a-z 0-9 - . _ ~` literal.
        fn percent_encode(value: &str) -> String {
            const HEX: &[u8; 16] = b"0123456789ABCDEF";
            let mut out = String::with_capacity(value.len());
            for byte in value.bytes() {
                if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'.' | b'_' | b'~') {
                    out.push(char::from(byte));
                } else {
                    out.push('%');
                    out.push(char::from(HEX[usize::from(byte >> 4)]));
                    out.push(char::from(HEX[usize::from(byte & 0x0F)]));
                }
            }
            out
        }

        let scope = scopes.join(" ");
        let query = [
            ("response_type", "code"),
            ("client_id", client_id),
            ("redirect_uri", redirect_uri),
            ("scope", scope.as_str()),
            ("state", state),
            ("code_challenge", self.challenge.as_str()),
            ("code_challenge_method", "S256"),
        ]
        .iter()
        .map(|(key, value)| format!("{key}={}", percent_encode(value)))
        .collect::<Vec<_>>()
        .join("&");
        let separator = if auth_url.contains('?') { '&' } else { '?' };
        format!("{auth_url}{separator}{query}")
    }
}
impl Default for PkceChallenge {
fn default() -> Self {
Self::new()
}
}
/// Encodes `data` as unpadded base64url (RFC 4648 §5), the encoding PKCE requires.
///
/// Each 3-byte input group becomes 4 output characters. A trailing group of 1 or 2 bytes
/// becomes 2 or 3 characters respectively, and no `=` padding is emitted.
fn base64_url_encode(data: &[u8]) -> String {
    const ALPHABET: &[u8; 64] =
        b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
    // Maps the low 6 bits of `bits` to its alphabet character.
    let sextet = |bits: u32| char::from(ALPHABET[(bits & 63) as usize]);

    let mut out = String::with_capacity((data.len() * 4 + 2) / 3);
    for chunk in data.chunks(3) {
        // Pack up to three bytes into one 24-bit group; missing bytes count as zero.
        let group = chunk
            .iter()
            .chain(std::iter::repeat(&0u8))
            .take(3)
            .fold(0u32, |acc, &byte| (acc << 8) | u32::from(byte));
        out.push(sextet(group >> 18));
        out.push(sextet(group >> 12));
        if chunk.len() > 1 {
            out.push(sextet(group >> 6));
        }
        if chunk.len() > 2 {
            out.push(sextet(group));
        }
    }
    out
}
/// OAuth 2.0 client for a single provider that caches tokens in an [`OAuthTokenStore`].
pub struct OAuthClient<S: OAuthTokenStore> {
    /// Provider, credentials, scopes, grant type and request timeout.
    config: OAuthConfig,
    /// Token cache keyed by `(user_id, config.provider)`.
    store: S,
    /// HTTP client built with `config.timeout` and used for token-endpoint requests.
    http: Client,
}
impl<S: OAuthTokenStore> OAuthClient<S> {
pub fn new(config: OAuthConfig, store: S) -> anyhow::Result<Self> {
let http = Client::builder().timeout(config.timeout).build()?;
Ok(Self {
config,
store,
http,
})
}
pub async fn access_token(&self, user_id: &str) -> anyhow::Result<String> {
if let Some(token) = self.store.get(user_id, &self.config.provider).await {
if !token.is_expired() {
return Ok(token.access_token.clone());
}
if let Some(refresh_token) = &token.refresh_token {
if let Ok(refreshed) = self.refresh_token(refresh_token).await {
self.store
.set(user_id, &self.config.provider, refreshed.clone())
.await;
return Ok(refreshed.access_token);
}
}
}
if let OAuthFlow::ClientCredentials { .. } = &self.config.flow {
let token = self.fetch_client_credentials().await?;
self.store
.set(user_id, &self.config.provider, token.clone())
.await;
return Ok(token.access_token);
}
anyhow::bail!(
"No valid token for user '{}' on provider '{}'. \
Initiate an authorization flow first via OAuthClient::authorization_url().",
user_id,
self.config.provider
)
}
pub async fn store_token(&self, user_id: &str, token: OAuthToken) {
self.store.set(user_id, &self.config.provider, token).await;
}
pub async fn revoke(&self, user_id: &str) {
self.store.delete(user_id, &self.config.provider).await;
}
pub async fn exchange_code(&self, code: &str, verifier: &str) -> anyhow::Result<OAuthToken> {
let token_url = match &self.config.flow {
OAuthFlow::AuthorizationCodePkce {
token_url,
redirect_uri,
..
} => (token_url.clone(), Some(redirect_uri.clone())),
_ => anyhow::bail!("exchange_code requires AuthorizationCodePkce flow"),
};
let mut params = vec![
("grant_type", "authorization_code".to_string()),
("code", code.to_string()),
("client_id", self.config.client_id.clone()),
("code_verifier", verifier.to_string()),
];
if let Some(uri) = token_url.1 {
params.push(("redirect_uri", uri));
}
if let Some(secret) = &self.config.client_secret {
params.push(("client_secret", secret.clone()));
}
self.post_token(&token_url.0, ¶ms).await
}
pub fn authorization_url(&self, state: &str) -> anyhow::Result<(String, PkceChallenge)> {
match &self.config.flow {
OAuthFlow::AuthorizationCodePkce {
auth_url,
redirect_uri,
..
} => {
let pkce = PkceChallenge::new();
let url = pkce.authorization_url(
auth_url,
&self.config.client_id,
redirect_uri,
&self.config.scopes,
state,
);
Ok((url, pkce))
}
_ => anyhow::bail!("authorization_url requires AuthorizationCodePkce flow"),
}
}
async fn fetch_client_credentials(&self) -> anyhow::Result<OAuthToken> {
let token_url = match &self.config.flow {
OAuthFlow::ClientCredentials { token_url } => token_url.clone(),
_ => anyhow::bail!("fetch_client_credentials called on non-ClientCredentials flow"),
};
let mut params = vec![
("grant_type", "client_credentials".to_string()),
("client_id", self.config.client_id.clone()),
];
if !self.config.scopes.is_empty() {
params.push(("scope", self.config.scopes.join(" ")));
}
if let Some(secret) = &self.config.client_secret {
params.push(("client_secret", secret.clone()));
}
self.post_token(&token_url, ¶ms).await
}
async fn refresh_token(&self, refresh_token: &str) -> anyhow::Result<OAuthToken> {
let token_url = match &self.config.flow {
OAuthFlow::AuthorizationCodePkce { token_url, .. } => token_url.clone(),
OAuthFlow::RefreshOnly { token_url } => token_url.clone(),
OAuthFlow::ClientCredentials { token_url } => token_url.clone(),
};
let mut params = vec![
("grant_type", "refresh_token".to_string()),
("refresh_token", refresh_token.to_string()),
("client_id", self.config.client_id.clone()),
];
if let Some(secret) = &self.config.client_secret {
params.push(("client_secret", secret.clone()));
}
self.post_token(&token_url, ¶ms).await
}
async fn post_token(&self, url: &str, params: &[(&str, String)]) -> anyhow::Result<OAuthToken> {
let resp = self
.http
.post(url)
.form(params)
.send()
.await
.map_err(|e| anyhow::anyhow!("Token request failed: {e}"))?;
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
if !status.is_success() {
anyhow::bail!("Token endpoint returned {status}: {body}");
}
let raw: TokenResponse =
serde_json::from_str(&body).map_err(|e| anyhow::anyhow!("Token parse error: {e}"))?;
let expires_at = raw.expires_in.map(|secs| {
SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap_or_default()
.as_secs()
+ secs
});
Ok(OAuthToken {
access_token: raw.access_token,
refresh_token: raw.refresh_token,
expires_at,
scope: raw.scope,
token_type: raw.token_type.unwrap_or_else(|| "Bearer".to_string()),
})
}
}
/// The subset of a token endpoint's JSON response that this client reads.
///
/// Only `access_token` is required. Missing `Option` fields deserialize as `None`, and
/// fields not listed here are ignored (no `deny_unknown_fields`).
#[derive(Deserialize)]
struct TokenResponse {
    access_token: String,
    refresh_token: Option<String>,
    /// Lifetime in seconds; `post_token` converts it to an absolute `expires_at`.
    expires_in: Option<u64>,
    scope: Option<String>,
    /// `post_token` substitutes `"Bearer"` when this is absent.
    token_type: Option<String>,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a bearer token with the given access value and expiry, and no refresh token.
    fn bearer(access_token: &str, expires_at: Option<u64>) -> OAuthToken {
        OAuthToken {
            access_token: access_token.to_string(),
            refresh_token: None,
            expires_at,
            scope: None,
            token_type: "Bearer".to_string(),
        }
    }

    #[test]
    fn pkce_challenge_base64url_no_padding() {
        let pkce = PkceChallenge::new();
        for value in [&pkce.verifier, &pkce.challenge] {
            assert!(!value.contains('='));
            assert!(!value.contains('+'));
            assert!(!value.contains('/'));
        }
    }

    #[test]
    fn pkce_verifier_and_challenge_are_43_chars() {
        // 32 random bytes and a 32-byte SHA-256 digest both encode to 43 characters,
        // inside RFC 7636's 43..=128 verifier range.
        let pkce = PkceChallenge::new();
        assert_eq!(pkce.verifier.len(), 43);
        assert_eq!(pkce.challenge.len(), 43);
    }

    #[test]
    fn pkce_s256_matches_rfc7636_appendix_b() {
        use sha2::{Digest, Sha256};
        let verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk";
        let challenge = base64_url_encode(&Sha256::digest(verifier.as_bytes()));
        assert_eq!(challenge, "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM");
    }

    #[test]
    fn base64_url_encode_matches_rfc4648_vectors() {
        assert_eq!(base64_url_encode(b""), "");
        assert_eq!(base64_url_encode(b"f"), "Zg");
        assert_eq!(base64_url_encode(b"fo"), "Zm8");
        assert_eq!(base64_url_encode(b"foo"), "Zm9v");
        assert_eq!(base64_url_encode(b"foobar"), "Zm9vYmFy");
        assert_eq!(base64_url_encode(&[0xfb, 0xff]), "-_8");
    }

    #[test]
    fn pkce_authorization_url_contains_required_params() {
        let pkce = PkceChallenge::new();
        let url = pkce.authorization_url(
            "https://auth.example.com/authorize",
            "client-abc",
            "https://myapp.example.com/callback",
            &["openid".to_string(), "profile".to_string()],
            "random-state",
        );
        assert!(url.contains("response_type=code"));
        assert!(url.contains("client_id=client-abc"));
        assert!(url.contains("code_challenge_method=S256"));
        assert!(url.contains(&pkce.challenge));
        assert!(url.contains("state=random-state"));
    }

    #[test]
    fn token_not_expired_without_expiry() {
        assert!(!bearer("tok", None).is_expired());
    }

    #[test]
    fn token_expired_in_past() {
        assert!(bearer("tok", Some(1)).is_expired());
    }

    #[test]
    fn token_far_in_future_not_expired() {
        let now = SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)
            .unwrap()
            .as_secs();
        assert!(!bearer("tok", Some(now + 3600)).is_expired());
    }

    #[test]
    fn in_memory_store_operations() {
        let rt = tokio::runtime::Builder::new_current_thread()
            .build()
            .unwrap();
        rt.block_on(async {
            let store = InMemoryTokenStore::new();
            store.set("user1", "github", bearer("abc", None)).await;
            let fetched = store.get("user1", "github").await.unwrap();
            assert_eq!(fetched.access_token, "abc");

            // Clones share the same underlying map.
            let shared = store.clone();
            assert!(shared.get("user1", "github").await.is_some());

            store.delete("user1", "github").await;
            assert!(store.get("user1", "github").await.is_none());
            assert!(shared.get("user1", "github").await.is_none());
        });
    }

    #[test]
    fn config_client_credentials_builder() {
        let cfg = OAuthConfig::client_credentials(
            "https://token.example.com",
            "id",
            "secret",
            &["read", "write"],
        );
        assert_eq!(cfg.scopes, vec!["read", "write"]);
        // `matches!` alone only produces a bool; it must be asserted to test anything.
        assert!(matches!(cfg.flow, OAuthFlow::ClientCredentials { .. }));
    }
}