1use agentforge_core::{AgentFile, Result, Scenario};
2
3use crate::{
4 adversarial::generate_adversarial_scenarios,
5 domain_seeded::{generate_domain_seeded_scenarios, DomainSeededConfig},
6 schema_derived::generate_schema_derived_scenarios,
7};
8
9#[derive(Debug, Clone)]
11pub struct ScenarioGeneratorConfig {
12 pub total_count: u32,
13 pub schema_derived_ratio: f64,
15 pub adversarial_ratio: f64,
17 pub domain_seeded_ratio: f64,
19 pub llm_base_url: Option<String>,
21 pub llm_api_key: Option<String>,
22 pub llm_model: Option<String>,
23 pub agent_id: uuid::Uuid,
24}
25
26impl ScenarioGeneratorConfig {
27 pub fn validate(&self) -> bool {
29 let total = self.schema_derived_ratio + self.adversarial_ratio + self.domain_seeded_ratio;
30 (total - 1.0).abs() < 0.01
31 }
32}
33
34impl Default for ScenarioGeneratorConfig {
35 fn default() -> Self {
36 Self {
37 total_count: 100,
38 schema_derived_ratio: 0.5,
39 adversarial_ratio: 0.3,
40 domain_seeded_ratio: 0.2,
41 llm_base_url: None,
42 llm_api_key: None,
43 llm_model: None,
44 agent_id: uuid::Uuid::new_v4(),
45 }
46 }
47}
48
49pub async fn generate_scenarios(
51 agent: &AgentFile,
52 config: &ScenarioGeneratorConfig,
53) -> Result<Vec<Scenario>> {
54 let total = config.total_count as usize;
55 let schema_n = (total as f64 * config.schema_derived_ratio).floor() as usize;
58 let adversarial_n = (total as f64 * config.adversarial_ratio).floor() as usize;
59 let domain_n = total.saturating_sub(schema_n + adversarial_n);
60
61 tracing::info!(
62 agent = %agent.name,
63 total = total,
64 schema_n = schema_n,
65 adversarial_n = adversarial_n,
66 domain_n = domain_n,
67 "Generating scenarios"
68 );
69
70 let mut scenarios = Vec::with_capacity(total);
71
72 let schema_scenarios = generate_schema_derived_scenarios(agent, schema_n, config.agent_id)?;
74 scenarios.extend(schema_scenarios);
75
76 let adversarial = generate_adversarial_scenarios(agent, adversarial_n, config.agent_id)?;
78 scenarios.extend(adversarial);
79
80 if domain_n > 0 {
82 let domain_config = DomainSeededConfig {
83 count: domain_n,
84 agent_id: config.agent_id,
85 llm_base_url: config.llm_base_url.clone(),
86 llm_api_key: config.llm_api_key.clone(),
87 llm_model: config
88 .llm_model
89 .clone()
90 .unwrap_or_else(|| "gpt-4o-mini".to_string()),
91 };
92
93 match generate_domain_seeded_scenarios(agent, &domain_config).await {
94 Ok(ds) => scenarios.extend(ds),
95 Err(e) => {
96 tracing::warn!(error = %e, "Domain-seeded generation failed, falling back to adversarial");
97 let fallback = generate_adversarial_scenarios(agent, domain_n, config.agent_id)?;
98 scenarios.extend(fallback);
99 }
100 }
101 }
102
103 tracing::info!(generated = scenarios.len(), "Scenario generation complete");
104 Ok(scenarios)
105}
106
#[cfg(test)]
mod tests {
    use super::*;
    use agentforge_core::{EvalHints, ModelConfig, ModelProvider, ToolDefinition};

    /// Builds a small customer-support agent fixture with two tools, an output
    /// schema, one constraint and eval hints.
    fn make_test_agent() -> AgentFile {
        AgentFile {
            agentforge_schema_version: "1".to_string(),
            name: "test-agent".to_string(),
            version: "1.0.0".to_string(),
            model: ModelConfig {
                provider: ModelProvider::Openai,
                model_id: "gpt-4o".to_string(),
                temperature: Some(0.2),
                max_tokens: Some(2048),
                top_p: None,
            },
            system_prompt: "You are a customer support agent. Help users with orders.".to_string(),
            tools: vec![
                ToolDefinition {
                    name: "get_order_status".to_string(),
                    description: "Get order status by ID".to_string(),
                    parameters: serde_json::json!({
                        "type": "object",
                        "properties": {
                            "order_id": {"type": "string", "description": "Order ID"}
                        },
                        "required": ["order_id"]
                    }),
                },
                ToolDefinition {
                    name: "cancel_order".to_string(),
                    description: "Cancel an order".to_string(),
                    parameters: serde_json::json!({
                        "type": "object",
                        "properties": {
                            "order_id": {"type": "string"},
                            "reason": {"type": "string"}
                        },
                        "required": ["order_id"]
                    }),
                },
            ],
            output_schema: Some(serde_json::json!({
                "type": "object",
                "properties": {
                    "response": {"type": "string"},
                    "action_taken": {"type": "string"}
                },
                "required": ["response"]
            })),
            constraints: vec![
                "Always confirm order ID before calling get_order_status.".to_string()
            ],
            eval_hints: Some(EvalHints {
                domain: Some("customer_support".to_string()),
                typical_turns: Some(3),
                critical_tools: vec!["get_order_status".to_string()],
                pass_threshold: Some(0.85),
                scenario_count: Some(20),
            }),
            metadata: None,
        }
    }

    #[test]
    fn default_config_valid() {
        let config = ScenarioGeneratorConfig::default();
        assert!(config.validate());
    }

    #[test]
    fn validate_rejects_ratios_not_summing_to_one() {
        // 0.5 + 0.3 + 0.1 = 0.9, outside the 0.01 tolerance.
        let config = ScenarioGeneratorConfig {
            schema_derived_ratio: 0.5,
            adversarial_ratio: 0.3,
            domain_seeded_ratio: 0.1,
            ..ScenarioGeneratorConfig::default()
        };
        assert!(!config.validate());
    }

    #[tokio::test]
    async fn generates_expected_count() {
        let agent = make_test_agent();
        let config = ScenarioGeneratorConfig {
            total_count: 10,
            schema_derived_ratio: 0.5,
            adversarial_ratio: 0.3,
            domain_seeded_ratio: 0.2,
            llm_base_url: None,
            llm_api_key: None,
            llm_model: None,
            agent_id: uuid::Uuid::new_v4(),
        };
        let scenarios = generate_scenarios(&agent, &config).await.unwrap();
        // Generators may return slightly fewer than requested, hence the lower bound.
        assert!(scenarios.len() >= 8, "got {} scenarios", scenarios.len());
    }
}