use anyhow::{Context, Result};
use bon::Builder;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use crate as dspy_rs;
use crate::{
Example, LM, Module, Optimizable, Optimizer, Predict, Prediction, Predictor,
evaluate::FeedbackEvaluator, example,
};
use dsrs_macros::Signature;
use super::pareto::ParetoFrontier;
/// One point in GEPA's search space: a candidate instruction for a named
/// module, plus lineage and per-example score bookkeeping.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GEPACandidate {
    /// Candidate identifier; constructed as 0 here — presumably assigned by
    /// the Pareto frontier on insertion. TODO confirm.
    pub id: usize,
    /// The prompt instruction text this candidate proposes for the module.
    pub instruction: String,
    /// Name of the predictor/module the instruction applies to.
    pub module_name: String,
    /// Per-example evaluation scores (empty until the candidate is evaluated).
    pub example_scores: Vec<f32>,
    /// `Some(parent_id)` for candidates produced by `mutate`; `None` for seeds.
    pub parent_id: Option<usize>,
    /// Generation counter; 0 for seed candidates.
    pub generation: usize,
}
impl GEPACandidate {
pub fn from_predictor(predictor: &dyn Optimizable, module_name: impl Into<String>) -> Self {
Self {
id: 0,
instruction: predictor.get_signature().instruction(),
module_name: module_name.into(),
example_scores: Vec::new(),
parent_id: None,
generation: 0,
}
}
pub fn average_score(&self) -> f32 {
if self.example_scores.is_empty() {
return 0.0;
}
self.example_scores.iter().sum::<f32>() / self.example_scores.len() as f32
}
pub fn mutate(&self, new_instruction: String, generation: usize) -> Self {
Self {
id: 0, instruction: new_instruction,
module_name: self.module_name.clone(),
example_scores: Vec::new(),
parent_id: Some(self.id),
generation,
}
}
}
/// Final artifacts of a GEPA optimization run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GEPAResult {
    /// Candidate with the highest average score on the eval set.
    pub best_candidate: GEPACandidate,
    /// Every child candidate generated (populated only when `track_stats`).
    pub all_candidates: Vec<GEPACandidate>,
    /// Total module rollouts (forward passes) consumed.
    pub total_rollouts: usize,
    /// Total reflection/proposal LM calls consumed.
    pub total_lm_calls: usize,
    /// `(generation index, best frontier average score)` per generation.
    pub evolution_history: Vec<(usize, f32)>,
    /// Best score seen per eval task; currently always returned empty in
    /// this file — TODO confirm whether population is intended elsewhere.
    pub highest_score_achieved_per_val_task: Vec<f32>,
    /// Best valset predictions; currently always returned as `None` here.
    pub best_outputs_valset: Option<Vec<Prediction>>,
    /// Frontier statistics snapshot per generation (when `track_stats`).
    pub frontier_history: Vec<ParetoStatistics>,
}
// Re-exported so callers of this module can name the statistics type directly.
pub use super::pareto::ParetoStatistics;
// LM signature for the reflection step: given the current instruction plus
// execution traces and evaluation feedback, produce a free-text critique of
// the instruction's weaknesses. Plain `//` comments are used here so the
// `#[Signature]` macro cannot interpret them as prompt content.
#[Signature]
struct ReflectOnTrace {
    #[input(desc = "The current instruction for the module")]
    pub current_instruction: String,
    #[input(desc = "Execution traces showing inputs, outputs, and evaluation feedback")]
    pub traces: String,
    #[input(desc = "Description of what the module should accomplish")]
    pub task_description: String,
    #[output(desc = "Analysis of weaknesses and specific improvement suggestions")]
    pub reflection: String,
}
// LM signature for the proposal step: turn the reflection's critique into a
// concretely improved instruction. Consumed by `generate_mutation` after
// `ReflectOnTrace`.
#[Signature]
struct ProposeImprovedInstruction {
    #[input(desc = "The current instruction")]
    pub current_instruction: String,
    #[input(desc = "Reflection on weaknesses and improvement suggestions")]
    pub reflection: String,
    #[input(desc = "Execution traces and feedback from recent rollouts")]
    pub traces_and_feedback: String,
    #[output(desc = "An improved instruction that addresses the identified weaknesses")]
    pub improved_instruction: String,
}
// LM signature for choosing which module of a multi-module program to
// optimize next. NOTE(review): not referenced anywhere in this file —
// presumably used by multi-module GEPA elsewhere; confirm before removing.
#[Signature]
struct SelectModuleToImprove {
    #[input(desc = "List of modules with their current instructions and performance")]
    pub module_summary: String,
    #[input(desc = "Recent execution traces showing module interactions")]
    pub execution_traces: String,
    #[output(desc = "Name of the module to optimize and reasoning")]
    pub selected_module: String,
}
/// GEPA (Genetic-Pareto) optimizer: evolves module instructions through LM
/// reflection over execution traces, maintaining a Pareto frontier of
/// candidates across eval examples.
#[derive(Builder)]
pub struct GEPA {
    /// Number of evolution generations to run.
    #[builder(default = 20)]
    pub num_iterations: usize,
    /// Examples per generation used to collect reflection traces.
    #[builder(default = 25)]
    pub minibatch_size: usize,
    /// Declared trial count — not read anywhere in this file; TODO confirm
    /// intended use.
    #[builder(default = 10)]
    pub num_trials: usize,
    /// Sampling temperature applied to `prompt_model` during reflection.
    #[builder(default = 1.0)]
    pub temperature: f32,
    /// When true, record candidates, evolution history, and frontier stats.
    #[builder(default = true)]
    pub track_stats: bool,
    /// Declared flag for capturing best valset outputs — not read in this file.
    #[builder(default = false)]
    pub track_best_outputs: bool,
    /// Optional cap on total module rollouts (forward passes).
    pub max_rollouts: Option<usize>,
    /// Optional cap on total reflection/proposal LM calls.
    pub max_lm_calls: Option<usize>,
    /// Dedicated LM for reflection; falls back to the module's default LM.
    pub prompt_model: Option<LM>,
    /// Held-out eval set; falls back to the trainset when `None`.
    pub valset: Option<Vec<Example>>,
}
impl GEPA {
    /// Seeds the Pareto frontier with one candidate per named predictor in
    /// `module`, each scored on `trainset`.
    async fn initialize_frontier<M>(
        &self,
        module: &mut M,
        trainset: &[Example],
    ) -> Result<ParetoFrontier>
    where
        M: Module + Optimizable + FeedbackEvaluator,
    {
        let mut frontier = ParetoFrontier::new();
        // Snapshot candidate info in an inner scope so the (mutable) borrow
        // taken by `module.parameters()` is released before evaluation below.
        let candidate_infos: Vec<GEPACandidate> = {
            let predictors = module.parameters();
            predictors
                .into_iter()
                .map(|(name, predictor)| GEPACandidate::from_predictor(predictor, name))
                .collect()
        };
        for candidate in candidate_infos {
            let scores = self
                .evaluate_candidate(module, trainset, &candidate)
                .await?;
            frontier.add_candidate(candidate, &scores);
        }
        Ok(frontier)
    }

    /// Scores `module` on every example, returning one score per example.
    ///
    /// NOTE(review): `_candidate` is unused — the candidate's instruction must
    /// already be applied to `module` by the caller before invoking this.
    /// All examples are evaluated concurrently; the first error aborts via
    /// the fallible `collect` into `Result<Vec<_>>`.
    async fn evaluate_candidate<M>(
        &self,
        module: &M,
        examples: &[Example],
        _candidate: &GEPACandidate,
    ) -> Result<Vec<f32>>
    where
        M: Module + FeedbackEvaluator,
    {
        use futures::future::join_all;
        let futures: Vec<_> = examples
            .iter()
            .map(|example| async move {
                let prediction = module.forward(example.clone()).await?;
                let feedback = module.feedback_metric(example, &prediction).await;
                Ok::<f32, anyhow::Error>(feedback.score)
            })
            .collect();
        let results = join_all(futures).await;
        results.into_iter().collect()
    }

    /// Runs `module` sequentially over `minibatch`, collecting for each
    /// example its prediction and a human-readable trace string (input,
    /// output, score, feedback) used as reflection context.
    async fn collect_traces<M>(
        &self,
        module: &M,
        minibatch: &[Example],
    ) -> Result<Vec<(Example, Prediction, String)>>
    where
        M: Module + FeedbackEvaluator,
    {
        let mut traces = Vec::with_capacity(minibatch.len());
        for example in minibatch {
            let prediction = module.forward(example.clone()).await?;
            let feedback = module.feedback_metric(example, &prediction).await;
            let trace_text = format!(
                "Input: {:?}\nOutput: {:?}\nScore: {:.3}\nFeedback: {}",
                example, prediction, feedback.score, feedback.feedback
            );
            traces.push((example.clone(), prediction, trace_text));
        }
        Ok(traces)
    }

    /// Two-step instruction mutation: (1) reflect on traces to critique the
    /// current instruction, (2) propose an improved instruction from that
    /// critique. Costs two LM calls.
    ///
    /// When `prompt_model` is configured, both calls use a clone of it with
    /// `self.temperature` applied; otherwise they fall back to the
    /// predictor's default forward path.
    async fn generate_mutation(
        &self,
        current_instruction: &str,
        traces: &[(Example, Prediction, String)],
        task_description: &str,
    ) -> Result<String> {
        let traces_text = traces
            .iter()
            .enumerate()
            .map(|(i, (_, _, trace))| format!("=== Trace {} ===\n{}\n", i + 1, trace))
            .collect::<Vec<_>>()
            .join("\n");
        let reflect_predictor = Predict::new(ReflectOnTrace::new());
        let reflection_input = example! {
            "current_instruction": "input" => current_instruction,
            "traces": "input" => &traces_text,
            "task_description": "input" => task_description
        };
        let reflection_output = if let Some(mut prompt_model) = self.prompt_model.clone() {
            prompt_model.temperature = self.temperature;
            reflect_predictor
                .forward_with_config(reflection_input, Arc::new(prompt_model))
                .await?
        } else {
            reflect_predictor.forward(reflection_input).await?
        };
        // Missing/non-string reflection degrades to "" rather than erroring.
        let reflection = reflection_output
            .get("reflection", None)
            .as_str()
            .unwrap_or("")
            .to_string();
        let propose_predictor = Predict::new(ProposeImprovedInstruction::new());
        let proposal_input = example! {
            "current_instruction": "input" => current_instruction,
            "reflection": "input" => &reflection,
            "traces_and_feedback": "input" => &traces_text
        };
        let proposal_output = if let Some(mut prompt_model) = self.prompt_model.clone() {
            prompt_model.temperature = self.temperature;
            propose_predictor
                .forward_with_config(proposal_input, Arc::new(prompt_model))
                .await?
        } else {
            propose_predictor.forward(proposal_input).await?
        };
        // If the LM fails to produce an improvement, keep the current
        // instruction instead of propagating an empty one.
        let improved = proposal_output
            .get("improved_instruction", None)
            .as_str()
            .unwrap_or(current_instruction)
            .to_string();
        Ok(improved)
    }
}
impl Optimizer for GEPA {
    /// GEPA cannot run through the generic `Optimizer` entry point: it needs
    /// textual feedback, not just a scalar metric. This always fails and
    /// directs callers to `compile_with_feedback` instead.
    async fn compile<M>(&self, _module: &mut M, _trainset: Vec<Example>) -> Result<()>
    where
        M: Module + Optimizable + crate::Evaluator,
    {
        Err(anyhow::anyhow!(
            "GEPA requires the module to implement FeedbackEvaluator trait. \
            Please implement feedback_metric() method that returns FeedbackMetric."
        ))
    }
}
impl GEPA {
    /// Runs the GEPA reflective-evolution loop.
    ///
    /// Each generation: sample a parent from the Pareto frontier (biased
    /// toward coverage), apply its instruction to the module, collect traces
    /// on a training minibatch, reflect to propose a mutated instruction,
    /// evaluate the child on the eval set, and insert it into the frontier
    /// if it is not dominated. The best candidate's instruction is applied
    /// to `module` before returning.
    ///
    /// # Errors
    /// Propagates failures from module forwards, instruction updates, and
    /// LM-backed reflection; errors if the frontier is ever empty.
    pub async fn compile_with_feedback<M>(
        &self,
        module: &mut M,
        trainset: Vec<Example>,
    ) -> Result<GEPAResult>
    where
        M: Module + Optimizable + FeedbackEvaluator,
    {
        println!("GEPA: Starting reflective prompt optimization");
        println!(" Iterations: {}", self.num_iterations);
        println!(" Minibatch size: {}", self.minibatch_size);
        // Evaluate against the valset when provided, else the trainset.
        let eval_set = self.valset.as_ref().unwrap_or(&trainset);
        let mut frontier = self.initialize_frontier(&mut *module, eval_set).await?;
        println!(" Initialized frontier with {} candidates", frontier.len());
        let mut all_candidates = Vec::new();
        let mut evolution_history = Vec::new();
        let mut frontier_history = Vec::new();
        let mut total_rollouts = 0;
        let mut total_lm_calls = 0;
        for generation in 0..self.num_iterations {
            println!("\nGeneration {}/{}", generation + 1, self.num_iterations);
            if let Some(max_rollouts) = self.max_rollouts
                && total_rollouts >= max_rollouts
            {
                println!(" Budget limit reached: max rollouts");
                break;
            }
            // FIX: `max_lm_calls` was accepted by the builder but never
            // enforced; check it alongside the rollout budget.
            if let Some(max_lm_calls) = self.max_lm_calls
                && total_lm_calls >= max_lm_calls
            {
                println!(" Budget limit reached: max LM calls");
                break;
            }
            let parent = frontier
                .sample_proportional_to_coverage()
                .context("Failed to sample from frontier")?
                .clone();
            println!(
                " Sampled parent (ID {}): avg score {:.3}",
                parent.id,
                parent.average_score()
            );
            // NOTE(review): the minibatch is always the same leading slice of
            // the trainset; random sampling per generation may be intended.
            let minibatch: Vec<Example> =
                trainset.iter().take(self.minibatch_size).cloned().collect();
            // Apply the parent's instruction before tracing (scoped so the
            // parameters borrow ends before `collect_traces`).
            {
                let mut predictors = module.parameters();
                if let Some(predictor) = predictors.get_mut(&parent.module_name) {
                    predictor.update_signature_instruction(parent.instruction.clone())?;
                }
            }
            let traces = self.collect_traces(module, &minibatch).await?;
            total_rollouts += traces.len();
            let task_desc = "Perform the task as specified";
            let new_instruction = self
                .generate_mutation(&parent.instruction, &traces, task_desc)
                .await?;
            // generate_mutation performs exactly two LM calls (reflect + propose).
            total_lm_calls += 2;
            println!(" Generated new instruction through reflection");
            let child = parent.mutate(new_instruction.clone(), generation + 1);
            // Apply the child's instruction before evaluating it.
            {
                let mut predictors = module.parameters();
                if let Some(predictor) = predictors.get_mut(&child.module_name) {
                    predictor.update_signature_instruction(child.instruction.clone())?;
                }
            }
            let child_scores = self.evaluate_candidate(module, eval_set, &child).await?;
            total_rollouts += child_scores.len();
            // FIX: guard against an empty eval set (0.0 / 0.0 would be NaN).
            let child_avg = if child_scores.is_empty() {
                0.0
            } else {
                child_scores.iter().sum::<f32>() / child_scores.len() as f32
            };
            println!(" Child avg score: {:.3}", child_avg);
            let added = frontier.add_candidate(child.clone(), &child_scores);
            if added {
                println!(" Added to Pareto frontier");
            } else {
                println!(" Dominated, not added");
            }
            if self.track_stats {
                all_candidates.push(child);
                let best_avg = frontier
                    .best_by_average()
                    .map(|c| c.average_score())
                    .unwrap_or(0.0);
                evolution_history.push((generation, best_avg));
                frontier_history.push(frontier.statistics());
            }
            println!(" Frontier size: {}", frontier.len());
        }
        let best_candidate = frontier
            .best_by_average()
            .context("No candidates on frontier")?
            .clone();
        println!("\nGEPA optimization complete");
        println!(
            " Best average score: {:.3}",
            best_candidate.average_score()
        );
        println!(" Total rollouts: {}", total_rollouts);
        println!(" Total LM calls: {}", total_lm_calls);
        // Leave the module configured with the winning instruction.
        {
            let mut predictors = module.parameters();
            if let Some(predictor) = predictors.get_mut(&best_candidate.module_name) {
                predictor.update_signature_instruction(best_candidate.instruction.clone())?;
            }
        }
        Ok(GEPAResult {
            best_candidate,
            all_candidates,
            total_rollouts,
            total_lm_calls,
            evolution_history,
            highest_score_achieved_per_val_task: vec![],
            best_outputs_valset: None,
            frontier_history,
        })
    }
}