use crate::error::{Error, ErrorCode, Result};
use crate::server::auth::oauth2::OAuthProvider;
use crate::server::auth::traits::AuthContext;
use crate::types::auth::{AuthInfo, AuthScheme};
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::Arc;
/// A pluggable authentication step that turns the (optional) credentials
/// attached to a request into an [`AuthContext`].
#[async_trait]
pub trait AuthMiddleware: Send + Sync {
    /// Reports whether this middleware demands credentials.
    ///
    /// Defaults to `true`. Implementations that can admit callers without
    /// credentials override it (for example [`BearerTokenMiddleware`] built
    /// with `with_required(false)`).
    fn is_required(&self) -> bool {
        true
    }

    /// Authenticates the caller described by `auth_info`.
    ///
    /// `None` means the request carried no credentials at all.
    ///
    /// # Errors
    ///
    /// Implementation-defined; the middlewares in this module return
    /// protocol errors built with [`Error::protocol`].
    async fn authenticate(&self, auth_info: Option<&AuthInfo>) -> Result<AuthContext>;
}
/// Authenticates requests that present an OAuth 2.0 bearer token, validating
/// the token through an [`OAuthProvider`].
pub struct BearerTokenMiddleware {
    /// Provider used to validate presented tokens.
    provider: Arc<dyn OAuthProvider>,
    /// When `false`, requests with no credentials are admitted as anonymous
    /// instead of being rejected.
    required: bool,
}

impl std::fmt::Debug for BearerTokenMiddleware {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The provider is a trait object, so a fixed placeholder stands in for it.
        let mut out = f.debug_struct("BearerTokenMiddleware");
        out.field("provider", &"<dyn OAuthProvider>");
        out.field("required", &self.required);
        out.finish()
    }
}

impl BearerTokenMiddleware {
    /// Creates a middleware that rejects requests carrying no credentials.
    pub fn new(provider: Arc<dyn OAuthProvider>) -> Self {
        Self {
            required: true,
            provider,
        }
    }

    /// Builder-style toggle. Passing `false` lets requests without credentials
    /// through with an anonymous, unauthenticated context.
    pub fn with_required(self, required: bool) -> Self {
        Self { required, ..self }
    }
}
#[async_trait]
impl AuthMiddleware for BearerTokenMiddleware {
    /// Validates a bearer token with the provider and converts the resulting
    /// token info into an [`AuthContext`].
    ///
    /// When no credentials are supplied and the middleware is not required, an
    /// anonymous context with `authenticated: false` is returned.
    ///
    /// # Errors
    ///
    /// Returns an `AUTHENTICATION_REQUIRED` protocol error when credentials are
    /// absent (and required), the scheme is not `Bearer`, no token is present,
    /// or the provider rejects the token. The provider's own error is discarded.
    async fn authenticate(&self, auth_info: Option<&AuthInfo>) -> Result<AuthContext> {
        let info = match auth_info {
            Some(info) => info,
            None if self.required => {
                return Err(Error::protocol(
                    ErrorCode::AUTHENTICATION_REQUIRED,
                    "Authentication required",
                ));
            }
            // Optional mode: admit the caller as an anonymous principal.
            None => {
                return Ok(AuthContext {
                    authenticated: false,
                    subject: "anonymous".to_string(),
                    client_id: Some("anonymous".to_string()),
                    scopes: vec![],
                    claims: HashMap::new(),
                    token: None,
                    expires_at: None,
                });
            }
        };

        if info.scheme != AuthScheme::Bearer {
            return Err(Error::protocol(
                ErrorCode::AUTHENTICATION_REQUIRED,
                "Invalid authentication scheme",
            ));
        }

        let Some(token) = info.token.as_ref() else {
            return Err(Error::protocol(
                ErrorCode::AUTHENTICATION_REQUIRED,
                "Missing token",
            ));
        };

        // Any provider failure collapses into one generic message.
        let validated = match self.provider.validate_token(token).await {
            Ok(validated) => validated,
            Err(_) => {
                return Err(Error::protocol(
                    ErrorCode::AUTHENTICATION_REQUIRED,
                    "Invalid token",
                ));
            }
        };

        Ok(AuthContext {
            authenticated: true,
            subject: validated.user_id,
            client_id: Some(validated.client_id),
            scopes: validated.scopes,
            expires_at: Some(validated.expires_at),
            token: Some(token.clone()),
            claims: HashMap::new(),
        })
    }

    fn is_required(&self) -> bool {
        self.required
    }
}
/// Authenticates OAuth clients by checking `client_id` / `client_secret`
/// request parameters against the clients known to an [`OAuthProvider`].
pub struct ClientCredentialsMiddleware {
    /// Provider consulted for client registrations.
    provider: Arc<dyn OAuthProvider>,
}

impl std::fmt::Debug for ClientCredentialsMiddleware {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The provider is a trait object; print a placeholder in its place.
        let mut out = f.debug_struct("ClientCredentialsMiddleware");
        out.field("provider", &"<dyn OAuthProvider>");
        out.finish()
    }
}

impl ClientCredentialsMiddleware {
    /// Creates a middleware that looks clients up through `provider`.
    pub fn new(provider: Arc<dyn OAuthProvider>) -> Self {
        Self { provider }
    }
}
/// Compares `a` and `b` without exiting early at the first differing byte.
///
/// Only the length check short-circuits. Response timing therefore reveals at
/// most whether the lengths match, not how long a matching prefix is. This is
/// a best-effort, dependency-free comparison: nothing formally prevents the
/// optimizer from reintroducing an early exit, so a vetted constant-time crate
/// is preferable if one becomes available to this crate.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    a.iter().zip(b).fold(0u8, |diff, (x, y)| diff | (x ^ y)) == 0
}

#[async_trait]
impl AuthMiddleware for ClientCredentialsMiddleware {
    /// Authenticates a client from the `client_id` and `client_secret` entries
    /// of `auth_info.params`. The authentication scheme itself is not inspected.
    ///
    /// # Errors
    ///
    /// Returns an `AUTHENTICATION_REQUIRED` protocol error in these cases:
    /// - no credentials are supplied;
    /// - either parameter is absent or not readable as a string;
    /// - the client id is unknown;
    /// - the secret does not match, including clients with no stored secret.
    ///
    /// Errors from `OAuthProvider::get_client` are propagated unchanged.
    async fn authenticate(&self, auth_info: Option<&AuthInfo>) -> Result<AuthContext> {
        let auth_info = auth_info.ok_or_else(|| {
            Error::protocol(
                ErrorCode::AUTHENTICATION_REQUIRED,
                "Authentication required",
            )
        })?;
        let client_id = auth_info
            .params
            .get("client_id")
            .and_then(|v| v.as_str())
            .ok_or_else(|| {
                Error::protocol(ErrorCode::AUTHENTICATION_REQUIRED, "Missing client_id")
            })?;
        let client_secret = auth_info
            .params
            .get("client_secret")
            .and_then(|v| v.as_str())
            .ok_or_else(|| {
                Error::protocol(ErrorCode::AUTHENTICATION_REQUIRED, "Missing client_secret")
            })?;
        // NOTE(review): "Invalid client" vs "Invalid client credentials" lets a
        // caller distinguish unknown client ids from wrong secrets; consider
        // unifying the messages if client-id enumeration is a concern.
        let client =
            self.provider.get_client(client_id).await?.ok_or_else(|| {
                Error::protocol(ErrorCode::AUTHENTICATION_REQUIRED, "Invalid client")
            })?;
        // The presented secret is caller-controlled, so compare it without an
        // early exit on the first mismatching byte. A client with no stored
        // secret never matches.
        let secret_matches = match client.client_secret.as_deref() {
            Some(expected) => constant_time_eq(expected.as_bytes(), client_secret.as_bytes()),
            None => false,
        };
        if !secret_matches {
            return Err(Error::protocol(
                ErrorCode::AUTHENTICATION_REQUIRED,
                "Invalid client credentials",
            ));
        }
        Ok(AuthContext {
            subject: client.client_id.clone(),
            scopes: client.scopes,
            claims: HashMap::new(),
            token: None,
            client_id: Some(client.client_id),
            expires_at: None,
            authenticated: true,
        })
    }
}
/// Wraps another [`AuthMiddleware`] and additionally checks the scopes of the
/// context it produces.
pub struct ScopeMiddleware {
    /// Middleware that performs the actual authentication.
    inner: Box<dyn AuthMiddleware>,
    /// Scopes checked against the authenticated context.
    required_scopes: Vec<String>,
    /// `true`: checked with `AuthContext::has_all_scopes`;
    /// `false`: checked with `AuthContext::has_any_scope`.
    require_all: bool,
}

impl std::fmt::Debug for ScopeMiddleware {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ScopeMiddleware")
            .field("inner", &"<dyn AuthMiddleware>")
            .field("required_scopes", &self.required_scopes)
            // Without this, `ScopeMiddleware::all` and `ScopeMiddleware::any`
            // instances print identically.
            .field("require_all", &self.require_all)
            .finish()
    }
}

impl ScopeMiddleware {
    /// Wraps `inner`, requiring every scope in `scopes`
    /// (checked via `AuthContext::has_all_scopes`).
    pub fn all(inner: Box<dyn AuthMiddleware>, scopes: Vec<String>) -> Self {
        Self {
            inner,
            required_scopes: scopes,
            require_all: true,
        }
    }

    /// Wraps `inner`, requiring at least one scope in `scopes`
    /// (checked via `AuthContext::has_any_scope`).
    pub fn any(inner: Box<dyn AuthMiddleware>, scopes: Vec<String>) -> Self {
        Self {
            inner,
            required_scopes: scopes,
            require_all: false,
        }
    }
}
#[async_trait]
impl AuthMiddleware for ScopeMiddleware {
async fn authenticate(&self, auth_info: Option<&AuthInfo>) -> Result<AuthContext> {
let context = self.inner.authenticate(auth_info).await?;
let scope_refs: Vec<&str> = self.required_scopes.iter().map(|s| s.as_str()).collect();
let has_required_scopes = if self.require_all {
context.has_all_scopes(&scope_refs)
} else {
context.has_any_scope(&scope_refs)
};
if !has_required_scopes {
return Err(Error::protocol(
ErrorCode::PERMISSION_DENIED,
"Insufficient scopes",
));
}
Ok(context)
}
fn is_required(&self) -> bool {
self.inner.is_required()
}
}
/// Tries a list of [`AuthMiddleware`]s in order and accepts the first one
/// that succeeds.
pub struct CompositeMiddleware {
    /// Middlewares consulted in order.
    middlewares: Vec<Box<dyn AuthMiddleware>>,
    /// `true`: fail when every middleware fails.
    /// `false`: fall back to an anonymous context instead.
    require_any: bool,
}

impl std::fmt::Debug for CompositeMiddleware {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Children are trait objects, so only their count is reported.
        let summary = format!("{} middlewares", self.middlewares.len());
        f.debug_struct("CompositeMiddleware")
            .field("middlewares", &summary)
            .field("require_any", &self.require_any)
            .finish()
    }
}

impl CompositeMiddleware {
    /// Creates a composite that fails when none of `middlewares` succeeds.
    pub fn new(middlewares: Vec<Box<dyn AuthMiddleware>>) -> Self {
        Self {
            require_any: true,
            middlewares,
        }
    }

    /// Builder-style toggle. Passing `false` turns total failure into an
    /// anonymous, unauthenticated context.
    pub fn with_require_any(self, require_any: bool) -> Self {
        Self {
            require_any,
            ..self
        }
    }
}
#[async_trait]
impl AuthMiddleware for CompositeMiddleware {
    /// Runs each wrapped middleware in order and returns the first successful
    /// context.
    ///
    /// If none succeeds and `require_any` is set, the last middleware's error is
    /// returned. With no middlewares at all, a generic `AUTHENTICATION_REQUIRED`
    /// error is returned instead. If `require_any` is not set, an anonymous
    /// context with `authenticated: false` is returned and the errors are dropped.
    async fn authenticate(&self, auth_info: Option<&AuthInfo>) -> Result<AuthContext> {
        let mut last_error = None;
        for middleware in &self.middlewares {
            match middleware.authenticate(auth_info).await {
                Ok(context) => return Ok(context),
                Err(e) => last_error = Some(e),
            }
        }
        if self.require_any {
            Err(last_error.unwrap_or_else(|| {
                Error::protocol(ErrorCode::AUTHENTICATION_REQUIRED, "Authentication failed")
            }))
        } else {
            Ok(AuthContext {
                subject: "anonymous".to_string(),
                scopes: vec![],
                claims: HashMap::new(),
                token: None,
                client_id: Some("anonymous".to_string()),
                expires_at: None,
                authenticated: false,
            })
        }
    }

    /// Reports whether this composite rejects requests that carry no credentials.
    ///
    /// `authenticate` accepts as soon as any child accepts. A child whose
    /// `is_required()` is `false` can accept a credential-less request; for
    /// example, a non-required [`BearerTokenMiddleware`] returns an anonymous
    /// context. So the composite is only required when *every* child is.
    ///
    /// An empty composite always errors (when `require_any` is set) and is
    /// therefore required; `all` over an empty iterator yields `true`, matching
    /// that. Without `require_any`, failures fall back to an anonymous context,
    /// so the composite is never required.
    fn is_required(&self) -> bool {
        self.require_any && self.middlewares.iter().all(|m| m.is_required())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::server::auth::oauth2::InMemoryOAuthProvider;

    /// Issuer URL shared by every in-memory provider in these tests.
    const ISSUER: &str = "http://localhost:8080";

    /// Turns scope literals into the owned strings the provider takes.
    fn scopes(names: &[&str]) -> Vec<String> {
        names.iter().map(|&name| String::from(name)).collect()
    }

    #[tokio::test]
    async fn test_bearer_token_middleware() {
        let provider = Arc::new(InMemoryOAuthProvider::new(ISSUER));
        let middleware = BearerTokenMiddleware::new(provider.clone());

        // No credentials: the default (required) middleware must refuse.
        assert!(middleware.authenticate(None).await.is_err());

        // A token the provider never issued is rejected.
        let bogus = AuthInfo::bearer("invalid-token");
        assert!(middleware.authenticate(Some(&bogus)).await.is_err());

        // A freshly issued token authenticates and carries subject, client and scopes.
        let issued = provider
            .create_access_token("client-123", "user-456", scopes(&["read", "write"]))
            .await
            .expect("in-memory provider should issue a token");
        let auth = AuthInfo::bearer(&issued.access_token);
        let ctx = middleware
            .authenticate(Some(&auth))
            .await
            .expect("an issued token should authenticate");
        assert_eq!(ctx.client_id, Some("client-123".to_string()));
        assert_eq!(ctx.subject, "user-456");
        assert!(ctx.has_scope("read"));
        assert!(ctx.has_scope("write"));
    }

    #[tokio::test]
    async fn test_scope_middleware() {
        let provider = Arc::new(InMemoryOAuthProvider::new(ISSUER));
        let guarded = ScopeMiddleware::all(
            Box::new(BearerTokenMiddleware::new(provider.clone())),
            scopes(&["read", "write"]),
        );

        // Holding only one of the two required scopes is not enough under `all`.
        let partial = provider
            .create_access_token("client-123", "user-456", scopes(&["read"]))
            .await
            .expect("in-memory provider should issue a token");
        let partial_auth = AuthInfo::bearer(&partial.access_token);
        assert!(guarded.authenticate(Some(&partial_auth)).await.is_err());

        // Holding both scopes satisfies the check and the context passes through.
        let full = provider
            .create_access_token("client-123", "user-456", scopes(&["read", "write"]))
            .await
            .expect("in-memory provider should issue a token");
        let full_auth = AuthInfo::bearer(&full.access_token);
        let ctx = guarded
            .authenticate(Some(&full_auth))
            .await
            .expect("a token with both scopes should pass");
        assert!(ctx.has_all_scopes(&["read", "write"]));
    }
}