// assay_core/mcp/policy.rs

use std::collections::{BTreeMap, HashMap, HashSet};
use std::sync::{Arc, OnceLock};

use serde::{Deserialize, Serialize};
use serde_json::{json, Value};

use super::jsonrpc::{
    ContentItem, JsonRpcRequest, JsonRpcResponse, ToolCallResult, ToolResultBody,
};

9#[derive(Debug, Clone, Serialize, Deserialize, Default)]
10pub struct McpPolicy {
11    #[serde(default)]
12    pub version: String,
13
14    #[serde(default)]
15    pub name: String,
16
17    #[serde(default)]
18    pub tools: ToolPolicy,
19
20    // Legacy v1: root-level allow/deny (normalized into tools.* on load)
21    #[serde(default)]
22    pub allow: Option<Vec<String>>,
23    #[serde(default)]
24    pub deny: Option<Vec<String>>,
25
26    /// V2: JSON Schema per tool (primary)
27    #[serde(default)]
28    pub schemas: HashMap<String, Value>,
29
30    /// V1 (deprecated): Regex constraints - auto-converted to schemas on load
31    #[serde(default, deserialize_with = "deserialize_constraints")]
32    pub constraints: Vec<ConstraintRule>,
33
34    #[serde(default)]
35    pub enforcement: EnforcementSettings,
36
37    #[serde(default)]
38    pub limits: Option<GlobalLimits>,
39
40    #[serde(default)]
41    pub signatures: Option<SignaturePolicy>,
42
43    // Phase 4: Runtime Features
44    #[serde(default)]
45    pub discovery: Option<DiscoveryConfig>,
46    #[serde(default)]
47    pub runtime_monitor: Option<RuntimeMonitorConfig>,
48    #[serde(default)]
49    pub kill_switch: Option<KillSwitchConfig>,
50
51    /// Compiled schemas (lazy, thread-safe, shared across clones)
52    #[serde(skip)]
53    pub(crate) compiled: Arc<OnceLock<HashMap<String, Arc<jsonschema::JSONSchema>>>>,
54}
55
56#[derive(Debug, Clone, Serialize, Deserialize)]
57pub struct EnforcementSettings {
58    /// What to do when a tool has no schema
59    #[serde(default = "default_unconstrained")]
60    pub unconstrained_tools: UnconstrainedMode,
61}
62
63impl Default for EnforcementSettings {
64    fn default() -> Self {
65        Self {
66            unconstrained_tools: UnconstrainedMode::Warn,
67        }
68    }
69}
70
71fn default_unconstrained() -> UnconstrainedMode {
72    UnconstrainedMode::Warn
73}
74
75#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
76#[serde(rename_all = "snake_case")]
77pub enum UnconstrainedMode {
78    #[default]
79    Warn,
80    Deny,
81    Allow,
82}
83
84#[derive(Debug, Clone, Serialize, Deserialize, Default)]
85pub struct SignaturePolicy {
86    #[serde(default)]
87    pub check_descriptions: bool,
88}
89
90#[derive(Debug, Clone, Serialize, Deserialize, Default)]
91pub struct GlobalLimits {
92    pub max_requests_total: Option<u64>,
93    pub max_tool_calls_total: Option<u64>,
94}
95
96#[derive(Debug, Clone, Serialize, Deserialize, Default)]
97pub struct ToolPolicy {
98    pub allow: Option<Vec<String>>,
99    pub deny: Option<Vec<String>>,
100}
101
102// Canonical Rule Shape (Legacy V1)
103#[derive(Debug, Clone, Serialize, Deserialize)]
104pub struct ConstraintRule {
105    pub tool: String,
106    pub params: BTreeMap<String, ConstraintParam>,
107}
108
109#[derive(Debug, Clone, Serialize, Deserialize)]
110pub struct ConstraintParam {
111    #[serde(default)]
112    pub matches: Option<String>,
113}
114
115pub use super::runtime_features::{
116    ActionLevel, DiscoveryActions, DiscoveryConfig, DiscoveryMethod,
117    KillMode, KillSwitchConfig, KillTrigger,
118    MonitorAction, MonitorMatch, MonitorProvider, MonitorRule,
119    MonitorRuleType, RuntimeMonitorConfig,
120};
121
/// Counters carried between calls to `McpPolicy::evaluate` / `McpPolicy::check` and
/// compared against `GlobalLimits`.
#[derive(Default, Debug)]
pub struct PolicyState {
    /// Requests seen, tool calls included.
    pub requests_count: u64,
    /// Tool calls seen.
    pub tool_calls_count: u64,
}

128#[derive(Debug, Clone, PartialEq)]
129pub enum PolicyDecision {
130    Allow,
131    AllowWithWarning {
132        tool: String,
133        code: String,
134        reason: String,
135    },
136    Deny {
137        tool: String,
138        code: String,
139        reason: String,
140        contract: Value,
141    },
142}
143
144// Dual-Shape Deserializer Helper (Legacy)
145#[derive(Debug, Clone, Deserialize)]
146#[serde(untagged)]
147enum ConstraintsCompat {
148    List(Vec<ConstraintRule>),
149    Map(BTreeMap<String, BTreeMap<String, InputParamConstraint>>),
150}
151
152#[derive(Debug, Clone, Deserialize)]
153#[serde(untagged)]
154enum InputParamConstraint {
155    Direct(String),
156    Object(ConstraintParam),
157}
158
159fn deserialize_constraints<'de, D>(d: D) -> Result<Vec<ConstraintRule>, D::Error>
160where
161    D: serde::Deserializer<'de>,
162{
163    let c = Option::<ConstraintsCompat>::deserialize(d)?;
164    let out = match c {
165        None => vec![],
166        Some(ConstraintsCompat::List(v)) => v,
167        Some(ConstraintsCompat::Map(m)) => m
168            .into_iter()
169            .map(|(tool, params)| {
170                let new_params = params
171                    .into_iter()
172                    .map(|(arg, val)| {
173                        let param = match val {
174                            InputParamConstraint::Direct(s) => ConstraintParam { matches: Some(s) },
175                            InputParamConstraint::Object(o) => o,
176                        };
177                        (arg, param)
178                    })
179                    .collect();
180                ConstraintRule {
181                    tool,
182                    params: new_params,
183                }
184            })
185            .collect(),
186    };
187    Ok(out)
188}
189
/// Returns `true` if `tool_name` matches the allow/deny `pattern`.
///
/// A pattern without `*` must equal the name exactly. Each `*` matches any run of
/// characters, including none, and may appear anywhere in the pattern: `"*"`,
/// `"fs_*"`, `"*_read"`, `"*exec*"` and `"fs_*_write"` are all supported.
/// There is no escape for a literal `*`.
fn matches_tool_pattern(tool_name: &str, pattern: &str) -> bool {
    if !pattern.contains('*') {
        return tool_name == pattern;
    }

    // `split` yields at least two segments here because the pattern contains a '*'.
    let mut segments = pattern.split('*');
    let prefix = segments.next().unwrap_or_default();
    let suffix = segments.next_back().unwrap_or_default();

    let Some(mut rest) = tool_name.strip_prefix(prefix) else {
        return false;
    };

    // Consume the inner literal segments left to right. Taking the leftmost occurrence
    // of each one leaves the most room for the end-anchored suffix.
    for segment in segments.filter(|s| !s.is_empty()) {
        match rest.find(segment) {
            Some(at) => rest = &rest[at + segment.len()..],
            None => return false,
        }
    }

    rest.ends_with(suffix)
}

220impl McpPolicy {
221    pub fn new() -> Self {
222        Self::default()
223    }
224
225    pub fn from_file(path: &std::path::Path) -> anyhow::Result<Self> {
226        let content = std::fs::read_to_string(path)?;
227
228        let mut unknown = Vec::new();
229        let de = serde_yaml::Deserializer::from_str(&content);
230        let mut policy: McpPolicy = serde_ignored::deserialize(de, |path| {
231            unknown.push(path.to_string());
232        }).map_err(anyhow::Error::from)?;
233
234        if !unknown.is_empty() {
235            // Filter out transient/internal fields if any. For now, log all.
236            tracing::warn!(?unknown, "Unknown fields in policy (ignored)");
237        }
238
239        // Check for v1 format and warn if necessary
240        if policy.is_v1_format() {
241            if std::env::var("ASSAY_STRICT_DEPRECATIONS").ok().as_deref() == Some("1") {
242                anyhow::bail!("Strict mode: v1 policy format (constraints) is not allowed.");
243            }
244            emit_deprecation_warning();
245        }
246
247        // Normalize legacy shapes
248        policy.normalize_legacy_shapes();
249
250        // Auto-migrate v1 constraints
251        if !policy.constraints.is_empty() {
252            policy.migrate_constraints_to_schemas();
253        }
254
255        policy.validate()?;
256
257        Ok(policy)
258    }
259
260    pub fn validate(&self) -> anyhow::Result<()> {
261        // Cross-validation: Kill triggers must reference valid rules
262        if let (Some(rm), Some(ks)) = (&self.runtime_monitor, &self.kill_switch) {
263             let rule_ids: std::collections::HashSet<&str> =
264                rm.rules.iter().map(|r| r.id.as_str()).collect();
265
266            for t in &ks.triggers {
267                if !rule_ids.contains(t.on_rule.as_str()) {
268                    anyhow::bail!(
269                        "kill_switch.triggers references unknown rule id: {}",
270                        t.on_rule
271                    );
272                }
273            }
274        }
275        Ok(())
276    }
277
278    pub fn is_v1_format(&self) -> bool {
279        // v1 if constraints are present OR version is explicitly "1.0"
280        !self.constraints.is_empty() || self.version == "1.0"
281    }
282
283    /// Normalize legacy root-level allow/deny into tools.allow/deny.
284    pub fn normalize_legacy_shapes(&mut self) {
285        if let Some(allow) = self.allow.take() {
286            let mut current = self.tools.allow.take().unwrap_or_default();
287            current.extend(allow);
288            self.tools.allow = Some(current);
289        }
290        if let Some(deny) = self.deny.take() {
291            let mut current = self.tools.deny.take().unwrap_or_default();
292            current.extend(deny);
293            self.tools.deny = Some(current);
294        }
295    }
296
297    /// Migrate V1 regex constraints to V2 JSON Schemas.
298    /// Warning: This clears the `constraints` field.
299    pub fn migrate_constraints_to_schemas(&mut self) {
300        for constraint in std::mem::take(&mut self.constraints) {
301            let schema = constraint_to_schema(&constraint);
302            self.schemas.insert(constraint.tool.clone(), schema);
303        }
304        if self.version.is_empty() || self.version == "1.0" {
305            self.version = "2.0".to_string();
306        }
307    }
308
309    fn compiled_schemas(&self) -> &HashMap<String, Arc<jsonschema::JSONSchema>> {
310        self.compiled.get_or_init(|| self.compile_all_schemas())
311    }
312
313    pub fn compile_all_schemas(&self) -> HashMap<String, Arc<jsonschema::JSONSchema>> {
314        // Option 1: Inline $defs into every schema to support relative #/$defs/... refs
315        let root_defs = self.schemas.get("$defs").cloned();
316
317        let mut compiled = HashMap::new();
318        for (tool_name, schema) in &self.schemas {
319            if tool_name.starts_with('$') {
320                continue;
321            }
322
323            let mut schema_to_compile = schema.clone();
324            // Inject $defs if they exist and the schema is an object
325            if let Some(defs) = &root_defs {
326                if let Value::Object(map) = &mut schema_to_compile {
327                    // Only insert if not already present to allow overrides (or just overwrite?)
328                    // For now, insert if missing or overwrite to ensure global defs availability.
329                    map.insert("$defs".to_string(), defs.clone());
330                }
331            }
332
333            match jsonschema::JSONSchema::compile(&schema_to_compile) {
334                Ok(validator) => {
335                    compiled.insert(tool_name.clone(), Arc::new(validator));
336                }
337                Err(e) => {
338                    tracing::error!("Failed to compile schema for tool {}: {}", tool_name, e);
339                    // Fail securely: do not allow tools with broken schemas to load.
340                    panic!(
341                        "Failed to compile JSON schema for tool '{}': {}",
342                        tool_name, e
343                    );
344                }
345            }
346        }
347        compiled
348    }
349
350    /// Single evaluation entry point for CLI and Server
351    pub fn evaluate(
352        &self,
353        tool_name: &str,
354        args: &Value,
355        state: &mut PolicyState,
356    ) -> PolicyDecision {
357        // 1. Rate limits
358        if let Some(decision) = self.check_rate_limits(state) {
359            return decision;
360        }
361
362        // 2. Deny list
363        if self.is_denied(tool_name) {
364            return PolicyDecision::Deny {
365                tool: tool_name.to_string(),
366                code: "E_TOOL_DENIED".to_string(),
367                reason: "Tool is explicitly denylisted".to_string(),
368                contract: self.format_deny_contract(
369                    tool_name,
370                    "E_TOOL_DENIED",
371                    "Tool is denylisted",
372                ),
373            };
374        }
375
376        // 3. Allow list
377        if self.has_allowlist() && !self.is_allowed(tool_name) {
378            return PolicyDecision::Deny {
379                tool: tool_name.to_string(),
380                code: "E_TOOL_NOT_ALLOWED".to_string(),
381                reason: "Tool is not in the allowlist".to_string(),
382                contract: self.format_deny_contract(
383                    tool_name,
384                    "E_TOOL_NOT_ALLOWED",
385                    "Tool is not in allowlist",
386                ),
387            };
388        }
389
390        // 4. Schema Validation
391        let compiled = self.compiled_schemas();
392        if let Some(validator) = compiled.get(tool_name) {
393            match validator.validate(args) {
394                Ok(_) => return PolicyDecision::Allow,
395                Err(errors) => {
396                    let violations: Vec<_> = errors
397                        .map(|e| {
398                            json!({
399                                "path": e.instance_path.to_string(),
400                                "message": e.to_string(),
401                            })
402                        })
403                        .collect();
404                    return PolicyDecision::Deny {
405                        tool: tool_name.to_string(),
406                        code: "E_ARG_SCHEMA".to_string(),
407                        reason: "JSON Schema validation failed".to_string(),
408                        contract: json!({
409                            "status": "deny",
410                            "error_code": "E_ARG_SCHEMA",
411                            "tool": tool_name,
412                            "violations": violations,
413                        }),
414                    };
415                }
416            }
417        }
418
419        // 5. Unconstrained Mode
420        match self.enforcement.unconstrained_tools {
421            UnconstrainedMode::Deny => PolicyDecision::Deny {
422                tool: tool_name.to_string(),
423                code: "E_TOOL_UNCONSTRAINED".to_string(),
424                reason: "Tool has no schema (enforcement: deny)".to_string(),
425                contract: self.format_deny_contract(
426                    tool_name,
427                    "E_TOOL_UNCONSTRAINED",
428                    "Tool has no schema (enforcement: deny)",
429                ),
430            },
431            UnconstrainedMode::Warn => PolicyDecision::AllowWithWarning {
432                tool: tool_name.to_string(),
433                code: "E_TOOL_UNCONSTRAINED".to_string(),
434                reason: "Tool allowed but has no schema".to_string(),
435            },
436            UnconstrainedMode::Allow => PolicyDecision::Allow,
437        }
438    }
439
440    // Helper methods (extracted from original code or refactored)
441    fn check_rate_limits(&self, state: &mut PolicyState) -> Option<PolicyDecision> {
442        state.requests_count += 1;
443        state.tool_calls_count += 1; // Simplified: Assumes evaluate called on tool call
444
445        if let Some(limits) = &self.limits {
446            if let Some(max) = limits.max_requests_total {
447                // Note: requests_count tracks total JSON-RPC, which we might not have here accurately
448                // unless state is persistent session state.
449                // For now, allow it to increment, assuming state is managing session.
450                if state.requests_count > max {
451                    return Some(PolicyDecision::Deny {
452                        tool: "ALL".to_string(),
453                        code: "E_RATE_LIMIT".to_string(),
454                        reason: "Rate limit exceeded (total requests)".to_string(),
455                        contract: json!({ "status": "deny", "error_code": "E_RATE_LIMIT" }),
456                    });
457                }
458            }
459
460            if let Some(max) = limits.max_tool_calls_total {
461                if state.tool_calls_count > max {
462                    return Some(PolicyDecision::Deny {
463                        tool: "ALL".to_string(),
464                        code: "E_RATE_LIMIT".to_string(),
465                        reason: "Rate limit exceeded (tool calls)".to_string(),
466                        contract: json!({ "status": "deny", "error_code": "E_RATE_LIMIT" }),
467                    });
468                }
469            }
470        }
471        None
472    }
473
474    fn is_denied(&self, tool_name: &str) -> bool {
475        let root_deny = self.deny.as_ref();
476        let tools_deny = self.tools.deny.as_ref();
477        root_deny
478            .iter()
479            .flat_map(|v| v.iter())
480            .chain(tools_deny.iter().flat_map(|v| v.iter()))
481            .any(|pattern| matches_tool_pattern(tool_name, pattern))
482    }
483
484    fn has_allowlist(&self) -> bool {
485        self.allow.is_some() || self.tools.allow.is_some()
486    }
487
488    fn is_allowed(&self, tool_name: &str) -> bool {
489        let root_allow = self.allow.as_ref();
490        let tools_allow = self.tools.allow.as_ref();
491        root_allow
492            .iter()
493            .flat_map(|v| v.iter())
494            .chain(tools_allow.iter().flat_map(|v| v.iter()))
495            .any(|pattern| matches_tool_pattern(tool_name, pattern))
496    }
497
498    fn format_deny_contract(&self, tool: &str, code: &str, reason: &str) -> Value {
499        json!({
500            "status": "deny",
501            "error_code": code,
502            "tool": tool,
503            "reason": reason
504        })
505    }
506
507    // Proxy-specific check method (Legacy compatibility wrapper)
508    pub fn check(&self, request: &JsonRpcRequest, state: &mut PolicyState) -> PolicyDecision {
509        if !request.is_tool_call() {
510            state.requests_count += 1;
511            return PolicyDecision::Allow;
512        }
513        if let Some(params) = request.tool_params() {
514            // evaluate() increments counts, so we don't need to increment requests_count here
515            self.evaluate(&params.name, &params.arguments, state)
516        } else {
517            // Ordinary request, just count it
518            state.requests_count += 1;
519            PolicyDecision::Allow
520        }
521    }
522}
523
524fn constraint_to_schema(constraint: &ConstraintRule) -> Value {
525    let mut properties = json!({});
526    let mut required = vec![];
527
528    for (param_name, param_constraint) in &constraint.params {
529        if let Some(pattern) = &param_constraint.matches {
530            properties[param_name] = json!({
531                "type": "string",
532                "pattern": pattern,
533                "minLength": 1
534                // No maxLength restriction for V1 backward compatibility
535            });
536            required.push(param_name.clone());
537        }
538    }
539
540    json!({
541        "type": "object",
542        // Allow additional properties for V1 backward compatibility
543        "additionalProperties": true,
544        "properties": properties,
545        "required": required,
546    })
547}
548
549pub fn make_deny_response(id: Value, msg: &str, contract: Value) -> String {
550    let body = ToolResultBody {
551        content: vec![ContentItem::Text {
552            text: msg.to_string(),
553        }],
554        is_error: true,
555        structured_content: Some(contract),
556    };
557    let resp = JsonRpcResponse {
558        jsonrpc: "2.0",
559        id,
560        payload: ToolCallResult { result: body },
561    };
562    serde_json::to_string(&resp).unwrap_or_default() + "\n"
563}
564
/// Prints the yellow v1-format deprecation banner to stderr, at most once per process.
fn emit_deprecation_warning() {
    // Process-wide latch: the closure below runs only for the first caller.
    static WARNED: OnceLock<()> = OnceLock::new();
    const BANNER: &str = concat!(
        "\n",
        "\x1b[33m⚠️  DEPRECATED: v1 policy format detected\x1b[0m\n",
        "\x1b[33m   The 'constraints:' syntax is deprecated and will be removed in Assay v2.0.0.\x1b[0m\n",
        "\x1b[33m   Migrate now:\x1b[0m\n",
        "\x1b[33m     assay policy migrate --input <file>\x1b[0m\n",
        "\x1b[33m   See: https://docs.assay.dev/migration/v1-to-v2\x1b[0m\n",
    );
    WARNED.get_or_init(|| eprintln!("{}", BANNER));
}