use super::identity::ToolIdentity;
use super::jsonrpc::{
ContentItem, JsonRpcRequest, JsonRpcResponse, ToolCallResult, ToolResultBody,
};
use super::tool_match::MatchBasis;
use super::tool_taxonomy::ToolTaxonomy;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::collections::{BTreeMap, BTreeSet, HashMap};
use std::sync::{Arc, OnceLock};
/// Top-level MCP tool policy as loaded from a policy document.
///
/// Bundles the tool allow/deny lists, per-tool JSON Schemas for argument
/// validation, legacy `constraints`, global limits, pinned tool identities and
/// the optional runtime feature configurations. Every field is optional in the
/// serialized form (`#[serde(default)]`).
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct McpPolicy {
/// Policy format version. `"1.0"` marks the legacy format (see
/// `is_v1_format`); migration rewrites an empty or `"1.0"` value to `"2.0"`.
#[serde(default)]
pub version: String,
/// Policy name. Not read by any code in this file.
#[serde(default)]
pub name: String,
/// Name- and class-based allow/deny lists.
#[serde(default)]
pub tools: ToolPolicy,
/// Legacy root-level allowlist; `normalize_legacy_shapes` moves it into
/// `tools.allow`.
#[serde(default)]
pub allow: Option<Vec<String>>,
/// Legacy root-level denylist; `normalize_legacy_shapes` moves it into
/// `tools.deny`.
#[serde(default)]
pub deny: Option<Vec<String>>,
/// JSON Schemas keyed by tool name. Keys starting with `$` are not compiled
/// as tool schemas; the `$defs` entry is injected into each tool schema at
/// compile time.
#[serde(default)]
pub schemas: HashMap<String, Value>,
/// Legacy v1 constraints. Accepts both a list of rules and a
/// `tool -> param -> constraint` map (see `deserialize_constraints`).
#[serde(default, deserialize_with = "deserialize_constraints")]
pub constraints: Vec<ConstraintRule>,
/// How tools without a schema are treated.
#[serde(default)]
pub enforcement: EnforcementSettings,
/// Optional global request / tool-call caps.
#[serde(default)]
pub limits: Option<GlobalLimits>,
/// Optional signature settings. Not read by any code in this file.
#[serde(default)]
pub signatures: Option<SignaturePolicy>,
/// Pinned identities keyed by tool name; compared against the runtime
/// identity during evaluation.
#[serde(default)]
pub tool_pins: HashMap<String, ToolIdentity>,
/// Tool classification data, flattened into the top level of the document.
#[serde(default, flatten)]
pub tool_taxonomy: ToolTaxonomy,
/// Optional discovery configuration. Not read by any code in this file.
#[serde(default)]
pub discovery: Option<DiscoveryConfig>,
/// Optional runtime monitor; its rule ids are cross-checked by `validate`.
#[serde(default)]
pub runtime_monitor: Option<RuntimeMonitorConfig>,
/// Optional kill switch; its triggers are cross-checked by `validate`.
#[serde(default)]
pub kill_switch: Option<KillSwitchConfig>,
/// Lazily built cache of compiled schema validators; never (de)serialized.
/// The cache sits behind an `Arc`, so clones of a policy share it.
/// NOTE(review): changing `schemas` after the cache has been initialised
/// would presumably leave stale validators in place — confirm with callers.
#[serde(skip)]
pub(crate) compiled: Arc<OnceLock<HashMap<String, Arc<jsonschema::Validator>>>>,
}
/// Enforcement knobs for situations the explicit rules do not cover.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EnforcementSettings {
/// What to do with a tool that has no compiled schema; falls back to
/// `default_unconstrained()` when the key is absent.
#[serde(default = "default_unconstrained")]
pub unconstrained_tools: UnconstrainedMode,
}
impl Default for EnforcementSettings {
fn default() -> Self {
Self {
unconstrained_tools: UnconstrainedMode::Warn,
}
}
}
/// Serde default for `EnforcementSettings::unconstrained_tools`: warn.
fn default_unconstrained() -> UnconstrainedMode {
UnconstrainedMode::Warn
}
/// Handling of tools that have no schema. Serialized in `snake_case`.
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum UnconstrainedMode {
/// Allow the call but report `E_TOOL_UNCONSTRAINED` as a warning (default).
#[default]
Warn,
/// Reject the call with `E_TOOL_UNCONSTRAINED`.
Deny,
/// Allow the call silently.
Allow,
}
/// Signature-related settings. Not consumed by any code in this file.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct SignaturePolicy {
/// NOTE(review): presumably toggles verification of tool descriptions —
/// confirm against the consumer of this flag.
#[serde(default)]
pub check_descriptions: bool,
}
/// Global caps compared against the running counters in `PolicyState`.
/// A call is denied once the corresponding counter exceeds the cap.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct GlobalLimits {
/// Cap on `PolicyState::requests_count`; `None` means no cap.
pub max_requests_total: Option<u64>,
/// Cap on `PolicyState::tool_calls_count`; `None` means no cap.
pub max_tool_calls_total: Option<u64>,
}
/// Allow/deny lists by tool-name pattern and by tool class.
///
/// Setting `allow` or `allow_classes` (even to an empty list) turns on
/// allowlist mode, in which tools matching neither are rejected.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ToolPolicy {
/// Tool-name patterns that are allowed (see `matches_tool_pattern`).
pub allow: Option<Vec<String>>,
/// Tool-name patterns that are denied; checked before the allowlist.
pub deny: Option<Vec<String>>,
/// Tool classes that are allowed; compared by exact class name.
#[serde(default)]
pub allow_classes: Option<Vec<String>>,
/// Tool classes that are denied; compared by exact class name.
#[serde(default)]
pub deny_classes: Option<Vec<String>>,
}
/// Explains which rule, if any, produced a policy decision.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct PolicyMatchMetadata {
/// All classes the taxonomy assigns to the evaluated tool.
pub tool_classes: Vec<String>,
/// The configured allow/deny classes the tool actually matched
/// (sorted, without duplicates).
pub matched_tool_classes: Vec<String>,
/// Whether the match came from the name, the class, both, or neither.
pub match_basis: MatchBasis,
/// Policy field(s) that matched, e.g. `tools.deny` or
/// `tools.allow+tools.allow_classes`; `None` when no list rule matched.
pub matched_rule: Option<String>,
}
/// Result of `McpPolicy::evaluate_with_metadata`: the decision plus the
/// metadata describing how it was reached.
#[derive(Debug, Clone, PartialEq)]
pub struct PolicyEvaluation {
/// The allow / warn / deny outcome.
pub decision: PolicyDecision,
/// Match details accompanying the decision.
pub metadata: PolicyMatchMetadata,
}
/// Legacy (v1) constraint for one tool; converted to a JSON Schema by
/// `constraint_to_schema` during migration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConstraintRule {
/// Name of the tool the rule applies to.
pub tool: String,
/// Per-argument constraints keyed by argument name.
pub params: BTreeMap<String, ConstraintParam>,
}
/// Legacy constraint on a single tool argument.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConstraintParam {
/// Pattern the argument must match; emitted as the JSON Schema `pattern`
/// keyword on migration. Arguments without it are left unconstrained.
#[serde(default)]
pub matches: Option<String>,
}
pub use super::runtime_features::{
ActionLevel, DiscoveryActions, DiscoveryConfig, DiscoveryMethod, KillMode, KillSwitchConfig,
KillTrigger, MonitorAction, MonitorMatch, MonitorProvider, MonitorRule, MonitorRuleType,
RuntimeMonitorConfig,
};
/// Mutable counters carried across policy checks and compared against
/// `GlobalLimits`.
#[derive(Debug, Default)]
pub struct PolicyState {
/// Number of requests seen (tool calls and other requests alike).
pub requests_count: u64,
/// Number of tool calls that reached the rate-limit check.
pub tool_calls_count: u64,
}
/// Outcome of evaluating a request against the policy.
#[derive(Debug, Clone, PartialEq)]
pub enum PolicyDecision {
/// The request may proceed.
Allow,
/// The request may proceed, but the caller should surface a warning.
AllowWithWarning {
/// Tool the warning refers to.
tool: String,
/// Machine-readable code (e.g. `E_TOOL_UNCONSTRAINED`).
code: String,
/// Human-readable explanation.
reason: String,
},
/// The request must be rejected.
Deny {
/// Tool the denial refers to (`"ALL"` for rate-limit denials).
tool: String,
/// Machine-readable code (e.g. `E_TOOL_DENIED`, `E_ARG_SCHEMA`).
code: String,
/// Human-readable explanation.
reason: String,
/// Structured JSON payload describing the denial.
contract: Value,
},
}
/// Accepted on-disk shapes for the legacy `constraints` key. Untagged: serde
/// tries the variants in declaration order.
#[derive(Debug, Clone, Deserialize)]
#[serde(untagged)]
enum ConstraintsCompat {
/// A list of `{ tool, params }` rules.
List(Vec<ConstraintRule>),
/// A `tool -> argument -> constraint` map.
Map(BTreeMap<String, BTreeMap<String, InputParamConstraint>>),
}
/// One argument constraint in the map form of `constraints`. Untagged: serde
/// tries the variants in declaration order.
#[derive(Debug, Clone, Deserialize)]
#[serde(untagged)]
enum InputParamConstraint {
/// Shorthand: a bare string used as the `matches` pattern.
Direct(String),
/// Full object form.
Object(ConstraintParam),
}
fn deserialize_constraints<'de, D>(d: D) -> Result<Vec<ConstraintRule>, D::Error>
where
D: serde::Deserializer<'de>,
{
let c = Option::<ConstraintsCompat>::deserialize(d)?;
let out = match c {
None => vec![],
Some(ConstraintsCompat::List(v)) => v,
Some(ConstraintsCompat::Map(m)) => m
.into_iter()
.map(|(tool, params)| {
let new_params = params
.into_iter()
.map(|(arg, val)| {
let param = match val {
InputParamConstraint::Direct(s) => ConstraintParam { matches: Some(s) },
InputParamConstraint::Object(o) => o,
};
(arg, param)
})
.collect();
ConstraintRule {
tool,
params: new_params,
}
})
.collect(),
};
Ok(out)
}
/// Returns `true` when `tool_name` matches the glob-style `pattern`.
///
/// `*` matches any (possibly empty) run of characters and may appear anywhere
/// in the pattern, any number of times; every other character must match
/// literally. A pattern without `*` requires an exact match.
///
/// The previous implementation only honoured `*` at the very start and/or end
/// of the pattern and compared interior stars literally, so a rule such as
/// `fs_*_write` could never match a real tool name. Every name that matched
/// before still matches now.
fn matches_tool_pattern(tool_name: &str, pattern: &str) -> bool {
    if !pattern.contains('*') {
        return tool_name == pattern;
    }

    // With at least one '*' the split yields two or more segments: the first
    // is anchored at the start, the last at the end, the rest float between.
    let mut segments = pattern.split('*');
    let prefix = segments.next().unwrap_or("");
    let suffix = segments.next_back().unwrap_or("");

    // Strip prefix and suffix one after the other so they cannot overlap
    // (e.g. "ab*ba" must not match "aba").
    let Some(after_prefix) = tool_name.strip_prefix(prefix) else {
        return false;
    };
    let Some(mut remaining) = after_prefix.strip_suffix(suffix) else {
        return false;
    };

    // Leftmost-first placement of each floating segment is sufficient for a
    // wildcard that matches arbitrary text. Empty segments come from
    // consecutive stars and constrain nothing.
    for segment in segments.filter(|s| !s.is_empty()) {
        match remaining.find(segment) {
            Some(at) => remaining = &remaining[at + segment.len()..],
            None => return false,
        }
    }
    true
}
impl McpPolicy {
/// Creates an empty policy; identical to `McpPolicy::default()`.
pub fn new() -> Self {
Self::default()
}
/// Loads, normalizes and validates a YAML policy file.
///
/// Steps, in order:
/// 1. read and parse the file, logging (not rejecting) unknown fields;
/// 2. if the policy is in v1 format, either fail (when the environment
///    variable `ASSAY_STRICT_DEPRECATIONS` is `1`) or print the deprecation
///    notice;
/// 3. fold legacy root-level `allow` / `deny` into `tools`;
/// 4. convert legacy `constraints` into `schemas`;
/// 5. run [`McpPolicy::validate`].
///
/// # Errors
/// Returns an error when the file cannot be read or parsed, when strict mode
/// rejects a v1 policy, or when validation fails.
pub fn from_file(path: &std::path::Path) -> anyhow::Result<Self> {
    let content = std::fs::read_to_string(path)?;

    // The binding must stay named `unknown`: tracing uses it as the field name.
    let mut unknown = Vec::new();
    let mut policy: McpPolicy = serde_ignored::deserialize(
        serde_yaml::Deserializer::from_str(&content),
        |ignored| unknown.push(ignored.to_string()),
    )?;
    if !unknown.is_empty() {
        tracing::warn!(?unknown, "Unknown fields in policy (ignored)");
    }

    // The v1 check runs before migration, which would otherwise erase the evidence.
    if policy.is_v1_format() {
        let strict = std::env::var("ASSAY_STRICT_DEPRECATIONS").map_or(false, |v| v == "1");
        if strict {
            anyhow::bail!("Strict mode: v1 policy format (constraints) is not allowed.");
        }
        emit_deprecation_warning();
    }

    policy.normalize_legacy_shapes();
    let has_legacy_constraints = !policy.constraints.is_empty();
    if has_legacy_constraints {
        policy.migrate_constraints_to_schemas();
    }

    policy.validate().map(|()| policy)
}
/// Cross-checks the kill switch against the runtime monitor.
///
/// When both sections are present, every kill-switch trigger must reference
/// the id of an existing monitor rule. If either section is missing, nothing
/// is checked.
///
/// # Errors
/// Returns an error naming the first trigger whose `on_rule` is unknown.
pub fn validate(&self) -> anyhow::Result<()> {
    let (monitor, kill_switch) = match (&self.runtime_monitor, &self.kill_switch) {
        (Some(monitor), Some(kill_switch)) => (monitor, kill_switch),
        _ => return Ok(()),
    };

    let mut known_ids = std::collections::HashSet::<&str>::new();
    for rule in &monitor.rules {
        known_ids.insert(rule.id.as_str());
    }

    let dangling = kill_switch
        .triggers
        .iter()
        .find(|trigger| !known_ids.contains(trigger.on_rule.as_str()));
    if let Some(trigger) = dangling {
        anyhow::bail!(
            "kill_switch.triggers references unknown rule id: {}",
            trigger.on_rule
        );
    }
    Ok(())
}
/// Reports whether this policy uses the legacy v1 format: either the version
/// is exactly `"1.0"` or legacy `constraints` are present.
pub fn is_v1_format(&self) -> bool {
    let declares_v1 = self.version == "1.0";
    let has_constraints = !self.constraints.is_empty();
    declares_v1 || has_constraints
}
/// Moves the legacy root-level `allow` / `deny` lists into `tools.allow` /
/// `tools.deny`, appending after any entries already there. The root-level
/// fields are left as `None`.
pub fn normalize_legacy_shapes(&mut self) {
    if let Some(legacy_allow) = self.allow.take() {
        self.tools
            .allow
            .get_or_insert_with(Vec::new)
            .extend(legacy_allow);
    }
    if let Some(legacy_deny) = self.deny.take() {
        self.tools
            .deny
            .get_or_insert_with(Vec::new)
            .extend(legacy_deny);
    }
}
/// Converts every legacy constraint into a JSON Schema stored under the
/// tool's name, leaving `constraints` empty.
///
/// A generated schema replaces any schema already stored for that tool, and a
/// later rule for the same tool replaces an earlier one. An empty or `"1.0"`
/// version is bumped to `"2.0"`.
pub fn migrate_constraints_to_schemas(&mut self) {
    let legacy_rules = std::mem::take(&mut self.constraints);
    for rule in legacy_rules {
        let schema = constraint_to_schema(&rule);
        self.schemas.insert(rule.tool, schema);
    }

    if matches!(self.version.as_str(), "" | "1.0") {
        self.version = String::from("2.0");
    }
}
/// Returns the compiled validators, building them on first use through
/// `compile_all_schemas` (which panics if a schema fails to compile).
fn compiled_schemas(&self) -> &HashMap<String, Arc<jsonschema::Validator>> {
self.compiled.get_or_init(|| self.compile_all_schemas())
}
/// Compiles every tool schema in `schemas` into a validator keyed by tool name.
///
/// Entries whose key starts with `$` are not tools and are skipped. When a
/// root `$defs` entry exists it is made available to every object-shaped tool
/// schema so `#/$defs/...` references resolve.
///
/// The root definitions are merged into a tool schema's own `$defs` instead
/// of replacing them wholesale (the previous behaviour silently dropped
/// definitions declared only on the tool schema). On a key collision the root
/// definition still wins, as before.
///
/// # Panics
/// Panics if any tool schema fails to compile; the failure is logged first.
pub fn compile_all_schemas(&self) -> HashMap<String, Arc<jsonschema::Validator>> {
    let root_defs = self.schemas.get("$defs");
    let mut compiled = HashMap::new();
    for (tool_name, schema) in &self.schemas {
        if tool_name.starts_with('$') {
            continue;
        }
        let mut schema_to_compile = schema.clone();
        if let (Some(defs), Value::Object(map)) = (root_defs, &mut schema_to_compile) {
            // Merge only when both sides are objects; otherwise fall back to
            // the original "insert the root value" behaviour.
            let merged = match (defs, map.get_mut("$defs")) {
                (Value::Object(shared), Some(Value::Object(local))) => {
                    local.extend(shared.iter().map(|(k, v)| (k.clone(), v.clone())));
                    true
                }
                _ => false,
            };
            if !merged {
                map.insert("$defs".to_string(), defs.clone());
            }
        }
        match jsonschema::validator_for(&schema_to_compile) {
            Ok(validator) => {
                compiled.insert(tool_name.clone(), Arc::new(validator));
            }
            Err(e) => {
                tracing::error!("Failed to compile schema for tool {}: {}", tool_name, e);
                panic!(
                    "Failed to compile JSON schema for tool '{}': {}",
                    tool_name, e
                );
            }
        }
    }
    compiled
}
pub fn evaluate(
&self,
tool_name: &str,
args: &Value,
state: &mut PolicyState,
runtime_identity: Option<&ToolIdentity>,
) -> PolicyDecision {
self.evaluate_with_metadata(tool_name, args, state, runtime_identity)
.decision
}
/// Evaluates one tool call against the policy and explains the outcome.
///
/// Checks run in this order; the first one that fails decides:
/// 1. identity pin (`E_TOOL_DRIFT`),
/// 2. global rate limits (`E_RATE_LIMIT`),
/// 3. denylist by name pattern or class (`E_TOOL_DENIED`),
/// 4. allowlist, if any is configured (`E_TOOL_NOT_ALLOWED`),
/// 5. JSON Schema validation of `args` (`E_ARG_SCHEMA`),
/// 6. handling of tools without a schema (`enforcement.unconstrained_tools`).
///
/// `state` counters are incremented in step 2, so a call rejected in step 1
/// is not counted while calls rejected later are.
///
/// # Panics
/// May panic on the first call if a schema in the policy fails to compile
/// (see `compile_all_schemas`).
pub fn evaluate_with_metadata(
&self,
tool_name: &str,
args: &Value,
state: &mut PolicyState,
runtime_identity: Option<&ToolIdentity>,
) -> PolicyEvaluation {
// The tool's classes are recorded in the metadata for every outcome.
let tool_classes = self.tool_taxonomy.classes_for(tool_name);
let tool_classes_vec: Vec<String> = tool_classes.iter().cloned().collect();
let mut metadata = PolicyMatchMetadata {
tool_classes: tool_classes_vec,
..PolicyMatchMetadata::default()
};
// 1. Identity pin. Only enforced when the tool is pinned AND the caller
// supplied a runtime identity; with `None` the pin is not checked.
// NOTE(review): confirm that skipping the pin check for a missing runtime
// identity is intended (`check` always passes `None`).
if let Some(pinned) = self.tool_pins.get(tool_name) {
if let Some(runtime) = runtime_identity {
if pinned != runtime {
return PolicyEvaluation {
decision: PolicyDecision::Deny {
tool: tool_name.to_string(),
code: "E_TOOL_DRIFT".to_string(),
reason: format!(
"Tool integrity failure: identity drifted from pinned version. (Runtime: {}, Pinned: {})",
runtime.fingerprint(),
pinned.fingerprint()
),
contract: self.format_deny_contract(
tool_name,
"E_TOOL_DRIFT",
"Tool metadata or schema has changed without policy update (SOTA Moat)",
),
},
metadata,
};
}
}
}
// 2. Global limits. This call also increments both counters in `state`.
if let Some(decision) = self.check_rate_limits(state) {
return PolicyEvaluation { decision, metadata };
}
// 3. Denylist: a name-pattern match or any matching denied class rejects.
let deny_name_match = self.is_denied(tool_name);
let deny_class_matches = self.matched_deny_classes(&tool_classes);
if deny_name_match || !deny_class_matches.is_empty() {
metadata.matched_tool_classes = deny_class_matches.clone();
metadata.match_basis =
Self::classify_match_basis(deny_name_match, !deny_class_matches.is_empty());
metadata.matched_rule = Some(Self::matched_rule_name(
"tools.deny",
"tools.deny_classes",
&metadata,
));
let deny_reason = if deny_name_match && !deny_class_matches.is_empty() {
"Tool is explicitly denylisted by name and class"
} else if deny_name_match {
"Tool is explicitly denylisted by name"
} else {
"Tool is explicitly denylisted by class"
};
return PolicyEvaluation {
decision: PolicyDecision::Deny {
tool: tool_name.to_string(),
code: "E_TOOL_DENIED".to_string(),
reason: deny_reason.to_string(),
contract: self.format_deny_contract(tool_name, "E_TOOL_DENIED", deny_reason),
},
metadata,
};
}
// 4. Allowlist: when one is configured, the tool must match it by name
// or class. On rejection the match fields of `metadata` keep their defaults.
let allow_name_match = self.is_allowed(tool_name);
let allow_class_matches = self.matched_allow_classes(&tool_classes);
if self.has_allowlist() && !allow_name_match && allow_class_matches.is_empty() {
return PolicyEvaluation {
decision: PolicyDecision::Deny {
tool: tool_name.to_string(),
code: "E_TOOL_NOT_ALLOWED".to_string(),
reason: "Tool is not in the allowlist".to_string(),
contract: self.format_deny_contract(
tool_name,
"E_TOOL_NOT_ALLOWED",
"Tool is not in allowlist",
),
},
metadata,
};
}
// Record which allow rule matched; later outcomes carry this metadata.
if allow_name_match || !allow_class_matches.is_empty() {
metadata.matched_tool_classes = allow_class_matches;
metadata.match_basis = Self::classify_match_basis(
allow_name_match,
!metadata.matched_tool_classes.is_empty(),
);
metadata.matched_rule = Some(Self::matched_rule_name(
"tools.allow",
"tools.allow_classes",
&metadata,
));
}
// 5. Argument validation. Validators are looked up by exact tool name.
let compiled = self.compiled_schemas();
if let Some(validator) = compiled.get(tool_name) {
if !validator.is_valid(args) {
// Collect every violation (not just the first) for the contract.
let violations: Vec<Value> = validator
.iter_errors(args)
.map(|e| {
json!({
"path": e.instance_path().to_string(),
"message": e.to_string(),
})
})
.collect();
return PolicyEvaluation {
decision: PolicyDecision::Deny {
tool: tool_name.to_string(),
code: "E_ARG_SCHEMA".to_string(),
reason: "JSON Schema validation failed".to_string(),
contract: json!({
"status": "deny",
"error_code": "E_ARG_SCHEMA",
"tool": tool_name,
"violations": violations,
}),
},
metadata,
};
}
return PolicyEvaluation {
decision: PolicyDecision::Allow,
metadata,
};
}
// 6. No schema for this tool: defer to the enforcement setting.
let decision = match self.enforcement.unconstrained_tools {
UnconstrainedMode::Deny => PolicyDecision::Deny {
tool: tool_name.to_string(),
code: "E_TOOL_UNCONSTRAINED".to_string(),
reason: "Tool has no schema (enforcement: deny)".to_string(),
contract: self.format_deny_contract(
tool_name,
"E_TOOL_UNCONSTRAINED",
"Tool has no schema (enforcement: deny)",
),
},
UnconstrainedMode::Warn => PolicyDecision::AllowWithWarning {
tool: tool_name.to_string(),
code: "E_TOOL_UNCONSTRAINED".to_string(),
reason: "Tool allowed but has no schema".to_string(),
},
UnconstrainedMode::Allow => PolicyDecision::Allow,
};
PolicyEvaluation { decision, metadata }
}
/// Counts the current call and enforces the global limits.
///
/// Both counters are incremented before any comparison, so calls that end up
/// denied are counted as well. Returns a `Deny` decision with code
/// `E_RATE_LIMIT` when a configured cap is exceeded (the request cap is
/// checked first), otherwise `None`.
fn check_rate_limits(&self, state: &mut PolicyState) -> Option<PolicyDecision> {
    state.requests_count += 1;
    state.tool_calls_count += 1;

    let limits = self.limits.as_ref()?;
    let rate_limited = |reason: &str| PolicyDecision::Deny {
        tool: "ALL".to_string(),
        code: "E_RATE_LIMIT".to_string(),
        reason: reason.to_string(),
        contract: json!({ "status": "deny", "error_code": "E_RATE_LIMIT" }),
    };

    if matches!(limits.max_requests_total, Some(max) if state.requests_count > max) {
        return Some(rate_limited("Rate limit exceeded (total requests)"));
    }
    if matches!(limits.max_tool_calls_total, Some(max) if state.tool_calls_count > max) {
        return Some(rate_limited("Rate limit exceeded (tool calls)"));
    }
    None
}
/// Returns `true` if `tool_name` matches any deny pattern, looking at both
/// the legacy root-level `deny` list and `tools.deny`.
fn is_denied(&self, tool_name: &str) -> bool {
    self.deny
        .iter()
        .chain(self.tools.deny.iter())
        .flatten()
        .any(|pattern| matches_tool_pattern(tool_name, pattern))
}
/// Returns `true` when any allowlist is configured, even an empty one: the
/// root-level `allow`, `tools.allow`, or `tools.allow_classes`.
fn has_allowlist(&self) -> bool {
    let allowlists = [&self.allow, &self.tools.allow, &self.tools.allow_classes];
    allowlists.iter().any(|list| list.is_some())
}
/// Returns `true` if `tool_name` matches any allow pattern, looking at both
/// the legacy root-level `allow` list and `tools.allow`.
fn is_allowed(&self, tool_name: &str) -> bool {
    self.allow
        .iter()
        .chain(self.tools.allow.iter())
        .flatten()
        .any(|pattern| matches_tool_pattern(tool_name, pattern))
}
fn format_deny_contract(&self, tool: &str, code: &str, reason: &str) -> Value {
json!({
"status": "deny",
"error_code": code,
"tool": tool,
"reason": reason
})
}
/// Returns the entries of `tools.deny_classes` that the tool belongs to.
fn matched_deny_classes(&self, tool_classes: &BTreeSet<String>) -> Vec<String> {
self.match_classes(tool_classes, self.tools.deny_classes.as_ref())
}
/// Returns the entries of `tools.allow_classes` that the tool belongs to.
fn matched_allow_classes(&self, tool_classes: &BTreeSet<String>) -> Vec<String> {
self.match_classes(tool_classes, self.tools.allow_classes.as_ref())
}
/// Intersects the tool's classes with a configured class list.
///
/// Comparison is by exact class name. The result is sorted and free of
/// duplicates; it is empty when `configured` is `None`.
fn match_classes(
    &self,
    tool_classes: &BTreeSet<String>,
    configured: Option<&Vec<String>>,
) -> Vec<String> {
    // Going through a BTreeSet provides both the ordering and the dedup.
    let matched: BTreeSet<String> = configured
        .into_iter()
        .flatten()
        .filter(|class_name| tool_classes.contains(*class_name))
        .cloned()
        .collect();
    matched.into_iter().collect()
}
/// Maps the pair "matched by name" / "matched by class" to a [`MatchBasis`].
fn classify_match_basis(name_match: bool, class_match: bool) -> MatchBasis {
    if name_match && class_match {
        MatchBasis::NameAndClass
    } else if name_match {
        MatchBasis::Name
    } else if class_match {
        MatchBasis::Class
    } else {
        MatchBasis::None
    }
}
/// Names the policy field(s) responsible for a match, based on
/// `metadata.match_basis`: the class field for class-only matches,
/// `"<name_field>+<class_field>"` when both matched, and the name field
/// otherwise (including when the basis is `None`).
fn matched_rule_name(
    name_field: &str,
    class_field: &str,
    metadata: &PolicyMatchMetadata,
) -> String {
    match metadata.match_basis {
        MatchBasis::Class => class_field.to_owned(),
        MatchBasis::NameAndClass => [name_field, class_field].join("+"),
        MatchBasis::Name | MatchBasis::None => name_field.to_owned(),
    }
}
/// Applies the policy to a raw JSON-RPC request.
///
/// Tool calls with readable parameters are passed to [`McpPolicy::evaluate`]
/// without a runtime identity, so identity pins are not checked on this path.
/// Every other request — including a tool call whose parameters cannot be
/// extracted — is allowed and only bumps `state.requests_count`.
///
/// NOTE(review): allowing tool calls with unreadable parameters looks
/// deliberate (they fall through to the upstream server) — confirm.
pub fn check(&self, request: &JsonRpcRequest, state: &mut PolicyState) -> PolicyDecision {
    if request.is_tool_call() {
        if let Some(params) = request.tool_params() {
            // Restores the `&params` borrows that were mangled in the source.
            return self.evaluate(&params.name, &params.arguments, state, None);
        }
    }
    state.requests_count += 1;
    PolicyDecision::Allow
}
}
fn constraint_to_schema(constraint: &ConstraintRule) -> Value {
let mut properties = json!({});
let mut required = vec![];
for (param_name, param_constraint) in &constraint.params {
if let Some(pattern) = ¶m_constraint.matches {
properties[param_name] = json!({
"type": "string",
"pattern": pattern,
"minLength": 1
});
required.push(param_name.clone());
}
}
json!({
"type": "object",
"additionalProperties": true,
"properties": properties,
"required": required,
})
}
/// Serializes a denial as a newline-terminated JSON-RPC tool-call result.
///
/// The result is flagged as an error, carries `msg` as its single text content
/// item and `contract` as its structured content. If serialization fails, the
/// returned string consists of the newline only.
pub fn make_deny_response(id: Value, msg: &str, contract: Value) -> String {
    let text_item = ContentItem::Text {
        text: msg.to_owned(),
    };
    let resp = JsonRpcResponse {
        jsonrpc: "2.0",
        id,
        payload: ToolCallResult {
            result: ToolResultBody {
                content: vec![text_item],
                is_error: true,
                structured_content: Some(contract),
            },
        },
    };

    let mut line = serde_json::to_string(&resp).unwrap_or_default();
    line.push('\n');
    line
}
/// Prints the v1-format deprecation notice to stderr, at most once per
/// process (guarded by a function-local `OnceLock`).
fn emit_deprecation_warning() {
static WARNED: OnceLock<()> = OnceLock::new();
WARNED.get_or_init(|| {
// `\x1b[33m` / `\x1b[0m` are the ANSI escapes that wrap each line.
eprintln!(
"\n\x1b[33m⚠️ DEPRECATED: v1 policy format detected\x1b[0m\n\
\x1b[33m The 'constraints:' syntax is deprecated and will be removed in Assay v2.0.0.\x1b[0m\n\
\x1b[33m Migrate now:\x1b[0m\n\
\x1b[33m assay policy migrate --input <file>\x1b[0m\n\
\x1b[33m See: https://docs.assay.dev/migration/v1-to-v2\x1b[0m\n"
);
});
}