dspy_rs/optimizer/gepa.rs

//! GEPA (Genetic-Pareto) Optimizer Implementation
//!
//! GEPA is a reflective prompt optimizer that uses:
//! 1. Rich textual feedback (not just scores)
//! 2. Pareto-based candidate selection
//! 3. LLM-driven reflection and mutation
//! 4. Per-example dominance tracking
//!
//! Reference: "GEPA: Reflective Prompt Evolution Can Outperform Reinforcement Learning"
//! (Agrawal et al., 2025, arXiv:2507.19457)

use anyhow::{Context, Result};
use bon::Builder;
use serde::{Deserialize, Serialize};
use std::sync::Arc;

use crate as dspy_rs;
use crate::{
    Example, LM, Module, Optimizable, Optimizer, Predict, Prediction, Predictor,
    evaluate::FeedbackEvaluator, example,
};
use dsrs_macros::Signature;

use super::pareto::ParetoFrontier;

// ============================================================================
// Core Data Structures
// ============================================================================

/// A candidate program in the evolutionary process
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GEPACandidate {
    /// Unique identifier
    pub id: usize,

    /// The instruction/prompt for this candidate
    pub instruction: String,

    /// Name of the module this candidate targets
    pub module_name: String,

    /// Scores achieved on each evaluation example
    pub example_scores: Vec<f32>,

    /// Parent candidate ID (for lineage tracking)
    pub parent_id: Option<usize>,

    /// Generation number in the evolutionary process
    pub generation: usize,
}

impl GEPACandidate {
    /// Create a new candidate from a predictor
    pub fn from_predictor(predictor: &dyn Optimizable, module_name: impl Into<String>) -> Self {
        Self {
            id: 0,
            instruction: predictor.get_signature().instruction(),
            module_name: module_name.into(),
            example_scores: Vec::new(),
            parent_id: None,
            generation: 0,
        }
    }

    /// Calculate average score across all examples
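    ///
    /// Returns 0.0 when no scores have been recorded. A small illustration
    /// (field values are hypothetical):
    ///
    /// ```ignore
    /// let candidate = GEPACandidate {
    ///     id: 1,
    ///     instruction: "Answer concisely.".to_string(),
    ///     module_name: "qa".to_string(),
    ///     example_scores: vec![0.5, 1.0, 0.75],
    ///     parent_id: None,
    ///     generation: 0,
    /// };
    /// assert_eq!(candidate.average_score(), 0.75);
    /// ```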
    pub fn average_score(&self) -> f32 {
        if self.example_scores.is_empty() {
            return 0.0;
        }
        self.example_scores.iter().sum::<f32>() / self.example_scores.len() as f32
    }

    /// Create a mutated child candidate
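    ///
    /// The child keeps the parent's `module_name`, records the parent's `id`
    /// for lineage tracking, and starts with an empty score vector. Continuing
    /// the hypothetical candidate above:
    ///
    /// ```ignore
    /// let child = candidate.mutate("Answer in one short sentence.".to_string(), 1);
    /// assert_eq!(child.parent_id, Some(candidate.id));
    /// assert!(child.example_scores.is_empty());
    /// ```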
    pub fn mutate(&self, new_instruction: String, generation: usize) -> Self {
        Self {
            id: 0, // Will be assigned by frontier
            instruction: new_instruction,
            module_name: self.module_name.clone(),
            example_scores: Vec::new(),
            parent_id: Some(self.id),
            generation,
        }
    }
}

/// Detailed results from GEPA optimization
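///
/// A sketch of inspecting a result after optimization (`result` comes from a
/// prior `compile_with_feedback` call; the printed values are illustrative):
///
/// ```ignore
/// println!("best avg: {:.3}", result.best_candidate.average_score());
/// for (generation, best_score) in &result.evolution_history {
///     println!("generation {generation}: best avg {best_score:.3}");
/// }
/// ```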
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GEPAResult {
    /// Best candidate found
    pub best_candidate: GEPACandidate,

    /// All candidates evaluated during optimization
    pub all_candidates: Vec<GEPACandidate>,

    /// Total number of rollouts performed
    pub total_rollouts: usize,

    /// Total LM calls made during optimization
    pub total_lm_calls: usize,

    /// Evolution history: (generation, best average score at that generation)
    pub evolution_history: Vec<(usize, f32)>,

    /// Highest score achieved on each validation task
    pub highest_score_achieved_per_val_task: Vec<f32>,

    /// Best outputs on validation set (if tracked)
    pub best_outputs_valset: Option<Vec<Prediction>>,

    /// Pareto frontier statistics over time
    pub frontier_history: Vec<ParetoStatistics>,
}

/// Statistics about the Pareto frontier (re-exported from the pareto module)
pub use super::pareto::ParetoStatistics;

// ============================================================================
// LLM Signatures for Reflection and Mutation
// ============================================================================

#[Signature]
struct ReflectOnTrace {
    /// You are an expert at analyzing program execution traces and identifying
    /// areas for improvement. Given the module instruction, example traces showing
    /// inputs, outputs, and feedback, identify specific weaknesses and suggest
    /// targeted improvements.

    #[input(desc = "The current instruction for the module")]
    pub current_instruction: String,

    #[input(desc = "Execution traces showing inputs, outputs, and evaluation feedback")]
    pub traces: String,

    #[input(desc = "Description of what the module should accomplish")]
    pub task_description: String,

    #[output(desc = "Analysis of weaknesses and specific improvement suggestions")]
    pub reflection: String,
}

#[Signature]
struct ProposeImprovedInstruction {
    /// You are an expert prompt engineer. Given the current instruction, execution
    /// traces, feedback, and reflection on weaknesses, propose an improved instruction
    /// that addresses the identified issues. Be creative and consider various prompting
    /// techniques.

    #[input(desc = "The current instruction")]
    pub current_instruction: String,

    #[input(desc = "Reflection on weaknesses and improvement suggestions")]
    pub reflection: String,

    #[input(desc = "Execution traces and feedback from recent rollouts")]
    pub traces_and_feedback: String,

    #[output(desc = "An improved instruction that addresses the identified weaknesses")]
    pub improved_instruction: String,
}

#[Signature]
struct SelectModuleToImprove {
    /// Given multiple modules in a program and their performance feedback, select which
    /// module would benefit most from optimization. Consider which module's errors are
    /// most impactful and addressable through instruction changes.

    #[input(desc = "List of modules with their current instructions and performance")]
    pub module_summary: String,

    #[input(desc = "Recent execution traces showing module interactions")]
    pub execution_traces: String,

    #[output(desc = "Name of the module to optimize and reasoning")]
    pub selected_module: String,
}
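
// `ReflectOnTrace` and `ProposeImprovedInstruction` are consumed via `Predict`
// in `generate_mutation` below; `SelectModuleToImprove` is declared for
// multi-module selection but not yet wired up in this file.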

// ============================================================================
// GEPA Optimizer
// ============================================================================

/// GEPA Optimizer Configuration
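///
/// Constructed through the builder derived by `bon`; fields left unset fall
/// back to the `#[builder(default = ...)]` values below. A construction
/// sketch (the particular values are illustrative):
///
/// ```ignore
/// let gepa = GEPA::builder()
///     .num_iterations(10)
///     .minibatch_size(8)
///     .max_rollouts(500)
///     .build();
/// ```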
#[derive(Builder)]
pub struct GEPA {
    /// Maximum number of evolutionary iterations
    #[builder(default = 20)]
    pub num_iterations: usize,

    /// Size of minibatch for each rollout
    #[builder(default = 25)]
    pub minibatch_size: usize,

    /// Number of trials per candidate evaluation
    #[builder(default = 10)]
    pub num_trials: usize,

    /// Temperature for LLM-based mutations
    #[builder(default = 1.0)]
    pub temperature: f32,

    /// Track detailed statistics
    #[builder(default = true)]
    pub track_stats: bool,

    /// Track best outputs on validation set (for inference-time search)
    #[builder(default = false)]
    pub track_best_outputs: bool,

    /// Maximum total rollouts (budget control)
    pub max_rollouts: Option<usize>,

    /// Maximum LM calls (budget control)
    pub max_lm_calls: Option<usize>,

    /// Optional separate LM for meta-prompting (instruction generation)
    pub prompt_model: Option<LM>,

    /// Validation set for Pareto evaluation (if None, uses trainset)
    pub valset: Option<Vec<Example>>,
}

impl GEPA {
    /// Initialize the Pareto frontier with the seed program
    async fn initialize_frontier<M>(
        &self,
        module: &mut M,
        trainset: &[Example],
    ) -> Result<ParetoFrontier>
    where
        M: Module + Optimizable + FeedbackEvaluator,
    {
        let mut frontier = ParetoFrontier::new();

        // Collect predictor information first (to release the mutable borrow)
        let candidate_infos: Vec<GEPACandidate> = {
            let predictors = module.parameters();
            predictors
                .into_iter()
                .map(|(name, predictor)| GEPACandidate::from_predictor(predictor, name))
                .collect()
        };

        // Now evaluate each candidate (module is no longer borrowed mutably)
        for candidate in candidate_infos {
            let scores = self
                .evaluate_candidate(module, trainset, &candidate)
                .await?;
            frontier.add_candidate(candidate, &scores);
        }

        Ok(frontier)
    }

    /// Evaluate a candidate on a set of examples (concurrently, via `join_all`)
    async fn evaluate_candidate<M>(
        &self,
        module: &M,
        examples: &[Example],
        _candidate: &GEPACandidate,
    ) -> Result<Vec<f32>>
    where
        M: Module + FeedbackEvaluator,
    {
        use futures::future::join_all;

        let futures: Vec<_> = examples
            .iter()
            .map(|example| async move {
                let prediction = module.forward(example.clone()).await?;
                let feedback = module.feedback_metric(example, &prediction).await;
                Ok::<f32, anyhow::Error>(feedback.score)
            })
            .collect();

        let results = join_all(futures).await;
        // Collecting Vec<Result<f32>> into Result<Vec<f32>> propagates the
        // first error, if any.
        results.into_iter().collect()
    }

    /// Collect execution traces with feedback
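    ///
    /// Each trace is rendered as text for LLM reflection, roughly of the form
    /// below (the exact contents depend on the `Debug` impls of `Example` and
    /// `Prediction` and on the feedback metric; this rendering is illustrative):
    ///
    /// ```text
    /// Input: Example { .. }
    /// Output: Prediction { .. }
    /// Score: 0.667
    /// Feedback: partially correct; missing units
    /// ```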
    async fn collect_traces<M>(
        &self,
        module: &M,
        minibatch: &[Example],
    ) -> Result<Vec<(Example, Prediction, String)>>
    where
        M: Module + FeedbackEvaluator,
    {
        let mut traces = Vec::with_capacity(minibatch.len());

        for example in minibatch {
            let prediction = module.forward(example.clone()).await?;
            let feedback = module.feedback_metric(example, &prediction).await;

            // Format the trace for LLM reflection
            let trace_text = format!(
                "Input: {:?}\nOutput: {:?}\nScore: {:.3}\nFeedback: {}",
                example, prediction, feedback.score, feedback.feedback
            );

            traces.push((example.clone(), prediction, trace_text));
        }

        Ok(traces)
    }

    /// Generate an improved instruction through LLM reflection
    async fn generate_mutation(
        &self,
        current_instruction: &str,
        traces: &[(Example, Prediction, String)],
        task_description: &str,
    ) -> Result<String> {
        // Combine traces into a single string
        let traces_text = traces
            .iter()
            .enumerate()
            .map(|(i, (_, _, trace))| format!("=== Trace {} ===\n{}\n", i + 1, trace))
            .collect::<Vec<_>>()
            .join("\n");

        // First, reflect on the traces
        let reflect_predictor = Predict::new(ReflectOnTrace::new());
        let reflection_input = example! {
            "current_instruction": "input" => current_instruction,
            "traces": "input" => &traces_text,
            "task_description": "input" => task_description
        };

        let reflection_output = if let Some(mut prompt_model) = self.prompt_model.clone() {
            prompt_model.temperature = self.temperature;
            reflect_predictor
                .forward_with_config(reflection_input, Arc::new(prompt_model))
                .await?
        } else {
            reflect_predictor.forward(reflection_input).await?
        };

        let reflection = reflection_output
            .get("reflection", None)
            .as_str()
            .unwrap_or("")
            .to_string();

        // Then, propose an improved instruction
        let propose_predictor = Predict::new(ProposeImprovedInstruction::new());
        let proposal_input = example! {
            "current_instruction": "input" => current_instruction,
            "reflection": "input" => &reflection,
            "traces_and_feedback": "input" => &traces_text
        };

        let proposal_output = if let Some(mut prompt_model) = self.prompt_model.clone() {
            prompt_model.temperature = self.temperature;
            propose_predictor
                .forward_with_config(proposal_input, Arc::new(prompt_model))
                .await?
        } else {
            propose_predictor.forward(proposal_input).await?
        };

        // Fall back to the current instruction if the proposal is missing
        let improved = proposal_output
            .get("improved_instruction", None)
            .as_str()
            .unwrap_or(current_instruction)
            .to_string();

        Ok(improved)
    }
}

impl Optimizer for GEPA {
    async fn compile<M>(&self, _module: &mut M, _trainset: Vec<Example>) -> Result<()>
    where
        M: Module + Optimizable + crate::Evaluator,
    {
        // GEPA requires FeedbackEvaluator, not just Evaluator. This fails at
        // runtime with an error that guides users to implement the right
        // trait and call compile_with_feedback() instead.
        anyhow::bail!(
            "GEPA requires the module to implement the FeedbackEvaluator trait. \
             Please implement the feedback_metric() method that returns a FeedbackMetric."
        )
    }
}

impl GEPA {
    /// Compile method specifically for `FeedbackEvaluator` modules
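    ///
    /// A usage sketch. `MyPipeline` is a hypothetical module implementing
    /// `Module`, `Optimizable`, and `FeedbackEvaluator`, and `trainset` is a
    /// `Vec<Example>` prepared by the caller:
    ///
    /// ```ignore
    /// let gepa = GEPA::builder().num_iterations(5).build();
    /// let mut pipeline = MyPipeline::new();
    /// let result = gepa.compile_with_feedback(&mut pipeline, trainset).await?;
    /// println!("best instruction: {}", result.best_candidate.instruction);
    /// ```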
    pub async fn compile_with_feedback<M>(
        &self,
        module: &mut M,
        trainset: Vec<Example>,
    ) -> Result<GEPAResult>
    where
        M: Module + Optimizable + FeedbackEvaluator,
    {
        println!("GEPA: Starting reflective prompt optimization");
        println!("  Iterations: {}", self.num_iterations);
        println!("  Minibatch size: {}", self.minibatch_size);

        // Use the valset if provided, otherwise use the trainset for Pareto evaluation
        let eval_set = self.valset.as_ref().unwrap_or(&trainset);

        // Initialize the frontier with the seed program
        let mut frontier = self.initialize_frontier(&mut *module, eval_set).await?;
        println!("  Initialized frontier with {} candidates", frontier.len());

        // Track statistics
        let mut all_candidates = Vec::new();
        let mut evolution_history = Vec::new();
        let mut frontier_history = Vec::new();
        let mut total_rollouts = 0;
        let mut total_lm_calls = 0;

        // Main evolutionary loop
        for generation in 0..self.num_iterations {
            println!("\nGeneration {}/{}", generation + 1, self.num_iterations);

            // Check budget constraints
            if let Some(max_rollouts) = self.max_rollouts
                && total_rollouts >= max_rollouts
            {
                println!("  Budget limit reached: max rollouts");
                break;
            }

            // Sample a candidate from the frontier (proportional to coverage)
            let parent = frontier
                .sample_proportional_to_coverage()
                .context("Failed to sample from frontier")?
                .clone();

            println!(
                "  Sampled parent (ID {}): avg score {:.3}",
                parent.id,
                parent.average_score()
            );

            // Take a minibatch (the first `minibatch_size` training examples)
            let minibatch: Vec<Example> =
                trainset.iter().take(self.minibatch_size).cloned().collect();

            // Apply the parent instruction to the module
            {
                let mut predictors = module.parameters();
                if let Some(predictor) = predictors.get_mut(&parent.module_name) {
                    predictor.update_signature_instruction(parent.instruction.clone())?;
                }
            }

            // Collect execution traces
            let traces = self.collect_traces(module, &minibatch).await?;
            total_rollouts += traces.len();

            // Generate a mutation through LLM reflection
            let task_desc = "Perform the task as specified";
            let new_instruction = self
                .generate_mutation(&parent.instruction, &traces, task_desc)
                .await?;

            total_lm_calls += 2; // Reflection + proposal

            println!("  Generated new instruction through reflection");

            // Create the child candidate
            let child = parent.mutate(new_instruction.clone(), generation + 1);

            // Apply the child instruction and evaluate
            {
                let mut predictors = module.parameters();
                if let Some(predictor) = predictors.get_mut(&child.module_name) {
                    predictor.update_signature_instruction(child.instruction.clone())?;
                }
            }

            let child_scores = self.evaluate_candidate(module, eval_set, &child).await?;
            total_rollouts += child_scores.len();

            let child_avg = child_scores.iter().sum::<f32>() / child_scores.len() as f32;
            println!("  Child avg score: {:.3}", child_avg);

            // Add to the frontier
            let added = frontier.add_candidate(child.clone(), &child_scores);
            if added {
                println!("  Added to Pareto frontier");
            } else {
                println!("  Dominated, not added");
            }

            // Track statistics
            if self.track_stats {
                all_candidates.push(child);
                let best_avg = frontier
                    .best_by_average()
                    .map(|c| c.average_score())
                    .unwrap_or(0.0);
                evolution_history.push((generation, best_avg));
                frontier_history.push(frontier.statistics());
            }

            println!("  Frontier size: {}", frontier.len());
        }

        // Get the best candidate
        let best_candidate = frontier
            .best_by_average()
            .context("No candidates on frontier")?
            .clone();

        println!("\nGEPA optimization complete");
        println!(
            "  Best average score: {:.3}",
            best_candidate.average_score()
        );
        println!("  Total rollouts: {}", total_rollouts);
        println!("  Total LM calls: {}", total_lm_calls);

        // Apply the best instruction to the module
        {
            let mut predictors = module.parameters();
            if let Some(predictor) = predictors.get_mut(&best_candidate.module_name) {
                predictor.update_signature_instruction(best_candidate.instruction.clone())?;
            }
        }

        Ok(GEPAResult {
            best_candidate,
            all_candidates,
            total_rollouts,
            total_lm_calls,
            evolution_history,
            highest_score_achieved_per_val_task: vec![], // TODO: Track per-task bests
            best_outputs_valset: None, // TODO: Implement if track_best_outputs is true
            frontier_history,
        })
    }
}