1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
use super::model::TraceAssertion;
use super::EpisodeGraph;
use crate::errors::diagnostic::Diagnostic;
// usage of HashMap removed
/// Runs every assertion against `graph` and gathers the resulting diagnostics.
///
/// Each assertion is checked independently via `check_one`; an assertion that
/// holds (or is skipped) contributes nothing to the output. The diagnostics
/// are returned in the same order as the assertions that produced them.
///
/// # Errors
///
/// The current implementation always returns `Ok`; the `Result` wrapper is
/// part of the signature callers already depend on.
pub fn evaluate(
    graph: &EpisodeGraph,
    assertions: &[TraceAssertion],
) -> anyhow::Result<Vec<Diagnostic>> {
    let diagnostics: Vec<Diagnostic> = assertions
        .iter()
        .filter_map(|assertion| check_one(graph, assertion))
        .collect();
    Ok(diagnostics)
}
/// Renders a pass/fail flag as the label used in policy diagnostics.
fn verdict_label(passed: bool) -> &'static str {
    if passed {
        "PASS"
    } else {
        "FAIL"
    }
}

/// Evaluates a single assertion and returns a diagnostic when it fails.
///
/// The `Trace*` variants inspect the recorded `graph` (its `tool_calls` and
/// `steps`). The policy variants (`ArgsValid`, `SequenceValid`,
/// `ToolBlocklist`) do not read `graph`; they only run when the assertion
/// carries its own inline test input.
///
/// Returns `None` when the assertion holds, and also when a policy variant is
/// skipped because its test input is absent (or, for `SequenceValid` and
/// `ToolBlocklist`, because `policy` is absent). `ArgsValid` with test
/// arguments but no policy is reported as `E_CONFIG_ERROR`.
fn check_one(graph: &EpisodeGraph, a: &TraceAssertion) -> Option<Diagnostic> {
    match a {
        TraceAssertion::TraceMustCallTool { tool, min_calls } => {
            let actual = graph
                .tool_calls
                .iter()
                .filter(|t| t.tool_name.as_deref() == Some(tool.as_str()))
                .count();
            // A missing `min_calls` means "at least once".
            let min = min_calls.unwrap_or(1);
            // Compare in u64: widening both sides is lossless, whereas
            // narrowing the usize count to u32 could wrap for huge traces.
            if (actual as u64) < u64::from(min) {
                return Some(make_diag(
                    "E_TRACE_ASSERT_FAIL",
                    &format!(
                        "Expected tool '{}' to be called at least {} times, but got {}.",
                        tool, min, actual
                    ),
                    Some(format!("Must call tool: {}", tool)),
                    None,
                ));
            }
        }
        TraceAssertion::TraceMustNotCallTool { tool } => {
            // Report the first offending call so the context can point at it.
            if let Some(call) = graph
                .tool_calls
                .iter()
                .find(|t| t.tool_name.as_deref() == Some(tool.as_str()))
            {
                return Some(make_diag(
                    "E_TRACE_ASSERT_FAIL",
                    &format!(
                        "Expected tool '{}' NOT to be called, but it was called.",
                        tool
                    ),
                    Some(format!("Must not call tool: {}", tool)),
                    Some(serde_json::json!({
                        "failing_step_id": call.step_id,
                        "failing_tool": tool,
                        "failing_call_index": call.call_index
                    })),
                ));
            }
        }
        TraceAssertion::TraceToolSequence {
            sequence,
            allow_other_tools,
        } => {
            if *allow_other_tools {
                // Other calls may be interleaved: `sequence` only has to
                // appear in order.
                if let Err(msg) = check_subsequence(&graph.tool_calls, sequence) {
                    return Some(make_diag(
                        "E_TRACE_ASSERT_FAIL",
                        &msg,
                        Some(format!("Tool sequence (subsequence): {:?}", sequence)),
                        None,
                    ));
                }
            } else {
                // The named calls must equal `sequence` exactly. Calls whose
                // `tool_name` is `None` are dropped before the comparison.
                let actual_seq: Vec<String> = graph
                    .tool_calls
                    .iter()
                    .filter_map(|t| t.tool_name.clone())
                    .collect();
                if actual_seq != *sequence {
                    return Some(make_diag(
                        "E_TRACE_ASSERT_FAIL",
                        &format!(
                            "Expected exact tool sequence {:?}, got {:?}.",
                            sequence, actual_seq
                        ),
                        Some(format!("Tool sequence (exact): {:?}", sequence)),
                        None,
                    ));
                }
            }
        }
        TraceAssertion::TraceMaxSteps { max } => {
            let count = graph.steps.len();
            // Widen instead of narrowing so an oversized count cannot wrap
            // below `max`.
            if (count as u64) > u64::from(*max) {
                return Some(make_diag(
                    "E_TRACE_ASSERT_FAIL",
                    &format!("Expected at most {} steps, got {}.", max, count),
                    Some(format!("Max steps: {}", max)),
                    None,
                ));
            }
        }
        TraceAssertion::ArgsValid {
            tool,
            test_args,
            policy,
            expect,
        } => {
            // Only evaluated when inline test arguments are supplied.
            if let Some(args) = test_args {
                let Some(pol) = policy else {
                    return Some(make_diag(
                        "E_CONFIG_ERROR",
                        "ArgsValid assertion requires 'policy' field (schema) when used in unit test mode.",
                        None,
                        None,
                    ));
                };
                // Accept both `{ schema: { ... } }` and a bare schema object.
                let schema = pol.get("schema").unwrap_or(pol);
                // The policy engine is handed a map keyed by tool name.
                let policy_map = serde_json::json!({ tool: schema });
                let verdict = crate::policy_engine::evaluate_tool_args(&policy_map, tool, args);
                // A missing `expect` defaults to "pass".
                let expected_pass = expect.as_deref().unwrap_or("pass") == "pass";
                let actual_pass = verdict.status == crate::policy_engine::VerdictStatus::Allowed;
                if expected_pass != actual_pass {
                    return Some(make_diag(
                        "E_POLICY_ASSERT_FAIL",
                        &format!(
                            "ArgsValid check failed. Expected {}, got {}. Reason: {:?}",
                            verdict_label(expected_pass),
                            verdict_label(actual_pass),
                            verdict.details
                        ),
                        None,
                        Some(serde_json::json!({
                            "tool": tool,
                            "args": args,
                            "verdict": verdict
                        })),
                    ));
                }
            }
        }
        TraceAssertion::SequenceValid {
            test_trace_raw,
            policy,
            expect,
            ..
        } => {
            // Skipped unless both an inline trace and a policy are present.
            if let (Some(trace_vals), Some(pol)) = (test_trace_raw, policy) {
                // Each trace entry names its tool under `tool` or, failing
                // that, `tool_name`; entries with neither string are ignored.
                let tools: Vec<String> = trace_vals
                    .iter()
                    .filter_map(|v| {
                        v.get("tool")
                            .or_else(|| v.get("tool_name"))
                            .and_then(|s| s.as_str())
                            .map(|s| s.to_string())
                    })
                    .collect();
                // The sequence pattern is read from the policy's `regex` key;
                // when that key is absent or not a string, `.*` is used.
                // NOTE(review): rule-list style policies are not translated
                // into a regex here -- confirm the intended policy format.
                let regex = pol.get("regex").and_then(|s| s.as_str()).unwrap_or(".*");
                let verdict = crate::policy_engine::evaluate_sequence(regex, &tools);
                let expected_pass = expect.as_deref().unwrap_or("pass") == "pass";
                let actual_pass = verdict.status == crate::policy_engine::VerdictStatus::Allowed;
                if expected_pass != actual_pass {
                    return Some(make_diag(
                        "E_POLICY_ASSERT_FAIL",
                        &format!(
                            "SequenceValid check failed. Expected {}, got {}.",
                            verdict_label(expected_pass),
                            verdict_label(actual_pass)
                        ),
                        None,
                        None,
                    ));
                }
            }
        }
        TraceAssertion::ToolBlocklist {
            test_tool_calls,
            policy,
            expect,
            ..
        } => {
            // Skipped unless both inline tool calls and a policy are present.
            if let (Some(tools), Some(pol)) = (test_tool_calls, policy) {
                // The policy is read as `{ "blocked": [...] }`; a missing or
                // non-array `blocked` yields an empty blocklist, and
                // non-string entries are ignored.
                let blocked: Vec<String> = pol
                    .get("blocked")
                    .and_then(|v| v.as_array())
                    .map(|arr| {
                        arr.iter()
                            .filter_map(|v| v.as_str().map(|s| s.to_string()))
                            .collect()
                    })
                    .unwrap_or_default();
                let expected_pass = expect.as_deref().unwrap_or("pass") == "pass";
                // The check passes only when no listed tool call is blocked.
                let actual_pass = !tools.iter().any(|t| blocked.contains(t));
                if expected_pass != actual_pass {
                    return Some(make_diag(
                        "E_POLICY_ASSERT_FAIL",
                        &format!(
                            "ToolBlocklist check failed. Expected {}, got {}.",
                            verdict_label(expected_pass),
                            verdict_label(actual_pass)
                        ),
                        None,
                        None,
                    ));
                }
            }
        }
    }
    None
}
/// Checks that the tool names in `expected` occur within `calls` in order.
///
/// Other calls may be interleaved between the expected ones. Each expected
/// name must match a call strictly after the call matched by the previous
/// name, so repeating a name in `expected` requires that many distinct calls.
///
/// # Errors
///
/// Returns a message naming the first expected tool that could not be found
/// in the remaining calls.
fn check_subsequence(
    calls: &[crate::storage::rows::ToolCallRow],
    expected: &[String],
) -> Result<(), String> {
    // One cursor shared across all expected names: `any` consumes calls up to
    // and including the match, so the next search resumes right after it.
    let mut remaining = calls.iter();
    for wanted in expected {
        let located = remaining.any(|row| row.tool_name.as_deref() == Some(wanted.as_str()));
        if !located {
            return Err(format!(
                "Expected tool '{}' in sequence, but not found (missing or out of order).",
                wanted
            ));
        }
    }
    Ok(())
}
/// Builds an error-severity `Diagnostic` attributed to `agent_assertions`.
///
/// * `code` - diagnostic code, stored as a string.
/// * `message` - human-readable failure description.
/// * `_expected` - accepted for the callers' convenience but currently not
///   stored anywhere in the diagnostic.
/// * `context` - structured details; an empty JSON object is used when `None`.
///
/// `fix_steps` is always left empty.
fn make_diag(
    code: &str,
    message: &str,
    _expected: Option<String>,
    context: Option<serde_json::Value>,
) -> Diagnostic {
    let context = context.unwrap_or_else(|| serde_json::json!({}));
    Diagnostic {
        code: code.to_owned(),
        severity: String::from("error"),
        source: String::from("agent_assertions"),
        message: message.to_owned(),
        context,
        fix_steps: Vec::new(),
    }
}