#![cfg(feature = "axum")]
use axum::{
Extension, body::Body, extract::Request, http::StatusCode, middleware::Next, response::Response,
};
use serde::{Deserialize, Serialize};
use crate::models::JwksDocument;
/// Claims carried by an authentication token, as (de)serialized by serde.
///
/// Each named field maps to the JSON key of the same name. The `Option`
/// fields are left out of serialized output when they are `None`.
///
/// NOTE(review): `iss`, `sub`, `iat`, `exp` and `sid` look like the standard
/// JWT/OIDC claim names; `iat`/`exp` are presumably Unix timestamps in
/// seconds — confirm against the token issuer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenClaims {
/// `iss` claim (issuer).
pub iss: String,
/// `sub` claim (subject).
pub sub: String,
/// `iat` claim (issued-at time).
pub iat: i64,
/// `exp` claim (expiration time).
pub exp: i64,
/// `sid` claim (session identifier).
pub sid: String,
/// Optional `organization` claim.
#[serde(skip_serializing_if = "Option::is_none")]
pub organization: Option<String>,
/// Optional `workspace` claim.
#[serde(skip_serializing_if = "Option::is_none")]
pub workspace: Option<String>,
/// Optional per-scope permission lists (see [`TokenPermissions`]).
#[serde(skip_serializing_if = "Option::is_none")]
pub permissions: Option<TokenPermissions>,
/// Optional nested `claims` object, kept as an untyped JSON map.
#[serde(skip_serializing_if = "Option::is_none")]
pub claims: Option<serde_json::Map<String, serde_json::Value>>,
/// Optional `metadata` value, kept as untyped JSON.
#[serde(skip_serializing_if = "Option::is_none")]
pub metadata: Option<serde_json::Value>,
/// Every top-level key that does not match a named field above.
///
/// Because of `#[serde(flatten)]`, unrecognised keys are collected here on
/// deserialization and written back at the top level on serialization.
/// For example, an `nbf` claim would end up here because it has no named field.
#[serde(flatten)]
pub custom_claims: serde_json::Map<String, serde_json::Value>,
}
/// Permission names granted to the token holder, grouped by scope.
///
/// `require_permission_middleware` checks `organization` for
/// `PermissionScope::Organization` and `workspace` for
/// `PermissionScope::Workspace`. A missing (`None`) list grants nothing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenPermissions {
/// Organization-scoped permission names. Omitted from output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub organization: Option<Vec<String>>,
/// Workspace-scoped permission names. Omitted from output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub workspace: Option<Vec<String>>,
}
/// Information about the authenticated caller, stored in request extensions.
///
/// `extract_auth_context` and `require_permission_middleware` look this type
/// up in a request's extensions. The code that inserts it is not in this file.
///
/// NOTE(review): the id fields presumably mirror `claims.sub`, `claims.sid`,
/// `claims.organization` and `claims.workspace` — confirm at the construction site.
#[derive(Debug, Clone)]
pub struct AuthContext {
/// Identifier of the authenticated user.
pub user_id: String,
/// Identifier of the session the token belongs to.
pub session_id: String,
/// Organization the request is associated with, if any.
pub organization_id: Option<String>,
/// Workspace the request is associated with, if any.
pub workspace_id: Option<String>,
/// Permission lists checked by `require_permission_middleware`.
/// When this is `None`, every permission check fails.
pub permissions: Option<TokenPermissions>,
/// Full decoded claims for access to anything not copied into the fields above.
pub claims: TokenClaims,
}
/// Settings for verifying incoming tokens.
///
/// No verification logic is visible in this file. The notes below describe
/// the fields only as far as their names and the `Default` impl show.
#[derive(Clone)]
pub struct AuthConfig {
/// Public key used for signature verification. `Default` leaves it empty.
/// NOTE(review): the expected encoding (PEM, raw, ...) is not visible here — confirm.
pub public_key: String,
/// Optional JWKS document as an alternative key source. `Default` is `None`.
pub public_jwks: Option<JwksDocument>,
/// Tolerance for clock differences when checking time-based claims. `Default` is `5`.
/// NOTE(review): units presumably seconds — confirm in the validation code.
pub allowed_clock_skew: u64,
/// Presumably toggles `exp` (expiry) checking. `Default` is `true`.
pub validate_exp: bool,
/// Presumably toggles `nbf` (not-before) checking. `Default` is `true`.
pub validate_nbf: bool,
/// When `Some`, presumably the issuer value tokens must carry. `Default` is `None`.
pub required_issuer: Option<String>,
}
impl Default for AuthConfig {
    /// Builds a config with no key material and no issuer requirement.
    ///
    /// Both the expiry and not-before checks are enabled, and the clock-skew
    /// allowance is `5`.
    fn default() -> Self {
        Self {
            required_issuer: None,
            validate_nbf: true,
            validate_exp: true,
            allowed_clock_skew: 5,
            public_jwks: None,
            public_key: String::default(),
        }
    }
}
/// Looks up the [`AuthContext`] stored in the request's extensions.
///
/// Returns `None` when nothing has inserted an `AuthContext` into this
/// request. This function only reads; it never inserts.
///
/// The function is generic over the body type `B`, so it accepts any
/// `Request<B>` (for example a `Request<()>` built in a unit test) as well as
/// `Request<Body>`. Existing callers are unaffected because `B` is inferred.
#[must_use]
pub fn extract_auth_context<B>(req: &Request<B>) -> Option<&AuthContext> {
    req.extensions().get::<AuthContext>()
}
/// A single permission that `require_permission_middleware` must find on the caller.
///
/// Supplied to the middleware through an `Extension` layer. Build it with
/// `RequirePermission::organization` or `RequirePermission::workspace`.
#[derive(Debug, Clone)]
pub struct RequirePermission {
/// Permission name, compared exactly against the scope's permission list.
pub permission: String,
/// Which permission list (organization or workspace) is searched.
pub scope: PermissionScope,
}
/// Selects which permission list in [`TokenPermissions`] a
/// [`RequirePermission`] is checked against.
///
/// This is a fieldless enum. It derives `Copy`/`PartialEq`/`Eq`/`Hash` so
/// scopes can be compared, copied out of borrowed requirements, and used as
/// map or set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PermissionScope {
    /// Checked against `TokenPermissions::organization`.
    Organization,
    /// Checked against `TokenPermissions::workspace`.
    Workspace,
}
impl RequirePermission {
pub fn organization(permission: impl Into<String>) -> Self {
Self {
permission: permission.into(),
scope: PermissionScope::Organization,
}
}
pub fn workspace(permission: impl Into<String>) -> Self {
Self {
permission: permission.into(),
scope: PermissionScope::Workspace,
}
}
}
/// Axum middleware that rejects requests lacking a required permission.
///
/// The [`RequirePermission`] to enforce comes from an `Extension` layer. The
/// caller's permissions are read from the [`AuthContext`] in the request's
/// extensions. When the check passes, the request goes to `next` unchanged.
///
/// # Errors
///
/// * `401 Unauthorized` with `"No auth context found"` when the request has
///   no [`AuthContext`].
/// * `403 Forbidden` with `"Missing required permission: <name>"` in any of
///   these cases: the context has no permissions, it has no list for the
///   required scope, or that list does not contain the permission.
pub async fn require_permission_middleware(
    Extension(required): Extension<RequirePermission>,
    req: Request<Body>,
    next: Next,
) -> Result<Response, (StatusCode, String)> {
    let Some(auth_context) = req.extensions().get::<AuthContext>() else {
        return Err((
            StatusCode::UNAUTHORIZED,
            "No auth context found".to_string(),
        ));
    };

    // A missing permission set or a missing per-scope list both become an
    // empty slice, so either case fails the `contains` check below.
    let granted: &[String] = auth_context
        .permissions
        .as_ref()
        .and_then(|perms| match required.scope {
            PermissionScope::Organization => perms.organization.as_deref(),
            PermissionScope::Workspace => perms.workspace.as_deref(),
        })
        .unwrap_or_default();

    if granted.contains(&required.permission) {
        Ok(next.run(req).await)
    } else {
        Err((
            StatusCode::FORBIDDEN,
            format!("Missing required permission: {}", required.permission),
        ))
    }
}