1use std::sync::Arc;
24
25use adk_core::{Event, Llm, LlmRequest, Part};
26use futures::StreamExt;
27use serde::{Deserialize, Serialize};
28use tracing::warn;
29
30use crate::error::{EvalError, Result};
31use crate::schema::{ContentData, EvalCase, Turn};
32
/// Configuration for [`TestGenerator`].
///
/// Derives `PartialEq`/`Eq` so configurations can be compared directly
/// (for example in tests) instead of field by field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GeneratorConfig {
    /// Number of test cases requested from the model for each agent description.
    pub cases_per_description: usize,
    /// Whether tool-call expectations are requested in the generation prompt
    /// and recorded when building cases from events.
    pub include_tool_expectations: bool,
}

impl Default for GeneratorConfig {
    /// Returns the default configuration: five cases per description, with
    /// tool expectations enabled.
    fn default() -> Self {
        Self { cases_per_description: 5, include_tool_expectations: true }
    }
}
47
/// Provenance metadata attached to an evaluation case.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct EvalCaseMetadata {
    /// `true` when the case was produced by a generator. Falls back to
    /// `false` when the field is missing from the deserialized input.
    #[serde(default)]
    pub generated: bool,
    /// Optional label describing where the case came from (the event-based
    /// generator in this file uses `"events"`). Left out of the serialized
    /// output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub source: Option<String>,
}
58
/// Builds evaluation cases either by prompting a language model with an
/// agent description or by converting recorded events.
pub struct TestGenerator {
    /// Model queried when generating cases from a description.
    model: Arc<dyn Llm>,
    /// Settings that control how many cases are requested and whether tool
    /// expectations are included.
    config: GeneratorConfig,
}
64
65impl TestGenerator {
66 pub fn new(model: Arc<dyn Llm>) -> Self {
68 Self { model, config: GeneratorConfig::default() }
69 }
70
71 pub fn with_config(model: Arc<dyn Llm>, config: GeneratorConfig) -> Self {
73 Self { model, config }
74 }
75
76 pub async fn generate_from_description(&self, description: &str) -> Result<Vec<EvalCase>> {
82 let prompt = self.build_generation_prompt(description);
83
84 let request = LlmRequest::new(
85 self.model.name().to_string(),
86 vec![adk_core::Content::new("user").with_text(&prompt)],
87 );
88
89 let mut stream = self
90 .model
91 .generate_content(request, false)
92 .await
93 .map_err(|e| EvalError::GenerationError(format!("LLM request failed: {e}")))?;
94
95 let mut response_text = String::new();
97 while let Some(chunk) = stream.next().await {
98 match chunk {
99 Ok(response) => {
100 if let Some(content) = &response.content {
101 for part in &content.parts {
102 if let Part::Text { text } = part {
103 response_text.push_str(text);
104 }
105 }
106 }
107 }
108 Err(e) => {
109 return Err(EvalError::GenerationError(format!("LLM stream error: {e}")));
110 }
111 }
112 }
113
114 self.parse_generated_cases(&response_text, description)
116 }
117
118 pub fn generate_from_events(&self, events: &[Event]) -> Result<Vec<EvalCase>> {
123 if events.is_empty() {
124 return Ok(Vec::new());
125 }
126
127 let mut invocations: Vec<(String, Vec<&Event>)> = Vec::new();
129 for event in events {
130 if let Some(last) = invocations.last_mut()
131 && last.0 == event.invocation_id
132 {
133 last.1.push(event);
134 continue;
135 }
136 invocations.push((event.invocation_id.clone(), vec![event]));
137 }
138
139 let mut turns = Vec::new();
140
141 for (invocation_id, inv_events) in &invocations {
142 let mut user_text = String::new();
143 let mut model_text = String::new();
144 let mut tool_uses = Vec::new();
145
146 for event in inv_events {
147 if let Some(content) = event.content() {
148 match content.role.as_str() {
149 "user" => {
150 for part in &content.parts {
151 if let Part::Text { text } = part {
152 if !user_text.is_empty() {
153 user_text.push(' ');
154 }
155 user_text.push_str(text);
156 }
157 }
158 }
159 "model" => {
160 for part in &content.parts {
161 match part {
162 Part::Text { text } => {
163 if !model_text.is_empty() {
164 model_text.push(' ');
165 }
166 model_text.push_str(text);
167 }
168 Part::FunctionCall { name, args, .. }
169 if self.config.include_tool_expectations =>
170 {
171 tool_uses.push(crate::schema::ToolUse {
172 name: name.clone(),
173 args: args.clone(),
174 expected_response: None,
175 });
176 }
177 _ => {}
178 }
179 }
180 }
181 _ => {}
182 }
183 }
184 }
185
186 if !user_text.is_empty() {
188 let final_response = if model_text.is_empty() {
189 None
190 } else {
191 Some(ContentData::model_response(&model_text))
192 };
193
194 let intermediate_data = if tool_uses.is_empty() {
195 None
196 } else {
197 Some(crate::schema::IntermediateData {
198 tool_uses,
199 intermediate_responses: Vec::new(),
200 })
201 };
202
203 turns.push(Turn {
204 invocation_id: invocation_id.clone(),
205 user_content: ContentData::text(&user_text),
206 final_response,
207 intermediate_data,
208 });
209 }
210 }
211
212 if turns.is_empty() {
213 return Ok(Vec::new());
214 }
215
216 let eval_case = EvalCase {
217 eval_id: format!("generated_from_events_{}", uuid::Uuid::new_v4()),
218 description: "Generated from event logs".to_string(),
219 conversation: turns,
220 session_input: Default::default(),
221 tags: vec!["generated".to_string()],
222 metadata: Some(EvalCaseMetadata {
223 generated: true,
224 source: Some("events".to_string()),
225 }),
226 };
227
228 Ok(vec![eval_case])
229 }
230
231 fn build_generation_prompt(&self, description: &str) -> String {
233 let tool_instruction = if self.config.include_tool_expectations {
234 r#"Include "intermediate_data" with "tool_uses" where appropriate, each with "name" and "args" fields."#
235 } else {
236 r#"Do not include "intermediate_data" in the output."#
237 };
238
239 format!(
240 r#"Generate exactly {count} evaluation test cases for the following agent description:
241
242"{description}"
243
244Each test case must be a JSON object with these fields:
245- "eval_id": a unique string identifier (e.g., "test_1", "test_2")
246- "description": a brief description of what the test case validates
247- "conversation": an array of conversation turns, each with:
248 - "invocation_id": a unique string (e.g., "inv_1")
249 - "user_content": object with "parts": [{{"text": "..."}}] and "role": "user"
250 - "final_response": object with "parts": [{{"text": "..."}}] and "role": "model"
251 {tool_instruction}
252
253Output ONLY a JSON array of test case objects. No markdown fences, no explanation text.
254Example format:
255[
256 {{
257 "eval_id": "test_1",
258 "description": "Basic greeting test",
259 "conversation": [
260 {{
261 "invocation_id": "inv_1",
262 "user_content": {{"parts": [{{"text": "Hello"}}], "role": "user"}},
263 "final_response": {{"parts": [{{"text": "Hi there! How can I help?"}}], "role": "model"}}
264 }}
265 ]
266 }}
267]"#,
268 count = self.config.cases_per_description,
269 description = description,
270 tool_instruction = tool_instruction,
271 )
272 }
273
274 fn parse_generated_cases(
276 &self,
277 response_text: &str,
278 description: &str,
279 ) -> Result<Vec<EvalCase>> {
280 let json_text = extract_json_array(response_text).unwrap_or(response_text);
281
282 let raw_cases: Vec<serde_json::Value> = match serde_json::from_str(json_text) {
284 Ok(cases) => cases,
285 Err(e) => {
286 warn!("failed to parse LLM response as JSON array: {e}");
288 return Err(EvalError::GenerationError(format!(
289 "LLM returned unparseable response: {e}"
290 )));
291 }
292 };
293
294 let source = format!("description: {description}");
295 let mut cases = Vec::new();
296
297 for (i, raw_case) in raw_cases.iter().enumerate() {
298 match serde_json::from_value::<EvalCase>(raw_case.clone()) {
299 Ok(mut eval_case) => {
300 if !eval_case.tags.contains(&"generated".to_string()) {
302 eval_case.tags.push("generated".to_string());
303 }
304 cases.push(eval_case);
305 }
306 Err(e) => {
307 warn!(
309 case_index = i,
310 error = %e,
311 "skipping unparseable generated case"
312 );
313 }
314 }
315 }
316
317 if cases.is_empty() && !raw_cases.is_empty() {
318 return Err(EvalError::GenerationError(format!(
319 "all {count} generated cases failed to parse (source: {source})",
320 count = raw_cases.len(),
321 )));
322 }
323
324 for case in &mut cases {
327 if !case.tags.contains(&source) {
328 case.tags.push(source.clone());
329 }
330 }
331
332 Ok(cases)
333 }
334}
335
/// Locates the JSON array inside an LLM reply and returns it as a slice of
/// `text`.
///
/// Candidates are tried in this order:
///
/// 1. the whole trimmed reply, when it both starts with `[` and ends with `]`;
/// 2. the contents of the first ```` ```json ```` fenced block;
/// 3. the contents of the first fenced block of any kind (the remainder of
///    the opening fence line, such as a language tag, is skipped);
/// 4. the span from the first `[` to the last `]`;
/// 5. the whole trimmed reply when it starts with `[` but has no closing `]`,
///    so the caller's JSON parser can report the truncation.
///
/// A fenced block is accepted only if its trimmed contents start with `[`.
/// Requiring the closing `]` in step 1 lets a bare array followed by trailing
/// prose (`[...]\nDone!`) reach step 4 instead of being returned with the
/// prose attached. Returns `None` when no candidate matches. The result is
/// not validated as JSON.
fn extract_json_array(text: &str) -> Option<&str> {
    /// Returns the trimmed text between `body_start` and the next closing
    /// fence, provided that text starts with `[`.
    fn fenced_array(text: &str, body_start: usize) -> Option<&str> {
        let rest = &text[body_start..];
        let fence_end = rest.find("```")?;
        let body = rest[..fence_end].trim();
        if body.starts_with('[') {
            Some(body)
        } else {
            None
        }
    }

    const JSON_FENCE: &str = "```json";
    const FENCE: &str = "```";

    let trimmed = text.trim();

    if trimmed.starts_with('[') && trimmed.ends_with(']') {
        return Some(trimmed);
    }

    let json_fenced = trimmed
        .find(JSON_FENCE)
        .and_then(|start| fenced_array(trimmed, start + JSON_FENCE.len()));
    if json_fenced.is_some() {
        return json_fenced;
    }

    if let Some(start) = trimmed.find(FENCE) {
        let after_fence = start + FENCE.len();
        // Skip whatever follows the fence on its own line (e.g. a language tag).
        let body_start = trimmed[after_fence..]
            .find('\n')
            .map_or(after_fence, |offset| after_fence + offset + 1);
        let fenced = fenced_array(trimmed, body_start);
        if fenced.is_some() {
            return fenced;
        }
    }

    match (trimmed.find('['), trimmed.rfind(']')) {
        (Some(start), Some(end)) if end > start => Some(&trimmed[start..=end]),
        // Unterminated array: hand it back unchanged so the parse error
        // reported by the caller points at the real problem.
        _ if trimmed.starts_with('[') => Some(trimmed),
        _ => None,
    }
}
387
#[cfg(test)]
mod tests {
    use super::*;
    use adk_core::{Content, LlmResponse};
    use async_trait::async_trait;

    /// Minimal [`Llm`] stub shared by every test. None of the tests below
    /// query the model, so `generate_content` is deliberately unimplemented.
    struct MockLlm;

    #[async_trait]
    impl Llm for MockLlm {
        fn name(&self) -> &str {
            "mock"
        }

        async fn generate_content(
            &self,
            _req: LlmRequest,
            _stream: bool,
        ) -> adk_core::Result<adk_core::LlmResponseStream> {
            unimplemented!()
        }
    }

    /// Generator with the default configuration backed by [`MockLlm`].
    fn mock_generator() -> TestGenerator {
        TestGenerator::new(Arc::new(MockLlm))
    }

    /// Event for `invocation_id` whose response carries `content`.
    fn event_with_content(invocation_id: &str, content: Content) -> Event {
        let mut event = Event::new(invocation_id);
        event.llm_response = LlmResponse { content: Some(content), ..Default::default() };
        event
    }

    /// Model event for `inv_1` holding a `get_weather` call followed by `reply`.
    fn weather_tool_call_event(call_id: Option<&str>, reply: &str) -> Event {
        event_with_content(
            "inv_1",
            Content {
                role: "model".to_string(),
                parts: vec![
                    Part::FunctionCall {
                        name: "get_weather".to_string(),
                        args: serde_json::json!({"location": "NYC"}),
                        id: call_id.map(str::to_string),
                        thought_signature: None,
                    },
                    Part::Text { text: reply.to_string() },
                ],
            },
        )
    }

    #[test]
    fn test_generator_config_defaults() {
        let config = GeneratorConfig::default();
        assert_eq!(config.cases_per_description, 5);
        assert!(config.include_tool_expectations);
    }

    #[test]
    fn test_extract_json_array_raw() {
        let input = r#"[{"eval_id": "test_1"}]"#;
        let result = extract_json_array(input);
        assert_eq!(result, Some(input));
    }

    #[test]
    fn test_extract_json_array_fenced() {
        let input = "Here are the cases:\n```json\n[{\"eval_id\": \"test_1\"}]\n```\nDone!";
        let result = extract_json_array(input);
        assert_eq!(result, Some(r#"[{"eval_id": "test_1"}]"#));
    }

    #[test]
    fn test_extract_json_array_plain_fence() {
        let input = "```\n[{\"eval_id\": \"test_1\"}]\n```";
        let result = extract_json_array(input);
        assert_eq!(result, Some(r#"[{"eval_id": "test_1"}]"#));
    }

    #[test]
    fn test_extract_json_array_embedded() {
        let input = "Sure, here are the cases: [{\"eval_id\": \"test_1\"}] and that's all.";
        let result = extract_json_array(input);
        assert_eq!(result, Some(r#"[{"eval_id": "test_1"}]"#));
    }

    #[test]
    fn test_extract_json_array_no_array() {
        let input = "No JSON here at all.";
        let result = extract_json_array(input);
        assert_eq!(result, None);
    }

    #[test]
    fn test_extract_json_array_with_whitespace() {
        let input = " \n [{\"eval_id\": \"test_1\"}] \n ";
        let result = extract_json_array(input);
        assert_eq!(result, Some(r#"[{"eval_id": "test_1"}]"#));
    }

    #[test]
    fn test_generate_from_events_empty() {
        let result = mock_generator().generate_from_events(&[]).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn test_generate_from_events_with_conversation() {
        let mut user_event =
            event_with_content("inv_1", Content::new("user").with_text("What is the weather?"));
        user_event.author = "user".to_string();

        let mut model_event =
            event_with_content("inv_1", Content::new("model").with_text("The weather is sunny."));
        model_event.author = "model".to_string();

        let events = vec![user_event, model_event];
        let cases = mock_generator().generate_from_events(&events).unwrap();

        assert_eq!(cases.len(), 1);
        let case = &cases[0];
        assert!(case.eval_id.starts_with("generated_from_events_"));
        assert_eq!(case.conversation.len(), 1);

        let turn = &case.conversation[0];
        assert_eq!(turn.invocation_id, "inv_1");
        assert_eq!(turn.user_content.get_text(), "What is the weather?");
        assert_eq!(turn.final_response.as_ref().unwrap().get_text(), "The weather is sunny.");
        assert!(case.tags.contains(&"generated".to_string()));
    }

    #[test]
    fn test_generate_from_events_multiple_invocations() {
        let events = vec![
            event_with_content("inv_1", Content::new("user").with_text("First question")),
            event_with_content("inv_1", Content::new("model").with_text("First answer")),
            event_with_content("inv_2", Content::new("user").with_text("Second question")),
            event_with_content("inv_2", Content::new("model").with_text("Second answer")),
        ];

        let cases = mock_generator().generate_from_events(&events).unwrap();

        assert_eq!(cases.len(), 1);
        let conversation = &cases[0].conversation;
        assert_eq!(conversation.len(), 2);
        assert_eq!(conversation[0].invocation_id, "inv_1");
        assert_eq!(conversation[0].user_content.get_text(), "First question");
        assert_eq!(conversation[1].invocation_id, "inv_2");
        assert_eq!(conversation[1].user_content.get_text(), "Second question");
    }

    #[test]
    fn test_generate_from_events_with_tool_calls() {
        let events = vec![
            event_with_content("inv_1", Content::new("user").with_text("Get weather in NYC")),
            weather_tool_call_event(Some("call_1"), "It's 72°F in NYC."),
        ];

        let cases = mock_generator().generate_from_events(&events).unwrap();

        assert_eq!(cases.len(), 1);
        let turn = &cases[0].conversation[0];
        let intermediate = turn.intermediate_data.as_ref().unwrap();
        assert_eq!(intermediate.tool_uses.len(), 1);
        assert_eq!(intermediate.tool_uses[0].name, "get_weather");
        assert_eq!(intermediate.tool_uses[0].args, serde_json::json!({"location": "NYC"}));
    }

    #[test]
    fn test_parse_generated_cases_valid() {
        let response = r#"[
            {
                "eval_id": "test_1",
                "description": "Greeting test",
                "conversation": [{
                    "invocation_id": "inv_1",
                    "user_content": {"parts": [{"text": "Hello"}], "role": "user"},
                    "final_response": {"parts": [{"text": "Hi!"}], "role": "model"}
                }]
            }
        ]"#;

        let cases = mock_generator().parse_generated_cases(response, "test agent").unwrap();
        assert_eq!(cases.len(), 1);
        assert_eq!(cases[0].eval_id, "test_1");
        assert!(cases[0].tags.contains(&"generated".to_string()));
        assert!(cases[0].tags.contains(&"description: test agent".to_string()));
    }

    #[test]
    fn test_parse_generated_cases_partial_failure() {
        let response = r#"[
            {
                "eval_id": "test_1",
                "description": "Valid case",
                "conversation": [{
                    "invocation_id": "inv_1",
                    "user_content": {"parts": [{"text": "Hello"}], "role": "user"},
                    "final_response": {"parts": [{"text": "Hi!"}], "role": "model"}
                }]
            },
            {
                "invalid_field": "This is not a valid EvalCase"
            }
        ]"#;

        // The malformed second element is skipped; the valid one survives.
        let cases = mock_generator().parse_generated_cases(response, "test").unwrap();
        assert_eq!(cases.len(), 1);
        assert_eq!(cases[0].eval_id, "test_1");
    }

    #[test]
    fn test_parse_generated_cases_all_invalid() {
        let response = r#"[{"bad": true}, {"also_bad": "yes"}]"#;

        let result = mock_generator().parse_generated_cases(response, "test");
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("all 2 generated cases failed to parse"));
    }

    #[test]
    fn test_eval_case_metadata_serialization() {
        let meta = EvalCaseMetadata { generated: true, source: Some("events".to_string()) };

        let json = serde_json::to_string(&meta).unwrap();
        assert!(json.contains("\"generated\":true"));
        assert!(json.contains("\"source\":\"events\""));

        let deserialized: EvalCaseMetadata = serde_json::from_str(&json).unwrap();
        assert!(deserialized.generated);
        assert_eq!(deserialized.source.as_deref(), Some("events"));
    }

    #[test]
    fn test_eval_case_metadata_defaults() {
        let meta = EvalCaseMetadata::default();
        assert!(!meta.generated);
        assert!(meta.source.is_none());

        // A `None` source must be omitted from the serialized form.
        let json = serde_json::to_string(&meta).unwrap();
        assert!(!json.contains("source"));
    }

    #[test]
    fn test_generate_from_events_no_tool_expectations() {
        let config = GeneratorConfig { cases_per_description: 5, include_tool_expectations: false };
        let generator = TestGenerator::with_config(Arc::new(MockLlm), config);

        let events = vec![
            event_with_content("inv_1", Content::new("user").with_text("Get weather")),
            weather_tool_call_event(None, "Sunny"),
        ];

        let cases = generator.generate_from_events(&events).unwrap();
        assert_eq!(cases.len(), 1);
        assert!(cases[0].conversation[0].intermediate_data.is_none());
    }
}