use axum::{
body::Body,
extract::{MatchedPath, Request, State},
http::{HeaderMap, Method},
middleware::Next,
response::Response,
};
use cedar_policy::{
Authorizer, Context, Decision, Entities, EntityUid, PolicySet, Request as CedarRequest,
};
use chrono::{Datelike, Timelike};
use figment;
use serde_json::json;
use std::sync::Arc;
use tokio::sync::RwLock;
use crate::{
config::{CedarConfig, Config},
error::Error,
middleware::token::Claims,
};
/// Builder for [`CedarAuthz`].
///
/// Collects the optional hooks before the policy file is loaded by
/// [`CedarAuthzBuilder::build`].
pub struct CedarAuthzBuilder {
    /// Cedar section of the application configuration.
    config: CedarConfig,
    /// Decision cache consulted before evaluating policies.
    #[cfg(feature = "cache")]
    cache: Option<Arc<dyn PolicyCache>>,
    /// Custom path normaliser used when axum provides no matched route.
    path_normalizer: Option<fn(&str) -> String>,
}
impl CedarAuthzBuilder {
pub fn new(config: CedarConfig) -> Self {
Self {
config,
path_normalizer: None,
#[cfg(feature = "cache")]
cache: None,
}
}
pub fn with_path_normalizer(mut self, normalizer: fn(&str) -> String) -> Self {
self.path_normalizer = Some(normalizer);
self
}
#[cfg(feature = "cache")]
pub fn with_cache<C: PolicyCache + 'static>(mut self, cache: C) -> Self {
self.cache = Some(Arc::new(cache));
self
}
pub async fn build(self) -> Result<CedarAuthz, Error> {
let path = self.config.policy_path.clone();
let policies = tokio::task::spawn_blocking(move || std::fs::read_to_string(&path))
.await
.map_err(|e| Error::Internal(format!("Task join error: {}", e)))?
.map_err(|e| {
Error::Config(Box::new(figment::Error::from(format!(
"Failed to read Cedar policy file from '{}': {}",
self.config.policy_path.display(),
e
))))
})?;
let policy_set: PolicySet = policies.parse().map_err(|e| {
Error::Config(Box::new(figment::Error::from(format!(
"Failed to parse Cedar policies: {}",
e
))))
})?;
Ok(CedarAuthz {
authorizer: Arc::new(Authorizer::new()),
policy_set: Arc::new(RwLock::new(policy_set)),
config: Arc::new(self.config),
#[cfg(feature = "cache")]
cache: self.cache,
path_normalizer: self.path_normalizer,
})
}
}
/// Cedar authorization state used by both the axum middleware and the gRPC layer.
///
/// Every field is either an `Arc` or a plain function pointer, so cloning only
/// bumps reference counts and all clones observe the same policy set.
#[derive(Clone)]
pub struct CedarAuthz {
    /// Current policies, behind a lock so `reload_policies` can swap them at runtime.
    policy_set: Arc<RwLock<PolicySet>>,
    /// Stateless Cedar evaluator.
    authorizer: Arc<Authorizer>,
    /// Cedar section of the application configuration.
    config: Arc<CedarConfig>,
    /// Optional custom path normaliser for HTTP actions.
    path_normalizer: Option<fn(&str) -> String>,
    /// Optional decision cache.
    #[cfg(feature = "cache")]
    cache: Option<Arc<dyn PolicyCache>>,
}
impl CedarAuthz {
pub fn builder(config: CedarConfig) -> CedarAuthzBuilder {
CedarAuthzBuilder::new(config)
}
pub async fn from_config(config: CedarConfig) -> Result<Self, Error> {
Self::builder(config).build().await
}
pub async fn from_app_config(config: &Config) -> Result<Option<Self>, Error> {
match &config.cedar {
Some(cedar_config) if cedar_config.enabled => {
Ok(Some(Self::from_config(cedar_config.clone()).await?))
}
_ => Ok(None),
}
}
pub async fn middleware(
State(authz): State<Self>,
request: Request<Body>,
next: Next,
) -> Result<Response, Error> {
if !authz.config.enabled {
return Ok(next.run(request).await);
}
let path = request.uri().path();
if path == "/health" || path == "/ready" {
return Ok(next.run(request).await);
}
let claims = request
.extensions()
.get::<Claims>()
.ok_or_else(|| {
Error::Unauthorized(
"Missing JWT claims. Ensure JWT middleware runs before Cedar middleware."
.to_string(),
)
})?
.clone();
let method = request.method().clone();
let principal = build_principal(&claims)?;
let action = build_action_http(&method, &request, authz.path_normalizer)?;
let context = build_context_http(request.headers(), &claims)?;
let resource = build_resource()?;
let cedar_request = CedarRequest::new(
principal.clone(),
action.clone(),
resource.clone(),
context,
None, )
.map_err(|e| Error::Internal(format!("Failed to build Cedar request: {}", e)))?;
#[cfg(feature = "cache")]
if let Some(cache) = &authz.cache {
if let Some(decision) = cache.get(&cedar_request).await? {
match decision {
Decision::Allow => return Ok(next.run(request).await),
Decision::Deny => {
return Err(Error::Forbidden("Access denied by policy".to_string()))
}
}
}
}
let policy_set = authz.policy_set.read().await;
let entities = build_entities(&claims)?;
let response = authz
.authorizer
.is_authorized(&cedar_request, &policy_set, &entities);
match response.decision() {
Decision::Allow => {
#[cfg(feature = "cache")]
if let Some(cache) = &authz.cache {
let _ = cache
.set(&cedar_request, Decision::Allow, authz.config.cache_ttl_secs)
.await;
}
Ok(next.run(request).await)
}
Decision::Deny => {
tracing::warn!(
principal = ?principal,
action = ?action,
"Cedar policy denied request"
);
#[cfg(feature = "cache")]
if let Some(cache) = &authz.cache {
let _ = cache
.set(&cedar_request, Decision::Deny, authz.config.cache_ttl_secs)
.await;
}
if authz.config.fail_open {
tracing::warn!("Cedar policy denied but fail_open=true, allowing request");
Ok(next.run(request).await)
} else {
Err(Error::Forbidden("Access denied by policy".to_string()))
}
}
}
}
pub async fn reload_policies(&self) -> Result<(), Error> {
let path = self.config.policy_path.clone();
let policies = tokio::task::spawn_blocking(move || std::fs::read_to_string(&path))
.await
.map_err(|e| Error::Internal(format!("Task join error: {}", e)))?
.map_err(|e| Error::Internal(format!("Failed to read policy file: {}", e)))?;
let new_policy_set: PolicySet = policies
.parse()
.map_err(|e| Error::Internal(format!("Failed to parse policies: {}", e)))?;
let mut policy_set = self.policy_set.write().await;
*policy_set = new_policy_set;
tracing::info!(
"Cedar policies reloaded from {}",
self.config.policy_path.display()
);
Ok(())
}
}
fn build_resource() -> Result<EntityUid, Error> {
r#"Resource::"default""#
.parse()
.map_err(|e| Error::Internal(format!("Failed to parse resource: {}", e)))
}
fn build_principal(claims: &Claims) -> Result<EntityUid, Error> {
let principal_str = if claims.is_user() {
format!(r#"User::"{}""#, claims.sub)
} else if claims.is_client() {
format!(r#"Client::"{}""#, claims.sub)
} else {
format!(r#"Principal::"{}""#, claims.sub)
};
let principal: EntityUid = principal_str
.parse()
.map_err(|e| Error::Internal(format!("Invalid principal: {}", e)))?;
Ok(principal)
}
fn build_action_http(
method: &Method,
request: &Request<Body>,
path_normalizer: Option<fn(&str) -> String>,
) -> Result<EntityUid, Error> {
let normalized_path = request
.extensions()
.get::<MatchedPath>()
.map(|matched| matched.as_str().to_string())
.unwrap_or_else(|| {
match path_normalizer {
Some(normalizer) => normalizer(request.uri().path()),
None => normalize_path_generic(request.uri().path()),
}
});
let action_str = format!(r#"Action::"{} {}""#, method, normalized_path);
let action: EntityUid = action_str
.parse()
.map_err(|e| Error::Internal(format!("Invalid action: {}", e)))?;
tracing::debug!(
method = %method,
path = %request.uri().path(),
normalized = %normalized_path,
action = %action,
"Built Cedar action"
);
Ok(action)
}
/// Collapses identifier-like parts of a request path into `{id}` so that
/// per-resource paths map to one Cedar action.
///
/// Two rewrites are applied, in order:
/// 1. Every lowercase-hex UUID (`xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx`) is
///    replaced wherever it appears, including inside a segment.
/// 2. Every segment after a `/` that consists solely of ASCII digits is
///    replaced.
///
/// Slashes are preserved exactly, so `/users/123/posts` becomes
/// `/users/{id}/posts`. The previous regex version consumed the trailing
/// slash and produced `/users/{id}posts`. Consecutive numeric segments are all
/// replaced. No regex is compiled per call.
fn normalize_path_generic(path: &str) -> String {
    let without_uuids = replace_lowercase_uuids(path);
    without_uuids
        .split('/')
        .enumerate()
        .map(|(index, segment)| {
            // Index 0 is the text before the first '/', which is not a segment
            // preceded by a slash and is therefore left untouched.
            if index > 0 && is_numeric_segment(segment) {
                "{id}"
            } else {
                segment
            }
        })
        .collect::<Vec<&str>>()
        .join("/")
}

/// True for a non-empty path segment made only of ASCII digits.
fn is_numeric_segment(segment: &str) -> bool {
    !segment.is_empty() && segment.bytes().all(|b| b.is_ascii_digit())
}

/// Replaces every non-overlapping lowercase-hex UUID in `text` with `{id}`,
/// scanning left to right (the same matches `Regex::replace_all` would find).
fn replace_lowercase_uuids(text: &str) -> String {
    const UUID_LEN: usize = 36;
    let bytes = text.as_bytes();
    let mut out = String::with_capacity(text.len());
    let mut copied_up_to = 0;
    let mut i = 0;
    while i + UUID_LEN <= bytes.len() {
        if is_lowercase_uuid(&bytes[i..i + UUID_LEN]) {
            // A match is pure ASCII, so `i` and `i + UUID_LEN` are char
            // boundaries and these slices cannot split a UTF-8 sequence.
            out.push_str(&text[copied_up_to..i]);
            out.push_str("{id}");
            i += UUID_LEN;
            copied_up_to = i;
        } else {
            i += 1;
        }
    }
    out.push_str(&text[copied_up_to..]);
    out
}

/// True when `candidate` is exactly `8-4-4-4-12` lowercase hex digits.
fn is_lowercase_uuid(candidate: &[u8]) -> bool {
    candidate.len() == 36
        && candidate.iter().enumerate().all(|(idx, &b)| match idx {
            8 | 13 | 18 | 23 => b == b'-',
            _ => b.is_ascii_digit() || (b'a'..=b'f').contains(&b),
        })
}
/// Builds the Cedar `context` record for an HTTP request.
///
/// The record always contains `roles`, `permissions`, and a UTC `timestamp`
/// object (`unix`, `hour`, `dayOfWeek`). It also contains `email`, `username`,
/// `ip`, `requestId`, and `userAgent` when each is available.
fn build_context_http(headers: &HeaderMap, claims: &Claims) -> Result<Context, Error> {
    let now = chrono::Utc::now();
    let header_value = |name: &str| headers.get(name).and_then(|v| v.to_str().ok());

    let mut fields = serde_json::Map::new();
    fields.insert("roles".to_string(), json!(claims.roles));
    fields.insert("permissions".to_string(), json!(claims.perms));
    fields.insert(
        "timestamp".to_string(),
        json!({
            "unix": now.timestamp(),
            "hour": now.hour(),
            "dayOfWeek": now.weekday().to_string(),
        }),
    );

    let optional_fields: [(&str, Option<String>); 5] = [
        ("email", claims.email.clone()),
        ("username", claims.username.clone()),
        ("ip", extract_client_ip(headers)),
        ("requestId", header_value("x-request-id").map(str::to_owned)),
        ("userAgent", header_value("user-agent").map(str::to_owned)),
    ];
    for (key, maybe_value) in optional_fields {
        if let Some(value) = maybe_value {
            fields.insert(key.to_string(), json!(value));
        }
    }

    Context::from_json_value(serde_json::Value::Object(fields), None)
        .map_err(|e| Error::Internal(format!("Failed to build context: {}", e)))
}
/// Best-effort client IP for the policy context.
///
/// Uses the first entry of `X-Forwarded-For`, falling back to `X-Real-IP`.
/// Blank values are treated as absent, so an empty leading `X-Forwarded-For`
/// entry no longer yields `Some("")` and masks a valid `X-Real-IP`.
///
/// Both headers are supplied by the client unless a trusted proxy overwrites
/// them. Policies must not treat `context.ip` as an authenticated value.
fn extract_client_ip(headers: &HeaderMap) -> Option<String> {
    let non_empty = |value: &str| {
        let trimmed = value.trim();
        (!trimmed.is_empty()).then(|| trimmed.to_string())
    };
    headers
        .get("x-forwarded-for")
        .and_then(|v| v.to_str().ok())
        .and_then(|xff| xff.split(',').next())
        .and_then(non_empty)
        .or_else(|| {
            headers
                .get("x-real-ip")
                .and_then(|v| v.to_str().ok())
                .and_then(non_empty)
        })
}
/// Builds the entity store for a request: a single entity for the principal,
/// whose attributes are `email` (empty when absent), `roles`, `permissions`,
/// and `sub`.
///
/// The entity type must match the one chosen by `build_principal`: `User`,
/// `Client`, or `Principal`. Previously, principals typed `Principal` got an
/// entity typed `Client`, so their attributes were never found during
/// evaluation.
fn build_entities(claims: &Claims) -> Result<Entities, Error> {
    use serde_json::Value;
    let entity_type = if claims.is_user() {
        "User"
    } else if claims.is_client() {
        "Client"
    } else {
        "Principal"
    };
    let entity = json!({
        "uid": {
            "type": entity_type,
            "id": claims.sub
        },
        "attrs": {
            "email": claims.email.as_deref().unwrap_or(""),
            "roles": claims.roles,
            "permissions": claims.perms,
            "sub": claims.sub,
        },
        "parents": []
    });
    Entities::from_json_value(Value::Array(vec![entity]), None)
        .map_err(|e| Error::Internal(format!("Failed to build entities: {}", e)))
}
/// Storage for previously computed authorization decisions.
#[cfg(feature = "cache")]
#[async_trait::async_trait]
pub trait PolicyCache: Send + Sync {
    /// Returns the stored decision for `request`, or `Ok(None)` on a miss.
    async fn get(&self, request: &CedarRequest) -> Result<Option<Decision>, Error>;

    /// Stores `decision` for `request`. `ttl_secs` is the intended entry lifetime.
    async fn set(&self, request: &CedarRequest, decision: Decision, ttl_secs: u64)
        -> Result<(), Error>;
}
/// [`PolicyCache`] backed by Redis through a `deadpool_redis` connection pool.
#[cfg(feature = "cache")]
pub struct RedisPolicyCache {
    /// Pool from which a connection is taken for each cache operation.
    pool: deadpool_redis::Pool,
}
#[cfg(feature = "cache")]
impl RedisPolicyCache {
    /// Wraps an existing connection pool.
    pub fn new(pool: deadpool_redis::Pool) -> Self {
        RedisPolicyCache { pool }
    }

    /// Redis key for `request`: `cedar:authz:<principal>:<action>:<resource>`,
    /// with `None` standing in for any missing component.
    ///
    /// NOTE(review): the key ignores the request context (roles, permissions,
    /// time, IP). A cached decision is therefore reused until its TTL expires,
    /// even if those inputs change, e.g. after a role revocation.
    fn cache_key(request: &CedarRequest) -> String {
        fn describe<T: ToString>(part: Option<T>) -> String {
            part.map_or_else(|| "None".to_string(), |p| p.to_string())
        }
        let principal = describe(request.principal());
        let action = describe(request.action());
        let resource = describe(request.resource());
        format!("cedar:authz:{}:{}:{}", principal, action, resource)
    }
}
#[cfg(feature = "cache")]
#[async_trait::async_trait]
impl PolicyCache for RedisPolicyCache {
    /// Fetches the stored decision.
    ///
    /// Values other than `"allow"` or `"deny"` are reported as a miss.
    async fn get(&self, request: &CedarRequest) -> Result<Option<Decision>, Error> {
        use deadpool_redis::redis::AsyncCommands;
        let key = Self::cache_key(request);
        let mut connection = self
            .pool
            .get()
            .await
            .map_err(|err| Error::Internal(format!("Redis connection failed: {}", err)))?;
        let stored: Option<String> = connection
            .get(&key)
            .await
            .map_err(|err| Error::Internal(format!("Redis GET failed: {}", err)))?;
        let decision = match stored.as_deref() {
            Some("allow") => Some(Decision::Allow),
            Some("deny") => Some(Decision::Deny),
            _ => None,
        };
        Ok(decision)
    }

    /// Stores the decision as `"allow"` or `"deny"` with a `SETEX` of `ttl_secs`.
    async fn set(
        &self,
        request: &CedarRequest,
        decision: Decision,
        ttl_secs: u64,
    ) -> Result<(), Error> {
        use deadpool_redis::redis::AsyncCommands;
        let key = Self::cache_key(request);
        let encoded = match decision {
            Decision::Allow => "allow",
            Decision::Deny => "deny",
        };
        let mut connection = self
            .pool
            .get()
            .await
            .map_err(|err| Error::Internal(format!("Redis connection failed: {}", err)))?;
        connection
            .set_ex::<_, _, ()>(&key, encoded, ttl_secs)
            .await
            .map_err(|err| Error::Internal(format!("Redis SETEX failed: {}", err)))
    }
}
#[cfg(feature = "grpc")]
use std::future::Future;
#[cfg(feature = "grpc")]
use std::pin::Pin;
#[cfg(feature = "grpc")]
use std::task::{Context as TaskContext, Poll};
#[cfg(feature = "grpc")]
use tonic::{body::Body as TonicBody, Request as TonicRequest, Response as TonicResponse, Status};
#[cfg(feature = "grpc")]
use tower::{Layer, Service};
/// Tower layer that installs Cedar authorization in front of a gRPC service.
#[cfg(feature = "grpc")]
#[derive(Clone)]
pub struct CedarAuthzLayer {
    /// Authorizer handed to every service this layer wraps.
    authz: CedarAuthz,
}
#[cfg(feature = "grpc")]
impl CedarAuthzLayer {
    /// Wraps an already-built [`CedarAuthz`] so it can be installed as a tower layer.
    pub fn new(authz: CedarAuthz) -> Self {
        CedarAuthzLayer { authz }
    }
}
#[cfg(feature = "grpc")]
impl<S> Layer<S> for CedarAuthzLayer {
    type Service = CedarAuthzService<S>;

    /// Wraps `inner`, giving the new service its own clone of the authorizer.
    /// The clone shares the same `Arc`-held state.
    fn layer(&self, inner: S) -> Self::Service {
        let authz = self.authz.clone();
        CedarAuthzService { authz, inner }
    }
}
/// Service produced by [`CedarAuthzLayer`].
///
/// It authorizes each request before forwarding it to `inner`.
#[cfg(feature = "grpc")]
#[derive(Clone)]
pub struct CedarAuthzService<S> {
    /// Shared Cedar authorization state.
    authz: CedarAuthz,
    /// Wrapped downstream service.
    inner: S,
}
#[cfg(feature = "grpc")]
impl<S, ReqBody> Service<TonicRequest<ReqBody>> for CedarAuthzService<S>
where
    S: Service<TonicRequest<ReqBody>, Response = TonicResponse<TonicBody>, Error = Status>
        + Clone
        + Send
        + 'static,
    S::Future: Send + 'static,
    ReqBody: Send + 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    fn poll_ready(&mut self, cx: &mut TaskContext<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Authorizes the call against the Cedar policy set, then forwards it.
    ///
    /// Denials return `permission_denied` unless `fail_open` is set. Cached
    /// decisions follow the same rules as fresh ones. Cache errors are treated
    /// as misses.
    fn call(&mut self, req: TonicRequest<ReqBody>) -> Self::Future {
        // `poll_ready` was driven on `self.inner`, so that instance must serve
        // this call. Calling a fresh, un-polled clone violates the `Service`
        // contract. Leave the clone behind for the next request instead.
        let fresh = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, fresh);
        let authz = self.authz.clone();
        Box::pin(async move {
            if !authz.config.enabled {
                return inner.call(req).await;
            }
            let claims = req
                .extensions()
                .get::<Claims>()
                .ok_or_else(|| {
                    Status::unauthenticated(
                        "Missing JWT claims. Ensure JWT interceptor runs before Cedar layer.",
                    )
                })?
                .clone();
            // NOTE(review): `MetadataMap` is built from ordinary headers. HTTP/2
            // pseudo-headers such as `:path` are likely absent from it, in which
            // case every call maps to `Action::"unknown"`. Confirm that this yields
            // the real method path in your deployment.
            let method_path = req
                .metadata()
                .get(":path")
                .and_then(|v| v.to_str().ok())
                .unwrap_or("unknown")
                .to_string();
            let principal = build_principal(&claims)
                .map_err(|_| Status::internal("Failed to build principal"))?;
            let action = build_action_grpc(&method_path)
                .map_err(|_| Status::internal("Failed to build action"))?;
            let context = build_context_grpc(req.metadata(), &claims)
                .map_err(|_| Status::internal("Failed to build context"))?;
            let resource =
                build_resource().map_err(|_| Status::internal("Failed to parse resource"))?;
            let cedar_request =
                CedarRequest::new(principal.clone(), action.clone(), resource, context, None)
                    .map_err(|_| Status::internal("Failed to build Cedar request"))?;

            // The cache is best-effort: lookup errors count as a miss.
            #[cfg(feature = "cache")]
            let cached = match &authz.cache {
                Some(cache) => cache.get(&cedar_request).await.ok().flatten(),
                None => None,
            };
            #[cfg(not(feature = "cache"))]
            let cached: Option<Decision> = None;

            let decision = match cached {
                Some(decision) => decision,
                None => {
                    let entities = build_entities(&claims)
                        .map_err(|_| Status::internal("Failed to build entities"))?;
                    // Drop the read guard before awaiting anything else. Holding it
                    // across `inner.call` would stall `reload_policies` (and, since
                    // tokio's RwLock is write-preferring, all later requests) until
                    // the downstream handler finishes.
                    let decision = {
                        let policy_set = authz.policy_set.read().await;
                        authz
                            .authorizer
                            .is_authorized(&cedar_request, &policy_set, &entities)
                            .decision()
                    };
                    #[cfg(feature = "cache")]
                    if let Some(cache) = &authz.cache {
                        let _ = cache
                            .set(&cedar_request, decision, authz.config.cache_ttl_secs)
                            .await;
                    }
                    decision
                }
            };

            match decision {
                Decision::Allow => inner.call(req).await,
                Decision::Deny => {
                    tracing::warn!(
                        principal = ?principal,
                        action = ?action,
                        method = %method_path,
                        "Cedar policy denied gRPC request"
                    );
                    // Also applied to cached denials, so `fail_open` is consistent.
                    if authz.config.fail_open {
                        tracing::warn!(
                            "Cedar policy denied but fail_open=true, allowing gRPC request"
                        );
                        inner.call(req).await
                    } else {
                        Err(Status::permission_denied("Access denied by policy"))
                    }
                }
            }
        })
    }
}
/// Maps a gRPC method path to the Cedar action `Action::"<method_path>"`.
///
/// The UID is assembled structurally so that quotes or backslash escapes in
/// the method path are kept verbatim in the entity id instead of being parsed
/// as Cedar syntax.
///
/// # Errors
/// `Error::Internal` only if the `Action` type name fails to parse.
#[cfg(feature = "grpc")]
fn build_action_grpc(method_path: &str) -> Result<EntityUid, Error> {
    let action_type: cedar_policy::EntityTypeName = "Action"
        .parse()
        .map_err(|e| Error::Internal(format!("Invalid gRPC action: {}", e)))?;
    Ok(EntityUid::from_type_name_and_id(
        action_type,
        cedar_policy::EntityId::new(method_path),
    ))
}
/// Builds the Cedar `context` record for a gRPC call.
///
/// The keys match the HTTP context: always `roles`, `permissions`, and
/// `timestamp`. `email`, `username`, `ip`, `requestId`, and `userAgent` are
/// added when each is available.
#[cfg(feature = "grpc")]
fn build_context_grpc(
    metadata: &tonic::metadata::MetadataMap,
    claims: &Claims,
) -> Result<Context, Error> {
    let now = chrono::Utc::now();
    let metadata_value = |name: &str| metadata.get(name).and_then(|v| v.to_str().ok());

    let mut fields = serde_json::Map::new();
    fields.insert("roles".to_string(), json!(claims.roles));
    fields.insert("permissions".to_string(), json!(claims.perms));
    fields.insert(
        "timestamp".to_string(),
        json!({
            "unix": now.timestamp(),
            "hour": now.hour(),
            "dayOfWeek": now.weekday().to_string(),
        }),
    );

    let optional_fields: [(&str, Option<String>); 5] = [
        ("email", claims.email.clone()),
        ("username", claims.username.clone()),
        ("ip", extract_grpc_client_ip(metadata)),
        ("requestId", metadata_value("x-request-id").map(str::to_owned)),
        ("userAgent", metadata_value("user-agent").map(str::to_owned)),
    ];
    for (key, maybe_value) in optional_fields {
        if let Some(value) = maybe_value {
            fields.insert(key.to_string(), json!(value));
        }
    }

    Context::from_json_value(serde_json::Value::Object(fields), None)
        .map_err(|e| Error::Internal(format!("Failed to build gRPC context: {}", e)))
}
/// Best-effort client IP for the gRPC policy context.
///
/// Uses the first `x-forwarded-for` entry, falling back to `x-real-ip`. Blank
/// values are treated as absent rather than returned as `Some("")`.
///
/// Both values are client-supplied unless a trusted proxy overwrites them.
/// Policies must not treat `context.ip` as authenticated.
#[cfg(feature = "grpc")]
fn extract_grpc_client_ip(metadata: &tonic::metadata::MetadataMap) -> Option<String> {
    let non_empty = |value: &str| {
        let trimmed = value.trim();
        (!trimmed.is_empty()).then(|| trimmed.to_string())
    };
    metadata
        .get("x-forwarded-for")
        .and_then(|v| v.to_str().ok())
        .and_then(|xff| xff.split(',').next())
        .and_then(non_empty)
        .or_else(|| {
            metadata
                .get("x-real-ip")
                .and_then(|v| v.to_str().ok())
                .and_then(non_empty)
        })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Table-driven check of the generic path normaliser.
    #[test]
    fn test_normalize_path_generic() {
        let cases = [
            ("/api/v1/users/123", "/api/v1/users/{id}"),
            (
                "/api/v1/users/550e8400-e29b-41d4-a716-446655440000",
                "/api/v1/users/{id}",
            ),
            ("/api/v1/users", "/api/v1/users"),
        ];
        for (input, expected) in cases {
            assert_eq!(normalize_path_generic(input), expected, "input: {}", input);
        }
    }

    /// A user token maps to a `User` principal whose id is the raw `sub`.
    #[test]
    fn test_build_principal() {
        let claims = Claims {
            sub: "user:123".to_string(),
            email: Some("test@example.com".to_string()),
            username: Some("testuser".to_string()),
            roles: vec!["user".to_string()],
            perms: vec![],
            exp: 0,
            iat: None,
            jti: None,
            iss: None,
            aud: None,
        };
        let principal = build_principal(&claims).expect("principal should build");
        assert_eq!(principal.to_string(), r#"User::"user:123""#);
    }
}