1use crate::error::EvaluationError;
2use crate::tasks::evaluator::{PATH_REGEX, REGEX_FIELD_PARSE_PATTERN};
3use potato_head::ChatResponse;
4use scouter_types::genai::AgentAssertion;
5use serde_json::{json, Value};
6use tracing::error;
7
8const MAX_PATH_LEN: usize = 512;
9const MAX_PATH_SEGMENTS: usize = 32;
10
/// Evaluates [`AgentAssertion`]s against a single provider response.
///
/// Holds both a parsed view of the response and the JSON it was parsed from, so that
/// assertions can use either one.
#[derive(Debug, Clone)]
pub struct AgentContextBuilder {
    /// Parsed response; queried for tool calls, text, model, finish reason and token counts.
    response: ChatResponse,
    /// The response JSON as received; searched by `AgentAssertion::ResponseField` paths.
    raw: Value,
}
18
19impl AgentContextBuilder {
20 pub fn from_context(context: &Value) -> Result<Self, EvaluationError> {
29 let response_val = context.get("response").unwrap_or(context);
30 let response = ChatResponse::from_response_value(response_val.clone()).map_err(|e| {
31 error!("Failed to parse response: {}", e);
32 EvaluationError::InvalidProviderResponse
33 })?;
34 Ok(Self {
35 response,
36 raw: response_val.clone(),
37 })
38 }
39
40 pub fn build_context(&self, assertion: &AgentAssertion) -> Result<Value, EvaluationError> {
42 match assertion {
43 AgentAssertion::ToolCalled { name } => {
44 let found = self
45 .response
46 .get_tool_calls()
47 .iter()
48 .any(|tc| tc.name == *name);
49 Ok(json!(found))
50 }
51 AgentAssertion::ToolNotCalled { name } => {
52 let not_found = !self
53 .response
54 .get_tool_calls()
55 .iter()
56 .any(|tc| tc.name == *name);
57 Ok(json!(not_found))
58 }
59 AgentAssertion::ToolCalledWithArgs { name, arguments } => {
60 let matched =
61 self.response.get_tool_calls().iter().any(|tc| {
62 tc.name == *name && Self::partial_match(&tc.arguments, &arguments.0)
63 });
64 Ok(json!(matched))
65 }
66 AgentAssertion::ToolCallSequence { names } => {
67 let actual: Vec<String> = self
68 .response
69 .get_tool_calls()
70 .iter()
71 .map(|tc| tc.name.clone())
72 .collect();
73 let mut expected_iter = names.iter();
74 let mut current = expected_iter.next();
75 for actual_name in &actual {
76 if let Some(exp) = current {
77 if actual_name == exp {
78 current = expected_iter.next();
79 }
80 }
81 }
82 Ok(json!(current.is_none()))
83 }
84 AgentAssertion::ToolCallCount { name } => {
85 let tools = &self.response.get_tool_calls();
86 let count = if let Some(name) = name {
87 tools.iter().filter(|tc| tc.name == *name).count()
88 } else {
89 tools.len()
90 };
91 Ok(json!(count))
92 }
93 AgentAssertion::ToolArgument { name, argument_key } => {
94 let value = self
95 .response
96 .get_tool_calls()
97 .iter()
98 .find(|tc| tc.name == *name)
99 .and_then(|tc| tc.arguments.get(argument_key))
100 .cloned()
101 .unwrap_or(Value::Null);
102
103 Ok(value)
104 }
105 AgentAssertion::ToolResult { name } => {
106 let value = self
107 .response
108 .get_tool_calls()
109 .iter()
110 .find(|tc| tc.name == *name)
111 .and_then(|tc| tc.result.clone())
112 .unwrap_or(Value::Null);
113
114 Ok(value)
115 }
116 AgentAssertion::ResponseContent {} => {
117 let text = self.response.response_text();
118 if text.is_empty() {
119 Ok(Value::Null)
120 } else {
121 Ok(json!(text))
122 }
123 }
124 AgentAssertion::ResponseModel {} => Ok(self
125 .response
126 .model_name()
127 .map(|m| json!(m))
128 .unwrap_or(Value::Null)),
129 AgentAssertion::ResponseFinishReason {} => Ok(self
130 .response
131 .finish_reason_str()
132 .map(|f| json!(f))
133 .unwrap_or(Value::Null)),
134 AgentAssertion::ResponseInputTokens {} => Ok(self
135 .response
136 .input_tokens()
137 .map(|t| json!(t))
138 .unwrap_or(Value::Null)),
139 AgentAssertion::ResponseOutputTokens {} => Ok(self
140 .response
141 .output_tokens()
142 .map(|t| json!(t))
143 .unwrap_or(Value::Null)),
144 AgentAssertion::ResponseTotalTokens {} => Ok(self
145 .response
146 .total_tokens()
147 .map(|t| json!(t))
148 .unwrap_or(Value::Null)),
149 AgentAssertion::ResponseField { path } => Self::extract_by_path(&self.raw, path),
150 }
151 }
152
153 fn partial_match(actual: &Value, expected: &Value) -> bool {
157 match (actual, expected) {
158 (Value::Object(actual_map), Value::Object(expected_map)) => {
159 for (key, expected_val) in expected_map {
160 match actual_map.get(key) {
161 Some(actual_val) => {
162 if !Self::partial_match(actual_val, expected_val) {
163 return false;
164 }
165 }
166 None => return false,
167 }
168 }
169 true
170 }
171 _ => actual == expected,
172 }
173 }
174
175 fn extract_by_path(val: &Value, path: &str) -> Result<Value, EvaluationError> {
178 let mut current = val.clone();
179
180 for segment in Self::parse_path_segments(path)? {
181 match segment {
182 PathSegment::Key(key) => {
183 current = current.get(&key).cloned().unwrap_or(Value::Null);
184 }
185 PathSegment::Index(idx) => {
186 current = current
187 .as_array()
188 .and_then(|arr| arr.get(idx))
189 .cloned()
190 .unwrap_or(Value::Null);
191 }
192 }
193 }
194
195 Ok(current)
196 }
197
198 fn parse_path_segments(path: &str) -> Result<Vec<PathSegment>, EvaluationError> {
199 if path.len() > MAX_PATH_LEN {
200 return Err(EvaluationError::PathTooLong(path.len()));
201 }
202
203 let regex = PATH_REGEX.get_or_init(|| {
204 regex::Regex::new(REGEX_FIELD_PARSE_PATTERN)
205 .expect("Invalid regex pattern in REGEX_FIELD_PARSE_PATTERN")
206 });
207
208 let mut segments = Vec::new();
209
210 for capture in regex.find_iter(path) {
211 let s = capture.as_str();
212 if s.starts_with('[') && s.ends_with(']') {
213 let idx_str = &s[1..s.len() - 1];
214 let idx = idx_str
215 .parse::<usize>()
216 .map_err(|_| EvaluationError::InvalidArrayIndex(idx_str.to_string()))?;
217 segments.push(PathSegment::Index(idx));
218 } else {
219 segments.push(PathSegment::Key(s.to_string()));
220 }
221 }
222
223 if segments.is_empty() {
224 return Err(EvaluationError::EmptyFieldPath);
225 }
226
227 if segments.len() > MAX_PATH_SEGMENTS {
228 return Err(EvaluationError::TooManyPathSegments(segments.len()));
229 }
230
231 Ok(segments)
232 }
233}
234
/// One step of a parsed response-field path.
///
/// The derives let parsed paths be logged and compared directly in assertions.
#[derive(Debug, Clone, PartialEq, Eq)]
enum PathSegment {
    /// Look up a member of a JSON object by key.
    Key(String),
    /// Look up an element of a JSON array by zero-based position.
    Index(usize),
}
239
#[cfg(test)]
mod tests {
    use super::*;
    use scouter_types::genai::PyValueWrapper;

    /// Builds one function tool-call entry; `arguments` is the JSON-encoded argument string.
    fn tool_call(id: &str, name: &str, arguments: &str) -> Value {
        json!({
            "id": id,
            "type": "function",
            "function": {"name": name, "arguments": arguments}
        })
    }

    /// Builds a chat-completion style response whose single choice carries `tool_calls`.
    fn tool_call_response(tool_calls: Vec<Value>) -> Value {
        json!({
            "model": "gpt-4o",
            "choices": [{
                "message": {
                    "role": "assistant",
                    "content": null,
                    "tool_calls": tool_calls
                },
                "finish_reason": "tool_calls"
            }]
        })
    }

    /// Builds a chat-completion style response without tool calls and with the given content.
    fn text_response(content: Value) -> Value {
        json!({
            "model": "gpt-4o",
            "choices": [{
                "message": {
                    "role": "assistant",
                    "content": content
                },
                "finish_reason": "stop"
            }]
        })
    }

    /// Creates a builder from `context`, panicking when the context is rejected.
    fn builder_for(context: &Value) -> AgentContextBuilder {
        AgentContextBuilder::from_context(context).unwrap()
    }

    /// Evaluates `assertion`, panicking when context building fails.
    fn evaluate(builder: &AgentContextBuilder, assertion: AgentAssertion) -> Value {
        builder.build_context(&assertion).unwrap()
    }

    /// Shorthand for a `ToolCalledWithArgs` assertion.
    fn called_with_args(name: &str, arguments: Value) -> AgentAssertion {
        AgentAssertion::ToolCalledWithArgs {
            name: name.to_string(),
            arguments: PyValueWrapper(arguments),
        }
    }

    /// Shorthand for a `ToolCallSequence` assertion.
    fn sequence(names: &[&str]) -> AgentAssertion {
        AgentAssertion::ToolCallSequence {
            names: names.iter().map(|name| name.to_string()).collect(),
        }
    }

    #[test]
    fn test_tool_called_assertion() {
        let context = tool_call_response(vec![tool_call(
            "call_1",
            "web_search",
            r#"{"query": "test"}"#,
        )]);
        let builder = builder_for(&context);

        let called = evaluate(
            &builder,
            AgentAssertion::ToolCalled {
                name: "web_search".to_string(),
            },
        );
        assert_eq!(called, json!(true));

        let not_called = evaluate(
            &builder,
            AgentAssertion::ToolNotCalled {
                name: "delete_user".to_string(),
            },
        );
        assert_eq!(not_called, json!(true));

        let total = evaluate(&builder, AgentAssertion::ToolCallCount { name: None });
        assert_eq!(total, json!(1));
    }

    #[test]
    fn test_tool_called_with_args_partial_match() {
        let context = tool_call_response(vec![tool_call(
            "call_1",
            "web_search",
            r#"{"query": "weather NYC", "lang": "en", "limit": 5}"#,
        )]);
        let builder = builder_for(&context);

        // A subset of the actual arguments is enough.
        let subset = called_with_args("web_search", json!({"query": "weather NYC"}));
        assert_eq!(evaluate(&builder, subset), json!(true));

        // A differing value for a shared key is a mismatch.
        let wrong_value = called_with_args("web_search", json!({"query": "weather LA"}));
        assert_eq!(evaluate(&builder, wrong_value), json!(false));
    }

    #[test]
    fn test_tool_call_sequence() {
        let context = tool_call_response(vec![
            tool_call("call_1", "web_search", "{}"),
            tool_call("call_2", "summarize", "{}"),
            tool_call("call_3", "respond", "{}"),
        ]);
        let builder = builder_for(&context);

        let in_order = sequence(&["web_search", "summarize", "respond"]);
        assert_eq!(evaluate(&builder, in_order), json!(true));

        let out_of_order = sequence(&["respond", "web_search"]);
        assert_eq!(evaluate(&builder, out_of_order), json!(false));
    }

    #[test]
    fn test_response_field_escape_hatch() {
        let context = json!({
            "response": {
                "candidates": [{
                    "content": {"role": "model", "parts": [{"text": "hello"}]},
                    "finishReason": "STOP",
                    "safety_ratings": [{"category": "HARM_CATEGORY_SAFE"}]
                }],
                "usageMetadata": {"promptTokenCount": 5, "candidatesTokenCount": 2}
            }
        });
        let builder = builder_for(&context);

        let category = evaluate(
            &builder,
            AgentAssertion::ResponseField {
                path: "candidates[0].safety_ratings[0].category".to_string(),
            },
        );
        assert_eq!(category, json!("HARM_CATEGORY_SAFE"));
    }

    #[test]
    fn test_no_tool_calls() {
        let context = text_response(json!("Just a text response."));
        let builder = builder_for(&context);

        let not_called = evaluate(
            &builder,
            AgentAssertion::ToolNotCalled {
                name: "web_search".to_string(),
            },
        );
        assert_eq!(not_called, json!(true));
    }

    #[test]
    fn test_from_context_invalid_json() {
        let empty = json!({});
        let outcome = AgentContextBuilder::from_context(&empty);

        assert!(outcome.is_err());
        assert!(matches!(
            outcome,
            Err(EvaluationError::InvalidProviderResponse)
        ));
    }

    #[test]
    fn test_tool_call_sequence_subsequence() {
        let context = tool_call_response(vec![
            tool_call("c1", "search", "{}"),
            tool_call("c2", "filter", "{}"),
            tool_call("c3", "rank", "{}"),
            tool_call("c4", "respond", "{}"),
        ]);
        let builder = builder_for(&context);

        // Calls between the expected names are allowed.
        let with_gap = sequence(&["search", "rank", "respond"]);
        assert_eq!(evaluate(&builder, with_gap), json!(true));

        // Order still matters.
        let reversed = sequence(&["respond", "search"]);
        assert_eq!(evaluate(&builder, reversed), json!(false));
    }

    #[test]
    fn test_parse_path_segments_errors() {
        let empty = AgentContextBuilder::parse_path_segments("");
        assert!(matches!(empty, Err(EvaluationError::EmptyFieldPath)));

        let oversized = "a".repeat(MAX_PATH_LEN + 1);
        let too_long = AgentContextBuilder::parse_path_segments(&oversized);
        assert!(matches!(too_long, Err(EvaluationError::PathTooLong(_))));

        let deep_path = (0..=MAX_PATH_SEGMENTS)
            .map(|i| format!("seg{}", i))
            .collect::<Vec<_>>()
            .join(".");
        let too_many = AgentContextBuilder::parse_path_segments(&deep_path);
        assert!(matches!(
            too_many,
            Err(EvaluationError::TooManyPathSegments(_))
        ));
    }

    #[test]
    fn test_response_content_empty() {
        let context = text_response(Value::Null);
        let builder = builder_for(&context);

        let content = evaluate(&builder, AgentAssertion::ResponseContent {});
        assert_eq!(content, Value::Null);
    }

    #[test]
    fn test_partial_match_nested() {
        let context = tool_call_response(vec![tool_call(
            "c1",
            "create_item",
            r#"{"item": {"name": "widget", "price": 9.99, "tags": ["sale"]}}"#,
        )]);
        let builder = builder_for(&context);

        let nested_subset = called_with_args("create_item", json!({"item": {"name": "widget"}}));
        assert_eq!(evaluate(&builder, nested_subset), json!(true));

        let nested_mismatch = called_with_args("create_item", json!({"item": {"name": "gadget"}}));
        assert_eq!(evaluate(&builder, nested_mismatch), json!(false));
    }

    #[test]
    fn test_tool_result_extraction() {
        let context = tool_call_response(vec![tool_call(
            "c1",
            "web_search",
            r#"{"query": "test"}"#,
        )]);
        let builder = builder_for(&context);

        // The call exists but carries no result.
        let without_result = evaluate(
            &builder,
            AgentAssertion::ToolResult {
                name: "web_search".to_string(),
            },
        );
        assert_eq!(without_result, Value::Null);

        // No call with this name at all.
        let unknown_tool = evaluate(
            &builder,
            AgentAssertion::ToolResult {
                name: "nonexistent".to_string(),
            },
        );
        assert_eq!(unknown_tool, Value::Null);
    }

    #[test]
    fn test_tool_argument_extraction() {
        let context = tool_call_response(vec![tool_call(
            "call_1",
            "web_search",
            r#"{"query": "test query", "limit": 10}"#,
        )]);
        let builder = builder_for(&context);

        let present = evaluate(
            &builder,
            AgentAssertion::ToolArgument {
                name: "web_search".to_string(),
                argument_key: "query".to_string(),
            },
        );
        assert_eq!(present, json!("test query"));

        let absent = evaluate(
            &builder,
            AgentAssertion::ToolArgument {
                name: "web_search".to_string(),
                argument_key: "missing".to_string(),
            },
        );
        assert_eq!(absent, Value::Null);
    }
}