use crate::errors::{AuthError, Result};
use crate::server::core::client_registry::ClientConfig;
use crate::server::oauth::oauth2::OAuth2Server;
use crate::storage::core::AuthStorage;
use std::sync::Arc;
/// Security policy knobs for the OAuth 2.1 server wrapper.
///
/// Several fields are only stored here; whether they are enforced elsewhere is
/// not visible from this file.
#[derive(Debug, Clone)]
pub struct OAuth21SecurityConfig {
    /// Reported as `pkce_required` in the server configuration document.
    pub require_pkce_for_public_clients: bool,
    /// When set, the `implicit` grant and token-only response types are
    /// stripped from client configs and from the advertised server metadata.
    pub disallow_implicit_grant: bool,
    /// Not consulted by the code in this file.
    pub require_exact_redirect_uri_matching: bool,
    /// When set, client redirect URIs must be `https://`, with an exception
    /// for plain-HTTP `localhost` / `127.0.0.1`.
    pub require_secure_redirect_uris: bool,
    /// Upper bound for authorization code lifetime (presumably seconds — TODO confirm).
    pub max_auth_code_lifetime: u64,
    /// Upper bound for access token lifetime (presumably seconds — TODO confirm).
    pub max_access_token_lifetime: u64,
    /// Not consulted by the code in this file.
    pub require_client_authentication: bool,
    /// When set, the `password` grant is stripped from client configs and
    /// from the advertised server metadata.
    pub disallow_password_grant: bool,
}

impl Default for OAuth21SecurityConfig {
    /// Returns the strictest profile: every restriction flag enabled, with
    /// lifetime caps of 600 (authorization code) and 3600 (access token).
    fn default() -> Self {
        const AUTH_CODE_LIFETIME_CAP: u64 = 600;
        const ACCESS_TOKEN_LIFETIME_CAP: u64 = 3600;

        Self {
            // Restriction flags: all on by default.
            require_pkce_for_public_clients: true,
            disallow_implicit_grant: true,
            disallow_password_grant: true,
            require_exact_redirect_uri_matching: true,
            require_secure_redirect_uris: true,
            require_client_authentication: true,
            // Lifetime caps.
            max_auth_code_lifetime: AUTH_CODE_LIFETIME_CAP,
            max_access_token_lifetime: ACCESS_TOKEN_LIFETIME_CAP,
        }
    }
}
/// OAuth 2.1 front end: pairs an [`OAuth2Server`] with an
/// [`OAuth21SecurityConfig`] policy.
///
/// Cloning shares the same underlying `OAuth2Server` (it is held in an `Arc`)
/// and copies the security configuration.
#[derive(Clone)]
pub struct OAuth21Server {
/// Underlying OAuth 2.0 server that client and configuration calls are forwarded to.
oauth2_server: Arc<OAuth2Server>,
/// Policy consulted when validating client configs and when reporting server configuration.
security_config: OAuth21SecurityConfig,
}
impl OAuth21Server {
pub async fn new(
security_config: Option<OAuth21SecurityConfig>,
storage: Arc<dyn AuthStorage>,
) -> Result<Self> {
let oauth2_server = Arc::new(OAuth2Server::new(storage).await?);
let security_config = security_config.unwrap_or_default();
Ok(Self {
oauth2_server,
security_config,
})
}
pub async fn register_client(&self, mut config: ClientConfig) -> Result<ClientConfig> {
self.validate_oauth21_client_config(&mut config)?;
self.oauth2_server.register_client(config).await
}
pub async fn get_client(&self, client_id: &str) -> Result<Option<ClientConfig>> {
self.oauth2_server.get_client(client_id).await
}
pub async fn update_client(&self, client_id: &str, config: ClientConfig) -> Result<()> {
self.oauth2_server.update_client(client_id, config).await
}
pub async fn delete_client(&self, client_id: &str) -> Result<()> {
self.oauth2_server.delete_client(client_id).await
}
pub async fn get_server_configuration(&self) -> Result<serde_json::Value> {
let mut config = self.oauth2_server.get_server_configuration().await?;
if let Some(obj) = config.as_object_mut() {
if self.security_config.disallow_implicit_grant {
if let Some(grant_types) = obj.get_mut("grant_types_supported")
&& let Some(grants) = grant_types.as_array_mut()
{
grants.retain(|g| g.as_str() != Some("implicit"));
}
if let Some(response_types) = obj.get_mut("response_types_supported")
&& let Some(types) = response_types.as_array_mut()
{
types.retain(|t| {
if let Some(type_str) = t.as_str() {
!type_str.contains("token") || type_str.contains("code")
} else {
true
}
});
}
}
if self.security_config.disallow_password_grant
&& let Some(grant_types) = obj.get_mut("grant_types_supported")
&& let Some(grants) = grant_types.as_array_mut()
{
grants.retain(|g| g.as_str() != Some("password"));
}
obj.insert(
"oauth21_compliant".to_string(),
serde_json::Value::Bool(true),
);
obj.insert(
"pkce_required".to_string(),
serde_json::Value::Bool(self.security_config.require_pkce_for_public_clients),
);
obj.insert(
"implicit_grant_disabled".to_string(),
serde_json::Value::Bool(self.security_config.disallow_implicit_grant),
);
obj.insert(
"password_grant_disabled".to_string(),
serde_json::Value::Bool(self.security_config.disallow_password_grant),
);
}
Ok(config)
}
fn validate_oauth21_client_config(&self, config: &mut ClientConfig) -> Result<()> {
if self.security_config.require_secure_redirect_uris {
for uri in &config.redirect_uris {
if !uri.starts_with("https://")
&& !uri.starts_with("http://localhost")
&& !uri.starts_with("http://127.0.0.1")
{
return Err(AuthError::validation(
"OAuth 2.1 requires HTTPS redirect URIs (except localhost)",
));
}
}
}
if self.security_config.disallow_implicit_grant {
config.authorized_grant_types.retain(|g| g != "implicit");
config
.authorized_response_types
.retain(|r| !r.contains("token") || r.contains("code"));
}
if self.security_config.disallow_password_grant {
config.authorized_grant_types.retain(|g| g != "password");
}
if config.authorized_grant_types.is_empty() {
config
.authorized_grant_types
.push("authorization_code".to_string());
}
if config.authorized_response_types.is_empty() {
config.authorized_response_types.push("code".to_string());
}
Ok(())
}
pub fn get_security_config(&self) -> &OAuth21SecurityConfig {
&self.security_config
}
pub fn update_security_config(&mut self, config: OAuth21SecurityConfig) {
self.security_config = config;
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::server::core::client_registry::{ClientConfig, ClientType};
    use crate::storage::memory::InMemoryStorage;

    /// Server with the default policy, backed by in-memory storage.
    async fn strict_server() -> OAuth21Server {
        let storage = Arc::new(InMemoryStorage::new());
        OAuth21Server::new(Some(OAuth21SecurityConfig::default()), storage)
            .await
            .unwrap()
    }

    /// Public client named `test_client` with a single redirect URI.
    fn public_client(redirect_uri: &str) -> ClientConfig {
        ClientConfig {
            client_id: "test_client".to_string(),
            client_type: ClientType::Public,
            redirect_uris: vec![redirect_uri.to_string()].into(),
            ..Default::default()
        }
    }

    #[tokio::test]
    async fn test_oauth21_server_creation() {
        let server = strict_server().await;

        let registered = server
            .register_client(public_client("https://example.com/callback"))
            .await
            .unwrap();

        assert_eq!(registered.client_id, "test_client");
        // The implicit grant must have been stripped.
        assert!(
            registered
                .authorized_grant_types
                .iter()
                .all(|grant| grant != "implicit")
        );
        // No token-only response type may survive.
        assert!(
            registered
                .authorized_response_types
                .iter()
                .all(|r| !r.contains("token") || r.contains("code"))
        );
    }

    #[tokio::test]
    async fn test_oauth21_security_validations() {
        let server = strict_server().await;

        // Plain HTTP on a non-loopback host is rejected.
        let insecure = server
            .register_client(public_client("http://example.com/callback"))
            .await;
        assert!(insecure.is_err());

        // The HTTPS equivalent is accepted.
        let secure = server
            .register_client(public_client("https://example.com/callback"))
            .await;
        assert!(secure.is_ok());
    }

    #[tokio::test]
    async fn test_oauth21_server_configuration() {
        let server = strict_server().await;
        let config = server.get_server_configuration().await.unwrap();

        for flag in [
            "oauth21_compliant",
            "pkce_required",
            "implicit_grant_disabled",
            "password_grant_disabled",
        ] {
            assert_eq!(config[flag], true);
        }

        let grant_types = config["grant_types_supported"].as_array().unwrap();
        for banned in ["implicit", "password"] {
            assert!(!grant_types.iter().any(|g| g.as_str() == Some(banned)));
        }
    }
}