use serde_json::{json, Value};
use crate::runtime::{evaluate_runtime_event, RuntimeEvent, RuntimeVerdict};
/// JSON-RPC `error.code` placed in the error response synthesised for a blocked tool call
/// (see `blocked_error`). JSON-RPC 2.0 reserves -32000 to -32099 for implementation-defined
/// server errors; this value falls in that range.
pub const BLOCKED_ERROR_CODE: i64 = -32001;
/// Outcome of [`decide`] for a client → server JSON-RPC message.
#[derive(Debug, Clone, PartialEq)]
pub enum ProxyDecision {
/// Pass the message to the server unchanged: it is not a `tools/call`, or no verdict
/// crossed the tool's [`FailOn`] threshold.
Forward,
/// Pass the message through even though it would have been blocked, because the tool's
/// threshold is [`FailOn::Never`]. `rule_id` names the suppressed finding so the caller can
/// record it (presumably for audit logging — confirm at the call site).
ForwardSuppressed { rule_id: String },
/// Do not forward; answer the client with this JSON-RPC error response instead.
Block(Value),
}
/// Verdict threshold at which [`decide`] stops a tool call instead of forwarding it.
///
/// Malformed `tools/call` requests are blocked regardless of this setting.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum FailOn {
/// Default: stop calls with a `Block` verdict; `Warn` verdicts are forwarded.
#[default]
Block,
/// Stricter: stop calls with either a `Warn` or a `Block` verdict.
Warn,
/// Never stop: a call with a `Block` verdict is forwarded as
/// [`ProxyDecision::ForwardSuppressed`]; `Warn` verdicts are forwarded plainly.
Never,
}
/// Blocking policy consulted by [`decide`].
#[derive(Debug, Clone, Default)]
pub struct ProxyPolicy {
/// Threshold for every tool that has no entry in `tool_overrides`.
pub fail_on: FailOn,
/// Per-tool thresholds keyed by exact tool name (the request's `params.name`). When a name
/// appears more than once, the first entry wins.
pub tool_overrides: Vec<(String, FailOn)>,
}
impl ProxyPolicy {
    /// Resolves the effective [`FailOn`] threshold for `tool_name`.
    ///
    /// The first override whose name equals `tool_name` exactly is used; a tool without an
    /// override falls back to the policy-wide `fail_on`.
    fn fail_on_for(&self, tool_name: &str) -> FailOn {
        for (name, threshold) in &self.tool_overrides {
            if name == tool_name {
                return *threshold;
            }
        }
        self.fail_on
    }
}
/// Decides what the proxy should do with a client → server JSON-RPC message.
///
/// * A message whose `method` is not `"tools/call"` is forwarded untouched.
/// * A `tools/call` that cannot be turned into a runtime event (for example, `params` is
///   missing or `params.name` is not a string) is blocked with rule id
///   `AGENTSHIELD-RUNTIME-INVALID-INPUT`, regardless of policy. The guard fails closed.
/// * Otherwise the call is evaluated by [`evaluate_runtime_event`], and the verdict is
///   compared against the tool's [`FailOn`] threshold (see [`ProxyPolicy`]):
///   - a `Block` verdict stops the call, unless the threshold is [`FailOn::Never`]; then the
///     call is forwarded as [`ProxyDecision::ForwardSuppressed`];
///   - a `Warn` verdict stops the call only when the threshold is [`FailOn::Warn`];
///   - an `Allow` verdict is always forwarded.
///
/// A JSON-RPC batch (a top-level array) is evaluated element by element, so a `tools/call`
/// cannot skip the guard by being wrapped in an array. See `decide_batch`.
///
/// A [`ProxyDecision::Block`] carries a complete JSON-RPC error response that echoes the
/// request `id` (`null` when absent). Its `rule_id` is the last finding's rule id, falling
/// back to `AGENTSHIELD-RUNTIME-BLOCK` when the engine reported none.
#[must_use]
pub fn decide(request: &Value, policy: &ProxyPolicy) -> ProxyDecision {
    match request {
        Value::Array(batch) => decide_batch(batch, policy),
        message => decide_message(message, policy),
    }
}

/// Applies [`decide_message`] to every element of a JSON-RPC batch.
///
/// Fails closed: the first element that would be blocked blocks the whole batch, and that
/// element's error response is returned. If nothing blocks but some element's block was
/// suppressed by a [`FailOn::Never`] override, the first suppressed rule id is reported.
/// An empty batch is forwarded.
///
/// NOTE(review): a blocked batch is answered with a single error object rather than a batch
/// response array; confirm the transport layer handles that shape.
fn decide_batch(batch: &[Value], policy: &ProxyPolicy) -> ProxyDecision {
    let mut first_suppressed: Option<String> = None;
    for message in batch {
        match decide_message(message, policy) {
            ProxyDecision::Forward => {}
            ProxyDecision::ForwardSuppressed { rule_id } => {
                first_suppressed = first_suppressed.or(Some(rule_id));
            }
            blocked @ ProxyDecision::Block(_) => return blocked,
        }
    }
    match first_suppressed {
        Some(rule_id) => ProxyDecision::ForwardSuppressed { rule_id },
        None => ProxyDecision::Forward,
    }
}

/// Evaluates a single (non-batch) JSON-RPC message against `policy`; see [`decide`].
fn decide_message(request: &Value, policy: &ProxyPolicy) -> ProxyDecision {
    if request.get("method").and_then(Value::as_str) != Some("tools/call") {
        return ProxyDecision::Forward;
    }
    let id = request.get("id").cloned().unwrap_or(Value::Null);
    let event = match tool_call_to_event(request) {
        Some(event) => event,
        None => {
            // Fail closed: a tool call we cannot parse is a tool call we cannot vet.
            return ProxyDecision::Block(blocked_error(
                &id,
                "block",
                "AGENTSHIELD-RUNTIME-INVALID-INPUT",
            ));
        }
    };
    let tool_name = event.tool_name.clone().unwrap_or_default();
    let fail_on = policy.fail_on_for(&tool_name);
    let result = evaluate_runtime_event(event);

    // Exhaustive on purpose (no `_` arm): a new verdict variant must be classified here
    // explicitly. Non-blocking outcomes return early, so only blocking ones get a label.
    let verdict = match result.verdict {
        RuntimeVerdict::Block => "block",
        RuntimeVerdict::Warn if fail_on == FailOn::Warn => "warn",
        RuntimeVerdict::Warn | RuntimeVerdict::Allow => return ProxyDecision::Forward,
    };

    let rule_id = result
        .findings
        .iter()
        .next_back()
        .map(|finding| finding.rule_id.clone())
        .unwrap_or_else(|| "AGENTSHIELD-RUNTIME-BLOCK".to_string());

    if fail_on == FailOn::Never {
        // Policy opted this tool out of blocking; surface what would have fired instead.
        return ProxyDecision::ForwardSuppressed { rule_id };
    }
    ProxyDecision::Block(blocked_error(&id, verdict, &rule_id))
}
/// Converts an MCP `tools/call` request into a [`RuntimeEvent`] for the rule engine.
///
/// Returns `None` (which [`decide`] treats as invalid input and blocks) when `params` is
/// missing, or when `params.name` is absent or not a string. A missing `params.arguments`
/// becomes an empty object; any other JSON value, including a bare string, is kept as-is.
///
/// `command`, `path` and the fallback `url` come from top-level *string* members of
/// `arguments`; a non-string member leaves the corresponding field `None`.
fn tool_call_to_event(request: &Value) -> Option<RuntimeEvent> {
    use crate::runtime::{RuntimeAction, RuntimeEventSource, RuntimeSchemaVersion};

    let params = request.get("params")?;
    let name = params.get("name")?.as_str()?.to_string();
    let arguments = params
        .get("arguments")
        .cloned()
        .unwrap_or_else(|| json!({}));

    let string_arg = |key: &str| {
        arguments
            .get(key)
            .and_then(Value::as_str)
            .map(str::to_string)
    };
    // Scan every string in `arguments` for a metadata endpoint *before* falling back to the
    // declared `url`. Checking `url` first let a benign decoy, e.g.
    // `{"url": "https://ok.example.com", "next": "http://169.254.169.254/"}`, short-circuit
    // the deep scan, so the metadata target never reached the event's `url` (which the
    // metadata-SSRF rule presumably inspects, judging by the nested-argument tests).
    let url = first_metadata_string(&arguments).or_else(|| string_arg("url"));
    let command = string_arg("command");
    let path = string_arg("path");

    Some(RuntimeEvent {
        schema_version: RuntimeSchemaVersion::V1,
        source: RuntimeEventSource::Mcp,
        action: RuntimeAction::ToolCall,
        tool_name: Some(name),
        command,
        url,
        path,
        arguments,
        redacted: false,
    })
}
/// Returns the first string inside `value` (including `value` itself) for which
/// `references_metadata_endpoint` returns `true`, or `None` if there is no such string.
///
/// The walk uses an explicit work list instead of recursion, so deeply nested input cannot
/// overflow the call stack. The list is last-in-first-out: later array elements and object
/// values are visited before earlier siblings, and when several strings match, that order
/// decides which one is returned.
///
/// NOTE(review): object *keys* are not inspected, only values.
fn first_metadata_string(value: &Value) -> Option<String> {
    use crate::rules::builtin::metadata_ssrf::references_metadata_endpoint;

    let mut pending: Vec<&Value> = Vec::from([value]);
    loop {
        // `?` ends the search with `None` once every node has been visited.
        let current = pending.pop()?;
        if let Value::String(candidate) = current {
            if references_metadata_endpoint(candidate) {
                return Some(candidate.to_owned());
            }
        } else if let Value::Array(elements) = current {
            pending.extend(elements);
        } else if let Value::Object(fields) = current {
            pending.extend(fields.values());
        }
    }
}
fn blocked_error(id: &Value, verdict: &str, rule_id: &str) -> Value {
json!({
"jsonrpc": "2.0",
"id": id,
"error": {
"code": BLOCKED_ERROR_CODE,
"message": "Blocked by AgentShield runtime guard",
"data": {
"verdict": verdict,
"rule_id": rule_id,
"schema_version": "v1"
}
}
})
}
#[cfg(test)]
mod tests {
    //! Behavioural tests for `decide`: pass-through of non-tool traffic, fail-closed handling
    //! of malformed calls, metadata-SSRF blocking wherever the URL hides, and per-tool
    //! `FailOn` overrides.

    use super::*;

    /// Builds a well-formed `tools/call` request (fixed id `7`) for tool `name`.
    fn tools_call(name: &str, arguments: Value) -> Value {
        json!({
            "jsonrpc": "2.0",
            "id": 7,
            "method": "tools/call",
            "params": { "name": name, "arguments": arguments }
        })
    }

    /// Unwraps the error response of a `Block` decision; any other decision panics with
    /// "`expectation`, got …".
    fn expect_block(decision: ProxyDecision, expectation: &str) -> Value {
        match decision {
            ProxyDecision::Block(response) => response,
            other => panic!("{expectation}, got {other:?}"),
        }
    }

    #[test]
    fn non_tool_call_is_passed_through() {
        let request = json!({"jsonrpc": "2.0", "id": 1, "method": "tools/list"});
        let decision = decide(&request, &ProxyPolicy::default());
        assert_eq!(decision, ProxyDecision::Forward);
    }

    #[test]
    fn benign_tool_call_is_forwarded() {
        let request = tools_call("calculator.add", json!({"a": 1, "b": 2}));
        let decision = decide(&request, &ProxyPolicy::default());
        assert_eq!(decision, ProxyDecision::Forward);
    }

    #[test]
    fn metadata_ssrf_tool_call_is_blocked() {
        let request = tools_call(
            "http.get",
            json!({"url": "http://169.254.169.254/latest/meta-data/"}),
        );
        let response = expect_block(decide(&request, &ProxyPolicy::default()), "expected block");
        let data = &response["error"]["data"];
        assert_eq!(response["error"]["code"], BLOCKED_ERROR_CODE);
        assert_eq!(response["id"], 7);
        assert_eq!(data["verdict"], "block");
        assert_eq!(data["rule_id"], "AGENTSHIELD-RUNTIME-METADATA-SSRF");
        // The blocked target must not be echoed back to the client.
        assert!(!response.to_string().contains("169.254.169.254"));
    }

    #[test]
    fn malformed_tool_call_fails_closed() {
        // `params` is missing entirely, so no runtime event can be built.
        let request = json!({"jsonrpc": "2.0", "id": 9, "method": "tools/call"});
        let response = expect_block(
            decide(&request, &ProxyPolicy::default()),
            "expected fail-closed block",
        );
        assert_eq!(response["id"], 9);
        assert_eq!(
            response["error"]["data"]["rule_id"],
            "AGENTSHIELD-RUNTIME-INVALID-INPUT"
        );
    }

    #[test]
    fn warn_verdict_forwards_by_default_but_blocks_under_strict_override() {
        let request = tools_call(
            "log.write",
            json!({"token": "ghp_EXAMPLEEXAMPLEEXAMPLEEXAMPLE00"}),
        );
        assert_eq!(
            decide(&request, &ProxyPolicy::default()),
            ProxyDecision::Forward
        );

        // Lowering this one tool's threshold to `Warn` turns the same verdict into a block.
        let strict = ProxyPolicy {
            fail_on: FailOn::Block,
            tool_overrides: vec![("log.write".to_string(), FailOn::Warn)],
        };
        let response = expect_block(
            decide(&request, &strict),
            "expected warn-block under strict override",
        );
        assert_eq!(response["error"]["data"]["verdict"], "warn");
    }

    #[test]
    fn never_override_forwards_but_audits_suppressed_block() {
        let request = tools_call("trusted.fetch", json!({"url": "http://169.254.169.254/"}));
        let policy = ProxyPolicy {
            fail_on: FailOn::Block,
            tool_overrides: vec![("trusted.fetch".to_string(), FailOn::Never)],
        };
        let decision = decide(&request, &policy);
        if let ProxyDecision::ForwardSuppressed { rule_id } = &decision {
            assert_eq!(rule_id, "AGENTSHIELD-RUNTIME-METADATA-SSRF");
        } else {
            panic!("expected forward-suppressed, got {decision:?}");
        }
    }

    #[test]
    fn metadata_endpoint_in_nested_argument_is_blocked() {
        let request = tools_call(
            "http.get",
            json!({"req": {"target": {"url": "http://169.254.169.254/latest/meta-data/"}}}),
        );
        let response = expect_block(
            decide(&request, &ProxyPolicy::default()),
            "expected nested metadata to block",
        );
        assert_eq!(
            response["error"]["data"]["rule_id"],
            "AGENTSHIELD-RUNTIME-METADATA-SSRF"
        );
        assert!(!response.to_string().contains("169.254.169.254"));
    }

    #[test]
    fn metadata_endpoint_in_string_arguments_is_blocked() {
        // `arguments` is a bare string rather than an object.
        let request = json!({
            "jsonrpc": "2.0", "id": 4, "method": "tools/call",
            "params": { "name": "fetch", "arguments": "http://169.254.169.254/" }
        });
        expect_block(
            decide(&request, &ProxyPolicy::default()),
            "expected string-arguments metadata to block",
        );
    }

    #[test]
    fn metadata_in_array_argument_is_blocked() {
        let request = tools_call(
            "batch",
            json!({"urls": ["https://ok.example.com", "http://169.254.169.254/"]}),
        );
        expect_block(
            decide(&request, &ProxyPolicy::default()),
            "expected array metadata to block",
        );
    }
}