use std::sync::Arc;
use aa_proto::assembly::common::v1::{ActionType, AgentId as ProtoAgentId, Decision};
use aa_proto::assembly::policy::v1::{
action_context::Action, ActionContext, CheckActionRequest, CheckActionResponse, RedactInstructions, ToolCallContext,
};
use aa_runtime::gateway_client::GatewayClient;
use tokio::sync::Mutex;
use crate::intercept::mcp::McpToolCall;
/// Value written into `ToolCallContext::tool_source` by `build_check_action_request`.
pub const MCP_TOOL_SOURCE: &str = "mcp";
/// Agent identifier placed in the `agent_id` field of the `ProtoAgentId` that
/// `build_check_action_request` attaches to every request.
const PROXY_AGENT_ID: &str = "aa-proxy";
/// Outcome of a policy check as exposed by this module.
///
/// Values are derived from a `CheckActionResponse` by `decision_from_response`.
#[derive(Debug, Clone, PartialEq)]
pub enum McpDecision {
/// The policy response carried `Decision::Allow`.
Allow,
/// The call is refused. `reason` is either the policy's own reason (for
/// `Decision::Deny`) or a message synthesised by `decision_from_response`
/// for pending, unspecified, or unrecognised decisions.
Deny { reason: String },
/// The policy response carried `Decision::Redact`; `instructions` is the
/// response's `redact` payload, or its default value when none was attached.
Redact { instructions: RedactInstructions },
}
/// Assembles the `CheckActionRequest` that describes a single MCP tool call.
///
/// The request is always typed as `ActionType::ToolCall` and attributed to the
/// proxy itself: `agent_id` is [`PROXY_AGENT_ID`] with empty `org_id`/`team_id`,
/// `credential_token` is empty, and `caller_agent_id` is `None`.
///
/// # Parameters
/// - `call`: the tool call; its `tool_name` is copied and its `arguments` are
///   serialised to JSON bytes.
/// - `target_url`: copied into the tool-call context's `target_url`.
/// - `trace_id` / `span_id`: copied verbatim into the request envelope.
///
/// If serialising `call.arguments` fails, `args_json` is left empty rather than
/// surfacing an error (this function has no error channel).
pub fn build_check_action_request(
    call: &McpToolCall,
    target_url: &str,
    trace_id: &str,
    span_id: &str,
) -> CheckActionRequest {
    // Innermost payload first: what tool is being called, with what, and where.
    let tool_call = ToolCallContext {
        tool_name: call.tool_name.clone(),
        tool_source: MCP_TOOL_SOURCE.into(),
        // Best-effort: a serialisation failure degrades to empty bytes.
        args_json: serde_json::to_vec(&call.arguments).unwrap_or_default(),
        target_url: target_url.into(),
    };

    // The proxy identifies as itself; org and team are left blank.
    let proxy_identity = ProtoAgentId {
        org_id: String::new(),
        team_id: String::new(),
        agent_id: PROXY_AGENT_ID.into(),
    };

    let context = ActionContext {
        action: Some(Action::ToolCall(tool_call)),
    };

    CheckActionRequest {
        agent_id: Some(proxy_identity),
        credential_token: String::new(),
        trace_id: trace_id.into(),
        span_id: span_id.into(),
        action_type: ActionType::ToolCall as i32,
        context: Some(context),
        caller_agent_id: None,
    }
}
/// Translates a `CheckActionResponse` into the [`McpDecision`] used by this module.
///
/// Mapping, by the response's `decision` code:
/// - `Allow`  → [`McpDecision::Allow`]
/// - `Deny`   → [`McpDecision::Deny`] carrying the response's `reason`
/// - `Redact` → [`McpDecision::Redact`] carrying the response's `redact`
///   payload, or default instructions when the payload is absent
/// - `Pending` → [`McpDecision::Deny`] with a message naming the approval id
/// - `Unspecified` or any code that does not convert to a `Decision` →
///   [`McpDecision::Deny`] with a message naming the raw code
///
/// Everything that is not an explicit allow/redact therefore ends up as a deny.
///
/// NOTE(review): a `Redact` response with no `redact` payload yields default
/// instructions — confirm downstream treats that case as intended.
pub fn decision_from_response(response: &CheckActionResponse) -> McpDecision {
    let deny = |reason: String| McpDecision::Deny { reason };

    // `.ok()` folds a code outside the enum's range into `None`, which shares
    // an arm with `Unspecified` below.
    match Decision::try_from(response.decision).ok() {
        None | Some(Decision::Unspecified) => deny(format!(
            "unrecognised policy decision code {}",
            response.decision
        )),
        Some(Decision::Pending) => deny(format!(
            "policy returned PENDING (approval queue {:?}) — proxy cannot block on human approval",
            response.approval_id,
        )),
        Some(Decision::Deny) => deny(response.reason.clone()),
        Some(Decision::Redact) => {
            let instructions = response.redact.clone().unwrap_or_default();
            McpDecision::Redact { instructions }
        }
        Some(Decision::Allow) => McpDecision::Allow,
    }
}
/// Runs a policy check for one MCP tool call and returns the resulting decision.
///
/// Builds the request with [`build_check_action_request`], sends it through the
/// shared gateway client's `check_action`, and maps the reply with
/// [`decision_from_response`].
///
/// The gateway mutex is held from just before the call is issued until it
/// completes, and is released before the response is interpreted.
///
/// # Errors
/// Returns an error when `check_action` fails; the underlying error's display
/// text is embedded in the message.
pub async fn evaluate_mcp_call(
    gateway: &Arc<Mutex<GatewayClient>>,
    call: &McpToolCall,
    target_url: &str,
    trace_id: &str,
    span_id: &str,
) -> anyhow::Result<McpDecision> {
    let request = build_check_action_request(call, target_url, trace_id, span_id);

    // Keep the guard confined to this block so the lock is dropped as soon as
    // the call has finished.
    let rpc_outcome = {
        let mut guard = gateway.lock().await;
        guard.check_action(request).await
    };

    match rpc_outcome {
        Ok(response) => Ok(decision_from_response(&response)),
        Err(e) => Err(anyhow::anyhow!("PolicyService.CheckAction failed: {e}")),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A representative tool call: `read_file` on `/etc/passwd`.
    fn sample_call() -> McpToolCall {
        McpToolCall {
            arguments: json!({ "path": "/etc/passwd" }),
            tool_name: "read_file".into(),
        }
    }

    /// A response carrying `decision` and `reason`; all other fields are defaulted.
    fn response_with(decision: Decision, reason: &str) -> CheckActionResponse {
        CheckActionResponse {
            reason: reason.into(),
            decision: decision as i32,
            ..Default::default()
        }
    }

    /// Extracts the reason from a `Deny`, panicking on any other decision.
    #[track_caller]
    fn expect_deny_reason(decision: McpDecision) -> String {
        match decision {
            McpDecision::Deny { reason } => reason,
            other => panic!("expected Deny, got {other:?}"),
        }
    }

    #[test]
    fn build_request_populates_tool_call_context_fields() {
        let request = build_check_action_request(
            &sample_call(),
            "https://mcp.example.com/tools",
            "trace-abc",
            "span-1",
        );

        // Envelope fields.
        assert_eq!(request.action_type, ActionType::ToolCall as i32);
        assert_eq!(request.trace_id, "trace-abc");
        assert_eq!(request.span_id, "span-1");

        // Tool-call payload.
        let tool_call = match request.context.expect("context").action.expect("action") {
            Action::ToolCall(inner) => inner,
            other => panic!("expected ToolCall action, got {other:?}"),
        };
        assert_eq!(tool_call.tool_name, "read_file");
        assert_eq!(tool_call.tool_source, MCP_TOOL_SOURCE);
        assert_eq!(tool_call.target_url, "https://mcp.example.com/tools");

        // The arguments must round-trip through the serialised bytes.
        let decoded: serde_json::Value =
            serde_json::from_slice(&tool_call.args_json).expect("args_json must be valid JSON");
        assert_eq!(decoded, json!({ "path": "/etc/passwd" }));
    }

    #[test]
    fn decision_allow_maps_to_mcp_allow() {
        let decision = decision_from_response(&response_with(Decision::Allow, "ok"));
        assert_eq!(decision, McpDecision::Allow);
    }

    #[test]
    fn decision_deny_maps_to_mcp_deny_with_reason() {
        let policy_reason = "tool_name read_file blocked on /etc paths";
        let decision = decision_from_response(&response_with(Decision::Deny, policy_reason));
        assert_eq!(
            expect_deny_reason(decision),
            "tool_name read_file blocked on /etc paths"
        );
    }

    #[test]
    fn decision_redact_maps_to_mcp_redact_with_instructions() {
        let resp = CheckActionResponse {
            redact: Some(RedactInstructions::default()),
            decision: Decision::Redact as i32,
            ..Default::default()
        };
        let decision = decision_from_response(&resp);
        assert!(matches!(decision, McpDecision::Redact { .. }));
    }

    #[test]
    fn decision_pending_downgrades_to_deny_at_proxy_layer() {
        let resp = CheckActionResponse {
            approval_id: "queue-7".into(),
            ..response_with(Decision::Pending, "")
        };
        let reason = expect_deny_reason(decision_from_response(&resp));

        // Either marker is accepted as evidence that the downgrade is explained.
        let explains_downgrade = ["queue-7", "PENDING"]
            .iter()
            .any(|marker| reason.contains(*marker));
        assert!(
            explains_downgrade,
            "deny reason should explain the downgrade, got: {reason}",
        );
    }

    #[test]
    fn decision_unspecified_downgrades_to_deny() {
        let decision = decision_from_response(&response_with(Decision::Unspecified, ""));
        assert!(matches!(decision, McpDecision::Deny { .. }));
    }

    #[test]
    fn unknown_decision_code_downgrades_to_deny() {
        // 9999 is not a valid `Decision` discriminant.
        let resp = CheckActionResponse {
            decision: 9999,
            ..Default::default()
        };
        let reason = expect_deny_reason(decision_from_response(&resp));
        assert!(
            reason.contains("9999"),
            "reason should name the unknown code, got: {reason}"
        );
    }
}