use std::sync::Arc;
use indexmap::IndexMap;
use serde::Serialize;
use tracing::{debug, warn};
use crate::core::adapter::{Candidate, GEPAAdapter, ReflectiveDataset};
use crate::core::component::{ComponentKind, ComponentMeta, ComponentMetaMap};
use crate::core::data_loader::{DataId, DataLoader};
use crate::core::state::GEPAState;
use crate::error::{GEPAError, Result};
use crate::lm::LanguageModel;
use crate::proposer::base::CandidateProposal;
use crate::strategies::batch_sampler::BatchSampler;
use crate::strategies::candidate_selector::CandidateSelector;
use crate::strategies::component_selector::ComponentSelector;
use crate::strategies::instruction_proposal::{
extract_output, format_samples_as_markdown, render_code_prompt, render_config_prompt,
render_prompt, render_prompt_with_template,
};
/// User-supplied reflection prompt template(s).
///
/// A selected template is rendered by `render_prompt_with_template` together
/// with the current component text and the markdown-formatted reflective
/// dataset.
#[derive(Clone, Debug)]
pub enum PromptTemplateConfig {
    /// One template shared by every component.
    Single(String),
    /// Templates keyed by component name; a component without an entry has
    /// no custom template.
    PerComponent(IndexMap<String, String>),
}

impl PromptTemplateConfig {
    /// Returns the template to use for `component`, if any.
    ///
    /// `Single` always yields its template. `PerComponent` yields the entry
    /// registered under `component`, or `None` when no entry exists.
    pub fn get_for_component(&self, component: &str) -> Option<&str> {
        let template = match self {
            Self::Single(shared) => shared,
            Self::PerComponent(by_name) => by_name.get(component)?,
        };
        Some(template.as_str())
    }
}
/// Proposes new candidates by evaluating a parent on a minibatch, asking a
/// reflection language model to rewrite selected component texts, and
/// re-evaluating the mutated candidate on the same minibatch.
pub struct ReflectiveMutationProposer<
    Id: DataId,
    Item: Clone + Send + Sync + 'static,
    T: Send + Sync + 'static,
    RO: Send + Sync + Serialize + 'static,
> {
    /// Training examples that minibatches are sampled from.
    pub trainset: Arc<dyn DataLoader<Id, Item>>,
    /// Evaluates candidates and builds reflective datasets from evaluations.
    pub adapter: Arc<dyn GEPAAdapter<Item, T, RO>>,
    /// Chooses which existing candidate to mutate.
    pub candidate_selector: Box<dyn CandidateSelector<Id>>,
    /// Chooses which components of the chosen candidate to rewrite.
    pub component_selector: Box<dyn ComponentSelector<Id>>,
    /// Produces the ids of the next minibatch.
    pub(crate) batch_sampler: Box<dyn BatchSampler<Id>>,
    /// Language model asked to propose replacement component texts.
    pub(crate) reflection_lm: Arc<dyn LanguageModel>,
    /// Optional custom prompt template(s); `None` uses the built-in prompts.
    pub reflection_prompt_template: Option<PromptTemplateConfig>,
    /// Per-component metadata that steers which built-in prompt is used.
    pub component_metadata: ComponentMetaMap,
    /// Score at or above which a minibatch example counts as perfect.
    pub perfect_score: Option<f64>,
    /// When set (and `perfect_score` is `Some`), a step is skipped if every
    /// minibatch score is perfect.
    pub skip_perfect_score: bool,
}
impl<Id, Item, T, RO> ReflectiveMutationProposer<Id, Item, T, RO>
where
    Id: DataId,
    Item: Clone + Send + Sync + 'static,
    T: Send + Sync + 'static,
    RO: Send + Sync + Serialize + 'static,
{
    /// Asks the reflection LM for a replacement text for one component.
    ///
    /// The prompt is rendered from the user-supplied template for
    /// `component_name` when one is configured; otherwise the built-in prompt
    /// from [`Self::render_default_prompt_for_component`] is used.
    ///
    /// # Errors
    ///
    /// * `GEPAError::Config` if rendering the custom template fails.
    /// * Any error returned by the reflection LM call.
    /// * `GEPAError::Proposal` if no output can be extracted from the reply.
    async fn propose_new_text_for_component(
        &self,
        component_name: &str,
        current_text: &str,
        records: &[serde_json::Value],
    ) -> Result<String> {
        let dataset_md = format_samples_as_markdown(records);
        let custom_template = self
            .reflection_prompt_template
            .as_ref()
            .and_then(|cfg| cfg.get_for_component(component_name));
        let prompt = if let Some(template) = custom_template {
            render_prompt_with_template(template, current_text, &dataset_md)
                .map_err(GEPAError::Config)?
        } else {
            self.render_default_prompt_for_component(component_name, current_text, &dataset_md)
        };

        debug!(component = %component_name, prompt_len = %prompt.len(), "calling reflection LM");
        let reply = self.reflection_lm.complete(&prompt).await?;
        let proposed = extract_output(&reply).ok_or_else(|| {
            GEPAError::Proposal(format!(
                "LM produced no extractable output for component '{component_name}'"
            ))
        })?;
        Ok(proposed)
    }

    /// Builds the built-in reflection prompt for `component_name`.
    ///
    /// Without metadata for the component this is the plain text prompt.
    /// With metadata, `meta.kind` picks the text, code, or config prompt (the
    /// latter including formatted constraints), and the result is prefixed
    /// with the component description when that description is non-blank.
    fn render_default_prompt_for_component(
        &self,
        component_name: &str,
        current_text: &str,
        dataset_str: &str,
    ) -> String {
        match self.component_metadata.get(component_name) {
            None => render_prompt(current_text, dataset_str),
            Some(meta) => {
                let body = match meta.kind {
                    ComponentKind::Text => render_prompt(current_text, dataset_str),
                    ComponentKind::Code => {
                        // Fall back to "text" when no language is declared.
                        let language = meta.language.as_deref().unwrap_or("text");
                        render_code_prompt(current_text, dataset_str, component_name, language)
                    }
                    ComponentKind::Config => render_config_prompt(
                        current_text,
                        dataset_str,
                        format_constraints(meta).as_deref(),
                    ),
                };
                with_component_description(meta, body)
            }
        }
    }
}
/// Renders a component's constraints as a markdown bullet list.
///
/// Returns `None` when the component has no constraints or an empty set.
/// Otherwise it emits one `- key: value` line per constraint, ordered by key,
/// so the output does not depend on the map's iteration order.
fn format_constraints(meta: &ComponentMeta) -> Option<String> {
    let constraints = meta.constraints.as_ref().filter(|c| !c.is_empty())?;
    let mut pairs: Vec<(&String, &String)> = constraints.iter().collect();
    pairs.sort_by(|lhs, rhs| lhs.0.cmp(rhs.0));
    let lines: Vec<String> = pairs
        .iter()
        .map(|(key, value)| format!("- {key}: {value}"))
        .collect();
    Some(lines.join("\n"))
}
/// Prefixes `prompt` with the component's description, if it has one.
///
/// The description is trimmed first; a blank description returns `prompt`
/// unchanged.
fn with_component_description(meta: &ComponentMeta, prompt: String) -> String {
    match meta.description.trim() {
        "" => prompt,
        description => format!("Component description: {description}\n\n{prompt}"),
    }
}
impl<Id, Item, T, RO> ReflectiveMutationProposer<Id, Item, T, RO>
where
    Id: DataId,
    Item: Clone + Send + Sync + 'static,
    T: Send + Sync + 'static,
    RO: Send + Sync + Serialize + 'static,
{
    /// Proposes new texts for each component in `predictor_names`.
    ///
    /// A component is skipped (with a debug log) when it is missing from
    /// `reflective_dataset` or its record list is empty. If proposing text for
    /// one component fails, the error is logged at `warn` level and that
    /// component is left out; the remaining components are still processed.
    /// The returned map therefore contains only components that received a
    /// new text, and it may be empty.
    ///
    /// `iter_i` is used only for logging.
    async fn propose_new_texts_for_components(
        &self,
        curr_prog: &Candidate,
        reflective_dataset: &ReflectiveDataset,
        predictor_names: &[String],
        iter_i: usize,
    ) -> Result<Candidate> {
        let mut new_texts: Candidate = Candidate::new();
        for name in predictor_names {
            // Borrow the records: they are only read while building the prompt,
            // so there is no need to deep-clone them.
            let records = match reflective_dataset.get(name) {
                Some(r) if !r.is_empty() => r,
                _ => {
                    debug!(
                        component = %name,
                        "component has no data in reflective dataset — skipping"
                    );
                    continue;
                }
            };
            let current_text = curr_prog.get(name).cloned().unwrap_or_default();
            match self
                .propose_new_text_for_component(name, &current_text, records)
                .await
            {
                Ok(text) => {
                    debug!(component = %name, iter = iter_i, "proposed new text");
                    new_texts.insert(name.clone(), text);
                }
                Err(e) => {
                    // Best-effort: one failing component must not abort the others.
                    warn!(component = %name, error = %e, "failed to propose text for component");
                }
            }
        }
        Ok(new_texts)
    }

    /// Runs one reflective-mutation step.
    ///
    /// The step:
    /// 1. Selects a parent candidate and samples a minibatch.
    /// 2. Evaluates the parent with trace capture.
    /// 3. Asks the reflection LM for new texts for the selected components.
    /// 4. Evaluates the mutated candidate on the same minibatch.
    ///
    /// Both evaluations are counted via `state.increment_evals`.
    ///
    /// Returns `Ok(None)` when the step is skipped:
    /// - no trajectories were captured;
    /// - `skip_perfect_score` is set and every minibatch score reaches
    ///   `perfect_score`;
    /// - no components were selected;
    /// - no new text was produced.
    ///
    /// # Errors
    ///
    /// * Propagates candidate-selector, batch-sampler, data-loader and
    ///   component-selector errors.
    /// * Adapter evaluation failures become `GEPAError::Evaluation`.
    /// * Reflective-dataset failures become `GEPAError::AdapterReflectiveDataset`.
    /// * Returns `GEPAError::Proposal` if a proposed component is not a key of
    ///   the parent candidate.
    pub async fn propose_mut(
        &mut self,
        state: &mut GEPAState<Id>,
    ) -> Result<Option<CandidateProposal<Id>>> {
        let iter_i = state.i.wrapping_add(1);
        let curr_prog_id = self.candidate_selector.select_candidate_idx(state)?;
        let curr_prog = state.program_candidates[curr_prog_id].clone();
        let curr_score = state.program_full_scores_val_set()[curr_prog_id];
        debug!(
            iteration = iter_i,
            candidate_idx = curr_prog_id,
            score = curr_score,
            "selected program candidate"
        );
        let all_ids = self.trainset.all_ids();
        let subsample_ids = self.batch_sampler.next_batch_ids(&all_ids, state)?;
        let minibatch = self.trainset.fetch(&subsample_ids)?;
        debug!(
            iteration = iter_i,
            batch_size = subsample_ids.len(),
            "sampled minibatch"
        );
        let eval_curr = self
            .adapter
            .evaluate(&minibatch, &curr_prog, true)
            .await
            .map_err(|e| GEPAError::Evaluation(e.to_string()))?;
        eval_curr.validate_lengths(minibatch.len(), false)?;
        state.increment_evals(subsample_ids.len());
        let trajectories = match &eval_curr.trajectories {
            Some(t) if !t.is_empty() => t,
            _ => {
                debug!(iteration = iter_i, "no trajectories captured — skipping");
                return Ok(None);
            }
        };
        if self.skip_perfect_score
            && let Some(threshold) = self.perfect_score
            && eval_curr.scores.iter().all(|&s| s >= threshold)
        {
            debug!(
                iteration = iter_i,
                threshold, "all subsample scores are perfect — skipping"
            );
            return Ok(None);
        }
        let subsample_before: f64 = eval_curr.scores.iter().sum();
        // `T` carries no `Serialize` bound here, so the component selector
        // receives one `Null` placeholder per trajectory. NOTE(review): confirm
        // no selector relies on real trajectory contents.
        let traj_values: Vec<serde_json::Value> = trajectories
            .iter()
            .map(|_| serde_json::Value::Null)
            .collect();
        let predictor_names = self.component_selector.select_components(
            state,
            &traj_values,
            &eval_curr.scores,
            curr_prog_id,
            &curr_prog,
        )?;
        if predictor_names.is_empty() {
            debug!(iteration = iter_i, "no components selected — skipping");
            return Ok(None);
        }
        let reflective_dataset = self
            .adapter
            .make_reflective_dataset(&curr_prog, &eval_curr, &predictor_names)
            .await
            .map_err(|e| GEPAError::AdapterReflectiveDataset(e.to_string()))?;
        let new_texts: Candidate = self
            .propose_new_texts_for_components(
                &curr_prog,
                &reflective_dataset,
                &predictor_names,
                iter_i,
            )
            .await?;
        if new_texts.is_empty() {
            // Empty either because every component lacked data or because every
            // per-component proposal failed (failures are logged individually).
            debug!(
                iteration = iter_i,
                "no new texts proposed (components absent from reflective dataset or proposal failed)"
            );
            return Ok(None);
        }
        let mut new_candidate = curr_prog.clone();
        for (name, text) in &new_texts {
            if !new_candidate.contains_key(name) {
                return Err(GEPAError::Proposal(format!(
                    "proposed component '{name}' is not a key in the current candidate"
                )));
            }
            new_candidate.insert(name.clone(), text.clone());
        }
        let eval_new = self
            .adapter
            .evaluate(&minibatch, &new_candidate, false)
            .await
            .map_err(|e| GEPAError::Evaluation(e.to_string()))?;
        eval_new.validate_lengths(minibatch.len(), false)?;
        state.increment_evals(subsample_ids.len());
        let new_sum: f64 = eval_new.scores.iter().sum();
        debug!(
            iteration = iter_i,
            subsample_before,
            subsample_after = new_sum,
            "evaluated new candidate on minibatch"
        );
        Ok(Some(
            CandidateProposal::new(new_candidate, vec![curr_prog_id])
                .with_tag("reflective_mutation")
                .with_subsample(subsample_ids, eval_curr.scores, eval_new.scores),
        ))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::adapter::{Candidate, EvaluationBatch, GEPAAdapter, ReflectiveDataset};
    use crate::core::data_loader::VecLoader;
    use crate::core::state::{FrontierType, ValsetEvaluation};
    use crate::lm::LanguageModel;
    use crate::strategies::batch_sampler::EpochShuffledSampler;
    use crate::strategies::candidate_selector::CurrentBestSelector;
    use crate::strategies::component_selector::AllComponentSelector;
    use std::sync::Arc;

    /// Adapter that scores every example 1.0, attaches one `()` trajectory per
    /// example when traces are requested, and returns one fixed record per
    /// requested component.
    struct PerfectAdapter;

    #[async_trait::async_trait]
    impl GEPAAdapter<String, (), String> for PerfectAdapter {
        async fn evaluate(
            &self,
            batch: &[String],
            _candidate: &Candidate,
            capture_traces: bool,
        ) -> Result<EvaluationBatch<(), String>> {
            let n = batch.len();
            let mut eval = EvaluationBatch::new(batch.to_vec(), vec![1.0; n]);
            if capture_traces {
                eval = eval.with_trajectories(vec![(); n]);
            }
            Ok(eval)
        }

        async fn make_reflective_dataset(
            &self,
            _candidate: &Candidate,
            _eval_batch: &EvaluationBatch<(), String>,
            components: &[String],
        ) -> Result<ReflectiveDataset> {
            Ok(components
                .iter()
                .map(|name| {
                    (
                        name.clone(),
                        vec![serde_json::json!({
                            "Inputs": "test input",
                            "Generated Outputs": "test output",
                            "Feedback": "looks good"
                        })],
                    )
                })
                .collect())
        }
    }

    /// LM stub that always replies with the same fenced block.
    struct EchoLM;

    #[async_trait::async_trait]
    impl LanguageModel for EchoLM {
        async fn complete(&self, _prompt: &str) -> Result<String> {
            Ok("```\nImproved instruction text.\n```".into())
        }
    }

    /// Builds a proposer over `trainset` using `PerfectAdapter` and `EchoLM`,
    /// with no custom template or metadata, `perfect_score = Some(1.0)` and
    /// `skip_perfect_score` disabled.
    fn make_proposer(
        trainset: Vec<String>,
    ) -> ReflectiveMutationProposer<usize, String, (), String> {
        ReflectiveMutationProposer {
            trainset: Arc::new(VecLoader::new(trainset)),
            adapter: Arc::new(PerfectAdapter),
            candidate_selector: Box::new(CurrentBestSelector),
            component_selector: Box::new(AllComponentSelector),
            batch_sampler: Box::new(EpochShuffledSampler::new(2, 0).unwrap()),
            reflection_lm: Arc::new(EchoLM),
            reflection_prompt_template: None,
            component_metadata: ComponentMetaMap::new(),
            perfect_score: Some(1.0),
            skip_perfect_score: false,
        }
    }

    /// Builds a state seeded with a single `instructions` component and a
    /// three-example validation evaluation.
    fn make_state() -> GEPAState<usize> {
        let mut seed = Candidate::new();
        seed.insert("instructions".into(), "Seed instruction.".into());
        let eval = ValsetEvaluation::from_vecs(
            vec![0usize, 1, 2],
            vec![
                serde_json::json!("a"),
                serde_json::json!("b"),
                serde_json::json!("c"),
            ],
            vec![0.5, 0.6, 0.7],
            None,
        );
        let mut s =
            GEPAState::new(seed, eval, FrontierType::Instance, None).expect("should construct");
        // Advance the iteration counter with wrapping arithmetic, as every
        // test here does before calling `propose_mut`.
        s.i = s.i.wrapping_add(1);
        s
    }

    #[tokio::test]
    async fn propose_mut_returns_proposal_when_skip_disabled() {
        // PerfectAdapter scores every example 1.0; with skip_perfect_score
        // off, a proposal must still be produced.
        let mut proposer = make_proposer(vec![
            "example 0".into(),
            "example 1".into(),
            "example 2".into(),
        ]);
        let mut state = make_state();
        let proposal = proposer
            .propose_mut(&mut state)
            .await
            .expect("propose_mut should not error");
        assert!(
            proposal.is_some(),
            "should return a proposal when skip_perfect_score=false"
        );
        let p = proposal.unwrap();
        assert_eq!(p.tag, "reflective_mutation");
        assert!(!p.parent_program_ids.is_empty());
    }

    #[tokio::test]
    async fn propose_mut_skips_when_all_perfect_and_flag_set() {
        let mut proposer = make_proposer(vec!["example 0".into(), "example 1".into()]);
        proposer.skip_perfect_score = true;
        proposer.perfect_score = Some(1.0);
        let mut state = make_state();
        let proposal = proposer
            .propose_mut(&mut state)
            .await
            .expect("should not error");
        assert!(
            proposal.is_none(),
            "should skip when all scores are perfect"
        );
    }

    #[tokio::test]
    async fn propose_mut_skips_empty_trainset() {
        let mut proposer: ReflectiveMutationProposer<usize, String, (), String> =
            ReflectiveMutationProposer {
                trainset: Arc::new(VecLoader::new(vec![])),
                adapter: Arc::new(PerfectAdapter),
                candidate_selector: Box::new(CurrentBestSelector),
                component_selector: Box::new(AllComponentSelector),
                batch_sampler: Box::new(EpochShuffledSampler::new(2, 0).unwrap()),
                reflection_lm: Arc::new(EchoLM),
                reflection_prompt_template: None,
                component_metadata: ComponentMetaMap::new(),
                perfect_score: None,
                skip_perfect_score: false,
            };
        let mut state = make_state();
        let result = proposer.propose_mut(&mut state).await;
        assert!(result.is_err(), "empty trainset should cause an error");
    }

    #[tokio::test]
    async fn test_lm_returns_empty_string_yields_none() {
        // An LM reply with no extractable output is a per-component failure:
        // it is logged, not propagated, so propose_mut returns Ok(None).
        struct EmptyLM;
        #[async_trait::async_trait]
        impl LanguageModel for EmptyLM {
            async fn complete(&self, _prompt: &str) -> crate::error::Result<String> {
                Ok(String::new())
            }
        }
        let mut proposer = ReflectiveMutationProposer {
            trainset: Arc::new(VecLoader::new(vec!["ex0".to_string(), "ex1".to_string()])),
            adapter: Arc::new(PerfectAdapter),
            candidate_selector: Box::new(CurrentBestSelector),
            component_selector: Box::new(AllComponentSelector),
            batch_sampler: Box::new(EpochShuffledSampler::new(2, 0).unwrap()),
            reflection_lm: Arc::new(EmptyLM),
            reflection_prompt_template: None,
            component_metadata: ComponentMetaMap::new(),
            perfect_score: None,
            skip_perfect_score: false,
        };
        let mut state = make_state();
        let result = proposer.propose_mut(&mut state).await;
        assert!(
            result.is_ok(),
            "propose_mut should not propagate per-component LM errors"
        );
        assert!(
            result.unwrap().is_none(),
            "when all components produce no text (empty LM), result should be Ok(None)"
        );
    }

    #[tokio::test]
    async fn test_adapter_returns_no_trajectories() {
        struct NoTraceAdapter;
        #[async_trait::async_trait]
        impl GEPAAdapter<String, (), String> for NoTraceAdapter {
            async fn evaluate(
                &self,
                batch: &[String],
                _candidate: &Candidate,
                _capture_traces: bool,
            ) -> crate::error::Result<EvaluationBatch<(), String>> {
                Ok(EvaluationBatch::new(batch.to_vec(), vec![0.5; batch.len()]))
            }

            async fn make_reflective_dataset(
                &self,
                _candidate: &Candidate,
                _eval_batch: &EvaluationBatch<(), String>,
                components: &[String],
            ) -> crate::error::Result<ReflectiveDataset> {
                Ok(components.iter().map(|n| (n.clone(), vec![])).collect())
            }
        }
        let mut proposer = ReflectiveMutationProposer {
            trainset: Arc::new(VecLoader::new(vec!["ex0".to_string(), "ex1".to_string()])),
            adapter: Arc::new(NoTraceAdapter),
            candidate_selector: Box::new(CurrentBestSelector),
            component_selector: Box::new(AllComponentSelector),
            batch_sampler: Box::new(EpochShuffledSampler::new(2, 0).unwrap()),
            reflection_lm: Arc::new(EchoLM),
            reflection_prompt_template: None,
            component_metadata: ComponentMetaMap::new(),
            perfect_score: None,
            skip_perfect_score: false,
        };
        let mut state = make_state();
        let result = proposer
            .propose_mut(&mut state)
            .await
            .expect("should not error");
        assert!(
            result.is_none(),
            "missing trajectories should return Ok(None)"
        );
    }

    #[tokio::test]
    async fn test_empty_reflective_dataset_returns_none() {
        struct EmptyDatasetAdapter;
        #[async_trait::async_trait]
        impl GEPAAdapter<String, (), String> for EmptyDatasetAdapter {
            async fn evaluate(
                &self,
                batch: &[String],
                _candidate: &Candidate,
                capture_traces: bool,
            ) -> crate::error::Result<EvaluationBatch<(), String>> {
                let n = batch.len();
                let mut eval = EvaluationBatch::new(batch.to_vec(), vec![0.5; n]);
                if capture_traces {
                    eval = eval.with_trajectories(vec![(); n]);
                }
                Ok(eval)
            }

            async fn make_reflective_dataset(
                &self,
                _candidate: &Candidate,
                _eval_batch: &EvaluationBatch<(), String>,
                components: &[String],
            ) -> crate::error::Result<ReflectiveDataset> {
                Ok(components.iter().map(|n| (n.clone(), vec![])).collect())
            }
        }
        let mut proposer = ReflectiveMutationProposer {
            trainset: Arc::new(VecLoader::new(vec!["ex0".to_string(), "ex1".to_string()])),
            adapter: Arc::new(EmptyDatasetAdapter),
            candidate_selector: Box::new(CurrentBestSelector),
            component_selector: Box::new(AllComponentSelector),
            batch_sampler: Box::new(EpochShuffledSampler::new(2, 0).unwrap()),
            reflection_lm: Arc::new(EchoLM),
            reflection_prompt_template: None,
            component_metadata: ComponentMetaMap::new(),
            perfect_score: None,
            skip_perfect_score: false,
        };
        let mut state = make_state();
        let result = proposer
            .propose_mut(&mut state)
            .await
            .expect("should not error");
        assert!(
            result.is_none(),
            "empty reflective dataset for all components should return Ok(None)"
        );
    }

    #[tokio::test]
    async fn code_component_metadata_uses_code_prompt() {
        use std::sync::Mutex;
        /// LM stub that records the last prompt it was given.
        struct RecordingLM {
            prompt: Arc<Mutex<Option<String>>>,
        }
        #[async_trait::async_trait]
        impl LanguageModel for RecordingLM {
            async fn complete(&self, prompt: &str) -> Result<String> {
                *self.prompt.lock().unwrap() = Some(prompt.to_string());
                Ok("```rust\nfn improved() {}\n```".into())
            }
        }
        let recorded_prompt = Arc::new(Mutex::new(None));
        let mut seed = Candidate::new();
        seed.insert("model_code".into(), "fn current() {}".into());
        let eval = ValsetEvaluation::from_vecs(
            vec![0usize],
            vec![serde_json::json!("out")],
            vec![0.5],
            None,
        );
        let mut state =
            GEPAState::new(seed, eval, FrontierType::Instance, None).expect("should construct");
        state.i = state.i.wrapping_add(1);
        let mut component_metadata = ComponentMetaMap::new();
        component_metadata.insert(
            "model_code".into(),
            ComponentMeta::code("Updates the model architecture", "rust"),
        );
        let mut proposer = ReflectiveMutationProposer {
            trainset: Arc::new(VecLoader::new(vec!["ex0".to_string(), "ex1".to_string()])),
            adapter: Arc::new(PerfectAdapter),
            candidate_selector: Box::new(CurrentBestSelector),
            component_selector: Box::new(AllComponentSelector),
            batch_sampler: Box::new(EpochShuffledSampler::new(2, 0).unwrap()),
            reflection_lm: Arc::new(RecordingLM {
                prompt: recorded_prompt.clone(),
            }),
            reflection_prompt_template: None,
            component_metadata,
            perfect_score: None,
            skip_perfect_score: false,
        };
        let proposal = proposer
            .propose_mut(&mut state)
            .await
            .expect("should not error");
        assert!(proposal.is_some());
        let prompt = recorded_prompt
            .lock()
            .unwrap()
            .clone()
            .expect("LM should have been called");
        assert!(prompt.contains("Component description: Updates the model architecture"));
        assert!(prompt.contains("model_code component"));
        assert!(prompt.contains("```rust"));
    }

    #[tokio::test]
    async fn test_round_robin_selects_different_components() {
        use crate::strategies::component_selector::RoundRobinSelector;
        let mut seed_map = Candidate::new();
        seed_map.insert("comp_a".into(), "text_a".into());
        seed_map.insert("comp_b".into(), "text_b".into());
        let eval = ValsetEvaluation::from_vecs(
            vec![0usize, 1, 2],
            vec![
                serde_json::json!("a"),
                serde_json::json!("b"),
                serde_json::json!("c"),
            ],
            vec![0.5, 0.6, 0.7],
            None,
        );
        let mut state = crate::core::state::GEPAState::new(
            seed_map,
            eval,
            crate::core::state::FrontierType::Instance,
            None,
        )
        .expect("should construct");
        state.i = state.i.wrapping_add(1);

        /// Adapter scoring 0.5 everywhere with one record per component.
        struct TwoCompAdapter;
        #[async_trait::async_trait]
        impl GEPAAdapter<String, (), String> for TwoCompAdapter {
            async fn evaluate(
                &self,
                batch: &[String],
                _candidate: &Candidate,
                capture_traces: bool,
            ) -> crate::error::Result<EvaluationBatch<(), String>> {
                let n = batch.len();
                let mut e = EvaluationBatch::new(batch.to_vec(), vec![0.5; n]);
                if capture_traces {
                    e = e.with_trajectories(vec![(); n]);
                }
                Ok(e)
            }

            async fn make_reflective_dataset(
                &self,
                _candidate: &Candidate,
                _eval_batch: &EvaluationBatch<(), String>,
                components: &[String],
            ) -> crate::error::Result<ReflectiveDataset> {
                Ok(components
                    .iter()
                    .map(|n| {
                        (
                            n.clone(),
                            vec![serde_json::json!({"Inputs": "x", "Generated Outputs": "y", "Feedback": "ok"})],
                        )
                    })
                    .collect())
            }
        }
        let mut proposer1 = ReflectiveMutationProposer {
            trainset: Arc::new(VecLoader::new(vec![
                "ex0".to_string(),
                "ex1".to_string(),
                "ex2".to_string(),
            ])),
            adapter: Arc::new(TwoCompAdapter),
            candidate_selector: Box::new(CurrentBestSelector),
            component_selector: Box::new(RoundRobinSelector),
            batch_sampler: Box::new(EpochShuffledSampler::new(2, 1).unwrap()),
            reflection_lm: Arc::new(EchoLM),
            reflection_prompt_template: None,
            component_metadata: ComponentMetaMap::new(),
            perfect_score: None,
            skip_perfect_score: false,
        };
        let proposal1 = proposer1
            .propose_mut(&mut state)
            .await
            .expect("should not error");
        state.i = state.i.wrapping_add(1);
        let proposal2 = proposer1
            .propose_mut(&mut state)
            .await
            .expect("should not error");
        // NOTE(review): this check is vacuous if either proposal is None;
        // consider requiring Some(..) once RoundRobinSelector's behavior on
        // this setup is confirmed.
        if let (Some(p1), Some(p2)) = (proposal1, proposal2) {
            let parent = &state.program_candidates[p1.parent_program_ids[0]];
            let changed1: Vec<&String> = p1
                .candidate
                .iter()
                .filter(|(k, v)| parent.get(*k) != Some(v))
                .map(|(k, _)| k)
                .collect();
            let changed2: Vec<&String> = p2
                .candidate
                .iter()
                .filter(|(k, v)| parent.get(*k) != Some(v))
                .map(|(k, _)| k)
                .collect();
            assert!(
                changed1 != changed2 || changed1.is_empty(),
                "RoundRobinSelector should alternate between components"
            );
        }
    }
}