Skip to main content

assay_core/
policy_engine.rs

1use serde::{Deserialize, Serialize};
2use serde_json::Value;
3
4#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
5#[serde(rename_all = "snake_case")]
6pub enum VerdictStatus {
7    Allowed,
8    Blocked,
9}
10
11#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
12pub struct Verdict {
13    pub status: VerdictStatus,
14    pub reason_code: String, // e.g., "OK", "E_ARG_SCHEMA", "E_TOOL_NOT_ALLOWED"
15    pub details: Value,      // JSON details, violations, etc.
16}
17
18/// Evaluates tool arguments against a policy (JSON/YAML Value).
19/// The policy is expected to be a map of tool_name -> schema.
20pub fn evaluate_tool_args(policy: &Value, tool_name: &str, tool_args: &Value) -> Verdict {
21    // 1. Check if tool exists in policy
22    let schema_val = match policy.get(tool_name) {
23        Some(s) => s,
24        None => {
25            // Check for potential typos
26            let mut message = format!("Tool '{}' not defined in policy", tool_name);
27            if let Some(obj) = policy.as_object() {
28                // Use our similarity helper
29                if let Some(match_) =
30                    crate::errors::similarity::closest_prompt(tool_name, obj.keys())
31                {
32                    message.push_str(&format!(". Did you mean '{}'?", match_.prompt));
33                }
34            }
35
36            return Verdict {
37                status: VerdictStatus::Blocked,
38                reason_code: "E_POLICY_MISSING_TOOL".to_string(),
39                details: serde_json::json!({
40                    "message": message
41                }),
42            };
43        }
44    };
45
46    // 2. Compile Schema
47    // In a real high-perf scenario, we'd cache this (Compilation is expensive).
48    // For this core function, we compile on the fly or need a cached compilation context.
49    // User Step 1.2: "Compile JSON Schema validators één keer bij policy load".
50    // Since this function takes `&Value`, it implies per-call.
51    // To support caching, we'd need a `PolicyState` struct.
52    // For now, I'll compile on the fly (parity correctness first).
53
54    let compiled = match jsonschema::validator_for(schema_val) {
55        Ok(c) => c,
56        Err(e) => {
57            return Verdict {
58                status: VerdictStatus::Blocked,
59                reason_code: "E_SCHEMA_COMPILE".to_string(),
60                details: serde_json::json!({
61                    "message": format!("Invalid schema for tool '{}': {}", tool_name, e)
62                }),
63            };
64        }
65    };
66
67    // 3. Validate
68    evaluate_schema(&compiled, tool_args)
69}
70
71/// Evaluates tool arguments against a compiled schema.
72pub fn evaluate_schema(compiled: &jsonschema::Validator, tool_args: &Value) -> Verdict {
73    if compiled.is_valid(tool_args) {
74        return Verdict {
75            status: VerdictStatus::Allowed,
76            reason_code: "OK".to_string(),
77            details: serde_json::json!({}),
78        };
79    }
80    let violations: Vec<Value> = compiled
81        .iter_errors(tool_args)
82        .map(|e| {
83            serde_json::json!({
84                "path": e.instance_path().to_string(),
85                "constraint": e.to_string(),
86                "message": e.to_string()
87            })
88        })
89        .collect();
90    Verdict {
91        status: VerdictStatus::Blocked,
92        reason_code: "E_ARG_SCHEMA".to_string(),
93        details: serde_json::json!({
94            "violations": violations
95        }),
96    }
97}
98
99/// Evaluates a sequence of tool calls against a sequence policy (regex-like).
100/// For v0.9, simplified: the policy is just a string (regex) of tool names.
101/// E.g. "^search (analyze )*report$"
102/// The input is a list of tool names invoked in order.
103pub fn evaluate_sequence(policy_regex: &str, tool_names: &[String]) -> Verdict {
104    // 1. Construct the sequence string
105    // We join tool names with space. Note: tool names should not contain spaces ideally.
106    // If they do, this simple approach might be ambiguous, but standard tools usually don't.
107    let trace_str = tool_names.join(" ");
108
109    // 2. Compile Regex
110    // Again, efficiency concern: compile once.
111    let re = match regex::Regex::new(policy_regex) {
112        Ok(r) => r,
113        Err(e) => {
114            return Verdict {
115                status: VerdictStatus::Blocked,
116                reason_code: "E_POLICY_REGEX_INVALID".to_string(),
117                details: serde_json::json!({
118                    "message": format!("Invalid regex policy '{}': {}", policy_regex, e)
119                }),
120            };
121        }
122    };
123
124    // 3. Match
125    if re.is_match(&trace_str) {
126        Verdict {
127            status: VerdictStatus::Allowed,
128            reason_code: "OK".to_string(),
129            details: serde_json::json!({}),
130        }
131    } else {
132        Verdict {
133            status: VerdictStatus::Blocked,
134            reason_code: "E_SEQUENCE_VIOLATION".to_string(),
135            details: serde_json::json!({
136                "expected": policy_regex,
137                "found": trace_str
138            }),
139        }
140    }
141}