Skip to main content

adk_eval/
test_generator.rs

1//! LLM-driven test case generation.
2//!
3//! Generates evaluation test cases from natural language descriptions (via LLM)
4//! or from production event logs (direct extraction). Produced cases follow
5//! the standard [`TestFile`] JSON format and include generation metadata.
6//!
7//! # Example
8//!
9//! ```rust,ignore
10//! use adk_eval::test_generator::{TestGenerator, GeneratorConfig};
11//! use std::sync::Arc;
12//!
13//! let generator = TestGenerator::with_config(model, GeneratorConfig {
14//!     cases_per_description: 3,
15//!     include_tool_expectations: true,
16//! });
17//!
18//! let cases = generator
19//!     .generate_from_description("A weather assistant that can look up forecasts")
20//!     .await?;
21//! ```
22
23use std::sync::Arc;
24
25use adk_core::{Event, Llm, LlmRequest, Part};
26use futures::StreamExt;
27use serde::{Deserialize, Serialize};
28use tracing::warn;
29
30use crate::error::{EvalError, Result};
31use crate::schema::{ContentData, EvalCase, Turn};
32
/// Tunable options controlling how test cases are generated.
#[derive(Debug, Clone)]
pub struct GeneratorConfig {
    /// How many cases the LLM is asked to produce for one description.
    pub cases_per_description: usize,
    /// When `true`, generated cases carry expected tool invocations.
    pub include_tool_expectations: bool,
}

impl Default for GeneratorConfig {
    /// Defaults to five cases per description with tool expectations enabled.
    fn default() -> Self {
        GeneratorConfig {
            cases_per_description: 5,
            include_tool_expectations: true,
        }
    }
}
47
48/// Metadata for generated eval cases.
49#[derive(Debug, Clone, Default, Serialize, Deserialize)]
50pub struct EvalCaseMetadata {
51    /// Whether this case was auto-generated.
52    #[serde(default)]
53    pub generated: bool,
54    /// Source description (e.g., "description: ..." or "events").
55    #[serde(default, skip_serializing_if = "Option::is_none")]
56    pub source: Option<String>,
57}
58
59/// Generates evaluation test cases from descriptions or event logs.
60pub struct TestGenerator {
61    model: Arc<dyn Llm>,
62    config: GeneratorConfig,
63}
64
65impl TestGenerator {
66    /// Creates a new test generator with default configuration.
67    pub fn new(model: Arc<dyn Llm>) -> Self {
68        Self { model, config: GeneratorConfig::default() }
69    }
70
71    /// Creates a new test generator with custom configuration.
72    pub fn with_config(model: Arc<dyn Llm>, config: GeneratorConfig) -> Self {
73        Self { model, config }
74    }
75
76    /// Generate eval cases from a natural language description.
77    ///
78    /// Prompts the LLM to produce eval case definitions as JSON. On parse failure
79    /// for any individual case, a warning is logged and that case is skipped
80    /// without aborting the batch.
81    pub async fn generate_from_description(&self, description: &str) -> Result<Vec<EvalCase>> {
82        let prompt = self.build_generation_prompt(description);
83
84        let request = LlmRequest::new(
85            self.model.name().to_string(),
86            vec![adk_core::Content::new("user").with_text(&prompt)],
87        );
88
89        let mut stream = self
90            .model
91            .generate_content(request, false)
92            .await
93            .map_err(|e| EvalError::GenerationError(format!("LLM request failed: {e}")))?;
94
95        // Collect the full response text
96        let mut response_text = String::new();
97        while let Some(chunk) = stream.next().await {
98            match chunk {
99                Ok(response) => {
100                    if let Some(content) = &response.content {
101                        for part in &content.parts {
102                            if let Part::Text { text } = part {
103                                response_text.push_str(text);
104                            }
105                        }
106                    }
107                }
108                Err(e) => {
109                    return Err(EvalError::GenerationError(format!("LLM stream error: {e}")));
110                }
111            }
112        }
113
114        // Parse the response into eval cases
115        self.parse_generated_cases(&response_text, description)
116    }
117
118    /// Generate eval cases from production event logs.
119    ///
120    /// Extracts conversation turns from the provided events, constructing
121    /// [`EvalCase`] objects directly without invoking the LLM.
122    pub fn generate_from_events(&self, events: &[Event]) -> Result<Vec<EvalCase>> {
123        if events.is_empty() {
124            return Ok(Vec::new());
125        }
126
127        // Group events by invocation_id to form conversation turns
128        let mut invocations: Vec<(String, Vec<&Event>)> = Vec::new();
129        for event in events {
130            if let Some(last) = invocations.last_mut()
131                && last.0 == event.invocation_id
132            {
133                last.1.push(event);
134                continue;
135            }
136            invocations.push((event.invocation_id.clone(), vec![event]));
137        }
138
139        let mut turns = Vec::new();
140
141        for (invocation_id, inv_events) in &invocations {
142            let mut user_text = String::new();
143            let mut model_text = String::new();
144            let mut tool_uses = Vec::new();
145
146            for event in inv_events {
147                if let Some(content) = event.content() {
148                    match content.role.as_str() {
149                        "user" => {
150                            for part in &content.parts {
151                                if let Part::Text { text } = part {
152                                    if !user_text.is_empty() {
153                                        user_text.push(' ');
154                                    }
155                                    user_text.push_str(text);
156                                }
157                            }
158                        }
159                        "model" => {
160                            for part in &content.parts {
161                                match part {
162                                    Part::Text { text } => {
163                                        if !model_text.is_empty() {
164                                            model_text.push(' ');
165                                        }
166                                        model_text.push_str(text);
167                                    }
168                                    Part::FunctionCall { name, args, .. }
169                                        if self.config.include_tool_expectations =>
170                                    {
171                                        tool_uses.push(crate::schema::ToolUse {
172                                            name: name.clone(),
173                                            args: args.clone(),
174                                            expected_response: None,
175                                        });
176                                    }
177                                    _ => {}
178                                }
179                            }
180                        }
181                        _ => {}
182                    }
183                }
184            }
185
186            // Only create a turn if we have user content
187            if !user_text.is_empty() {
188                let final_response = if model_text.is_empty() {
189                    None
190                } else {
191                    Some(ContentData::model_response(&model_text))
192                };
193
194                let intermediate_data = if tool_uses.is_empty() {
195                    None
196                } else {
197                    Some(crate::schema::IntermediateData {
198                        tool_uses,
199                        intermediate_responses: Vec::new(),
200                    })
201                };
202
203                turns.push(Turn {
204                    invocation_id: invocation_id.clone(),
205                    user_content: ContentData::text(&user_text),
206                    final_response,
207                    intermediate_data,
208                });
209            }
210        }
211
212        if turns.is_empty() {
213            return Ok(Vec::new());
214        }
215
216        let eval_case = EvalCase {
217            eval_id: format!("generated_from_events_{}", uuid::Uuid::new_v4()),
218            description: "Generated from event logs".to_string(),
219            conversation: turns,
220            session_input: Default::default(),
221            tags: vec!["generated".to_string()],
222            metadata: Some(EvalCaseMetadata {
223                generated: true,
224                source: Some("events".to_string()),
225            }),
226        };
227
228        Ok(vec![eval_case])
229    }
230
231    /// Build the prompt for LLM-based case generation.
232    fn build_generation_prompt(&self, description: &str) -> String {
233        let tool_instruction = if self.config.include_tool_expectations {
234            r#"Include "intermediate_data" with "tool_uses" where appropriate, each with "name" and "args" fields."#
235        } else {
236            r#"Do not include "intermediate_data" in the output."#
237        };
238
239        format!(
240            r#"Generate exactly {count} evaluation test cases for the following agent description:
241
242"{description}"
243
244Each test case must be a JSON object with these fields:
245- "eval_id": a unique string identifier (e.g., "test_1", "test_2")
246- "description": a brief description of what the test case validates
247- "conversation": an array of conversation turns, each with:
248  - "invocation_id": a unique string (e.g., "inv_1")
249  - "user_content": object with "parts": [{{"text": "..."}}] and "role": "user"
250  - "final_response": object with "parts": [{{"text": "..."}}] and "role": "model"
251  {tool_instruction}
252
253Output ONLY a JSON array of test case objects. No markdown fences, no explanation text.
254Example format:
255[
256  {{
257    "eval_id": "test_1",
258    "description": "Basic greeting test",
259    "conversation": [
260      {{
261        "invocation_id": "inv_1",
262        "user_content": {{"parts": [{{"text": "Hello"}}], "role": "user"}},
263        "final_response": {{"parts": [{{"text": "Hi there! How can I help?"}}], "role": "model"}}
264      }}
265    ]
266  }}
267]"#,
268            count = self.config.cases_per_description,
269            description = description,
270            tool_instruction = tool_instruction,
271        )
272    }
273
274    /// Parse the LLM response text into eval cases, skipping unparseable entries.
275    fn parse_generated_cases(
276        &self,
277        response_text: &str,
278        description: &str,
279    ) -> Result<Vec<EvalCase>> {
280        let json_text = extract_json_array(response_text).unwrap_or(response_text);
281
282        // Try parsing as an array of eval cases
283        let raw_cases: Vec<serde_json::Value> = match serde_json::from_str(json_text) {
284            Ok(cases) => cases,
285            Err(e) => {
286                // Try to extract JSON array from the text
287                warn!("failed to parse LLM response as JSON array: {e}");
288                return Err(EvalError::GenerationError(format!(
289                    "LLM returned unparseable response: {e}"
290                )));
291            }
292        };
293
294        let source = format!("description: {description}");
295        let mut cases = Vec::new();
296
297        for (i, raw_case) in raw_cases.iter().enumerate() {
298            match serde_json::from_value::<EvalCase>(raw_case.clone()) {
299                Ok(mut eval_case) => {
300                    // Add generation tags
301                    if !eval_case.tags.contains(&"generated".to_string()) {
302                        eval_case.tags.push("generated".to_string());
303                    }
304                    cases.push(eval_case);
305                }
306                Err(e) => {
307                    // Log warning and skip this case without aborting the batch
308                    warn!(
309                        case_index = i,
310                        error = %e,
311                        "skipping unparseable generated case"
312                    );
313                }
314            }
315        }
316
317        if cases.is_empty() && !raw_cases.is_empty() {
318            return Err(EvalError::GenerationError(format!(
319                "all {count} generated cases failed to parse (source: {source})",
320                count = raw_cases.len(),
321            )));
322        }
323
324        // Attach metadata as tags for traceability
325        // The full EvalCaseMetadata integration happens in task 16.2
326        for case in &mut cases {
327            if !case.tags.contains(&source) {
328                case.tags.push(source.clone());
329            }
330        }
331
332        Ok(cases)
333    }
334}
335
/// Extract a JSON array from text that may contain markdown fences or prose.
///
/// Handles common LLM output patterns, tried in this order:
/// - Text that starts with `[`: returned up to its last `]`, so trailing
///   prose or a dangling fence after the array is dropped. If there is no
///   `]` at all the trimmed text is returned unchanged.
/// - Fenced code blocks (with or without a language tag): each closed fence
///   pair is inspected in order and the first whose content starts with `[`
///   wins.
/// - Anything else: the span from the first `[` to the last `]`.
///
/// Returns `None` when no candidate is found. The result is only a
/// best-effort slice; it is not validated as JSON here.
fn extract_json_array(text: &str) -> Option<&str> {
    const FENCE: &str = "```";

    /// Trims `candidate` and keeps it only when it begins like a JSON array.
    fn array_candidate(candidate: &str) -> Option<&str> {
        let candidate = candidate.trim();
        if candidate.starts_with('[') { Some(candidate) } else { None }
    }

    let trimmed = text.trim();

    // Already an array: cut off whatever follows the final closing bracket.
    if trimmed.starts_with('[') {
        let end = trimmed.rfind(']').map_or(trimmed.len(), |idx| idx + 1);
        return Some(&trimmed[..end]);
    }

    // Walk every closed fence pair, not just the first one.
    let mut rest = trimmed;
    while let Some(open) = rest.find(FENCE) {
        let after_open = &rest[open + FENCE.len()..];
        let close = match after_open.find(FENCE) {
            Some(close) => close,
            // Unterminated fence: let the bracket scan below handle it.
            None => break,
        };
        let block = &after_open[..close];

        let found = array_candidate(block)
            // "```json [..] ```" on one line, or "```json\n[..]\n```".
            .or_else(|| block.strip_prefix("json").and_then(array_candidate))
            // Any other language tag occupies the rest of the opening line.
            .or_else(|| block.split_once('\n').and_then(|(_, body)| array_candidate(body)));
        if found.is_some() {
            return found;
        }

        rest = &after_open[close + FENCE.len()..];
    }

    // Last resort: widest bracketed span anywhere in the text.
    let start = trimmed.find('[')?;
    let end = trimmed.rfind(']')?;
    if end > start { Some(&trimmed[start..=end]) } else { None }
}
387
#[cfg(test)]
mod tests {
    use super::*;

    use adk_core::{Content, Llm, LlmResponse};
    use async_trait::async_trait;

    /// Stand-in model for tests that never reach the LLM.
    struct MockLlm;

    #[async_trait]
    impl Llm for MockLlm {
        fn name(&self) -> &str {
            "mock"
        }

        async fn generate_content(
            &self,
            _req: LlmRequest,
            _stream: bool,
        ) -> adk_core::Result<adk_core::LlmResponseStream> {
            unimplemented!()
        }
    }

    /// Generator with the default configuration backed by [`MockLlm`].
    fn mock_generator() -> TestGenerator {
        TestGenerator::new(Arc::new(MockLlm))
    }

    /// Event for `invocation_id` whose response carries `content`.
    fn event_with_content(invocation_id: &str, content: Content) -> Event {
        let mut event = Event::new(invocation_id);
        event.llm_response = LlmResponse { content: Some(content), ..Default::default() };
        event
    }

    /// Model content holding one `get_weather` call followed by `reply` text.
    fn weather_call_content(call_id: Option<String>, reply: &str) -> Content {
        Content {
            role: "model".to_string(),
            parts: vec![
                Part::FunctionCall {
                    name: "get_weather".to_string(),
                    args: serde_json::json!({"location": "NYC"}),
                    id: call_id,
                    thought_signature: None,
                },
                Part::Text { text: reply.to_string() },
            ],
        }
    }

    #[test]
    fn test_generator_config_defaults() {
        let config = GeneratorConfig::default();
        assert_eq!(config.cases_per_description, 5);
        assert!(config.include_tool_expectations);
    }

    #[test]
    fn test_extract_json_array_raw() {
        let input = r#"[{"eval_id": "test_1"}]"#;
        assert_eq!(extract_json_array(input), Some(input));
    }

    #[test]
    fn test_extract_json_array_fenced() {
        let input = "Here are the cases:\n```json\n[{\"eval_id\": \"test_1\"}]\n```\nDone!";
        assert_eq!(extract_json_array(input), Some(r#"[{"eval_id": "test_1"}]"#));
    }

    #[test]
    fn test_extract_json_array_embedded() {
        let input = "Sure, here are the cases: [{\"eval_id\": \"test_1\"}] and that's all.";
        assert_eq!(extract_json_array(input), Some(r#"[{"eval_id": "test_1"}]"#));
    }

    #[test]
    fn test_extract_json_array_no_array() {
        assert_eq!(extract_json_array("No JSON here at all."), None);
    }

    #[test]
    fn test_extract_json_array_with_whitespace() {
        let input = "  \n  [{\"eval_id\": \"test_1\"}]  \n  ";
        assert_eq!(extract_json_array(input), Some(r#"[{"eval_id": "test_1"}]"#));
    }

    #[test]
    fn test_generate_from_events_empty() {
        let result = mock_generator().generate_from_events(&[]).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn test_generate_from_events_with_conversation() {
        let mut user_event =
            event_with_content("inv_1", Content::new("user").with_text("What is the weather?"));
        user_event.author = "user".to_string();

        let mut model_event =
            event_with_content("inv_1", Content::new("model").with_text("The weather is sunny."));
        model_event.author = "model".to_string();

        let events = vec![user_event, model_event];
        let cases = mock_generator().generate_from_events(&events).unwrap();

        assert_eq!(cases.len(), 1);
        let case = &cases[0];
        assert!(case.eval_id.starts_with("generated_from_events_"));
        assert_eq!(case.conversation.len(), 1);

        let turn = &case.conversation[0];
        assert_eq!(turn.invocation_id, "inv_1");
        assert_eq!(turn.user_content.get_text(), "What is the weather?");
        assert_eq!(turn.final_response.as_ref().unwrap().get_text(), "The weather is sunny.");
        assert!(case.tags.contains(&"generated".to_string()));
    }

    #[test]
    fn test_generate_from_events_with_tool_calls() {
        let events = vec![
            event_with_content("inv_1", Content::new("user").with_text("Get weather in NYC")),
            event_with_content(
                "inv_1",
                weather_call_content(Some("call_1".to_string()), "It's 72°F in NYC."),
            ),
        ];

        let cases = mock_generator().generate_from_events(&events).unwrap();

        assert_eq!(cases.len(), 1);
        let turn = &cases[0].conversation[0];
        let intermediate = turn.intermediate_data.as_ref().unwrap();
        assert_eq!(intermediate.tool_uses.len(), 1);
        assert_eq!(intermediate.tool_uses[0].name, "get_weather");
        assert_eq!(intermediate.tool_uses[0].args, serde_json::json!({"location": "NYC"}));
    }

    #[test]
    fn test_parse_generated_cases_valid() {
        let response = r#"[
            {
                "eval_id": "test_1",
                "description": "Greeting test",
                "conversation": [{
                    "invocation_id": "inv_1",
                    "user_content": {"parts": [{"text": "Hello"}], "role": "user"},
                    "final_response": {"parts": [{"text": "Hi!"}], "role": "model"}
                }]
            }
        ]"#;

        let cases = mock_generator().parse_generated_cases(response, "test agent").unwrap();
        assert_eq!(cases.len(), 1);
        assert_eq!(cases[0].eval_id, "test_1");
        assert!(cases[0].tags.contains(&"generated".to_string()));
        assert!(cases[0].tags.contains(&"description: test agent".to_string()));
    }

    #[test]
    fn test_parse_generated_cases_partial_failure() {
        // One well-formed case followed by one that cannot decode.
        let response = r#"[
            {
                "eval_id": "test_1",
                "description": "Valid case",
                "conversation": [{
                    "invocation_id": "inv_1",
                    "user_content": {"parts": [{"text": "Hello"}], "role": "user"},
                    "final_response": {"parts": [{"text": "Hi!"}], "role": "model"}
                }]
            },
            {
                "invalid_field": "This is not a valid EvalCase"
            }
        ]"#;

        let cases = mock_generator().parse_generated_cases(response, "test").unwrap();
        // The good case survives; the bad one is skipped.
        assert_eq!(cases.len(), 1);
        assert_eq!(cases[0].eval_id, "test_1");
    }

    #[test]
    fn test_parse_generated_cases_all_invalid() {
        let response = r#"[{"bad": true}, {"also_bad": "yes"}]"#;

        let result = mock_generator().parse_generated_cases(response, "test");
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("all 2 generated cases failed to parse"));
    }

    #[test]
    fn test_eval_case_metadata_serialization() {
        let meta = EvalCaseMetadata { generated: true, source: Some("events".to_string()) };

        let json = serde_json::to_string(&meta).unwrap();
        assert!(json.contains("\"generated\":true"));
        assert!(json.contains("\"source\":\"events\""));

        let deserialized: EvalCaseMetadata = serde_json::from_str(&json).unwrap();
        assert!(deserialized.generated);
        assert_eq!(deserialized.source.as_deref(), Some("events"));
    }

    #[test]
    fn test_eval_case_metadata_defaults() {
        let meta = EvalCaseMetadata::default();
        assert!(!meta.generated);
        assert!(meta.source.is_none());

        // A `None` source must not appear in the serialized output.
        let json = serde_json::to_string(&meta).unwrap();
        assert!(!json.contains("source"));
    }

    #[test]
    fn test_generate_from_events_no_tool_expectations() {
        let config = GeneratorConfig { cases_per_description: 5, include_tool_expectations: false };
        let generator = TestGenerator::with_config(Arc::new(MockLlm), config);

        let events = vec![
            event_with_content("inv_1", Content::new("user").with_text("Get weather")),
            event_with_content("inv_1", weather_call_content(None, "Sunny")),
        ];

        let cases = generator.generate_from_events(&events).unwrap();
        assert_eq!(cases.len(), 1);
        // Tool expectations should NOT be included
        assert!(cases[0].conversation[0].intermediate_data.is_none());
    }
}