use std::sync::Arc;
use std::time::Duration;
use serde::Serialize;
use tracing::info;
use crate::core::adapter::{Candidate, GEPAAdapter};
use crate::core::callbacks::GEPACallback;
use crate::core::component::ComponentMetaMap;
use crate::core::data_loader::{DataId, DataLoader};
use crate::core::engine::GEPAEngine;
use crate::core::result::GEPAResult;
use crate::core::state::FrontierType;
use crate::error::{GEPAError, Result};
use crate::lm::OpenAICompatibleLM;
use crate::proposer::merge::MergeProposer;
use crate::proposer::reflective_mutation::ReflectiveMutationProposer;
use crate::strategies::batch_sampler::EpochShuffledSampler;
use crate::strategies::candidate_selector::{
CurrentBestSelector, EpsilonGreedySelector, ParetoCandidateSelector,
};
use crate::strategies::component_selector::{AllComponentSelector, RoundRobinSelector};
use crate::strategies::eval_policy::FullEvalPolicy;
use crate::tracking::NoopTracker;
use crate::utils::stop_condition::{
CompositeMode, CompositeStopper, MaxIterationsStopper, MaxMetricCallsStopper, StopCondition,
TimeoutStopper,
};
/// Strategy the reflective mutation proposer uses to pick the candidate to mutate.
///
/// Each variant is turned into a concrete selector by `optimize`; the default
/// is [`CandidateSelectorKind::Pareto`].
#[derive(Debug, Clone, Copy, Default)]
#[derive(PartialEq, Eq)]
pub enum CandidateSelectorKind {
    /// Built as `ParetoCandidateSelector`, seeded from `OptimizeConfig::rng_seed`.
    #[default]
    Pareto,
    /// Built as `CurrentBestSelector` (takes no seed).
    CurrentBest,
    /// Built as `EpsilonGreedySelector` from `OptimizeConfig::epsilon`; an
    /// epsilon it rejects makes `optimize` fail with `GEPAError::Config`.
    EpsilonGreedy,
}
/// Strategy the reflective mutation proposer uses to choose which candidate
/// components to rewrite.
///
/// Mapped by `optimize` to `AllComponentSelector` / `RoundRobinSelector`; the
/// default is [`ComponentSelectorKind::RoundRobin`].
#[derive(Debug, Clone, Copy, Default)]
#[derive(PartialEq, Eq)]
pub enum ComponentSelectorKind {
    /// Built as `AllComponentSelector`.
    All,
    /// Built as `RoundRobinSelector`.
    #[default]
    RoundRobin,
}
/// Budget limits that end an optimisation run.
///
/// `build_stop_condition` combines every limit that is `Some` in
/// `CompositeMode::Any`, so the run ends as soon as the first one is hit.
/// When all three are `None`, a 10 000-iteration cap is installed instead so
/// a run can never be unbounded.
#[derive(Debug, Clone)]
pub struct StopConditionConfig {
    /// Maximum number of metric calls, enforced by `MaxMetricCallsStopper`.
    pub max_metric_calls: Option<usize>,
    /// Maximum number of engine iterations, enforced by `MaxIterationsStopper`.
    pub max_iterations: Option<usize>,
    /// Wall-clock limit, enforced by `TimeoutStopper`.
    pub timeout: Option<Duration>,
}

impl Default for StopConditionConfig {
    /// Only the metric-call budget is set (500 calls); no iteration or time limit.
    fn default() -> Self {
        const DEFAULT_METRIC_CALL_BUDGET: usize = 500;

        Self {
            timeout: None,
            max_iterations: None,
            max_metric_calls: Some(DEFAULT_METRIC_CALL_BUDGET),
        }
    }
}
/// Connection and sampling settings for the OpenAI-compatible reflection LM.
///
/// `Debug` is written by hand so that `api_key` is always printed as a fixed
/// placeholder and never leaks into logs.
#[derive(Clone)]
pub struct LMConfig {
    /// Model identifier passed to `OpenAICompatibleLM::new`.
    pub model: String,
    /// API key passed to `OpenAICompatibleLM::new`; masked in `Debug` output.
    pub api_key: String,
    /// Base URL of the OpenAI-compatible endpoint.
    pub base_url: String,
    /// Sampling temperature forwarded to the client (how `None` is treated is
    /// up to `OpenAICompatibleLM` — confirm there).
    pub temperature: Option<f64>,
    /// Token limit forwarded to the client (same `None` caveat as `temperature`).
    pub max_tokens: Option<u32>,
    /// Value handed to `OpenAICompatibleLM::with_max_retries`.
    pub max_retries: u32,
}

impl std::fmt::Debug for LMConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const MASK: &str = "***REDACTED***";

        // Exhaustive destructuring: adding a field forces a decision here
        // about whether it is safe to print.
        let Self {
            model,
            api_key: _,
            base_url,
            temperature,
            max_tokens,
            max_retries,
        } = self;

        let mut out = f.debug_struct("LMConfig");
        out.field("model", model);
        out.field("api_key", &MASK);
        out.field("base_url", base_url);
        out.field("temperature", temperature);
        out.field("max_tokens", max_tokens);
        out.field("max_retries", max_retries);
        out.finish()
    }
}

impl Default for LMConfig {
    /// `gpt-4o-mini` on `https://api.openai.com` with an empty key,
    /// temperature 1.0, 4096 max tokens and 3 retries.
    fn default() -> Self {
        LMConfig {
            model: String::from("gpt-4o-mini"),
            api_key: String::default(),
            base_url: String::from("https://api.openai.com"),
            temperature: Some(1.0),
            max_tokens: Some(4096),
            max_retries: 3,
        }
    }
}
/// Complete set of inputs for [`optimize`].
///
/// Build one with [`OptimizeConfig::new`], which takes the five required
/// inputs and fills every other field with a default, then override individual
/// public fields as needed before calling `optimize`.
pub struct OptimizeConfig<Id, Item, T, RO>
where
Id: DataId,
Item: Clone + Send + Sync + 'static,
T: Send + Sync + 'static,
RO: Send + Sync + Serialize + 'static,
{
/// Starting candidate; handed to the engine as the seed.
pub seed_candidate: Candidate,
/// Training examples; used by the reflective mutation proposer and the engine.
pub trainset: Arc<dyn DataLoader<Id, Item>>,
/// Validation examples; used by the merge proposer and the engine.
pub valset: Arc<dyn DataLoader<Id, Item>>,
/// Task adapter (evaluates candidates and builds reflective datasets); shared
/// by the mutation proposer, the merge proposer and the engine.
pub adapter: Arc<dyn GEPAAdapter<Item, T, RO>>,
/// Settings for the OpenAI-compatible LM used for reflection.
pub lm_config: LMConfig,
/// Budget limits; turned into a composite stopper by `build_stop_condition`.
pub stop_condition: StopConditionConfig,
/// Candidate-selection strategy used by the mutation proposer.
pub candidate_selector: CandidateSelectorKind,
/// Exploration rate, only read (and validated) when `candidate_selector` is
/// `CandidateSelectorKind::EpsilonGreedy`.
pub epsilon: f64,
/// Component-selection strategy used by the mutation proposer.
pub component_selector: ComponentSelectorKind,
/// Minibatch size for `EpochShuffledSampler`; a value it rejects makes
/// `optimize` return `GEPAError::Config`.
pub minibatch_size: usize,
/// Passed to `MergeProposer::new`; presumably switches merge proposals on/off
/// — confirm in `MergeProposer`.
pub use_merge: bool,
/// Passed to `MergeProposer::new`; NOTE(review): presumably a cap on merge
/// attempts — confirm.
pub max_merge_invocations: usize,
/// Passed to `MergeProposer::new`; NOTE(review): exact meaning not visible
/// here (minimum validation overlap between parents?) — confirm.
pub val_overlap_floor: usize,
/// Frontier granularity forwarded to the engine.
pub frontier_type: FrontierType,
/// Score forwarded to the mutation proposer together with `skip_perfect_score`.
pub perfect_score: Option<f64>,
/// Forwarded to the mutation proposer; NOTE(review): presumably skips
/// reflection when a minibatch already reaches `perfect_score` — confirm.
pub skip_perfect_score: bool,
/// Optional reflection prompt override; `None` presumably means the
/// proposer's built-in template — confirm.
pub reflection_prompt_template:
Option<crate::proposer::reflective_mutation::PromptTemplateConfig>,
/// Per-component metadata forwarded to the mutation proposer.
pub component_metadata: ComponentMetaMap,
/// Callbacks registered with the engine.
pub callbacks: Vec<Box<dyn GEPACallback<Id>>>,
/// RNG seed. When `None`, the candidate selector, batch sampler and merge
/// proposer are seeded with 0; the raw `Option` is also forwarded to the engine.
pub rng_seed: Option<u64>,
/// Run directory forwarded to the engine (behaviour for `None` is decided there).
pub run_dir: Option<String>,
/// Forwarded to the engine; NOTE(review): presumably the component key used
/// when a candidate is treated as a single string — confirm.
pub str_candidate_key: Option<String>,
/// Forwarded to the engine.
pub track_best_outputs: bool,
/// Forwarded to the engine; presumably enables caching of evaluation results
/// — confirm in `GEPAEngine`.
pub cache_evaluation: bool,
}
impl<Id, Item, T, RO> OptimizeConfig<Id, Item, T, RO>
where
    Id: DataId,
    Item: Clone + Send + Sync + 'static,
    T: Send + Sync + 'static,
    RO: Send + Sync + Serialize + 'static,
{
    /// Creates a configuration from the required inputs, filling every other
    /// field with its default.
    ///
    /// Defaults:
    /// - stopping: `StopConditionConfig::default()` (500 metric calls)
    /// - candidate selector: Pareto, `epsilon` = 0.1
    /// - component selector: round-robin
    /// - `minibatch_size` = 3, `frontier_type` = `FrontierType::Instance`
    /// - merge disabled, `max_merge_invocations` = 5, `val_overlap_floor` = 5
    /// - `perfect_score` = `Some(1.0)` with `skip_perfect_score` = `true`
    /// - no prompt override, empty component metadata, no callbacks
    /// - no RNG seed, run dir or string-candidate key
    /// - `track_best_outputs` = `true`, `cache_evaluation` = `false`
    pub fn new(
        seed_candidate: Candidate,
        trainset: Arc<dyn DataLoader<Id, Item>>,
        valset: Arc<dyn DataLoader<Id, Item>>,
        adapter: Arc<dyn GEPAAdapter<Item, T, RO>>,
        lm_config: LMConfig,
    ) -> Self {
        Self {
            // Caller-supplied inputs.
            seed_candidate,
            trainset,
            valset,
            adapter,
            lm_config,

            // Search strategy.
            stop_condition: Default::default(),
            candidate_selector: Default::default(),
            component_selector: Default::default(),
            epsilon: 0.1,
            minibatch_size: 3,
            frontier_type: FrontierType::Instance,

            // Merge proposer.
            use_merge: false,
            max_merge_invocations: 5,
            val_overlap_floor: 5,

            // Reflective mutation.
            perfect_score: Some(1.0),
            skip_perfect_score: true,
            reflection_prompt_template: None,
            component_metadata: ComponentMetaMap::new(),

            // Engine bookkeeping.
            callbacks: Vec::new(),
            rng_seed: None,
            run_dir: None,
            str_candidate_key: None,
            track_best_outputs: true,
            cache_evaluation: false,
        }
    }
}
pub async fn optimize<Id, Item, T, RO>(
config: OptimizeConfig<Id, Item, T, RO>,
) -> Result<GEPAResult<Id>>
where
Id: DataId,
Item: Clone + Send + Sync + 'static,
T: Send + Sync + 'static,
RO: Send + Sync + Serialize + 'static,
{
let lm_cfg = &config.lm_config;
let lm = OpenAICompatibleLM::new(
lm_cfg.model.clone(),
lm_cfg.api_key.clone(),
lm_cfg.base_url.clone(),
lm_cfg.temperature,
lm_cfg.max_tokens,
)
.map_err(|e| GEPAError::Config(format!("Failed to construct LM client: {e}")))?
.with_max_retries(lm_cfg.max_retries);
let lm = Arc::new(lm);
let stop: Box<dyn StopCondition<Id>> = build_stop_condition(&config.stop_condition);
let rng_seed = config.rng_seed.unwrap_or(0);
let candidate_selector: Box<dyn crate::strategies::candidate_selector::CandidateSelector<Id>> =
match config.candidate_selector {
CandidateSelectorKind::Pareto => Box::new(ParetoCandidateSelector::new(rng_seed)),
CandidateSelectorKind::CurrentBest => Box::new(CurrentBestSelector),
CandidateSelectorKind::EpsilonGreedy => Box::new(
EpsilonGreedySelector::new(config.epsilon, rng_seed)
.map_err(|e| GEPAError::Config(format!("Invalid epsilon: {e}")))?,
),
};
let component_selector: Box<dyn crate::strategies::component_selector::ComponentSelector<Id>> =
match config.component_selector {
ComponentSelectorKind::All => Box::new(AllComponentSelector),
ComponentSelectorKind::RoundRobin => Box::new(RoundRobinSelector),
};
let batch_sampler = Box::new(
EpochShuffledSampler::new(config.minibatch_size, rng_seed)
.map_err(|e| GEPAError::Config(format!("Invalid minibatch_size: {e}")))?,
);
let mutation_proposer = ReflectiveMutationProposer {
trainset: config.trainset.clone(),
adapter: config.adapter.clone(),
candidate_selector,
component_selector,
batch_sampler,
reflection_lm: lm.clone(),
reflection_prompt_template: config.reflection_prompt_template.clone(),
component_metadata: config.component_metadata.clone(),
perfect_score: config.perfect_score,
skip_perfect_score: config.skip_perfect_score,
};
let merge_proposer = MergeProposer::new(
config.valset.clone(),
config.adapter.clone(),
config.use_merge,
config.max_merge_invocations,
config.val_overlap_floor,
rng_seed,
)
.map_err(|e| GEPAError::Config(format!("Failed to construct MergeProposer: {e}")))?;
let trainset_len = config.trainset.all_ids().len();
let valset_len = config.valset.all_ids().len();
info!(
trainset_size = trainset_len,
valset_size = valset_len,
stop = %stop.description(),
"Starting GEPA optimisation"
);
let mut engine = GEPAEngine {
trainset: config.trainset,
valset: config.valset,
adapter: config.adapter,
seed_candidate: config.seed_candidate,
mutation_proposer,
merge_proposer,
eval_policy: Box::new(FullEvalPolicy),
stop_condition: stop,
frontier_type: config.frontier_type,
callbacks: config.callbacks,
rng_seed: config.rng_seed,
run_dir: config.run_dir,
str_candidate_key: config.str_candidate_key,
track_best_outputs: config.track_best_outputs,
cache_evaluation: config.cache_evaluation,
tracker: Box::new(NoopTracker),
};
engine.run().await
}
/// Turns a [`StopConditionConfig`] into a single stopper that fires as soon as
/// any configured limit is reached (`CompositeMode::Any`).
///
/// Conditions are added in a fixed order: metric calls, iterations, timeout.
/// If the config sets no limit at all, a 10 000-iteration cap is added so the
/// run is always bounded.
fn build_stop_condition<Id: DataId>(cfg: &StopConditionConfig) -> Box<dyn StopCondition<Id>> {
    /// Fallback iteration cap used when no limit was configured.
    const SAFETY_VALVE_ITERATIONS: usize = 10_000;

    let unbounded =
        cfg.max_metric_calls.is_none() && cfg.max_iterations.is_none() && cfg.timeout.is_none();

    let mut stopper: CompositeStopper<Id> = CompositeStopper::new(CompositeMode::Any);
    if let Some(limit) = cfg.max_metric_calls {
        stopper = stopper.push_condition(MaxMetricCallsStopper::new(limit));
    }
    if let Some(limit) = cfg.max_iterations {
        stopper = stopper.push_condition(MaxIterationsStopper::new(limit));
    }
    if let Some(limit) = cfg.timeout {
        stopper = stopper.push_condition(TimeoutStopper::new(limit));
    }
    if unbounded {
        stopper = stopper.push_condition(MaxIterationsStopper::new(SAFETY_VALVE_ITERATIONS));
    }
    Box::new(stopper)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use async_trait::async_trait;
    use crate::core::adapter::{Candidate, EvaluationBatch, GEPAAdapter, ReflectiveDataset};
    use crate::core::data_loader::VecLoader;
    use crate::error::Result;

    /// Adapter that scores every example 0.8 and returns an empty reflective
    /// record list per requested component, so runs need no real task.
    struct ConstantAdapter;

    #[async_trait]
    impl GEPAAdapter<String, (), String> for ConstantAdapter {
        async fn evaluate(
            &self,
            batch: &[String],
            _candidate: &Candidate,
            _capture_traces: bool,
        ) -> Result<EvaluationBatch<(), String>> {
            Ok(EvaluationBatch::new(batch.to_vec(), vec![0.8; batch.len()]))
        }

        async fn make_reflective_dataset(
            &self,
            _candidate: &Candidate,
            _eval_batch: &EvaluationBatch<(), String>,
            components: &[String],
        ) -> Result<ReflectiveDataset> {
            Ok(components.iter().map(|k| (k.clone(), vec![])).collect())
        }
    }

    /// Two-example train/val config whose only stop limit is `max_iters`
    /// iterations, with merging disabled.
    fn make_config(max_iters: usize) -> OptimizeConfig<usize, String, (), String> {
        let trainset = Arc::new(VecLoader::new(vec![
            "train0".to_string(),
            "train1".to_string(),
        ]));
        let valset = Arc::new(VecLoader::new(vec!["val0".to_string(), "val1".to_string()]));
        let adapter: Arc<dyn GEPAAdapter<String, (), String>> = Arc::new(ConstantAdapter);
        let mut seed = Candidate::new();
        seed.insert("instructions".into(), "Be helpful.".into());
        let lm_config = LMConfig {
            model: "test-model".into(),
            api_key: String::new(),
            base_url: "http://localhost:11434".into(),
            temperature: None,
            max_tokens: Some(64),
            max_retries: 0,
        };
        let mut cfg = OptimizeConfig::new(seed, trainset, valset, adapter, lm_config);
        cfg.stop_condition = StopConditionConfig {
            max_metric_calls: None,
            max_iterations: Some(max_iters),
            timeout: None,
        };
        cfg.use_merge = false;
        cfg
    }

    #[test]
    fn stop_condition_builds_correctly() {
        let cfg = StopConditionConfig {
            max_metric_calls: Some(100),
            max_iterations: Some(10),
            timeout: None,
        };
        let stop: Box<dyn StopCondition<usize>> = build_stop_condition(&cfg);
        assert!(stop.description().contains("Any"));
    }

    #[test]
    fn stop_condition_empty_uses_safety_valve() {
        let cfg = StopConditionConfig {
            max_metric_calls: None,
            max_iterations: None,
            timeout: None,
        };
        let stop: Box<dyn StopCondition<usize>> = build_stop_condition(&cfg);
        assert!(!stop.description().is_empty());
    }

    #[test]
    fn optimize_config_default_stop_condition() {
        let cfg: StopConditionConfig = StopConditionConfig::default();
        assert_eq!(cfg.max_metric_calls, Some(500));
        assert_eq!(cfg.max_iterations, None);
        assert_eq!(cfg.timeout, None);
    }

    #[test]
    fn lm_config_default_values() {
        let cfg = LMConfig::default();
        assert_eq!(cfg.model, "gpt-4o-mini");
        assert_eq!(cfg.max_retries, 3);
        assert!(cfg.api_key.is_empty());
        assert_eq!(cfg.base_url, "https://api.openai.com");
        assert_eq!(cfg.temperature, Some(1.0));
        assert_eq!(cfg.max_tokens, Some(4096));
    }

    /// The hand-written `Debug` impl must never print the API key.
    #[test]
    fn lm_config_debug_redacts_api_key() {
        let cfg = LMConfig {
            api_key: "sk-super-secret".into(),
            ..LMConfig::default()
        };
        let rendered = format!("{cfg:?}");
        assert!(!rendered.contains("sk-super-secret"));
        assert!(rendered.contains("***REDACTED***"));
        assert!(rendered.contains("gpt-4o-mini"));
    }

    /// Fields not overridden by `make_config` keep `OptimizeConfig::new`'s defaults.
    #[test]
    fn optimize_config_new_applies_defaults() {
        let cfg = make_config(1);
        assert_eq!(cfg.candidate_selector, CandidateSelectorKind::Pareto);
        assert_eq!(cfg.component_selector, ComponentSelectorKind::RoundRobin);
        assert!((cfg.epsilon - 0.1).abs() < f64::EPSILON);
        assert_eq!(cfg.minibatch_size, 3);
        assert_eq!(cfg.max_merge_invocations, 5);
        assert_eq!(cfg.val_overlap_floor, 5);
        assert_eq!(cfg.perfect_score, Some(1.0));
        assert!(cfg.skip_perfect_score);
        assert!(cfg.callbacks.is_empty());
        assert_eq!(cfg.rng_seed, None);
        assert!(cfg.track_best_outputs);
        assert!(!cfg.cache_evaluation);
    }

    #[tokio::test]
    async fn optimize_zero_iterations_returns_seed() {
        let mut config = make_config(0);
        // The LM is pointed at a local port (assumed unused) with no retries,
        // so an unexpected reflection call fails fast instead of reaching a
        // real server.
        config.lm_config.base_url = "http://127.0.0.1:19999".into();
        config.lm_config.max_retries = 0;
        let result = optimize(config).await.expect("should succeed with 0 iters");
        assert_eq!(result.num_candidates(), 1);
        assert_eq!(
            result.candidates[0].get("instructions").unwrap(),
            "Be helpful."
        );
    }
}