use crate::error::{Error, ErrorCode, Result};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
#[cfg(not(target_arch = "wasm32"))]
use tokio::sync::RwLock;
use uuid::Uuid;
/// OAuth 2.0 grant types.
///
/// On the wire each variant is its snake_case name: `authorization_code`,
/// `refresh_token`, `client_credentials`, `password`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum GrantType {
    /// `authorization_code`
    AuthorizationCode,
    /// `refresh_token`
    RefreshToken,
    /// `client_credentials`
    ClientCredentials,
    /// `password`
    Password,
}
/// Authorization-endpoint `response_type` values (`code` or `token` on the wire).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ResponseType {
    /// `code`: authorization-code flow.
    Code,
    /// `token`: token returned directly from the authorization endpoint.
    Token,
}
/// Kind of access token issued; serialized as `"bearer"`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum TokenType {
    #[serde(rename = "bearer")]
    Bearer,
}
/// A registered OAuth client.
///
/// Unknown top-level keys are not dropped. Because of `#[serde(flatten)]`, they
/// are collected into `metadata` on deserialization and written back out on
/// serialization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthClient {
    /// Client identifier (may be empty when submitted for registration).
    pub client_id: String,
    /// Client secret; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub client_secret: Option<String>,
    /// Human-readable client name.
    pub client_name: String,
    /// Redirect URIs the client is allowed to use.
    pub redirect_uris: Vec<String>,
    /// Grant types registered for the client.
    pub grant_types: Vec<GrantType>,
    /// Response types the client may request.
    pub response_types: Vec<ResponseType>,
    /// Scopes registered for the client.
    pub scopes: Vec<String>,
    /// Any additional registration fields.
    #[serde(flatten)]
    pub metadata: HashMap<String, serde_json::Value>,
}
/// Server-side record of an issued authorization code.
///
/// It is kept until exchanged, and binds the code to the client, user, redirect
/// URI, scopes and optional PKCE challenge it was issued for.
#[derive(Debug, Clone)]
pub struct AuthorizationCode {
    /// The opaque code value handed to the client.
    pub code: String,
    /// Client the code was issued to.
    pub client_id: String,
    /// User who authorized the request.
    pub user_id: String,
    /// Redirect URI that must be presented again at exchange time.
    pub redirect_uri: String,
    /// Scopes granted with this code.
    pub scopes: Vec<String>,
    /// PKCE challenge, if the client supplied one.
    pub code_challenge: Option<String>,
    /// PKCE challenge method (e.g. `plain` or `S256`), if supplied.
    pub code_challenge_method: Option<String>,
    /// Expiry as a Unix timestamp in seconds.
    pub expires_at: u64,
}
/// Token-endpoint response body.
///
/// Optional members are omitted from the serialized JSON when `None`. Extra
/// top-level keys are carried in `extra` via `#[serde(flatten)]`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccessToken {
    /// The access token value.
    pub access_token: String,
    /// Token type (always bearer in this module).
    pub token_type: TokenType,
    /// Lifetime of the access token, in seconds.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub expires_in: Option<u64>,
    /// Refresh token that can be redeemed for a new access token.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub refresh_token: Option<String>,
    /// Space-separated granted scopes.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub scope: Option<String>,
    /// Additional response members.
    #[serde(flatten)]
    pub extra: HashMap<String, serde_json::Value>,
}
/// Server-side metadata recorded for an issued access token.
#[derive(Debug, Clone)]
pub struct TokenInfo {
    /// The access token value this record describes.
    pub token: String,
    /// Client the token was issued to.
    pub client_id: String,
    /// User on whose behalf the token was issued.
    pub user_id: String,
    /// Scopes granted to the token.
    pub scopes: Vec<String>,
    /// Expiry as a Unix timestamp in seconds.
    pub expires_at: u64,
    /// Kind of token.
    pub token_type: TokenType,
}
/// OAuth 2.0 error response body (`error`, `error_description`, `error_uri`).
///
/// Optional members are omitted from the serialized JSON when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthError {
    /// Error code string.
    pub error: String,
    /// Human-readable explanation.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error_description: Option<String>,
    /// URI of a page describing the error.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error_uri: Option<String>,
}
/// Authorization-server / OIDC discovery document.
///
/// Field names match the discovery document's JSON keys. Optional endpoints are
/// omitted from the serialized JSON when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OidcDiscoveryMetadata {
    /// Issuer identifier.
    pub issuer: String,
    /// Authorization endpoint URL.
    pub authorization_endpoint: String,
    /// Token endpoint URL.
    pub token_endpoint: String,
    /// JWKS endpoint URL.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub jwks_uri: Option<String>,
    /// UserInfo endpoint URL.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub userinfo_endpoint: Option<String>,
    /// Dynamic client registration endpoint URL.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub registration_endpoint: Option<String>,
    /// Token revocation endpoint URL.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub revocation_endpoint: Option<String>,
    /// Token introspection endpoint URL.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub introspection_endpoint: Option<String>,
    /// Device authorization endpoint URL.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub device_authorization_endpoint: Option<String>,
    /// Supported `response_type` values.
    pub response_types_supported: Vec<ResponseType>,
    /// Supported grant types.
    pub grant_types_supported: Vec<GrantType>,
    /// Supported scopes.
    pub scopes_supported: Vec<String>,
    /// Supported client authentication methods at the token endpoint.
    pub token_endpoint_auth_methods_supported: Vec<String>,
    /// Supported PKCE challenge methods.
    pub code_challenge_methods_supported: Vec<String>,
}

/// Alias: plain OAuth server metadata uses the same shape as the OIDC document.
pub type OAuthMetadata = OidcDiscoveryMetadata;
/// Parameters of an authorization-endpoint request.
///
/// This type is only deserialized. Missing optional parameters become `None`,
/// and a missing `scope` becomes the empty string.
#[derive(Debug, Clone, Deserialize)]
pub struct AuthorizationRequest {
    /// Requested `response_type`.
    pub response_type: ResponseType,
    /// Requesting client.
    pub client_id: String,
    /// Redirect URI; must be registered for the client.
    pub redirect_uri: String,
    /// Space-separated requested scopes.
    #[serde(default)]
    pub scope: String,
    /// Opaque client state echoed back on redirect.
    pub state: Option<String>,
    /// PKCE code challenge.
    pub code_challenge: Option<String>,
    /// PKCE code challenge method.
    pub code_challenge_method: Option<String>,
}
/// Parameters of a token-endpoint request.
///
/// This type is only deserialized. Which optional fields are meaningful depends
/// on `grant_type`; any that are missing deserialize as `None`.
#[derive(Debug, Clone, Deserialize)]
pub struct TokenRequest {
    /// Grant being exercised.
    pub grant_type: GrantType,
    /// Authorization code (authorization-code grant).
    pub code: Option<String>,
    /// Redirect URI used in the authorization request.
    pub redirect_uri: Option<String>,
    /// Client identifier.
    pub client_id: Option<String>,
    /// Client secret.
    pub client_secret: Option<String>,
    /// Refresh token (refresh grant).
    pub refresh_token: Option<String>,
    /// Resource-owner username (password grant).
    pub username: Option<String>,
    /// Resource-owner password (password grant).
    pub password: Option<String>,
    /// Space-separated requested scopes.
    pub scope: Option<String>,
    /// PKCE code verifier.
    pub code_verifier: Option<String>,
}
#[derive(Debug, Clone, Deserialize)]
pub struct RevocationRequest {
pub token: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub token_type_hint: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub client_id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub client_secret: Option<String>,
}
#[async_trait]
pub trait OAuthProvider: Send + Sync {
async fn register_client(&self, client: OAuthClient) -> Result<OAuthClient>;
async fn get_client(&self, client_id: &str) -> Result<Option<OAuthClient>>;
async fn validate_authorization(&self, request: &AuthorizationRequest) -> Result<()>;
async fn create_authorization_code(
&self,
client_id: &str,
user_id: &str,
redirect_uri: &str,
scopes: Vec<String>,
code_challenge: Option<String>,
code_challenge_method: Option<String>,
) -> Result<String>;
async fn exchange_code(&self, request: &TokenRequest) -> Result<AccessToken>;
async fn create_access_token(
&self,
client_id: &str,
user_id: &str,
scopes: Vec<String>,
) -> Result<AccessToken>;
async fn refresh_token(&self, refresh_token: &str) -> Result<AccessToken>;
async fn revoke_token(&self, token: &str) -> Result<()>;
async fn validate_token(&self, token: &str) -> Result<TokenInfo>;
async fn metadata(&self) -> Result<OAuthMetadata>;
async fn discover(&self, _issuer_url: &str) -> Result<OidcDiscoveryMetadata> {
Err(Error::protocol(
ErrorCode::METHOD_NOT_FOUND,
"OIDC discovery not implemented for this provider",
))
}
}
/// [`OAuthProvider`] that keeps all state in process memory.
///
/// Nothing in this file purges expired codes or tokens. Entries stay in the maps
/// until they are exchanged, refreshed or revoked.
#[derive(Debug)]
pub struct InMemoryOAuthProvider {
    /// Issuer URL; endpoint URLs in the metadata are derived from it.
    base_url: String,
    /// Registered clients keyed by `client_id`.
    clients: Arc<RwLock<HashMap<String, OAuthClient>>>,
    /// Outstanding authorization codes keyed by code value.
    codes: Arc<RwLock<HashMap<String, AuthorizationCode>>>,
    /// Issued access tokens keyed by token value.
    tokens: Arc<RwLock<HashMap<String, TokenInfo>>>,
    /// Refresh token -> the access token it was issued alongside.
    refresh_tokens: Arc<RwLock<HashMap<String, String>>>,
    /// Access-token lifetime in seconds.
    token_expiration: u64,
    /// Authorization-code lifetime in seconds.
    code_expiration: u64,
    /// Scopes accepted by `validate_authorization` and advertised in metadata.
    supported_scopes: Vec<String>,
}
impl InMemoryOAuthProvider {
    /// Default access-token lifetime, in seconds.
    const DEFAULT_TOKEN_EXPIRATION_SECS: u64 = 3600;
    /// Default authorization-code lifetime, in seconds.
    const DEFAULT_CODE_EXPIRATION_SECS: u64 = 600;

    /// Creates an empty provider whose issuer is `base_url`.
    ///
    /// Defaults: 3600 s access tokens, 600 s authorization codes, and the
    /// supported scopes `read` and `write`. Use the `with_*` methods to
    /// override them.
    pub fn new(base_url: impl Into<String>) -> Self {
        Self {
            base_url: base_url.into(),
            clients: Arc::new(RwLock::new(HashMap::new())),
            codes: Arc::new(RwLock::new(HashMap::new())),
            tokens: Arc::new(RwLock::new(HashMap::new())),
            refresh_tokens: Arc::new(RwLock::new(HashMap::new())),
            token_expiration: Self::DEFAULT_TOKEN_EXPIRATION_SECS,
            code_expiration: Self::DEFAULT_CODE_EXPIRATION_SECS,
            supported_scopes: vec!["read".to_string(), "write".to_string()],
        }
    }

    /// Overrides the access-token lifetime, in seconds.
    #[must_use]
    pub fn with_token_expiration(mut self, secs: u64) -> Self {
        self.token_expiration = secs;
        self
    }

    /// Overrides the authorization-code lifetime, in seconds.
    #[must_use]
    pub fn with_code_expiration(mut self, secs: u64) -> Self {
        self.code_expiration = secs;
        self
    }

    /// Replaces the set of scopes clients may request.
    #[must_use]
    pub fn with_supported_scopes<I, S>(mut self, scopes: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.supported_scopes = scopes.into_iter().map(Into::into).collect();
        self
    }

    /// Returns a fresh random identifier (UUID v4 string), used for client ids,
    /// secrets, codes and tokens.
    fn generate_token() -> String {
        Uuid::new_v4().to_string()
    }

    /// Current Unix time in whole seconds.
    ///
    /// A system clock set before the Unix epoch yields 0 instead of panicking.
    fn now() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map_or(0, |elapsed| elapsed.as_secs())
    }

    /// Checks a PKCE `verifier` against the stored `challenge`.
    ///
    /// `plain` compares the two strings directly. `S256` compares the unpadded
    /// base64url encoding of SHA-256(verifier) with the challenge. Any other
    /// method is rejected.
    fn verify_pkce(verifier: &str, challenge: &str, method: &str) -> bool {
        match method {
            "plain" => verifier == challenge,
            "S256" => {
                use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
                use sha2::{Digest, Sha256};
                let mut hasher = Sha256::new();
                hasher.update(verifier.as_bytes());
                let digest = hasher.finalize();
                URL_SAFE_NO_PAD.encode(digest) == challenge
            },
            _ => false,
        }
    }
}
/// In-memory implementation of [`OAuthProvider`].
///
/// Lock discipline: each method holds at most one of the `clients`, `codes`,
/// `tokens` and `refresh_tokens` locks at a time. Every guard is dropped before
/// the next lock is taken or another async call is awaited, so concurrent
/// minting, refreshing and revocation cannot deadlock on lock order.
#[async_trait]
impl OAuthProvider for InMemoryOAuthProvider {
    /// Stores `client` and returns the stored record.
    ///
    /// A random `client_id` is generated when it is empty, and a random
    /// `client_secret` when none was supplied. A client already registered
    /// under the same `client_id` is replaced.
    async fn register_client(&self, mut client: OAuthClient) -> Result<OAuthClient> {
        if client.client_id.is_empty() {
            client.client_id = Self::generate_token();
        }
        if client.client_secret.is_none() {
            client.client_secret = Some(Self::generate_token());
        }
        let mut clients = self.clients.write().await;
        clients.insert(client.client_id.clone(), client.clone());
        Ok(client)
    }

    /// Returns a copy of the client registered under `client_id`, if any.
    async fn get_client(&self, client_id: &str) -> Result<Option<OAuthClient>> {
        let clients = self.clients.read().await;
        Ok(clients.get(client_id).cloned())
    }

    /// Validates an authorization request.
    ///
    /// The client must exist, `redirect_uri` must be registered for it, the
    /// `response_type` must be allowed for it, and every requested scope must
    /// be in the provider's supported scopes.
    async fn validate_authorization(&self, request: &AuthorizationRequest) -> Result<()> {
        let client = self
            .get_client(&request.client_id)
            .await?
            .ok_or_else(|| Error::protocol(ErrorCode::INVALID_REQUEST, "Invalid client_id"))?;
        if !client.redirect_uris.contains(&request.redirect_uri) {
            return Err(Error::protocol(
                ErrorCode::INVALID_REQUEST,
                "Invalid redirect_uri",
            ));
        }
        if !client.response_types.contains(&request.response_type) {
            return Err(Error::protocol(
                ErrorCode::INVALID_REQUEST,
                "Unsupported response_type",
            ));
        }
        let has_unsupported_scope = request
            .scope
            .split_whitespace()
            .any(|scope| !self.supported_scopes.iter().any(|s| s == scope));
        if has_unsupported_scope {
            return Err(Error::protocol(ErrorCode::INVALID_REQUEST, "Invalid scope"));
        }
        Ok(())
    }

    /// Issues a single-use authorization code that expires after
    /// `code_expiration` seconds.
    async fn create_authorization_code(
        &self,
        client_id: &str,
        user_id: &str,
        redirect_uri: &str,
        scopes: Vec<String>,
        code_challenge: Option<String>,
        code_challenge_method: Option<String>,
    ) -> Result<String> {
        let code = Self::generate_token();
        let auth_code = AuthorizationCode {
            code: code.clone(),
            client_id: client_id.to_string(),
            user_id: user_id.to_string(),
            redirect_uri: redirect_uri.to_string(),
            scopes,
            code_challenge,
            code_challenge_method,
            expires_at: Self::now().saturating_add(self.code_expiration),
        };
        self.codes.write().await.insert(code.clone(), auth_code);
        Ok(code)
    }

    /// Redeems an authorization code for an access/refresh token pair.
    ///
    /// The code is removed before any other check, so it is consumed even if
    /// the exchange then fails. The request's `client_id` and `redirect_uri`
    /// must match the ones the code was issued with. If a PKCE challenge was
    /// recorded, a matching `code_verifier` is required; the method defaults to
    /// `plain` when none was given.
    async fn exchange_code(&self, request: &TokenRequest) -> Result<AccessToken> {
        if request.grant_type != GrantType::AuthorizationCode {
            return Err(Error::protocol(
                ErrorCode::INVALID_REQUEST,
                "Invalid grant_type",
            ));
        }
        let code = request
            .code
            .as_ref()
            .ok_or_else(|| Error::protocol(ErrorCode::INVALID_REQUEST, "Missing code"))?;
        // The guard is a temporary, so the `codes` lock is released at the end
        // of this statement rather than being held while tokens are minted.
        let auth_code = self
            .codes
            .write()
            .await
            .remove(code)
            .ok_or_else(|| Error::protocol(ErrorCode::INVALID_REQUEST, "Invalid code"))?;
        if auth_code.expires_at < Self::now() {
            return Err(Error::protocol(ErrorCode::INVALID_REQUEST, "Code expired"));
        }
        let client_id = request
            .client_id
            .as_ref()
            .ok_or_else(|| Error::protocol(ErrorCode::INVALID_REQUEST, "Missing client_id"))?;
        if auth_code.client_id != *client_id {
            return Err(Error::protocol(
                ErrorCode::INVALID_REQUEST,
                "Invalid client_id",
            ));
        }
        let redirect_uri = request
            .redirect_uri
            .as_ref()
            .ok_or_else(|| Error::protocol(ErrorCode::INVALID_REQUEST, "Missing redirect_uri"))?;
        if auth_code.redirect_uri != *redirect_uri {
            return Err(Error::protocol(
                ErrorCode::INVALID_REQUEST,
                "Invalid redirect_uri",
            ));
        }
        if let Some(challenge) = &auth_code.code_challenge {
            let verifier = request.code_verifier.as_ref().ok_or_else(|| {
                Error::protocol(ErrorCode::INVALID_REQUEST, "Missing code_verifier")
            })?;
            let method = auth_code
                .code_challenge_method
                .as_deref()
                .unwrap_or("plain");
            if !Self::verify_pkce(verifier, challenge, method) {
                return Err(Error::protocol(
                    ErrorCode::INVALID_REQUEST,
                    "Invalid code_verifier",
                ));
            }
        }
        self.create_access_token(&auth_code.client_id, &auth_code.user_id, auth_code.scopes)
            .await
    }

    /// Mints a bearer access token that lives `token_expiration` seconds,
    /// together with a refresh token linked to it.
    async fn create_access_token(
        &self,
        client_id: &str,
        user_id: &str,
        scopes: Vec<String>,
    ) -> Result<AccessToken> {
        let access_token = Self::generate_token();
        let refresh_token = Self::generate_token();
        let token_info = TokenInfo {
            token: access_token.clone(),
            client_id: client_id.to_string(),
            user_id: user_id.to_string(),
            scopes: scopes.clone(),
            expires_at: Self::now().saturating_add(self.token_expiration),
            token_type: TokenType::Bearer,
        };
        // Take each lock in its own statement. Never hold `tokens` while
        // waiting on `refresh_tokens`: that is the reverse of the order used
        // by `revoke_token` and could deadlock.
        self.tokens
            .write()
            .await
            .insert(access_token.clone(), token_info);
        self.refresh_tokens
            .write()
            .await
            .insert(refresh_token.clone(), access_token.clone());
        Ok(AccessToken {
            access_token,
            token_type: TokenType::Bearer,
            expires_in: Some(self.token_expiration),
            refresh_token: Some(refresh_token),
            scope: Some(scopes.join(" ")),
            extra: HashMap::new(),
        })
    }

    /// Rotates tokens. The refresh token and its access token are both
    /// invalidated, and a new pair is minted with the same client, user and
    /// scopes.
    async fn refresh_token(&self, refresh_token: &str) -> Result<AccessToken> {
        // Remove (rather than read) the refresh token so that two concurrent
        // refreshes cannot both redeem it.
        let old_access_token = self
            .refresh_tokens
            .write()
            .await
            .remove(refresh_token)
            .ok_or_else(|| Error::protocol(ErrorCode::INVALID_REQUEST, "Invalid refresh_token"))?;
        let old_info = self
            .tokens
            .write()
            .await
            .remove(&old_access_token)
            .ok_or_else(|| Error::protocol(ErrorCode::INVALID_REQUEST, "Invalid refresh_token"))?;
        self.create_access_token(&old_info.client_id, &old_info.user_id, old_info.scopes)
            .await
    }

    /// Revokes `token`, which may be an access token or a refresh token.
    ///
    /// Revoking either one also removes its linked counterpart. Unknown tokens
    /// are ignored.
    async fn revoke_token(&self, token: &str) -> Result<()> {
        let revoked_access = self.tokens.write().await.remove(token);
        if revoked_access.is_some() {
            // Drop refresh tokens that pointed at this access token; they could
            // no longer be redeemed and would otherwise leak.
            self.refresh_tokens
                .write()
                .await
                .retain(|_, access| access.as_str() != token);
            return Ok(());
        }
        let linked_access = self.refresh_tokens.write().await.remove(token);
        if let Some(access_token) = linked_access {
            self.tokens.write().await.remove(&access_token);
        }
        Ok(())
    }

    /// Returns the metadata of `token` if it exists and has not expired.
    async fn validate_token(&self, token: &str) -> Result<TokenInfo> {
        let tokens = self.tokens.read().await;
        let token_info = tokens
            .get(token)
            .ok_or_else(|| Error::protocol(ErrorCode::INVALID_REQUEST, "Invalid token"))?;
        if token_info.expires_at < Self::now() {
            return Err(Error::protocol(ErrorCode::INVALID_REQUEST, "Token expired"));
        }
        Ok(token_info.clone())
    }

    /// Builds the metadata document, with endpoint URLs under `base_url/oauth2/`.
    async fn metadata(&self) -> Result<OAuthMetadata> {
        Ok(OAuthMetadata {
            issuer: self.base_url.clone(),
            authorization_endpoint: format!("{}/oauth2/authorize", self.base_url),
            token_endpoint: format!("{}/oauth2/token", self.base_url),
            jwks_uri: Some(format!("{}/oauth2/jwks", self.base_url)),
            userinfo_endpoint: Some(format!("{}/oauth2/userinfo", self.base_url)),
            registration_endpoint: Some(format!("{}/oauth2/register", self.base_url)),
            revocation_endpoint: Some(format!("{}/oauth2/revoke", self.base_url)),
            introspection_endpoint: Some(format!("{}/oauth2/introspect", self.base_url)),
            device_authorization_endpoint: Some(format!(
                "{}/oauth2/device/authorize",
                self.base_url
            )),
            response_types_supported: vec![ResponseType::Code],
            grant_types_supported: vec![GrantType::AuthorizationCode, GrantType::RefreshToken],
            scopes_supported: self.supported_scopes.clone(),
            token_endpoint_auth_methods_supported: vec![
                "client_secret_basic".to_string(),
                "client_secret_post".to_string(),
            ],
            code_challenge_methods_supported: vec!["plain".to_string(), "S256".to_string()],
        })
    }
}
/// Provider intended to delegate to an upstream OAuth server.
///
/// NOTE(review): both fields are underscore-prefixed (currently unused), and no
/// `OAuthProvider` impl for this type appears in this file. It looks like a
/// placeholder; confirm before relying on it.
#[derive(Debug)]
pub struct ProxyOAuthProvider {
    /// Base URL of the upstream server.
    _upstream_url: String,
    /// Cache of token metadata keyed by token value.
    _token_cache: Arc<RwLock<HashMap<String, TokenInfo>>>,
}
impl ProxyOAuthProvider {
    /// Creates a proxy for `upstream_url` with an empty token cache.
    pub fn new(upstream_url: impl Into<String>) -> Self {
        let token_cache = Arc::new(RwLock::new(HashMap::new()));
        Self {
            _upstream_url: upstream_url.into(),
            _token_cache: token_cache,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    const REDIRECT_URI: &str = "http://localhost:3000/callback";

    /// Registers a client that may use `REDIRECT_URI` with the `code` response type.
    async fn register_test_client(provider: &InMemoryOAuthProvider) -> OAuthClient {
        let client = OAuthClient {
            client_id: String::new(),
            client_secret: None,
            client_name: "Test Client".to_string(),
            redirect_uris: vec![REDIRECT_URI.to_string()],
            grant_types: vec![GrantType::AuthorizationCode],
            response_types: vec![ResponseType::Code],
            scopes: vec!["read".to_string(), "write".to_string()],
            metadata: HashMap::new(),
        };
        provider.register_client(client).await.unwrap()
    }

    /// Builds an `authorization_code` token request for `code`.
    fn code_exchange_request(
        client: &OAuthClient,
        code: String,
        code_verifier: Option<String>,
    ) -> TokenRequest {
        TokenRequest {
            grant_type: GrantType::AuthorizationCode,
            code: Some(code),
            redirect_uri: Some(REDIRECT_URI.to_string()),
            client_id: Some(client.client_id.clone()),
            client_secret: client.client_secret.clone(),
            refresh_token: None,
            username: None,
            password: None,
            scope: None,
            code_verifier,
        }
    }

    /// Issues a `read`-scoped code for `client` with an optional PKCE challenge.
    async fn issue_code(
        provider: &InMemoryOAuthProvider,
        client: &OAuthClient,
        challenge: Option<&str>,
        method: Option<&str>,
    ) -> String {
        provider
            .create_authorization_code(
                &client.client_id,
                "user-123",
                REDIRECT_URI,
                vec!["read".to_string()],
                challenge.map(str::to_string),
                method.map(str::to_string),
            )
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn test_oauth_flow() {
        let provider = InMemoryOAuthProvider::new("http://localhost:8080");
        let registered = register_test_client(&provider).await;
        assert!(!registered.client_id.is_empty());
        assert!(registered.client_secret.is_some());

        let auth_req = AuthorizationRequest {
            response_type: ResponseType::Code,
            client_id: registered.client_id.clone(),
            redirect_uri: REDIRECT_URI.to_string(),
            scope: "read write".to_string(),
            state: Some("test-state".to_string()),
            code_challenge: None,
            code_challenge_method: None,
        };
        provider.validate_authorization(&auth_req).await.unwrap();

        let code = provider
            .create_authorization_code(
                &registered.client_id,
                "user-123",
                &auth_req.redirect_uri,
                vec!["read".to_string(), "write".to_string()],
                None,
                None,
            )
            .await
            .unwrap();

        let token_req = code_exchange_request(&registered, code, None);
        let token = provider.exchange_code(&token_req).await.unwrap();
        assert_eq!(token.token_type, TokenType::Bearer);
        assert!(token.refresh_token.is_some());

        let token_info = provider.validate_token(&token.access_token).await.unwrap();
        assert_eq!(token_info.client_id, registered.client_id);
        assert_eq!(token_info.user_id, "user-123");

        // Authorization codes are single-use.
        assert!(provider.exchange_code(&token_req).await.is_err());
    }

    #[tokio::test]
    async fn test_pkce_s256() {
        // Test vector from RFC 7636, Appendix B.
        const VERIFIER: &str = "dBjftJeZ4CVP-mJ0by7TIXqUo76G5g8YOa8xGWmQJlU";
        const CHALLENGE: &str = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM";

        let provider = InMemoryOAuthProvider::new("http://localhost:8080");
        let registered = register_test_client(&provider).await;

        let code = issue_code(&provider, &registered, Some(CHALLENGE), Some("S256")).await;
        let missing = code_exchange_request(&registered, code, None);
        assert!(provider.exchange_code(&missing).await.is_err());

        let code = issue_code(&provider, &registered, Some(CHALLENGE), Some("S256")).await;
        let wrong = code_exchange_request(&registered, code, Some("not-the-verifier".into()));
        assert!(provider.exchange_code(&wrong).await.is_err());

        let code = issue_code(&provider, &registered, Some(CHALLENGE), Some("S256")).await;
        let right = code_exchange_request(&registered, code, Some(VERIFIER.into()));
        assert!(provider.exchange_code(&right).await.is_ok());
    }

    #[tokio::test]
    async fn test_refresh_token_rotation() {
        let provider = InMemoryOAuthProvider::new("http://localhost:8080");
        let registered = register_test_client(&provider).await;
        let first = provider
            .create_access_token(&registered.client_id, "user-123", vec!["read".to_string()])
            .await
            .unwrap();
        let refresh = first.refresh_token.clone().unwrap();

        let second = provider.refresh_token(&refresh).await.unwrap();
        assert_ne!(second.access_token, first.access_token);
        assert_eq!(second.scope.as_deref(), Some("read"));
        assert!(provider.validate_token(&first.access_token).await.is_err());
        let info = provider.validate_token(&second.access_token).await.unwrap();
        assert_eq!(info.user_id, "user-123");

        // A consumed refresh token cannot be replayed.
        assert!(provider.refresh_token(&refresh).await.is_err());
    }

    #[tokio::test]
    async fn test_revoke_tokens() {
        let provider = InMemoryOAuthProvider::new("http://localhost:8080");
        let registered = register_test_client(&provider).await;

        // Revoking a refresh token also kills its access token.
        let token = provider
            .create_access_token(&registered.client_id, "user-123", vec!["read".to_string()])
            .await
            .unwrap();
        let refresh = token.refresh_token.clone().unwrap();
        provider.revoke_token(&refresh).await.unwrap();
        assert!(provider.validate_token(&token.access_token).await.is_err());
        assert!(provider.refresh_token(&refresh).await.is_err());

        // Revoking an access token makes its refresh token unusable.
        let other = provider
            .create_access_token(&registered.client_id, "user-123", vec!["read".to_string()])
            .await
            .unwrap();
        let other_refresh = other.refresh_token.clone().unwrap();
        provider.revoke_token(&other.access_token).await.unwrap();
        assert!(provider.validate_token(&other.access_token).await.is_err());
        assert!(provider.refresh_token(&other_refresh).await.is_err());

        // Unknown tokens are ignored.
        provider.revoke_token("unknown").await.unwrap();
    }
}