use car_ir::{Action, ActionType, ToolSchema};
use car_policy::PolicyEngine;
use car_state::StateStore;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Stages of the authorization pipeline, listed in the order
/// `AuthzPipeline::authorize` evaluates them.
///
/// Serialized in snake_case (e.g. `ToolExists` <-> `"tool_exists"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum AuthzStage {
/// A `ToolCall` action must name a tool present in the tool registry.
ToolExists,
/// When a capability set is supplied, the named tool must be allowed by it.
Capability,
/// The configured `PermissionHandler` is consulted for actions that name a tool.
Permission,
/// Permanent `Restriction`s are checked in insertion order.
Restriction,
/// The `PolicyEngine` is checked against the current state.
Policy,
/// Final stage; the pipeline currently records it as "validation deferred".
Validation,
}
/// Outcome of a single stage or of the whole pipeline.
///
/// Serialized in snake_case (`"allow"`, `"ask_user"`, `"deny"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum AuthzDecision {
/// The action may proceed.
Allow,
/// The action requires user approval before it may proceed.
AskUser,
/// The action is rejected.
Deny,
}
/// Final verdict of an authorization run, plus a trace of the stages evaluated.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthzResult {
/// Overall decision.
pub decision: AuthzDecision,
/// Stage that produced `decision`.
pub stage: AuthzStage,
/// Machine-readable code, e.g. `"allowed"`, `"tool_not_found"`, `"policy_violation"`.
pub reason_code: String,
/// Human-readable explanation of the decision.
pub explanation: String,
/// Per-stage trace in evaluation order; empty unless a trace was attached.
pub stage_results: Vec<StageResult>,
}
/// One entry in an authorization trace: the outcome of a single stage.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StageResult {
/// Stage this entry describes.
pub stage: AuthzStage,
/// Decision reached at this stage.
pub decision: AuthzDecision,
/// Free-form, human-readable reason for the decision.
pub reason: String,
}
impl AuthzResult {
pub fn allowed(stage: AuthzStage) -> Self {
Self {
decision: AuthzDecision::Allow,
stage,
reason_code: "allowed".to_string(),
explanation: "All authorization checks passed".to_string(),
stage_results: Vec::new(),
}
}
pub fn denied(stage: AuthzStage, reason_code: &str, explanation: &str) -> Self {
Self {
decision: AuthzDecision::Deny,
stage,
reason_code: reason_code.to_string(),
explanation: explanation.to_string(),
stage_results: Vec::new(),
}
}
pub fn ask_user(stage: AuthzStage, reason_code: &str, explanation: &str) -> Self {
Self {
decision: AuthzDecision::AskUser,
stage,
reason_code: reason_code.to_string(),
explanation: explanation.to_string(),
stage_results: Vec::new(),
}
}
fn with_stages(mut self, stages: Vec<StageResult>) -> Self {
self.stage_results = stages;
self
}
}
/// A permanent, named rule evaluated during the [`AuthzStage::Restriction`] stage.
///
/// The rule is a predicate over an [`Action`]: it returns `Some(reason)` to deny the
/// action and `None` to let it through.
pub struct Restriction {
    /// Identifier of the rule; the pipeline embeds it in the denial reason code.
    pub name: String,
    /// Human-readable description of the rule.
    pub description: String,
    /// The predicate; `Some(reason)` means the action violates this restriction.
    check: Box<dyn Fn(&Action) -> Option<String> + Send + Sync>,
}

impl Restriction {
    /// Creates a restriction called `name` whose predicate is `check`.
    pub fn new<F>(name: &str, description: &str, check: F) -> Self
    where
        F: Fn(&Action) -> Option<String> + Send + Sync + 'static,
    {
        Self {
            name: name.to_string(),
            description: description.to_string(),
            check: Box::new(check),
        }
    }

    /// Evaluates the predicate: `Some(reason)` if `action` violates this restriction.
    fn check(&self, action: &Action) -> Option<String> {
        (self.check)(action)
    }
}

// The boxed predicate cannot implement `Debug`, so `Debug` cannot be derived; print the
// descriptive fields and mark the struct as having elided fields instead.
impl std::fmt::Debug for Restriction {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Restriction")
            .field("name", &self.name)
            .field("description", &self.description)
            .finish_non_exhaustive()
    }
}
/// Decides whether invoking a tool as part of an action is permitted.
///
/// Consulted by `AuthzPipeline::authorize` during the `Permission` stage, only for actions
/// that name a tool. Returning `Deny` or `AskUser` stops the pipeline at that stage.
#[async_trait::async_trait]
pub trait PermissionHandler: Send + Sync {
/// Returns the permission decision for invoking `tool_name` as part of `action`.
async fn check(&self, tool_name: &str, action: &Action) -> AuthzDecision;
}
/// Permission handler that allows every tool; installed by `AuthzPipeline::new`.
pub struct AllowAllPermissions;
#[async_trait::async_trait]
impl PermissionHandler for AllowAllPermissions {
    /// Grants permission unconditionally; neither the tool name nor the action is inspected.
    async fn check(&self, _name: &str, _act: &Action) -> AuthzDecision {
        AuthzDecision::Allow
    }
}
/// Staged authorization pipeline; `authorize` runs an action through every stage in order.
pub struct AuthzPipeline {
/// Permanent restrictions, checked in insertion order.
restrictions: Vec<Restriction>,
/// Consulted for every action that names a tool.
permission_handler: Box<dyn PermissionHandler>,
}
impl AuthzPipeline {
pub fn new() -> Self {
Self {
restrictions: Vec::new(),
permission_handler: Box::new(AllowAllPermissions),
}
}
pub fn add_restriction(&mut self, restriction: Restriction) {
self.restrictions.push(restriction);
}
pub fn set_permission_handler(&mut self, handler: Box<dyn PermissionHandler>) {
self.permission_handler = handler;
}
pub async fn authorize(
&self,
action: &Action,
tools: &HashMap<String, ToolSchema>,
capabilities: Option<&crate::capabilities::CapabilitySet>,
policies: &PolicyEngine,
state: &StateStore,
) -> AuthzResult {
let mut stages = Vec::new();
if let Some(tool_name) = &action.tool {
if action.action_type == ActionType::ToolCall && !tools.contains_key(tool_name) {
stages.push(StageResult {
stage: AuthzStage::ToolExists,
decision: AuthzDecision::Deny,
reason: format!("tool '{}' not registered", tool_name),
});
return AuthzResult::denied(
AuthzStage::ToolExists,
"tool_not_found",
&format!("Tool '{}' is not registered", tool_name),
)
.with_stages(stages);
}
}
stages.push(StageResult {
stage: AuthzStage::ToolExists,
decision: AuthzDecision::Allow,
reason: "tool registered".to_string(),
});
if let Some(caps) = capabilities {
if let Some(tool_name) = &action.tool {
if !caps.tool_allowed(tool_name) {
stages.push(StageResult {
stage: AuthzStage::Capability,
decision: AuthzDecision::Deny,
reason: format!("tool '{}' not in capability set", tool_name),
});
return AuthzResult::denied(
AuthzStage::Capability,
"capability_denied",
&format!("Tool '{}' denied by capability set", tool_name),
)
.with_stages(stages);
}
}
}
stages.push(StageResult {
stage: AuthzStage::Capability,
decision: AuthzDecision::Allow,
reason: "capability check passed".to_string(),
});
if let Some(tool_name) = &action.tool {
let perm = self.permission_handler.check(tool_name, action).await;
stages.push(StageResult {
stage: AuthzStage::Permission,
decision: perm,
reason: format!("permission handler returned {:?}", perm),
});
if perm == AuthzDecision::Deny {
return AuthzResult::denied(
AuthzStage::Permission,
"permission_denied",
&format!("Permission denied for tool '{}'", tool_name),
)
.with_stages(stages);
}
if perm == AuthzDecision::AskUser {
return AuthzResult::ask_user(
AuthzStage::Permission,
"approval_required",
&format!("Tool '{}' requires user approval", tool_name),
)
.with_stages(stages);
}
} else {
stages.push(StageResult {
stage: AuthzStage::Permission,
decision: AuthzDecision::Allow,
reason: "no tool name, skipped".to_string(),
});
}
for restriction in &self.restrictions {
if let Some(reason) = restriction.check(action) {
stages.push(StageResult {
stage: AuthzStage::Restriction,
decision: AuthzDecision::Deny,
reason: reason.clone(),
});
return AuthzResult::denied(
AuthzStage::Restriction,
&format!("restriction_{}", restriction.name),
&format!("Permanent restriction '{}': {}", restriction.name, reason),
)
.with_stages(stages);
}
}
stages.push(StageResult {
stage: AuthzStage::Restriction,
decision: AuthzDecision::Allow,
reason: "all restrictions passed".to_string(),
});
let violations = policies.check(action, state);
if !violations.is_empty() {
let reasons: Vec<String> = violations
.iter()
.map(|v| format!("{}: {}", v.policy_name, v.reason))
.collect();
stages.push(StageResult {
stage: AuthzStage::Policy,
decision: AuthzDecision::Deny,
reason: reasons.join("; "),
});
return AuthzResult::denied(
AuthzStage::Policy,
"policy_violation",
&format!("Policy violations: {}", reasons.join("; ")),
)
.with_stages(stages);
}
stages.push(StageResult {
stage: AuthzStage::Policy,
decision: AuthzDecision::Allow,
reason: "all policies passed".to_string(),
});
stages.push(StageResult {
stage: AuthzStage::Validation,
decision: AuthzDecision::Allow,
reason: "validation deferred".to_string(),
});
AuthzResult::allowed(AuthzStage::Validation).with_stages(stages)
}
}
impl Default for AuthzPipeline {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use car_ir::{Action, ActionType, FailureBehavior, ToolSchema};

    /// A `ToolCall` action with id `test-1` that targets `tool`.
    fn test_action(tool: &str) -> Action {
        Action {
            id: String::from("test-1"),
            action_type: ActionType::ToolCall,
            tool: Some(String::from(tool)),
            parameters: Default::default(),
            preconditions: Vec::new(),
            expected_effects: Default::default(),
            state_dependencies: vec![],
            idempotent: false,
            max_retries: 3,
            failure_behavior: FailureBehavior::Abort,
            timeout_ms: None,
            metadata: Default::default(),
        }
    }

    /// A registry containing only the idempotent `read` tool.
    fn test_tools() -> HashMap<String, ToolSchema> {
        let read = ToolSchema {
            name: String::from("read"),
            description: String::from("Read a file"),
            parameters: serde_json::json!({"type": "object"}),
            returns: None,
            idempotent: true,
            cache_ttl_secs: None,
            rate_limit: None,
        };
        HashMap::from([(String::from("read"), read)])
    }

    /// Authorizes `test_action(tool)` against `test_tools()` and a fresh `StateStore`.
    async fn run(
        pipeline: &AuthzPipeline,
        tool: &str,
        caps: Option<&crate::capabilities::CapabilitySet>,
        policies: &PolicyEngine,
    ) -> AuthzResult {
        let tools = test_tools();
        let state = StateStore::new();
        pipeline
            .authorize(&test_action(tool), &tools, caps, policies, &state)
            .await
    }

    #[tokio::test]
    async fn test_allow_registered_tool() {
        let result = run(&AuthzPipeline::new(), "read", None, &PolicyEngine::new()).await;
        assert_eq!(result.decision, AuthzDecision::Allow);
        assert_eq!(result.stage_results.len(), 6);
    }

    #[tokio::test]
    async fn test_deny_unregistered_tool() {
        let result = run(&AuthzPipeline::new(), "delete", None, &PolicyEngine::new()).await;
        assert_eq!(result.decision, AuthzDecision::Deny);
        assert_eq!(result.stage, AuthzStage::ToolExists);
        assert_eq!(result.reason_code, "tool_not_found");
    }

    #[tokio::test]
    async fn test_capability_denial() {
        let mut caps = crate::capabilities::CapabilitySet::default();
        caps.denied_tools.insert("read".to_string());
        let result = run(&AuthzPipeline::new(), "read", Some(&caps), &PolicyEngine::new()).await;
        assert_eq!(result.decision, AuthzDecision::Deny);
        assert_eq!(result.stage, AuthzStage::Capability);
    }

    #[tokio::test]
    async fn test_restriction() {
        let mut pipeline = AuthzPipeline::new();
        pipeline.add_restriction(Restriction::new("no_read", "Never allow read", |action| {
            (action.tool.as_deref() == Some("read")).then(|| "reads are restricted".to_string())
        }));
        let result = run(&pipeline, "read", None, &PolicyEngine::new()).await;
        assert_eq!(result.decision, AuthzDecision::Deny);
        assert_eq!(result.stage, AuthzStage::Restriction);
    }

    #[tokio::test]
    async fn test_policy_violation() {
        let mut policies = PolicyEngine::new();
        policies.register(
            "deny_all",
            Box::new(|_action: &Action, _state: &StateStore| Some("denied by test".to_string())),
            "test policy",
        );
        let result = run(&AuthzPipeline::new(), "read", None, &policies).await;
        assert_eq!(result.decision, AuthzDecision::Deny);
        assert_eq!(result.stage, AuthzStage::Policy);
    }

    #[tokio::test]
    async fn test_ask_user_permission() {
        /// Handler that always requests user approval.
        struct AskPermissions;

        #[async_trait::async_trait]
        impl PermissionHandler for AskPermissions {
            async fn check(&self, _tool_name: &str, _action: &Action) -> AuthzDecision {
                AuthzDecision::AskUser
            }
        }

        let mut pipeline = AuthzPipeline::new();
        pipeline.set_permission_handler(Box::new(AskPermissions));
        let result = run(&pipeline, "read", None, &PolicyEngine::new()).await;
        assert_eq!(result.decision, AuthzDecision::AskUser);
        assert_eq!(result.stage, AuthzStage::Permission);
        assert_eq!(result.reason_code, "approval_required");
    }

    #[tokio::test]
    async fn test_stage_results_trace() {
        let result = run(&AuthzPipeline::new(), "read", None, &PolicyEngine::new()).await;
        let stage_names: Vec<AuthzStage> =
            result.stage_results.iter().map(|entry| entry.stage).collect();
        assert_eq!(
            stage_names,
            [
                AuthzStage::ToolExists,
                AuthzStage::Capability,
                AuthzStage::Permission,
                AuthzStage::Restriction,
                AuthzStage::Policy,
                AuthzStage::Validation,
            ]
        );
    }

    #[tokio::test]
    async fn test_short_circuit_on_deny() {
        let result = run(&AuthzPipeline::new(), "nonexistent", None, &PolicyEngine::new()).await;
        assert_eq!(result.stage_results.len(), 1);
        assert_eq!(result.stage_results[0].stage, AuthzStage::ToolExists);
    }

    #[test]
    fn test_serde_roundtrip() {
        let original = AuthzResult::denied(AuthzStage::Policy, "policy_violation", "Test violation");
        let json = serde_json::to_string(&original).unwrap();
        let decoded: AuthzResult = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded.decision, AuthzDecision::Deny);
        assert_eq!(decoded.stage, AuthzStage::Policy);
        assert_eq!(decoded.reason_code, "policy_violation");
    }
}