1use super::jsonrpc::{
2 ContentItem, JsonRpcRequest, JsonRpcResponse, ToolCallResult, ToolResultBody,
3};
4use serde::{Deserialize, Serialize};
5use serde_json::{json, Value};
6use std::collections::{BTreeMap, HashMap};
7use std::sync::{Arc, OnceLock};
8
9#[derive(Debug, Clone, Serialize, Deserialize, Default)]
10pub struct McpPolicy {
11 #[serde(default)]
12 pub version: String,
13
14 #[serde(default)]
15 pub name: String,
16
17 #[serde(default)]
18 pub tools: ToolPolicy,
19
20 #[serde(default)]
22 pub allow: Option<Vec<String>>,
23 #[serde(default)]
24 pub deny: Option<Vec<String>>,
25
26 #[serde(default)]
28 pub schemas: HashMap<String, Value>,
29
30 #[serde(default, deserialize_with = "deserialize_constraints")]
32 pub constraints: Vec<ConstraintRule>,
33
34 #[serde(default)]
35 pub enforcement: EnforcementSettings,
36
37 #[serde(default)]
38 pub limits: Option<GlobalLimits>,
39
40 #[serde(default)]
41 pub signatures: Option<SignaturePolicy>,
42
43 #[serde(default)]
45 pub discovery: Option<DiscoveryConfig>,
46 #[serde(default)]
47 pub runtime_monitor: Option<RuntimeMonitorConfig>,
48 #[serde(default)]
49 pub kill_switch: Option<KillSwitchConfig>,
50
51 #[serde(skip)]
53 pub(crate) compiled: Arc<OnceLock<HashMap<String, Arc<jsonschema::JSONSchema>>>>,
54}
55
56#[derive(Debug, Clone, Serialize, Deserialize)]
57pub struct EnforcementSettings {
58 #[serde(default = "default_unconstrained")]
60 pub unconstrained_tools: UnconstrainedMode,
61}
62
63impl Default for EnforcementSettings {
64 fn default() -> Self {
65 Self {
66 unconstrained_tools: UnconstrainedMode::Warn,
67 }
68 }
69}
70
71fn default_unconstrained() -> UnconstrainedMode {
72 UnconstrainedMode::Warn
73}
74
75#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
76#[serde(rename_all = "snake_case")]
77pub enum UnconstrainedMode {
78 #[default]
79 Warn,
80 Deny,
81 Allow,
82}
83
84#[derive(Debug, Clone, Serialize, Deserialize, Default)]
85pub struct SignaturePolicy {
86 #[serde(default)]
87 pub check_descriptions: bool,
88}
89
90#[derive(Debug, Clone, Serialize, Deserialize, Default)]
91pub struct GlobalLimits {
92 pub max_requests_total: Option<u64>,
93 pub max_tool_calls_total: Option<u64>,
94}
95
96#[derive(Debug, Clone, Serialize, Deserialize, Default)]
97pub struct ToolPolicy {
98 pub allow: Option<Vec<String>>,
99 pub deny: Option<Vec<String>>,
100}
101
102#[derive(Debug, Clone, Serialize, Deserialize)]
104pub struct ConstraintRule {
105 pub tool: String,
106 pub params: BTreeMap<String, ConstraintParam>,
107}
108
109#[derive(Debug, Clone, Serialize, Deserialize)]
110pub struct ConstraintParam {
111 #[serde(default)]
112 pub matches: Option<String>,
113}
114
115pub use super::runtime_features::{
116 ActionLevel, DiscoveryActions, DiscoveryConfig, DiscoveryMethod,
117 KillMode, KillSwitchConfig, KillTrigger,
118 MonitorAction, MonitorMatch, MonitorProvider, MonitorRule,
119 MonitorRuleType, RuntimeMonitorConfig,
120};
121
/// Mutable counters threaded through `McpPolicy::check` / `McpPolicy::evaluate`
/// and compared against `GlobalLimits`.
#[derive(Debug, Default)]
pub struct PolicyState {
    /// Requests seen so far, tool calls included.
    pub requests_count: u64,
    /// Tool calls evaluated so far.
    pub tool_calls_count: u64,
}
127
128#[derive(Debug, Clone, PartialEq)]
129pub enum PolicyDecision {
130 Allow,
131 AllowWithWarning {
132 tool: String,
133 code: String,
134 reason: String,
135 },
136 Deny {
137 tool: String,
138 code: String,
139 reason: String,
140 contract: Value,
141 },
142}
143
144#[derive(Debug, Clone, Deserialize)]
146#[serde(untagged)]
147enum ConstraintsCompat {
148 List(Vec<ConstraintRule>),
149 Map(BTreeMap<String, BTreeMap<String, InputParamConstraint>>),
150}
151
152#[derive(Debug, Clone, Deserialize)]
153#[serde(untagged)]
154enum InputParamConstraint {
155 Direct(String),
156 Object(ConstraintParam),
157}
158
159fn deserialize_constraints<'de, D>(d: D) -> Result<Vec<ConstraintRule>, D::Error>
160where
161 D: serde::Deserializer<'de>,
162{
163 let c = Option::<ConstraintsCompat>::deserialize(d)?;
164 let out = match c {
165 None => vec![],
166 Some(ConstraintsCompat::List(v)) => v,
167 Some(ConstraintsCompat::Map(m)) => m
168 .into_iter()
169 .map(|(tool, params)| {
170 let new_params = params
171 .into_iter()
172 .map(|(arg, val)| {
173 let param = match val {
174 InputParamConstraint::Direct(s) => ConstraintParam { matches: Some(s) },
175 InputParamConstraint::Object(o) => o,
176 };
177 (arg, param)
178 })
179 .collect();
180 ConstraintRule {
181 tool,
182 params: new_params,
183 }
184 })
185 .collect(),
186 };
187 Ok(out)
188}
189
/// Returns `true` if `tool_name` matches the glob-style `pattern`.
///
/// `*` matches any (possibly empty) run of characters. It may appear anywhere in the
/// pattern and any number of times. All other characters match literally, so a pattern
/// without `*` is an exact comparison.
///
/// Examples: `"*"` matches everything, `"fs_*"` is a prefix match, `"*_write"` is a
/// suffix match, `"*exec*"` is a substring match, and `"fs_*_write"` requires both the
/// prefix and the suffix.
fn matches_tool_pattern(tool_name: &str, pattern: &str) -> bool {
    // Literal head = text before the first `*`. No `*` at all means an exact name.
    let Some((head, tail)) = pattern.split_once('*') else {
        return tool_name == pattern;
    };
    // Literal suffix = text after the last `*`. Whatever lies between the first and
    // last `*` is a sequence of literal segments that must appear in order.
    let (middle, last) = tail.rsplit_once('*').unwrap_or(("", tail));

    let Some(mut rest) = tool_name.strip_prefix(head) else {
        return false;
    };
    for segment in middle.split('*').filter(|s| !s.is_empty()) {
        // Taking the leftmost occurrence leaves the most room for the remaining
        // segments, so greedy matching never misses a valid assignment.
        match rest.find(segment) {
            Some(idx) => rest = &rest[idx + segment.len()..],
            None => return false,
        }
    }
    // The suffix is checked only against what remains, so it cannot overlap the head
    // or any middle segment.
    rest.ends_with(last)
}
219
220impl McpPolicy {
221 pub fn new() -> Self {
222 Self::default()
223 }
224
225 pub fn from_file(path: &std::path::Path) -> anyhow::Result<Self> {
226 let content = std::fs::read_to_string(path)?;
227
228 let mut unknown = Vec::new();
229 let de = serde_yaml::Deserializer::from_str(&content);
230 let mut policy: McpPolicy = serde_ignored::deserialize(de, |path| {
231 unknown.push(path.to_string());
232 }).map_err(anyhow::Error::from)?;
233
234 if !unknown.is_empty() {
235 tracing::warn!(?unknown, "Unknown fields in policy (ignored)");
237 }
238
239 if policy.is_v1_format() {
241 if std::env::var("ASSAY_STRICT_DEPRECATIONS").ok().as_deref() == Some("1") {
242 anyhow::bail!("Strict mode: v1 policy format (constraints) is not allowed.");
243 }
244 emit_deprecation_warning();
245 }
246
247 policy.normalize_legacy_shapes();
249
250 if !policy.constraints.is_empty() {
252 policy.migrate_constraints_to_schemas();
253 }
254
255 policy.validate()?;
256
257 Ok(policy)
258 }
259
260 pub fn validate(&self) -> anyhow::Result<()> {
261 if let (Some(rm), Some(ks)) = (&self.runtime_monitor, &self.kill_switch) {
263 let rule_ids: std::collections::HashSet<&str> =
264 rm.rules.iter().map(|r| r.id.as_str()).collect();
265
266 for t in &ks.triggers {
267 if !rule_ids.contains(t.on_rule.as_str()) {
268 anyhow::bail!(
269 "kill_switch.triggers references unknown rule id: {}",
270 t.on_rule
271 );
272 }
273 }
274 }
275 Ok(())
276 }
277
278 pub fn is_v1_format(&self) -> bool {
279 !self.constraints.is_empty() || self.version == "1.0"
281 }
282
283 pub fn normalize_legacy_shapes(&mut self) {
285 if let Some(allow) = self.allow.take() {
286 let mut current = self.tools.allow.take().unwrap_or_default();
287 current.extend(allow);
288 self.tools.allow = Some(current);
289 }
290 if let Some(deny) = self.deny.take() {
291 let mut current = self.tools.deny.take().unwrap_or_default();
292 current.extend(deny);
293 self.tools.deny = Some(current);
294 }
295 }
296
297 pub fn migrate_constraints_to_schemas(&mut self) {
300 for constraint in std::mem::take(&mut self.constraints) {
301 let schema = constraint_to_schema(&constraint);
302 self.schemas.insert(constraint.tool.clone(), schema);
303 }
304 if self.version.is_empty() || self.version == "1.0" {
305 self.version = "2.0".to_string();
306 }
307 }
308
309 fn compiled_schemas(&self) -> &HashMap<String, Arc<jsonschema::JSONSchema>> {
310 self.compiled.get_or_init(|| self.compile_all_schemas())
311 }
312
313 pub fn compile_all_schemas(&self) -> HashMap<String, Arc<jsonschema::JSONSchema>> {
314 let root_defs = self.schemas.get("$defs").cloned();
316
317 let mut compiled = HashMap::new();
318 for (tool_name, schema) in &self.schemas {
319 if tool_name.starts_with('$') {
320 continue;
321 }
322
323 let mut schema_to_compile = schema.clone();
324 if let Some(defs) = &root_defs {
326 if let Value::Object(map) = &mut schema_to_compile {
327 map.insert("$defs".to_string(), defs.clone());
330 }
331 }
332
333 match jsonschema::JSONSchema::compile(&schema_to_compile) {
334 Ok(validator) => {
335 compiled.insert(tool_name.clone(), Arc::new(validator));
336 }
337 Err(e) => {
338 tracing::error!("Failed to compile schema for tool {}: {}", tool_name, e);
339 panic!(
341 "Failed to compile JSON schema for tool '{}': {}",
342 tool_name, e
343 );
344 }
345 }
346 }
347 compiled
348 }
349
350 pub fn evaluate(
352 &self,
353 tool_name: &str,
354 args: &Value,
355 state: &mut PolicyState,
356 ) -> PolicyDecision {
357 if let Some(decision) = self.check_rate_limits(state) {
359 return decision;
360 }
361
362 if self.is_denied(tool_name) {
364 return PolicyDecision::Deny {
365 tool: tool_name.to_string(),
366 code: "E_TOOL_DENIED".to_string(),
367 reason: "Tool is explicitly denylisted".to_string(),
368 contract: self.format_deny_contract(
369 tool_name,
370 "E_TOOL_DENIED",
371 "Tool is denylisted",
372 ),
373 };
374 }
375
376 if self.has_allowlist() && !self.is_allowed(tool_name) {
378 return PolicyDecision::Deny {
379 tool: tool_name.to_string(),
380 code: "E_TOOL_NOT_ALLOWED".to_string(),
381 reason: "Tool is not in the allowlist".to_string(),
382 contract: self.format_deny_contract(
383 tool_name,
384 "E_TOOL_NOT_ALLOWED",
385 "Tool is not in allowlist",
386 ),
387 };
388 }
389
390 let compiled = self.compiled_schemas();
392 if let Some(validator) = compiled.get(tool_name) {
393 match validator.validate(args) {
394 Ok(_) => return PolicyDecision::Allow,
395 Err(errors) => {
396 let violations: Vec<_> = errors
397 .map(|e| {
398 json!({
399 "path": e.instance_path.to_string(),
400 "message": e.to_string(),
401 })
402 })
403 .collect();
404 return PolicyDecision::Deny {
405 tool: tool_name.to_string(),
406 code: "E_ARG_SCHEMA".to_string(),
407 reason: "JSON Schema validation failed".to_string(),
408 contract: json!({
409 "status": "deny",
410 "error_code": "E_ARG_SCHEMA",
411 "tool": tool_name,
412 "violations": violations,
413 }),
414 };
415 }
416 }
417 }
418
419 match self.enforcement.unconstrained_tools {
421 UnconstrainedMode::Deny => PolicyDecision::Deny {
422 tool: tool_name.to_string(),
423 code: "E_TOOL_UNCONSTRAINED".to_string(),
424 reason: "Tool has no schema (enforcement: deny)".to_string(),
425 contract: self.format_deny_contract(
426 tool_name,
427 "E_TOOL_UNCONSTRAINED",
428 "Tool has no schema (enforcement: deny)",
429 ),
430 },
431 UnconstrainedMode::Warn => PolicyDecision::AllowWithWarning {
432 tool: tool_name.to_string(),
433 code: "E_TOOL_UNCONSTRAINED".to_string(),
434 reason: "Tool allowed but has no schema".to_string(),
435 },
436 UnconstrainedMode::Allow => PolicyDecision::Allow,
437 }
438 }
439
440 fn check_rate_limits(&self, state: &mut PolicyState) -> Option<PolicyDecision> {
442 state.requests_count += 1;
443 state.tool_calls_count += 1; if let Some(limits) = &self.limits {
446 if let Some(max) = limits.max_requests_total {
447 if state.requests_count > max {
451 return Some(PolicyDecision::Deny {
452 tool: "ALL".to_string(),
453 code: "E_RATE_LIMIT".to_string(),
454 reason: "Rate limit exceeded (total requests)".to_string(),
455 contract: json!({ "status": "deny", "error_code": "E_RATE_LIMIT" }),
456 });
457 }
458 }
459
460 if let Some(max) = limits.max_tool_calls_total {
461 if state.tool_calls_count > max {
462 return Some(PolicyDecision::Deny {
463 tool: "ALL".to_string(),
464 code: "E_RATE_LIMIT".to_string(),
465 reason: "Rate limit exceeded (tool calls)".to_string(),
466 contract: json!({ "status": "deny", "error_code": "E_RATE_LIMIT" }),
467 });
468 }
469 }
470 }
471 None
472 }
473
474 fn is_denied(&self, tool_name: &str) -> bool {
475 let root_deny = self.deny.as_ref();
476 let tools_deny = self.tools.deny.as_ref();
477 root_deny
478 .iter()
479 .flat_map(|v| v.iter())
480 .chain(tools_deny.iter().flat_map(|v| v.iter()))
481 .any(|pattern| matches_tool_pattern(tool_name, pattern))
482 }
483
484 fn has_allowlist(&self) -> bool {
485 self.allow.is_some() || self.tools.allow.is_some()
486 }
487
488 fn is_allowed(&self, tool_name: &str) -> bool {
489 let root_allow = self.allow.as_ref();
490 let tools_allow = self.tools.allow.as_ref();
491 root_allow
492 .iter()
493 .flat_map(|v| v.iter())
494 .chain(tools_allow.iter().flat_map(|v| v.iter()))
495 .any(|pattern| matches_tool_pattern(tool_name, pattern))
496 }
497
498 fn format_deny_contract(&self, tool: &str, code: &str, reason: &str) -> Value {
499 json!({
500 "status": "deny",
501 "error_code": code,
502 "tool": tool,
503 "reason": reason
504 })
505 }
506
507 pub fn check(&self, request: &JsonRpcRequest, state: &mut PolicyState) -> PolicyDecision {
509 if !request.is_tool_call() {
510 state.requests_count += 1;
511 return PolicyDecision::Allow;
512 }
513 if let Some(params) = request.tool_params() {
514 self.evaluate(¶ms.name, ¶ms.arguments, state)
516 } else {
517 state.requests_count += 1;
519 PolicyDecision::Allow
520 }
521 }
522}
523
524fn constraint_to_schema(constraint: &ConstraintRule) -> Value {
525 let mut properties = json!({});
526 let mut required = vec![];
527
528 for (param_name, param_constraint) in &constraint.params {
529 if let Some(pattern) = ¶m_constraint.matches {
530 properties[param_name] = json!({
531 "type": "string",
532 "pattern": pattern,
533 "minLength": 1
534 });
536 required.push(param_name.clone());
537 }
538 }
539
540 json!({
541 "type": "object",
542 "additionalProperties": true,
544 "properties": properties,
545 "required": required,
546 })
547}
548
549pub fn make_deny_response(id: Value, msg: &str, contract: Value) -> String {
550 let body = ToolResultBody {
551 content: vec![ContentItem::Text {
552 text: msg.to_string(),
553 }],
554 is_error: true,
555 structured_content: Some(contract),
556 };
557 let resp = JsonRpcResponse {
558 jsonrpc: "2.0",
559 id,
560 payload: ToolCallResult { result: body },
561 };
562 serde_json::to_string(&resp).unwrap_or_default() + "\n"
563}
564
/// Prints the v1-policy deprecation banner to stderr, at most once per process.
fn emit_deprecation_warning() {
    const BANNER: &str = concat!(
        "\n\x1b[33m⚠️ DEPRECATED: v1 policy format detected\x1b[0m\n",
        "\x1b[33m The 'constraints:' syntax is deprecated and will be removed in Assay v2.0.0.\x1b[0m\n",
        "\x1b[33m Migrate now:\x1b[0m\n",
        "\x1b[33m assay policy migrate --input <file>\x1b[0m\n",
        "\x1b[33m See: https://docs.assay.dev/migration/v1-to-v2\x1b[0m\n",
    );

    // Process-wide latch: the closure runs only for the first caller.
    static ALREADY_SHOWN: OnceLock<()> = OnceLock::new();
    ALREADY_SHOWN.get_or_init(|| eprintln!("{BANNER}"));
}