#![allow(deprecated)]
use crate::repository::{SimpleUserRepository, UserRepository};
use crate::{AuthenticationBackend, AuthenticationError, User};
use async_trait::async_trait;
use reinhardt_http::Request;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use subtle::ConstantTimeEq;
use tokio::sync::Mutex;
use uuid::Uuid;
/// OAuth2 grant types that an [`OAuth2Application`] can list in its
/// `grant_types` field.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum GrantType {
    /// Authorization-code grant.
    AuthorizationCode,
    /// Client-credentials grant.
    ClientCredentials,
    /// Refresh-token grant.
    RefreshToken,
    /// Implicit grant.
    Implicit,
}
/// Token bundle returned by [`OAuth2Authentication::exchange_code`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccessToken {
    /// The access token string. Token stores use it as the lookup key.
    pub token: String,
    /// Token type label; `exchange_code` sets this to `"Bearer"`.
    pub token_type: String,
    /// Advertised lifetime of the token; `exchange_code` sets 3600.
    /// NOTE(review): presumably seconds, matching the OAuth2 `expires_in`
    /// response parameter — confirm.
    pub expires_in: u64,
    /// Optional refresh token issued alongside the access token.
    pub refresh_token: Option<String>,
    /// Optional scope, copied from the authorization code on exchange.
    pub scope: Option<String>,
}
/// Maximum age of an authorization code (ten minutes). Codes older than this
/// are rejected by `InMemoryOAuth2Store::consume_code`.
const AUTHORIZATION_CODE_TTL: Duration = Duration::from_secs(10 * 60);
/// A single-use authorization code together with the request data it was
/// issued for.
#[derive(Debug, Clone)]
pub struct AuthorizationCode {
    /// The code string handed to the client; also the store lookup key.
    pub code: String,
    /// Client the code was issued to; checked again on exchange.
    pub client_id: String,
    /// Redirect URI of the authorization request; checked again on exchange.
    pub redirect_uri: String,
    /// Identifier of the user the resulting access token will belong to.
    pub user_id: String,
    /// Optional scope carried over into the access token.
    pub scope: Option<String>,
    /// Creation time, compared against `AUTHORIZATION_CODE_TTL` when the code
    /// is consumed.
    pub created_at: Instant,
}
/// A client application registered with [`OAuth2Authentication`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuth2Application {
    /// Public client identifier; used as the key in the application registry.
    pub client_id: String,
    /// Client secret that `validate_client` compares against the presented
    /// secret.
    pub client_secret: String,
    /// Redirect URIs registered for this application.
    pub redirect_uris: Vec<String>,
    /// Grant types this application is registered for.
    pub grant_types: Vec<GrantType>,
}
/// Persistence interface for authorization codes and access tokens.
///
/// Every method reports failures as a human-readable `String`.
#[async_trait]
pub trait OAuth2TokenStore: Send + Sync {
    /// Persists a freshly issued authorization code.
    async fn store_code(&self, code: AuthorizationCode) -> Result<(), String>;

    /// Looks up an authorization code for exchange.
    ///
    /// `Ok(None)` means the code is unknown. The in-memory implementation in
    /// this file also removes the code, making it single-use.
    async fn consume_code(&self, code: &str) -> Result<Option<AuthorizationCode>, String>;

    /// Records `token` as belonging to `user_id`.
    async fn store_token(&self, user_id: &str, token: AccessToken) -> Result<(), String>;

    /// Resolves an access token string to the id of the user it was stored
    /// for, or `Ok(None)` when the token is not known.
    async fn get_token(&self, token: &str) -> Result<Option<String>, String>;

    /// Invalidates an access token.
    async fn revoke_token(&self, token: &str) -> Result<(), String>;
}
pub struct InMemoryOAuth2Store {
codes: Arc<Mutex<HashMap<String, AuthorizationCode>>>,
tokens: Arc<Mutex<HashMap<String, String>>>, }
impl InMemoryOAuth2Store {
pub fn new() -> Self {
Self {
codes: Arc::new(Mutex::new(HashMap::new())),
tokens: Arc::new(Mutex::new(HashMap::new())),
}
}
}
impl Default for InMemoryOAuth2Store {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl OAuth2TokenStore for InMemoryOAuth2Store {
async fn store_code(&self, code: AuthorizationCode) -> Result<(), String> {
let mut codes = self.codes.lock().await;
codes.insert(code.code.clone(), code);
Ok(())
}
async fn consume_code(&self, code: &str) -> Result<Option<AuthorizationCode>, String> {
let mut codes = self.codes.lock().await;
match codes.remove(code) {
Some(auth_code) if auth_code.created_at.elapsed() > AUTHORIZATION_CODE_TTL => {
Err("authorization code has expired".to_string())
}
other => Ok(other),
}
}
async fn store_token(&self, user_id: &str, token: AccessToken) -> Result<(), String> {
let mut tokens = self.tokens.lock().await;
tokens.insert(token.token.clone(), user_id.to_string());
Ok(())
}
async fn get_token(&self, token: &str) -> Result<Option<String>, String> {
let tokens = self.tokens.lock().await;
Ok(tokens.get(token).cloned())
}
async fn revoke_token(&self, token: &str) -> Result<(), String> {
let mut tokens = self.tokens.lock().await;
tokens.remove(token);
Ok(())
}
}
/// OAuth2 authorization-code flow and bearer-token authentication backend.
pub struct OAuth2Authentication {
    /// Registered client applications, keyed by `client_id`.
    applications: Arc<Mutex<HashMap<String, OAuth2Application>>>,
    /// Storage for authorization codes and access tokens.
    token_store: Arc<dyn OAuth2TokenStore>,
    /// Source of user objects for the user ids stored with access tokens.
    user_repository: Arc<dyn UserRepository>,
}
impl OAuth2Authentication {
pub fn new() -> Self {
Self {
applications: Arc::new(Mutex::new(HashMap::new())),
token_store: Arc::new(InMemoryOAuth2Store::new()),
user_repository: Arc::new(SimpleUserRepository),
}
}
pub fn with_store(token_store: Arc<dyn OAuth2TokenStore>) -> Self {
Self {
applications: Arc::new(Mutex::new(HashMap::new())),
token_store,
user_repository: Arc::new(SimpleUserRepository),
}
}
pub fn with_repository(user_repository: Arc<dyn UserRepository>) -> Self {
Self {
applications: Arc::new(Mutex::new(HashMap::new())),
token_store: Arc::new(InMemoryOAuth2Store::new()),
user_repository,
}
}
pub fn with_store_and_repository(
token_store: Arc<dyn OAuth2TokenStore>,
user_repository: Arc<dyn UserRepository>,
) -> Self {
Self {
applications: Arc::new(Mutex::new(HashMap::new())),
token_store,
user_repository,
}
}
pub async fn register_application(&self, app: OAuth2Application) {
let mut applications = self.applications.lock().await;
applications.insert(app.client_id.clone(), app);
}
pub async fn validate_client(&self, client_id: &str, client_secret: &str) -> bool {
let applications = self.applications.lock().await;
if let Some(app) = applications.get(client_id) {
app.client_secret
.as_bytes()
.ct_eq(client_secret.as_bytes())
.into()
} else {
false
}
}
pub async fn generate_authorization_code(
&self,
client_id: &str,
redirect_uri: &str,
user_id: &str,
scope: Option<String>,
) -> Result<String, String> {
let code = format!("code_{}", Uuid::new_v4());
let auth_code = AuthorizationCode {
code: code.clone(),
client_id: client_id.to_string(),
redirect_uri: redirect_uri.to_string(),
user_id: user_id.to_string(),
scope,
created_at: Instant::now(),
};
self.token_store.store_code(auth_code).await?;
Ok(code)
}
pub async fn exchange_code(
&self,
code: &str,
client_id: &str,
client_secret: &str,
redirect_uri: &str,
) -> Result<AccessToken, String> {
if !self.validate_client(client_id, client_secret).await {
return Err("Invalid client credentials".to_string());
}
let auth_code = self
.token_store
.consume_code(code)
.await?
.ok_or_else(|| "Invalid or expired authorization code".to_string())?;
if auth_code.client_id != client_id {
return Err("Authorization code was not issued to this client".to_string());
}
if auth_code.redirect_uri != redirect_uri {
return Err("redirect_uri does not match the authorization request".to_string());
}
let token = AccessToken {
token: format!("access_{}", Uuid::new_v4()),
token_type: "Bearer".to_string(),
expires_in: 3600,
refresh_token: Some(format!("refresh_{}", Uuid::new_v4())),
scope: auth_code.scope.clone(),
};
self.token_store
.store_token(&auth_code.user_id, token.clone())
.await?;
Ok(token)
}
}
impl Default for OAuth2Authentication {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl AuthenticationBackend for OAuth2Authentication {
    /// Authenticates a request from its `Authorization: Bearer <token>`
    /// header.
    ///
    /// Returns `Ok(None)` when the header is missing, is not valid text, uses
    /// a scheme other than `Bearer`, or carries a token the store does not
    /// know. The scheme is matched case-insensitively, as HTTP auth schemes
    /// are, and extra spaces between scheme and token are ignored.
    ///
    /// # Errors
    ///
    /// Returns [`AuthenticationError::Unknown`] when the token store or the
    /// user repository reports a failure.
    async fn authenticate(
        &self,
        request: &Request,
    ) -> Result<Option<Box<dyn User>>, AuthenticationError> {
        let Some(header) = request
            .headers
            .get("Authorization")
            .and_then(|h| h.to_str().ok())
        else {
            return Ok(None);
        };
        let Some((scheme, credentials)) = header.split_once(' ') else {
            return Ok(None);
        };
        if !scheme.eq_ignore_ascii_case("Bearer") {
            return Ok(None);
        }
        // The grammar allows one or more spaces after the scheme.
        let token = credentials.trim_start_matches(' ');

        match self.token_store.get_token(token).await {
            Ok(Some(user_id)) => self.get_user(&user_id).await,
            Ok(None) => Ok(None),
            Err(e) => Err(AuthenticationError::Unknown(format!(
                "Token store error: {}",
                e
            ))),
        }
    }

    /// Loads the user with `user_id` from the configured user repository.
    ///
    /// # Errors
    ///
    /// Returns [`AuthenticationError::Unknown`] wrapping the repository's
    /// error message.
    async fn get_user(&self, user_id: &str) -> Result<Option<Box<dyn User>>, AuthenticationError> {
        self.user_repository
            .get_user_by_id(user_id)
            .await
            .map_err(|e| AuthenticationError::Unknown(format!("User repository error: {}", e)))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::user::SimpleUser;
    use rstest::rstest;

    /// Redirect URI registered for the default `test_client` application.
    const CALLBACK: &str = "https://example.com/callback";

    /// Builds an application with one redirect URI and the authorization-code
    /// grant.
    fn make_app(client_id: &str, client_secret: &str, redirect_uri: &str) -> OAuth2Application {
        OAuth2Application {
            client_id: client_id.to_string(),
            client_secret: client_secret.to_string(),
            redirect_uris: vec![redirect_uri.to_string()],
            grant_types: vec![GrantType::AuthorizationCode],
        }
    }

    /// Returns an authenticator with `test_client` / `test_secret` registered.
    async fn auth_with_test_client() -> OAuth2Authentication {
        let auth = OAuth2Authentication::new();
        auth.register_application(make_app("test_client", "test_secret", CALLBACK))
            .await;
        auth
    }

    /// Issues a code for `user_123` to `test_client` using the default callback.
    async fn issue_code(auth: &OAuth2Authentication, scope: Option<String>) -> String {
        auth.generate_authorization_code("test_client", CALLBACK, "user_123", scope)
            .await
            .unwrap()
    }

    // Registered clients validate only with the matching secret.
    #[rstest]
    #[tokio::test]
    async fn test_oauth2_application() {
        let auth = auth_with_test_client().await;

        assert!(auth.validate_client("test_client", "test_secret").await);
        assert!(!auth.validate_client("test_client", "wrong_secret").await);
    }

    // Happy path: issue a code, exchange it, inspect the token.
    #[rstest]
    #[tokio::test]
    async fn test_authorization_code_flow() {
        let auth = auth_with_test_client().await;

        let code = issue_code(&auth, Some("read write".to_string())).await;
        assert!(code.starts_with("code_"));

        let token = auth
            .exchange_code(&code, "test_client", "test_secret", CALLBACK)
            .await
            .unwrap();
        assert_eq!(token.token_type, "Bearer");
        assert_eq!(token.expires_in, 3600);
        assert!(token.refresh_token.is_some());
    }

    // A stored code can be consumed exactly once.
    #[rstest]
    #[tokio::test]
    async fn test_token_store() {
        let store = InMemoryOAuth2Store::new();
        let stored = AuthorizationCode {
            code: "test_code".to_string(),
            client_id: "client_1".to_string(),
            redirect_uri: "https://example.com/callback".to_string(),
            user_id: "user_123".to_string(),
            scope: Some("read".to_string()),
            created_at: Instant::now(),
        };
        store.store_code(stored).await.unwrap();

        let first = store.consume_code("test_code").await.unwrap();
        assert!(first.is_some());
        assert_eq!(first.unwrap().user_id, "user_123");

        let second = store.consume_code("test_code").await.unwrap();
        assert!(second.is_none());
    }

    // A code issued to one client cannot be redeemed by another.
    #[rstest]
    #[tokio::test]
    async fn test_exchange_code_rejects_mismatched_client_id() {
        let auth = OAuth2Authentication::new();
        auth.register_application(make_app(
            "client_a",
            "secret_a",
            "https://a.example.com/callback",
        ))
        .await;
        auth.register_application(make_app(
            "client_b",
            "secret_b",
            "https://b.example.com/callback",
        ))
        .await;

        let code = auth
            .generate_authorization_code(
                "client_a",
                "https://a.example.com/callback",
                "user_123",
                None,
            )
            .await
            .unwrap();

        let outcome = auth
            .exchange_code(
                &code,
                "client_b",
                "secret_b",
                "https://b.example.com/callback",
            )
            .await;
        assert!(outcome.is_err());
        assert_eq!(
            outcome.unwrap_err(),
            "Authorization code was not issued to this client"
        );
    }

    // A wrong client secret fails the exchange.
    #[rstest]
    #[tokio::test]
    async fn test_invalid_client_credentials() {
        let auth = auth_with_test_client().await;
        let code = issue_code(&auth, None).await;

        let outcome = auth
            .exchange_code(&code, "test_client", "wrong_secret", CALLBACK)
            .await;
        assert!(outcome.is_err());
    }

    // The default repository returns an active, authenticated user for any id.
    #[rstest]
    #[tokio::test]
    async fn test_simple_user_repository() {
        let repo = SimpleUserRepository;

        let found = repo.get_user_by_id("test_user").await.unwrap();
        assert!(found.is_some());

        let user = found.unwrap();
        assert_eq!(user.get_username(), "test_user");
        assert!(user.is_authenticated());
        assert!(user.is_active());
    }

    // `new()` wires in the default repository.
    #[rstest]
    #[tokio::test]
    async fn test_oauth2_with_default_repository() {
        let auth = OAuth2Authentication::new();

        let found = auth.get_user("user_456").await.unwrap();
        assert!(found.is_some());
        assert_eq!(found.unwrap().get_username(), "user_456");
    }

    // `with_repository` routes user lookups to the supplied repository.
    #[rstest]
    #[tokio::test]
    async fn test_oauth2_with_custom_repository() {
        struct MockUserRepository {
            username: String,
        }

        #[async_trait]
        impl UserRepository for MockUserRepository {
            async fn get_user_by_id(&self, user_id: &str) -> Result<Option<Box<dyn User>>, String> {
                if user_id != "mock_user" {
                    return Ok(None);
                }
                Ok(Some(Box::new(SimpleUser {
                    id: Uuid::from_u128(999),
                    username: self.username.clone(),
                    email: "mock@example.com".to_string(),
                    is_active: true,
                    is_admin: true,
                    is_staff: true,
                    is_superuser: true,
                })))
            }
        }

        let auth = OAuth2Authentication::with_repository(Arc::new(MockUserRepository {
            username: "custom_mock_user".to_string(),
        }));

        let known = auth.get_user("mock_user").await.unwrap();
        assert!(known.is_some());
        let known = known.unwrap();
        assert_eq!(known.get_username(), "custom_mock_user");
        assert!(known.is_admin());

        let unknown = auth.get_user("nonexistent").await.unwrap();
        assert!(unknown.is_none());
    }

    // `with_store_and_repository` accepts both collaborators at once.
    #[rstest]
    #[tokio::test]
    async fn test_oauth2_with_store_and_repository() {
        struct CustomRepository;

        #[async_trait]
        impl UserRepository for CustomRepository {
            async fn get_user_by_id(&self, user_id: &str) -> Result<Option<Box<dyn User>>, String> {
                Ok(Some(Box::new(SimpleUser {
                    id: Uuid::from_u128(777),
                    username: format!("custom_{}", user_id),
                    email: format!("{}@custom.com", user_id),
                    is_active: true,
                    is_admin: false,
                    is_staff: true,
                    is_superuser: false,
                })))
            }
        }

        let auth = OAuth2Authentication::with_store_and_repository(
            Arc::new(InMemoryOAuth2Store::new()),
            Arc::new(CustomRepository),
        );

        let user = auth.get_user("test").await.unwrap().unwrap();
        assert_eq!(user.get_username(), "custom_test");
    }

    // The redirect URI presented on exchange must equal the one used at issue time.
    #[rstest]
    #[tokio::test]
    async fn test_exchange_code_rejects_mismatched_redirect_uri() {
        let auth = auth_with_test_client().await;
        let code = issue_code(&auth, None).await;

        let outcome = auth
            .exchange_code(
                &code,
                "test_client",
                "test_secret",
                "https://attacker.example.com/callback",
            )
            .await;
        assert!(outcome.is_err());
        assert_eq!(
            outcome.unwrap_err(),
            "redirect_uri does not match the authorization request"
        );
    }

    // A matching redirect URI lets the exchange through.
    #[rstest]
    #[tokio::test]
    async fn test_exchange_code_succeeds_with_matching_redirect_uri() {
        let auth = auth_with_test_client().await;
        let code = issue_code(&auth, None).await;

        let outcome = auth
            .exchange_code(&code, "test_client", "test_secret", CALLBACK)
            .await;
        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap().token_type, "Bearer");
    }
}