1use anyhow::{Context, Result};
12use bon::Builder;
13use serde::{Deserialize, Serialize};
14use std::sync::Arc;
15
16use crate as dspy_rs;
17use crate::{
18 Example, LM, Module, Optimizable, Optimizer, Predict, Prediction, Predictor,
19 evaluate::FeedbackEvaluator, example,
20};
21use dsrs_macros::Signature;
22
23use super::pareto::ParetoFrontier;
24
/// One point in GEPA's search space: a concrete instruction for a single
/// named module, plus its per-example scores and mutation lineage.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GEPACandidate {
    /// Candidate identifier. NOTE(review): constructors set this to 0;
    /// presumably assigned when the candidate is added to the frontier — confirm.
    pub id: usize,

    /// The instruction text this candidate proposes for the module.
    pub instruction: String,

    /// Name of the module (predictor) this instruction applies to.
    pub module_name: String,

    /// Score obtained on each evaluation example; empty until evaluated.
    pub example_scores: Vec<f32>,

    /// `id` of the candidate this one was mutated from, if any.
    pub parent_id: Option<usize>,

    /// Generation (iteration index) at which this candidate was created.
    pub generation: usize,
}
50
51impl GEPACandidate {
52 pub fn from_predictor(predictor: &dyn Optimizable, module_name: impl Into<String>) -> Self {
54 Self {
55 id: 0,
56 instruction: predictor.get_signature().instruction(),
57 module_name: module_name.into(),
58 example_scores: Vec::new(),
59 parent_id: None,
60 generation: 0,
61 }
62 }
63
64 pub fn average_score(&self) -> f32 {
66 if self.example_scores.is_empty() {
67 return 0.0;
68 }
69 self.example_scores.iter().sum::<f32>() / self.example_scores.len() as f32
70 }
71
72 pub fn mutate(&self, new_instruction: String, generation: usize) -> Self {
74 Self {
75 id: 0, instruction: new_instruction,
77 module_name: self.module_name.clone(),
78 example_scores: Vec::new(),
79 parent_id: Some(self.id),
80 generation,
81 }
82 }
83}
84
/// Final artifact of a GEPA run: the winning candidate plus run statistics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GEPAResult {
    /// Candidate with the best average score on the evaluation set.
    pub best_candidate: GEPACandidate,

    /// Every child candidate produced during the run (when stats tracking is on).
    pub all_candidates: Vec<GEPACandidate>,

    /// Total number of module executions (trace collection + evaluations).
    pub total_rollouts: usize,

    /// Total number of LM calls made for reflection/proposal.
    pub total_lm_calls: usize,

    /// (generation, best average score) pairs recorded per generation.
    pub evolution_history: Vec<(usize, f32)>,

    /// Best score seen per validation task.
    /// NOTE(review): currently always empty in `compile_with_feedback` — confirm intent.
    pub highest_score_achieved_per_val_task: Vec<f32>,

    /// Best validation outputs, when output tracking is enabled.
    /// NOTE(review): currently always `None` in `compile_with_feedback`.
    pub best_outputs_valset: Option<Vec<Prediction>>,

    /// Per-generation snapshots of Pareto frontier statistics.
    pub frontier_history: Vec<ParetoStatistics>,
}
112
113pub use super::pareto::ParetoStatistics;
115
/// LM signature: analyze execution traces against the current instruction and
/// surface weaknesses plus concrete improvement suggestions.
#[Signature]
struct ReflectOnTrace {
    #[input(desc = "The current instruction for the module")]
    pub current_instruction: String,

    #[input(desc = "Execution traces showing inputs, outputs, and evaluation feedback")]
    pub traces: String,

    #[input(desc = "Description of what the module should accomplish")]
    pub task_description: String,

    #[output(desc = "Analysis of weaknesses and specific improvement suggestions")]
    pub reflection: String,
}
139
/// LM signature: given a reflection and the raw traces, propose a rewritten
/// instruction that addresses the identified weaknesses.
#[Signature]
struct ProposeImprovedInstruction {
    #[input(desc = "The current instruction")]
    pub current_instruction: String,

    #[input(desc = "Reflection on weaknesses and improvement suggestions")]
    pub reflection: String,

    #[input(desc = "Execution traces and feedback from recent rollouts")]
    pub traces_and_feedback: String,

    #[output(desc = "An improved instruction that addresses the identified weaknesses")]
    pub improved_instruction: String,
}
159
/// LM signature: pick which module of a multi-module program to optimize next.
/// NOTE(review): not referenced anywhere in this file's visible code — possibly
/// reserved for multi-module selection; confirm before removing.
#[Signature]
struct SelectModuleToImprove {
    #[input(desc = "List of modules with their current instructions and performance")]
    pub module_summary: String,

    #[input(desc = "Recent execution traces showing module interactions")]
    pub execution_traces: String,

    #[output(desc = "Name of the module to optimize and reasoning")]
    pub selected_module: String,
}
175
/// GEPA (Genetic-Pareto) optimizer configuration: evolves module instructions
/// via LM-driven reflection while maintaining a Pareto frontier of candidates.
#[derive(Builder)]
pub struct GEPA {
    /// Number of evolutionary generations to run.
    #[builder(default = 20)]
    pub num_iterations: usize,

    /// Number of training examples used per generation for trace collection.
    #[builder(default = 25)]
    pub minibatch_size: usize,

    /// NOTE(review): not referenced in this file's visible code — confirm use.
    #[builder(default = 10)]
    pub num_trials: usize,

    /// Sampling temperature applied to the prompt model during reflection/proposal.
    #[builder(default = 1.0)]
    pub temperature: f32,

    /// When true, records candidates, evolution history, and frontier snapshots.
    #[builder(default = true)]
    pub track_stats: bool,

    /// NOTE(review): declared but not consulted in this file's visible code.
    #[builder(default = false)]
    pub track_best_outputs: bool,

    /// Optional budget: stop once this many rollouts have been spent.
    pub max_rollouts: Option<usize>,

    /// Optional budget: stop once this many LM calls have been spent.
    pub max_lm_calls: Option<usize>,

    /// Dedicated LM for reflection/proposal; falls back to the default when `None`.
    pub prompt_model: Option<LM>,

    /// Evaluation set; falls back to the training set when `None`.
    pub valset: Option<Vec<Example>>,
}
219
impl GEPA {
    /// Seeds the Pareto frontier with one candidate per named predictor in
    /// `module`, evaluating each on `trainset`.
    async fn initialize_frontier<M>(
        &self,
        module: &mut M,
        trainset: &[Example],
    ) -> Result<ParetoFrontier>
    where
        M: Module + Optimizable + FeedbackEvaluator,
    {
        let mut frontier = ParetoFrontier::new();

        // Snapshot candidate info first so the borrow of `module` taken by
        // `parameters()` is released before evaluating below.
        let candidate_infos: Vec<GEPACandidate> = {
            let predictors = module.parameters();
            predictors
                .into_iter()
                .map(|(name, predictor)| GEPACandidate::from_predictor(predictor, name))
                .collect()
        };

        for candidate in candidate_infos {
            let scores = self
                .evaluate_candidate(module, trainset, &candidate)
                .await?;
            frontier.add_candidate(candidate, &scores);
        }

        Ok(frontier)
    }

    /// Runs `module` on every example concurrently and returns one feedback
    /// score per example (same order as `examples`).
    ///
    /// NOTE(review): `_candidate` is unused — the module is assumed to already
    /// carry the candidate's instruction; confirm callers install it first.
    async fn evaluate_candidate<M>(
        &self,
        module: &M,
        examples: &[Example],
        _candidate: &GEPACandidate,
    ) -> Result<Vec<f32>>
    where
        M: Module + FeedbackEvaluator,
    {
        use futures::future::join_all;

        let futures: Vec<_> = examples
            .iter()
            .map(|example| async move {
                let prediction = module.forward(example.clone()).await?;
                let feedback = module.feedback_metric(example, &prediction).await;
                Ok::<f32, anyhow::Error>(feedback.score)
            })
            .collect();

        let results = join_all(futures).await;
        // Collecting Vec<Result<_>> into Result<Vec<_>> short-circuits on the
        // first error.
        results.into_iter().collect()
    }

    /// Runs `module` on each minibatch example sequentially, pairing every
    /// (example, prediction) with a human-readable trace string that includes
    /// the feedback score and text.
    async fn collect_traces<M>(
        &self,
        module: &M,
        minibatch: &[Example],
    ) -> Result<Vec<(Example, Prediction, String)>>
    where
        M: Module + FeedbackEvaluator,
    {
        let mut traces = Vec::with_capacity(minibatch.len());

        for example in minibatch {
            let prediction = module.forward(example.clone()).await?;
            let feedback = module.feedback_metric(example, &prediction).await;

            let trace_text = format!(
                "Input: {:?}\nOutput: {:?}\nScore: {:.3}\nFeedback: {}",
                example, prediction, feedback.score, feedback.feedback
            );

            traces.push((example.clone(), prediction, trace_text));
        }

        Ok(traces)
    }

    /// Two-step reflective mutation: (1) ask the LM to reflect on the traces,
    /// (2) ask it to propose an improved instruction from that reflection.
    /// Falls back to `current_instruction` if the proposal output is missing.
    /// This costs two LM calls per invocation.
    async fn generate_mutation(
        &self,
        current_instruction: &str,
        traces: &[(Example, Prediction, String)],
        task_description: &str,
    ) -> Result<String> {
        let traces_text = traces
            .iter()
            .enumerate()
            .map(|(i, (_, _, trace))| format!("=== Trace {} ===\n{}\n", i + 1, trace))
            .collect::<Vec<_>>()
            .join("\n");

        // Step 1: reflection.
        let reflect_predictor = Predict::new(ReflectOnTrace::new());
        let reflection_input = example! {
            "current_instruction": "input" => current_instruction,
            "traces": "input" => &traces_text,
            "task_description": "input" => task_description
        };

        // Use the dedicated prompt model (with this optimizer's temperature)
        // when configured; otherwise the predictor's default LM.
        let reflection_output = if let Some(mut prompt_model) = self.prompt_model.clone() {
            prompt_model.temperature = self.temperature;
            reflect_predictor
                .forward_with_config(reflection_input, Arc::new(prompt_model))
                .await?
        } else {
            reflect_predictor.forward(reflection_input).await?
        };

        // Missing/invalid reflection degrades to an empty string rather than erroring.
        let reflection = reflection_output
            .get("reflection", None)
            .as_str()
            .unwrap_or("")
            .to_string();

        // Step 2: proposal based on the reflection plus the raw traces.
        let propose_predictor = Predict::new(ProposeImprovedInstruction::new());
        let proposal_input = example! {
            "current_instruction": "input" => current_instruction,
            "reflection": "input" => &reflection,
            "traces_and_feedback": "input" => &traces_text
        };

        let proposal_output = if let Some(mut prompt_model) = self.prompt_model.clone() {
            prompt_model.temperature = self.temperature;
            propose_predictor
                .forward_with_config(proposal_input, Arc::new(prompt_model))
                .await?
        } else {
            propose_predictor.forward(proposal_input).await?
        };

        // Fall back to the unchanged instruction if the LM gave no usable output.
        let improved = proposal_output
            .get("improved_instruction", None)
            .as_str()
            .unwrap_or(current_instruction)
            .to_string();

        Ok(improved)
    }
}
368
369impl Optimizer for GEPA {
370 async fn compile<M>(&self, _module: &mut M, _trainset: Vec<Example>) -> Result<()>
371 where
372 M: Module + Optimizable + crate::Evaluator,
373 {
374 anyhow::bail!(
377 "GEPA requires the module to implement FeedbackEvaluator trait. \
378 Please implement feedback_metric() method that returns FeedbackMetric."
379 )
380 }
381}
382
383impl GEPA {
384 pub async fn compile_with_feedback<M>(
386 &self,
387 module: &mut M,
388 trainset: Vec<Example>,
389 ) -> Result<GEPAResult>
390 where
391 M: Module + Optimizable + FeedbackEvaluator,
392 {
393 println!("GEPA: Starting reflective prompt optimization");
394 println!(" Iterations: {}", self.num_iterations);
395 println!(" Minibatch size: {}", self.minibatch_size);
396
397 let eval_set = self.valset.as_ref().unwrap_or(&trainset);
399
400 let mut frontier = self.initialize_frontier(&mut *module, eval_set).await?;
402 println!(" Initialized frontier with {} candidates", frontier.len());
403
404 let mut all_candidates = Vec::new();
406 let mut evolution_history = Vec::new();
407 let mut frontier_history = Vec::new();
408 let mut total_rollouts = 0;
409 let mut total_lm_calls = 0;
410
411 for generation in 0..self.num_iterations {
413 println!("\nGeneration {}/{}", generation + 1, self.num_iterations);
414
415 if let Some(max_rollouts) = self.max_rollouts
417 && total_rollouts >= max_rollouts
418 {
419 println!(" Budget limit reached: max rollouts");
420 break;
421 }
422
423 let parent = frontier
425 .sample_proportional_to_coverage()
426 .context("Failed to sample from frontier")?
427 .clone();
428
429 println!(
430 " Sampled parent (ID {}): avg score {:.3}",
431 parent.id,
432 parent.average_score()
433 );
434
435 let minibatch: Vec<Example> =
437 trainset.iter().take(self.minibatch_size).cloned().collect();
438
439 {
441 let mut predictors = module.parameters();
442 if let Some(predictor) = predictors.get_mut(&parent.module_name) {
443 predictor.update_signature_instruction(parent.instruction.clone())?;
444 }
445 }
446
447 let traces = self.collect_traces(module, &minibatch).await?;
449 total_rollouts += traces.len();
450
451 let task_desc = "Perform the task as specified";
453 let new_instruction = self
454 .generate_mutation(&parent.instruction, &traces, task_desc)
455 .await?;
456
457 total_lm_calls += 2; println!(" Generated new instruction through reflection");
460
461 let child = parent.mutate(new_instruction.clone(), generation + 1);
463
464 {
466 let mut predictors = module.parameters();
467 if let Some(predictor) = predictors.get_mut(&child.module_name) {
468 predictor.update_signature_instruction(child.instruction.clone())?;
469 }
470 }
471
472 let child_scores = self.evaluate_candidate(module, eval_set, &child).await?;
473 total_rollouts += child_scores.len();
474
475 let child_avg = child_scores.iter().sum::<f32>() / child_scores.len() as f32;
476 println!(" Child avg score: {:.3}", child_avg);
477
478 let added = frontier.add_candidate(child.clone(), &child_scores);
480 if added {
481 println!(" Added to Pareto frontier");
482 } else {
483 println!(" Dominated, not added");
484 }
485
486 if self.track_stats {
488 all_candidates.push(child);
489 let best_avg = frontier
490 .best_by_average()
491 .map(|c| c.average_score())
492 .unwrap_or(0.0);
493 evolution_history.push((generation, best_avg));
494 frontier_history.push(frontier.statistics());
495 }
496
497 println!(" Frontier size: {}", frontier.len());
498 }
499
500 let best_candidate = frontier
502 .best_by_average()
503 .context("No candidates on frontier")?
504 .clone();
505
506 println!("\nGEPA optimization complete");
507 println!(
508 " Best average score: {:.3}",
509 best_candidate.average_score()
510 );
511 println!(" Total rollouts: {}", total_rollouts);
512 println!(" Total LM calls: {}", total_lm_calls);
513
514 {
516 let mut predictors = module.parameters();
517 if let Some(predictor) = predictors.get_mut(&best_candidate.module_name) {
518 predictor.update_signature_instruction(best_candidate.instruction.clone())?;
519 }
520 }
521
522 Ok(GEPAResult {
523 best_candidate,
524 all_candidates,
525 total_rollouts,
526 total_lm_calls,
527 evolution_history,
528 highest_score_achieved_per_val_task: vec![], best_outputs_valset: None, frontier_history,
531 })
532 }
533}