// agentforge_scenarios/generator.rs

1use agentforge_core::{AgentFile, Result, Scenario};
2
3use crate::{
4    adversarial::generate_adversarial_scenarios,
5    domain_seeded::{generate_domain_seeded_scenarios, DomainSeededConfig},
6    schema_derived::generate_schema_derived_scenarios,
7};
8
9/// Configuration for scenario generation.
10#[derive(Debug, Clone)]
11pub struct ScenarioGeneratorConfig {
12    pub total_count: u32,
13    /// Fraction of scenarios generated via schema-derived strategy (default: 0.5)
14    pub schema_derived_ratio: f64,
15    /// Fraction of scenarios generated adversarially (default: 0.3)
16    pub adversarial_ratio: f64,
17    /// Fraction of scenarios generated via domain seeding (default: 0.2)
18    pub domain_seeded_ratio: f64,
19    /// OpenAI-compatible LLM base URL for domain-seeded generation (optional)
20    pub llm_base_url: Option<String>,
21    pub llm_api_key: Option<String>,
22    pub llm_model: Option<String>,
23    pub agent_id: uuid::Uuid,
24}
25
26impl ScenarioGeneratorConfig {
27    /// Validate that ratios sum to approximately 1.0
28    pub fn validate(&self) -> bool {
29        let total = self.schema_derived_ratio + self.adversarial_ratio + self.domain_seeded_ratio;
30        (total - 1.0).abs() < 0.01
31    }
32}
33
34impl Default for ScenarioGeneratorConfig {
35    fn default() -> Self {
36        Self {
37            total_count: 100,
38            schema_derived_ratio: 0.5,
39            adversarial_ratio: 0.3,
40            domain_seeded_ratio: 0.2,
41            llm_base_url: None,
42            llm_api_key: None,
43            llm_model: None,
44            agent_id: uuid::Uuid::new_v4(),
45        }
46    }
47}
48
49/// Generate scenarios for an agent file using all three strategies.
50pub async fn generate_scenarios(
51    agent: &AgentFile,
52    config: &ScenarioGeneratorConfig,
53) -> Result<Vec<Scenario>> {
54    let total = config.total_count as usize;
55    // Use floor so that rounding never eats the domain-seeded budget entirely.
56    // With total=5: schema=2, adversarial=1, domain=2.
57    let schema_n = (total as f64 * config.schema_derived_ratio).floor() as usize;
58    let adversarial_n = (total as f64 * config.adversarial_ratio).floor() as usize;
59    let domain_n = total.saturating_sub(schema_n + adversarial_n);
60
61    tracing::info!(
62        agent = %agent.name,
63        total = total,
64        schema_n = schema_n,
65        adversarial_n = adversarial_n,
66        domain_n = domain_n,
67        "Generating scenarios"
68    );
69
70    let mut scenarios = Vec::with_capacity(total);
71
72    // Schema-derived scenarios
73    let schema_scenarios = generate_schema_derived_scenarios(agent, schema_n, config.agent_id)?;
74    scenarios.extend(schema_scenarios);
75
76    // Adversarial scenarios
77    let adversarial = generate_adversarial_scenarios(agent, adversarial_n, config.agent_id)?;
78    scenarios.extend(adversarial);
79
80    // Domain-seeded scenarios (LLM-based, optional)
81    if domain_n > 0 {
82        let domain_config = DomainSeededConfig {
83            count: domain_n,
84            agent_id: config.agent_id,
85            llm_base_url: config.llm_base_url.clone(),
86            llm_api_key: config.llm_api_key.clone(),
87            llm_model: config
88                .llm_model
89                .clone()
90                .unwrap_or_else(|| "gpt-4o-mini".to_string()),
91        };
92
93        match generate_domain_seeded_scenarios(agent, &domain_config).await {
94            Ok(ds) => scenarios.extend(ds),
95            Err(e) => {
96                tracing::warn!(error = %e, "Domain-seeded generation failed, falling back to adversarial");
97                let fallback = generate_adversarial_scenarios(agent, domain_n, config.agent_id)?;
98                scenarios.extend(fallback);
99            }
100        }
101    }
102
103    tracing::info!(generated = scenarios.len(), "Scenario generation complete");
104    Ok(scenarios)
105}
106
#[cfg(test)]
mod tests {
    use super::*;
    use agentforge_core::{EvalHints, ModelConfig, ModelProvider, ToolDefinition};

    /// Builds a customer-support agent fixture exposing two order tools.
    fn make_test_agent() -> AgentFile {
        let model = ModelConfig {
            provider: ModelProvider::Openai,
            model_id: String::from("gpt-4o"),
            temperature: Some(0.2),
            max_tokens: Some(2048),
            top_p: None,
        };

        let order_status_tool = ToolDefinition {
            name: String::from("get_order_status"),
            description: String::from("Get order status by ID"),
            parameters: serde_json::json!({
                "type": "object",
                "properties": {
                    "order_id": {"type": "string", "description": "Order ID"}
                },
                "required": ["order_id"]
            }),
        };

        let cancel_order_tool = ToolDefinition {
            name: String::from("cancel_order"),
            description: String::from("Cancel an order"),
            parameters: serde_json::json!({
                "type": "object",
                "properties": {
                    "order_id": {"type": "string"},
                    "reason": {"type": "string"}
                },
                "required": ["order_id"]
            }),
        };

        let output_schema = serde_json::json!({
            "type": "object",
            "properties": {
                "response": {"type": "string"},
                "action_taken": {"type": "string"}
            },
            "required": ["response"]
        });

        let eval_hints = EvalHints {
            domain: Some(String::from("customer_support")),
            typical_turns: Some(3),
            critical_tools: vec![String::from("get_order_status")],
            pass_threshold: Some(0.85),
            scenario_count: Some(20),
        };

        AgentFile {
            agentforge_schema_version: String::from("1"),
            name: String::from("test-agent"),
            version: String::from("1.0.0"),
            model,
            system_prompt: String::from(
                "You are a customer support agent. Help users with orders.",
            ),
            tools: vec![order_status_tool, cancel_order_tool],
            output_schema: Some(output_schema),
            constraints: vec![String::from(
                "Always confirm order ID before calling get_order_status.",
            )],
            eval_hints: Some(eval_hints),
            metadata: None,
        }
    }

    /// The stock configuration must pass its own ratio validation.
    #[test]
    fn default_config_valid() {
        assert!(ScenarioGeneratorConfig::default().validate());
    }

    /// Generation without LLM settings must still yield most of the budget.
    #[tokio::test]
    async fn generates_expected_count() {
        let agent = make_test_agent();
        // The defaults already carry the 0.5 / 0.3 / 0.2 split, no LLM
        // settings and a fresh random agent id; only the budget differs.
        let config = ScenarioGeneratorConfig {
            total_count: 10,
            ..ScenarioGeneratorConfig::default()
        };

        let scenarios = generate_scenarios(&agent, &config).await.unwrap();
        let produced = scenarios.len();

        assert!(produced > 0);
        // Should have at least schema-derived and adversarial
        assert!(produced >= 8);
    }
}
192        assert!(!scenarios.is_empty());
193        // Should have at least schema-derived and adversarial
194        assert!(scenarios.len() >= 8);
195    }
196}