use super::policy::*;
use serde_json::{json, Value};
use std::collections::HashMap;
fn create_v2_policy(schemas: HashMap<String, Value>) -> McpPolicy {
McpPolicy {
version: "2.0".to_string(),
schemas,
enforcement: EnforcementSettings::default(),
..Default::default()
}
}
/// A `read_file` call whose `path` matches the schema pattern `^/safe/.*`
/// must be allowed by a v2 policy.
#[test]
fn test_v2_schema_validation_allow() {
    let read_file_schema = json!({
        "type": "object",
        "properties": {
            "path": { "type": "string", "pattern": "^/safe/.*" }
        },
        "required": ["path"]
    });
    let policy = create_v2_policy(HashMap::from([(
        "read_file".to_string(),
        read_file_schema,
    )]));

    let mut state = PolicyState::default();
    let safe_args = json!({ "path": "/safe/test.txt" });

    assert_eq!(
        policy.evaluate("read_file", &safe_args, &mut state, None),
        PolicyDecision::Allow
    );
}
/// A `read_file` call must be denied with `E_ARG_SCHEMA` both when `path`
/// violates the `^/safe/.*` pattern and when the required `path` is missing.
#[test]
fn test_v2_schema_validation_deny() {
    let read_file_schema = json!({
        "type": "object",
        "properties": {
            "path": { "type": "string", "pattern": "^/safe/.*" }
        },
        "required": ["path"]
    });
    let policy = create_v2_policy(HashMap::from([(
        "read_file".to_string(),
        read_file_schema,
    )]));
    let mut state = PolicyState::default();

    // Case 1: the argument is present but does not match the pattern.
    let outside_safe_root = json!({ "path": "/unsafe/hack.sh" });
    match policy.evaluate("read_file", &outside_safe_root, &mut state, None) {
        PolicyDecision::Deny { code, .. } => assert_eq!(code, "E_ARG_SCHEMA"),
        other => panic!("Expected Deny, got {:?}", other),
    }

    // Case 2: the required argument is absent altogether.
    let no_args = json!({});
    match policy.evaluate("read_file", &no_args, &mut state, None) {
        PolicyDecision::Deny { code, .. } => assert_eq!(code, "E_ARG_SCHEMA"),
        other => panic!("Expected Deny for missing arg, got {:?}", other),
    }
}
/// Migrating a v1 `constraints` entry must produce an equivalent JSON schema.
///
/// Checks the migrated `read_file` schema carries the original regex as
/// `properties.path.pattern`, lists `path` under `required`, and that the
/// migrated policy allows a matching path and denies a non-matching one with
/// `E_ARG_SCHEMA`.
#[test]
fn test_v1_migration_correctness() {
    let yaml = r#"
version: "1.0"
constraints:
- tool: read_file
params:
path:
matches: "^/safe/.*"
"#;
    let mut policy: McpPolicy = serde_yaml::from_str(yaml).unwrap();
    policy.migrate_constraints_to_schemas();

    // Structural checks on the generated schema.
    assert!(policy.schemas.contains_key("read_file"));
    let schema = policy.schemas.get("read_file").unwrap();

    // Indexing a `Value` with a missing key yields `Null`, so `as_str` returns
    // `None` whenever any segment of the path is absent.
    let path_pattern = schema["properties"]["path"]["pattern"]
        .as_str()
        .expect("Missing pattern in migrated schema");
    assert_eq!(path_pattern, "^/safe/.*");

    let required = schema
        .get("required")
        .and_then(Value::as_array)
        .expect("Missing required array");
    assert!(required
        .iter()
        .filter_map(Value::as_str)
        .any(|name| name == "path"));

    // Behavioural checks: the migrated policy enforces the same constraint.
    let mut state = PolicyState::default();

    let args_ok = json!({ "path": "/safe/file" });
    let allowed = policy.evaluate("read_file", &args_ok, &mut state, None);
    assert_eq!(allowed, PolicyDecision::Allow);

    let args_bad = json!({ "path": "/unsafe/file" });
    match policy.evaluate("read_file", &args_bad, &mut state, None) {
        PolicyDecision::Deny { code, .. } => assert_eq!(code, "E_ARG_SCHEMA"),
        _ => panic!("Migrated policy failed to deny invalid arg"),
    }
}
/// Exercises each `UnconstrainedMode` against a tool that has no schema.
///
/// `Warn` must yield `AllowWithWarning` and `Deny` must yield `Deny`, both
/// with code `E_TOOL_UNCONSTRAINED`; `Allow` must yield a plain `Allow`.
#[test]
fn test_enforcement_modes() {
    let mut policy = McpPolicy::default();
    let mut state = PolicyState::default();
    let empty_args = json!({});

    policy.enforcement.unconstrained_tools = UnconstrainedMode::Warn;
    match policy.evaluate("unknown_tool", &empty_args, &mut state, None) {
        PolicyDecision::AllowWithWarning { code, .. } => {
            assert_eq!(code, "E_TOOL_UNCONSTRAINED")
        }
        other => panic!("Expected AllowWithWarning, got {:?}", other),
    }

    policy.enforcement.unconstrained_tools = UnconstrainedMode::Deny;
    match policy.evaluate("unknown_tool", &empty_args, &mut state, None) {
        PolicyDecision::Deny { code, .. } => assert_eq!(code, "E_TOOL_UNCONSTRAINED"),
        other => panic!("Expected Deny, got {:?}", other),
    }

    policy.enforcement.unconstrained_tools = UnconstrainedMode::Allow;
    assert_eq!(
        policy.evaluate("unknown_tool", &empty_args, &mut state, None),
        PolicyDecision::Allow
    );
}
/// A tool schema may reference a shared definition through `#/$defs/...`.
///
/// The `$defs` entry is stored alongside the tool schemas; `refined_tool`
/// points its `path` property at `path_pattern`. A matching path must be
/// allowed and a non-matching one denied with `E_ARG_SCHEMA`.
#[test]
fn test_defs_resolution() {
    let shared_defs = json!({
        "path_pattern": { "type": "string", "pattern": "^/safe/.*" }
    });
    let refined_tool_schema = json!({
        "type": "object",
        "properties": {
            "path": { "$ref": "#/$defs/path_pattern" }
        },
        "required": ["path"]
    });
    let policy = create_v2_policy(HashMap::from([
        ("$defs".to_string(), shared_defs),
        ("refined_tool".to_string(), refined_tool_schema),
    ]));
    let mut state = PolicyState::default();

    let args_ok = json!({ "path": "/safe/ok" });
    let allowed = policy.evaluate("refined_tool", &args_ok, &mut state, None);
    assert_eq!(allowed, PolicyDecision::Allow);

    let args_bad = json!({ "path": "/unsafe/bad" });
    match policy.evaluate("refined_tool", &args_bad, &mut state, None) {
        PolicyDecision::Deny { code, .. } => assert_eq!(code, "E_ARG_SCHEMA"),
        _ => panic!("Expected Deny for ref violation"),
    }
}
/// A pinned tool whose runtime identity differs from the pin must be denied.
///
/// Two identities are built for the same server and tool, differing only in
/// description. Evaluating without a runtime identity is expected to produce
/// `AllowWithWarning`; supplying the differing runtime identity is expected
/// to produce `Deny` with code `E_TOOL_DRIFT`.
#[test]
fn test_tool_integrity_drift() {
    use crate::mcp::identity::ToolIdentity;

    let tool_name = "test_tool";
    let pinned_id = ToolIdentity::new("srv1", tool_name, &None, &Some("old desc".into()));
    let runtime_id = ToolIdentity::new("srv1", tool_name, &None, &Some("new desc".into()));

    let mut policy = McpPolicy::default();
    policy
        .tool_pins
        .insert(tool_name.to_string(), pinned_id.clone());

    let mut state = PolicyState::default();
    let empty_args = json!({});

    // No runtime identity supplied.
    let without_identity = policy.evaluate(tool_name, &empty_args, &mut state, None);
    assert!(matches!(
        without_identity,
        PolicyDecision::AllowWithWarning { .. }
    ));

    // Runtime identity supplied and it differs from the pinned one.
    match policy.evaluate(tool_name, &empty_args, &mut state, Some(&runtime_id)) {
        PolicyDecision::Deny { code, .. } => assert_eq!(code, "E_TOOL_DRIFT"),
        other => panic!("Expected E_TOOL_DRIFT, got {:?}", other),
    }
}
/// Checks `McpPolicy::is_v1_format` across four policy shapes.
///
/// Expected `true` for an explicit `"1.0"` version and for a policy that has
/// constraints but otherwise default fields; expected `false` for an explicit
/// `"2.0"` version and for a fully default policy.
#[test]
fn test_is_v1_format() {
    // Explicit v1 version string.
    let explicit_v1 = McpPolicy {
        version: String::from("1.0"),
        ..Default::default()
    };
    assert!(explicit_v1.is_v1_format());

    // Default version, but a constraint rule is present.
    let lone_rule = ConstraintRule {
        tool: "t".into(),
        params: std::collections::BTreeMap::new(),
    };
    let implied_v1 = McpPolicy {
        constraints: vec![lone_rule],
        ..Default::default()
    };
    assert!(implied_v1.is_v1_format());

    // Explicit v2 version string.
    let explicit_v2 = McpPolicy {
        version: String::from("2.0"),
        ..Default::default()
    };
    assert!(!explicit_v2.is_v1_format());

    // Nothing set at all.
    let untouched = McpPolicy::default();
    assert!(!untouched.is_v1_format());
}
/// Checks that `ASSAY_STRICT_DEPRECATIONS` decides whether a v1 policy file loads.
///
/// With the variable unset, `McpPolicy::from_file` must succeed on a
/// `version: '1.0'` file. With it set to `"1"`, loading must fail with an error
/// whose message contains "Strict mode".
///
/// The variable is cleared through a drop guard so that a failing assertion
/// cannot leave it set for tests that run afterwards in the same process.
#[test]
#[serial_test::serial]
// Environment mutation requires `unsafe` blocks here (the calls are `unsafe fn`
// on newer editions), so the crate-level `unsafe_code` lint is relaxed for
// this test only.
#[allow(unsafe_code)]
fn test_strict_deprecation_env_var() {
    const STRICT_VAR: &str = "ASSAY_STRICT_DEPRECATIONS";

    /// Removes `STRICT_VAR` when dropped, including during a panic unwind.
    struct ClearStrictVar;

    impl Drop for ClearStrictVar {
        fn drop(&mut self) {
            // SAFETY: the enclosing test is `#[serial]`, so no other serialized
            // test mutates the environment at the same time. Assumes non-serial
            // tests do not touch the process environment concurrently.
            unsafe {
                std::env::remove_var(STRICT_VAR);
            }
        }
    }

    let tmp = tempfile::NamedTempFile::new().unwrap();
    let path = tmp.path();
    std::fs::write(path, "version: '1.0'\nconstraints: []").unwrap();

    // Start from a known state: strict mode off.
    // SAFETY: see `ClearStrictVar::drop`; the same serialization assumption applies.
    unsafe {
        std::env::remove_var(STRICT_VAR);
    }
    // Declared after `tmp`, so it is dropped (and the variable cleared) first.
    let _clear_on_exit = ClearStrictVar;

    let res = McpPolicy::from_file(path);
    assert!(res.is_ok());

    // SAFETY: see `ClearStrictVar::drop`; the same serialization assumption applies.
    unsafe {
        std::env::set_var(STRICT_VAR, "1");
    }
    let res_strict = McpPolicy::from_file(path);
    assert!(res_strict.is_err());
    assert!(res_strict.unwrap_err().to_string().contains("Strict mode"));
    // `_clear_on_exit` removes the variable here, or earlier if an assertion panicked.
}