use async_trait::async_trait;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::error::{Error, Result, StorageError, ValidationError};
use crate::random::{generate_random_alphanumeric, generate_random_base64_url};
/// Kind of OAuth client, which determines whether it is issued a secret.
///
/// Defaults to [`ClientType::Confidential`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum ClientType {
    /// Receives a generated secret from [`OAuthClientBuilder::build`]; only its hash is stored.
    #[default]
    Confidential,
    /// Receives no secret from the builder, so [`OAuthClient::verify_secret`] returns `false`
    /// for such clients and [`OAuthClient::rotate_secret`] is a no-op returning `Ok(None)`.
    Public,
}
/// A registered OAuth 2.0 client application.
///
/// Normally created through [`OAuthClient::builder`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthClient {
    /// Client identifier. The builder generates `oa_` followed by 24 characters from
    /// `generate_random_alphanumeric`.
    pub client_id: String,
    /// Lowercase hex SHA-256 of the client secret, or `None` when the client has no secret
    /// (the builder leaves it `None` for [`ClientType::Public`]). Omitted from serialized
    /// output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub client_secret_hash: Option<String>,
    /// Human-readable display name.
    pub name: String,
    /// Confidential (has a secret) or public.
    pub client_type: ClientType,
    /// Registered redirect URIs; [`OAuthClient::allows_redirect_uri`] requires an exact
    /// string match against one of these.
    pub redirect_uris: Vec<String>,
    /// Grant types this client may use. The builder defaults this to `[AuthorizationCode]`.
    pub grant_types: Vec<GrantType>,
    /// Scopes this client may request. An empty list means "no restriction".
    pub scopes: Vec<String>,
    /// Free-form key/value metadata; defaults to empty when absent during deserialization.
    #[serde(default)]
    pub metadata: HashMap<String, String>,
    /// Creation time (UTC).
    pub created_at: DateTime<Utc>,
    /// Last modification time (UTC); bumped by `rotate_secret`, `enable` and `disable`.
    pub updated_at: DateTime<Utc>,
    /// Whether the client is active; the builder sets this to `true`.
    /// NOTE(review): nothing in this module rejects disabled clients (e.g. `verify_secret`
    /// ignores this flag); callers presumably must check it — confirm.
    pub enabled: bool,
}
/// OAuth 2.0 grant types a client can be allowed to use.
///
/// `Display` / `FromStr` use snake_case names (`authorization_code`, `client_credentials`, ...).
/// NOTE(review): the serde derive has no `rename_all`, so serialized values use the Rust
/// variant names (`"AuthorizationCode"`, ...) instead; the two representations differ —
/// confirm this is intended before relying on either for interop.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum GrantType {
    /// Authorization code grant (`authorization_code`).
    AuthorizationCode,
    /// Client credentials grant (`client_credentials`).
    ClientCredentials,
    /// Refresh token grant (`refresh_token`).
    RefreshToken,
    /// Implicit flow (`implicit`); discouraged by current OAuth security guidance.
    Implicit,
    /// Resource owner password credentials (`password`); discouraged by current OAuth
    /// security guidance.
    Password,
}
impl std::fmt::Display for GrantType {
    /// Writes the lowercase snake_case name of the grant type — the same strings accepted by
    /// the `FromStr` implementation. Width/fill flags are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let wire_name = match *self {
            GrantType::AuthorizationCode => "authorization_code",
            GrantType::ClientCredentials => "client_credentials",
            GrantType::RefreshToken => "refresh_token",
            GrantType::Implicit => "implicit",
            GrantType::Password => "password",
        };
        f.write_str(wire_name)
    }
}
impl std::str::FromStr for GrantType {
type Err = Error;
fn from_str(s: &str) -> Result<Self> {
match s {
"authorization_code" => Ok(GrantType::AuthorizationCode),
"client_credentials" => Ok(GrantType::ClientCredentials),
"refresh_token" => Ok(GrantType::RefreshToken),
"implicit" => Ok(GrantType::Implicit),
"password" => Ok(GrantType::Password),
_ => Err(Error::Validation(ValidationError::Custom(format!(
"Unknown grant type: {}",
s
)))),
}
}
}
/// Fluent builder for [`OAuthClient`]; obtain one via [`OAuthClient::builder`] and finish with
/// [`OAuthClientBuilder::build`].
#[derive(Debug, Default)]
pub struct OAuthClientBuilder {
    /// Display name; `build` fails when it is not set.
    name: Option<String>,
    /// Defaults to `ClientType::Confidential` (the enum's `#[default]` variant).
    client_type: ClientType,
    /// Must contain at least one URI when `build` is called.
    redirect_uris: Vec<String>,
    /// When left empty, `build` substitutes `[GrantType::AuthorizationCode]`.
    grant_types: Vec<GrantType>,
    /// When left empty, the built client is unrestricted in scope.
    scopes: Vec<String>,
    /// Copied verbatim onto the built client.
    metadata: HashMap<String, String>,
}
impl OAuthClientBuilder {
pub fn new() -> Self {
Self::default()
}
pub fn name(mut self, name: impl Into<String>) -> Self {
self.name = Some(name.into());
self
}
pub fn client_type(mut self, client_type: ClientType) -> Self {
self.client_type = client_type;
self
}
pub fn redirect_uri(mut self, uri: impl Into<String>) -> Self {
self.redirect_uris.push(uri.into());
self
}
pub fn redirect_uris(mut self, uris: Vec<String>) -> Self {
self.redirect_uris = uris;
self
}
pub fn grant_type(mut self, grant_type: GrantType) -> Self {
if !self.grant_types.contains(&grant_type) {
self.grant_types.push(grant_type);
}
self
}
pub fn grant_types(mut self, grant_types: Vec<GrantType>) -> Self {
self.grant_types = grant_types;
self
}
pub fn scope(mut self, scope: impl Into<String>) -> Self {
self.scopes.push(scope.into());
self
}
pub fn scopes(mut self, scopes: Vec<String>) -> Self {
self.scopes = scopes;
self
}
pub fn metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
self.metadata.insert(key.into(), value.into());
self
}
pub fn build(self) -> Result<(OAuthClient, Option<String>)> {
let name = self
.name
.ok_or_else(|| Error::Validation(ValidationError::EmptyField("name".to_string())))?;
if self.redirect_uris.is_empty() {
return Err(Error::Validation(ValidationError::Custom(
"At least one redirect URI is required".to_string(),
)));
}
for uri in &self.redirect_uris {
validate_redirect_uri(uri)?;
}
let client_id = generate_client_id()?;
let now = Utc::now();
let (client_secret_hash, plain_secret) = if self.client_type == ClientType::Confidential {
let secret = generate_client_secret()?;
let hash = hash_client_secret(&secret);
(Some(hash), Some(secret))
} else {
(None, None)
};
let grant_types = if self.grant_types.is_empty() {
vec![GrantType::AuthorizationCode]
} else {
self.grant_types
};
let client = OAuthClient {
client_id,
client_secret_hash,
name,
client_type: self.client_type,
redirect_uris: self.redirect_uris,
grant_types,
scopes: self.scopes,
metadata: self.metadata,
created_at: now,
updated_at: now,
enabled: true,
};
Ok((client, plain_secret))
}
}
impl OAuthClient {
    /// Returns a fresh [`OAuthClientBuilder`] for registering a new client.
    pub fn builder() -> OAuthClientBuilder {
        OAuthClientBuilder::default()
    }

    /// Checks a plaintext secret presented by the client against the stored hash.
    ///
    /// Returns `false` whenever no hash is stored (as for public clients created by the
    /// builder), so such clients can never authenticate with a secret.
    pub fn verify_secret(&self, secret: &str) -> bool {
        if let Some(stored_hash) = self.client_secret_hash.as_deref() {
            verify_client_secret(secret, stored_hash)
        } else {
            false
        }
    }

    /// Reports whether `grant_type` is in this client's allowed list.
    pub fn allows_grant_type(&self, grant_type: GrantType) -> bool {
        let allowed = &self.grant_types;
        allowed.contains(&grant_type)
    }

    /// Reports whether `uri` is registered. This is an exact string comparison with no
    /// normalization (case, trailing slash, query order, ...).
    pub fn allows_redirect_uri(&self, uri: &str) -> bool {
        self.redirect_uris
            .iter()
            .any(|registered| registered.as_str() == uri)
    }

    /// Reports whether `scope` may be requested. A client with no configured scopes is
    /// unrestricted.
    pub fn allows_scope(&self, scope: &str) -> bool {
        if self.scopes.is_empty() {
            return true;
        }
        self.scopes.iter().any(|granted| granted.as_str() == scope)
    }

    /// Reports whether every scope in `scopes` is allowed; vacuously `true` for an empty slice.
    pub fn allows_scopes(&self, scopes: &[String]) -> bool {
        for requested in scopes {
            if !self.allows_scope(requested) {
                return false;
            }
        }
        true
    }

    /// Returns the subset of `requested` this client may use, preserving order and duplicates.
    /// A client with no configured scopes gets every requested scope back.
    pub fn filter_scopes(&self, requested: &[String]) -> Vec<String> {
        requested
            .iter()
            .filter(|candidate| self.allows_scope(candidate))
            .cloned()
            .collect()
    }

    /// Generates and installs a new secret for a confidential client.
    ///
    /// Returns `Ok(Some(secret))` with the new plaintext; the previous secret stops verifying.
    /// For public clients nothing changes and `Ok(None)` is returned.
    ///
    /// # Errors
    ///
    /// Propagates failures from secret generation; the client is left unchanged in that case.
    pub fn rotate_secret(&mut self) -> Result<Option<String>> {
        match self.client_type {
            ClientType::Public => Ok(None),
            ClientType::Confidential => {
                let fresh = generate_client_secret()?;
                self.client_secret_hash = Some(hash_client_secret(&fresh));
                self.updated_at = Utc::now();
                Ok(Some(fresh))
            }
        }
    }

    /// Marks the client inactive and bumps `updated_at`.
    pub fn disable(&mut self) {
        self.updated_at = Utc::now();
        self.enabled = false;
    }

    /// Marks the client active and bumps `updated_at`.
    pub fn enable(&mut self) {
        self.updated_at = Utc::now();
        self.enabled = true;
    }
}
/// Persistence backend for [`OAuthClient`] registrations.
///
/// Implementors must be `Send + Sync`. Mutating methods take `&mut self`, so sharing one store
/// between concurrent tasks needs external synchronization (NOTE(review): presumably a lock
/// around the store — confirm this is the intended usage).
#[async_trait]
pub trait OAuthClientStore: Send + Sync {
    /// Persists `client` under its `client_id`. The in-memory implementation overwrites any
    /// existing entry with the same ID.
    async fn save(&mut self, client: &OAuthClient) -> Result<()>;
    /// Looks up a client by ID; presumably `Ok(None)` when the ID is not registered (as the
    /// in-memory implementation does).
    async fn find_by_id(&self, client_id: &str) -> Result<Option<OAuthClient>>;
    /// Removes a client by ID. The in-memory implementation reports an unknown ID as
    /// `StorageError::NotFound`.
    async fn delete(&mut self, client_id: &str) -> Result<()>;
    /// Returns all registered clients. The in-memory implementation yields them in unspecified
    /// (hash-map) order.
    async fn list(&self) -> Result<Vec<OAuthClient>>;
}
/// [`OAuthClientStore`] backed by a `HashMap` keyed by `client_id`.
///
/// Data lives only in process memory and is lost when the store is dropped; presumably meant
/// for tests and prototyping.
#[derive(Debug, Default)]
pub struct InMemoryClientStore {
    /// Registered clients keyed by their `client_id`.
    clients: HashMap<String, OAuthClient>,
}
impl InMemoryClientStore {
pub fn new() -> Self {
Self::default()
}
}
#[async_trait]
impl OAuthClientStore for InMemoryClientStore {
async fn save(&mut self, client: &OAuthClient) -> Result<()> {
self.clients
.insert(client.client_id.clone(), client.clone());
Ok(())
}
async fn find_by_id(&self, client_id: &str) -> Result<Option<OAuthClient>> {
Ok(self.clients.get(client_id).cloned())
}
async fn delete(&mut self, client_id: &str) -> Result<()> {
self.clients
.remove(client_id)
.ok_or_else(|| Error::Storage(StorageError::NotFound(client_id.to_string())))?;
Ok(())
}
async fn list(&self) -> Result<Vec<OAuthClient>> {
Ok(self.clients.values().cloned().collect())
}
}
/// Generates a new client identifier: the literal prefix `oa_` followed by 24 characters from
/// `generate_random_alphanumeric`.
///
/// # Errors
///
/// Propagates any failure from the random generator.
fn generate_client_id() -> Result<String> {
    let suffix = generate_random_alphanumeric(24)?;
    let mut id = String::from("oa_");
    id.push_str(&suffix);
    Ok(id)
}
/// Generates a new plaintext client secret via `generate_random_base64_url(32)`.
///
/// NOTE(review): `32` is presumably the number of random bytes before base64url encoding —
/// confirm against `crate::random`.
///
/// # Errors
///
/// Propagates any failure from the random generator.
fn generate_client_secret() -> Result<String> {
    let secret = generate_random_base64_url(32)?;
    Ok(secret)
}
/// Hashes a client secret for storage: SHA-256 over the secret's UTF-8 bytes, rendered as
/// 64 lowercase hex characters.
///
/// NOTE: this is an unsalted, fast hash. It is only appropriate because secrets come from
/// `generate_client_secret` (generated values, not user-chosen passwords). Do not reuse it
/// for passwords.
fn hash_client_secret(secret: &str) -> String {
    use sha2::{Digest, Sha256};
    let digest = Sha256::digest(secret.as_bytes());
    let mut hex = String::with_capacity(digest.len() * 2);
    for byte in digest.iter() {
        hex.push_str(&format!("{:02x}", byte));
    }
    hex
}
/// Returns whether `secret` hashes to the stored `hash` (as produced by `hash_client_secret`).
///
/// The comparison goes through `crate::random::constant_time_compare_str`. Per its name, this
/// presumably avoids leaking match progress through timing.
fn verify_client_secret(secret: &str, hash: &str) -> bool {
    use crate::random::constant_time_compare_str;
    constant_time_compare_str(&hash_client_secret(secret), hash)
}
/// Validates a redirect URI supplied at client registration.
///
/// A URI is accepted when all of the following hold:
/// - it is non-empty;
/// - it has the form `scheme://...`, where `scheme` is syntactically valid per RFC 3986 §3.1
///   (an ASCII letter followed by ASCII letters, digits, `+`, `-` or `.`);
/// - it has no fragment (`#...`), which RFC 6749 §3.1.2 forbids for redirection endpoints.
///
/// This admits `https://` URIs, `http://` URIs (e.g. `http://localhost:3000/cb`) and custom
/// schemes such as `myapp://callback`.
///
/// # Errors
///
/// Returns `Error::Validation(ValidationError::Custom(..))` describing the first rule violated.
fn validate_redirect_uri(uri: &str) -> Result<()> {
    if uri.is_empty() {
        return Err(Error::Validation(ValidationError::Custom(
            "Redirect URI cannot be empty".to_string(),
        )));
    }

    // Merely checking that "://" occurs somewhere would accept "://cb" or "a b://x", so
    // validate the characters of the scheme itself.
    let has_valid_scheme = match uri.split_once("://") {
        Some((scheme, _)) => {
            let mut chars = scheme.chars();
            matches!(chars.next(), Some(first) if first.is_ascii_alphabetic())
                && chars.all(|c| c.is_ascii_alphanumeric() || matches!(c, '+' | '-' | '.'))
        }
        None => false,
    };
    if !has_valid_scheme {
        return Err(Error::Validation(ValidationError::Custom(
            "Redirect URI must have a valid scheme".to_string(),
        )));
    }

    // NOTE(review): cleartext `http://` is accepted for any host, not only loopback. RFC 8252
    // and the OAuth 2.0 Security BCP restrict it to loopback. It is kept permissive here for
    // compatibility with existing registrations — confirm intent.

    if uri.contains('#') {
        return Err(Error::Validation(ValidationError::Custom(
            "Redirect URI must not contain a fragment".to_string(),
        )));
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    // Confidential client: generated ID prefix, stored hash, returned plaintext secret that
    // verifies, and rejection of a wrong secret.
    #[test]
    fn test_create_confidential_client() {
        let (client, secret) = OAuthClient::builder()
            .name("Test App")
            .client_type(ClientType::Confidential)
            .redirect_uri("https://example.com/callback")
            .grant_type(GrantType::AuthorizationCode)
            .scope("read")
            .scope("write")
            .build()
            .unwrap();
        assert!(client.client_id.starts_with("oa_"));
        assert!(client.client_secret_hash.is_some());
        assert!(secret.is_some());
        assert_eq!(client.name, "Test App");
        assert!(client.enabled);
        assert!(client.verify_secret(&secret.unwrap()));
        assert!(!client.verify_secret("wrong_secret"));
    }

    // Public client: no secret issued or stored, so secret verification always fails.
    #[test]
    fn test_create_public_client() {
        let (client, secret) = OAuthClient::builder()
            .name("Mobile App")
            .client_type(ClientType::Public)
            .redirect_uri("myapp://callback")
            .build()
            .unwrap();
        assert!(client.client_secret_hash.is_none());
        assert!(secret.is_none());
        assert!(!client.verify_secret("any_secret"));
    }

    // Only the grant types explicitly added are allowed.
    #[test]
    fn test_grant_type_check() {
        let (client, _) = OAuthClient::builder()
            .name("Test")
            .redirect_uri("https://example.com/cb")
            .grant_type(GrantType::AuthorizationCode)
            .grant_type(GrantType::RefreshToken)
            .build()
            .unwrap();
        assert!(client.allows_grant_type(GrantType::AuthorizationCode));
        assert!(client.allows_grant_type(GrantType::RefreshToken));
        assert!(!client.allows_grant_type(GrantType::ClientCredentials));
    }

    // With a non-empty scope list, unlisted scopes are rejected and filtered out.
    #[test]
    fn test_scope_validation() {
        let (client, _) = OAuthClient::builder()
            .name("Test")
            .redirect_uri("https://example.com/cb")
            .scope("read")
            .scope("write")
            .build()
            .unwrap();
        assert!(client.allows_scope("read"));
        assert!(client.allows_scope("write"));
        assert!(!client.allows_scope("admin"));
        let filtered = client.filter_scopes(&["read".to_string(), "admin".to_string()]);
        assert_eq!(filtered, vec!["read".to_string()]);
    }

    // https, loopback http and custom schemes pass; empty and scheme-less strings fail.
    #[test]
    fn test_redirect_uri_validation() {
        assert!(validate_redirect_uri("https://example.com/callback").is_ok());
        assert!(validate_redirect_uri("http://localhost:3000/cb").is_ok());
        assert!(validate_redirect_uri("myapp://callback").is_ok());
        assert!(validate_redirect_uri("").is_err());
        assert!(validate_redirect_uri("not-a-uri").is_err());
    }

    // Rotation invalidates the old secret and the new one verifies.
    #[test]
    fn test_secret_rotation() {
        let (mut client, original_secret) = OAuthClient::builder()
            .name("Test")
            .client_type(ClientType::Confidential)
            .redirect_uri("https://example.com/cb")
            .build()
            .unwrap();
        let original_secret = original_secret.unwrap();
        assert!(client.verify_secret(&original_secret));
        let new_secret = client.rotate_secret().unwrap().unwrap();
        assert!(!client.verify_secret(&original_secret));
        assert!(client.verify_secret(&new_secret));
    }

    // FromStr accepts the snake_case names and rejects unknown ones.
    #[test]
    fn test_grant_type_parsing() {
        assert_eq!(
            "authorization_code".parse::<GrantType>().unwrap(),
            GrantType::AuthorizationCode
        );
        assert_eq!(
            "client_credentials".parse::<GrantType>().unwrap(),
            GrantType::ClientCredentials
        );
        assert!("invalid".parse::<GrantType>().is_err());
    }

    // Save / find / list / delete round trip against the in-memory store.
    #[tokio::test]
    async fn test_in_memory_store() {
        let mut store = InMemoryClientStore::new();
        let (client, _) = OAuthClient::builder()
            .name("Test")
            .redirect_uri("https://example.com/cb")
            .build()
            .unwrap();
        let client_id = client.client_id.clone();
        store.save(&client).await.unwrap();
        assert!(store.find_by_id(&client_id).await.unwrap().is_some());
        assert_eq!(store.list().await.unwrap().len(), 1);
        store.delete(&client_id).await.unwrap();
        assert!(store.find_by_id(&client_id).await.unwrap().is_none());
    }
}