#![allow(clippy::all, clippy::pedantic, unused_mut)]
use std::collections::HashMap;
use std::sync::Arc;
use async_trait::async_trait;
use gepa::core::data_loader::VecLoader;
use gepa::{
Candidate, EvaluationBatch, GEPAAdapter, LMConfig, OptimizeConfig, ReflectiveDataset, Result,
StopConditionConfig, optimize,
};
use serde_json::json;
/// A single labelled sentence used as a training or validation instance.
///
/// `Debug`, `PartialEq` and `Eq` are derived in addition to `Clone` so that
/// examples can be logged with `{:?}` and compared directly in assertions.
#[derive(Debug, Clone, PartialEq, Eq)]
struct SentimentExample {
    /// Raw input text handed to the classifier.
    text: String,
    /// Expected label that the classifier's prediction is compared against.
    label: String,
}

impl SentimentExample {
    /// Builds an example from any pair of values convertible into `String`.
    ///
    /// Accepting `impl Into<String>` lets callers pass string literals or
    /// already-owned strings without converting at the call site.
    fn new(text: impl Into<String>, label: impl Into<String>) -> Self {
        Self {
            text: text.into(),
            label: label.into(),
        }
    }
}
/// Stateless adapter that evaluates instruction candidates with
/// `mock_classify` instead of a live model.
struct SentimentAdapter;

#[async_trait]
impl GEPAAdapter<SentimentExample, String, String> for SentimentAdapter {
    /// Classifies every example in `batch` using the candidate's instructions.
    ///
    /// The instruction text is read from the candidate's `"instructions"`
    /// entry; a built-in default sentence is used when that entry is missing.
    /// An example scores `1.0` when the trimmed, lowercased prediction equals
    /// the trimmed, lowercased label, and `0.0` otherwise.
    ///
    /// Trace strings are only formatted and attached to the result when
    /// `capture_traces` is `true`; otherwise no per-example trace work is done.
    async fn evaluate(
        &self,
        batch: &[SentimentExample],
        candidate: &Candidate,
        capture_traces: bool,
    ) -> Result<EvaluationBatch<String, String>> {
        let instructions = candidate
            .get("instructions")
            .map(String::as_str)
            .unwrap_or("Classify the sentiment of the following text.");

        let mut outputs = Vec::with_capacity(batch.len());
        let mut scores = Vec::with_capacity(batch.len());
        // `None` when traces were not requested, so nothing is allocated or
        // formatted for callers that will never look at them.
        let mut trajectories: Option<Vec<String>> =
            capture_traces.then(|| Vec::with_capacity(batch.len()));

        for example in batch {
            let prediction = mock_classify(&example.text, instructions);
            let is_match =
                prediction.trim().to_lowercase() == example.label.trim().to_lowercase();

            // Format the trace before `prediction` is moved into `outputs`.
            if let Some(traces) = trajectories.as_mut() {
                traces.push(format!(
                    "instructions: {instructions}\ninput: {}\nprediction: {prediction}\nlabel: {}",
                    example.text, example.label
                ));
            }

            outputs.push(prediction);
            scores.push(if is_match { 1.0_f64 } else { 0.0_f64 });
        }

        let batch_result = EvaluationBatch::new(outputs, scores);
        Ok(match trajectories {
            Some(traces) => batch_result.with_trajectories(traces),
            None => batch_result,
        })
    }

    /// Builds one list of feedback records per requested component.
    ///
    /// A record is emitted for every evaluated example whose score is below
    /// `1.0`; it carries the captured trace, the generated output and a fixed
    /// feedback sentence. The candidate itself is not inspected, and every
    /// component receives the same set of records.
    ///
    /// Scores, outputs and traces are zipped together, so when `eval_batch`
    /// holds no trajectories the zip yields nothing and every component maps
    /// to an empty record list.
    async fn make_reflective_dataset(
        &self,
        _candidate: &Candidate,
        eval_batch: &EvaluationBatch<String, String>,
        components: &[String],
    ) -> Result<ReflectiveDataset> {
        let traces = eval_batch.trajectories.as_deref().unwrap_or(&[]);
        let mut dataset: ReflectiveDataset = HashMap::new();

        for component in components {
            let records = eval_batch
                .scores
                .iter()
                .zip(eval_batch.outputs.iter())
                .zip(traces.iter())
                // Keep only the examples the candidate got wrong.
                .filter(|((score, _), _)| **score < 1.0)
                .map(|((_, output), trace)| {
                    json!({
                        "Inputs": { "trace": trace },
                        "Generated Outputs": output,
                        "Feedback": "The prediction was incorrect. Revise the instructions to be more precise.",
                    })
                })
                .collect::<Vec<_>>();
            dataset.insert(component.clone(), records);
        }

        Ok(dataset)
    }
}
/// Deterministic keyword-based stand-in for a real sentiment model.
///
/// Counts how many of the positive and negative keywords occur as
/// case-insensitive substrings of `text` (each keyword counts at most once).
/// The text is labelled `"positive"` when positive keywords outnumber
/// negative ones, and `"negative"` when they do not. A tie is resolved as
/// `"positive"` only if `instructions` contains the word "carefully"
/// (case-insensitively); otherwise a tie is `"negative"`.
fn mock_classify(text: &str, instructions: &str) -> String {
    use std::cmp::Ordering;

    const POSITIVE_KEYWORDS: [&str; 6] = [
        "great",
        "good",
        "love",
        "excellent",
        "wonderful",
        "fantastic",
    ];
    const NEGATIVE_KEYWORDS: [&str; 6] =
        ["bad", "terrible", "awful", "hate", "horrible", "dreadful"];

    /// Number of `keywords` that appear somewhere inside `haystack`.
    fn keyword_hits(haystack: &str, keywords: &[&str]) -> usize {
        keywords.iter().filter(|kw| haystack.contains(**kw)).count()
    }

    let haystack = text.to_lowercase();
    let ties_go_positive = instructions.to_lowercase().contains("carefully");

    let positive_hits = keyword_hits(&haystack, &POSITIVE_KEYWORDS);
    let negative_hits = keyword_hits(&haystack, &NEGATIVE_KEYWORDS);

    let verdict = match positive_hits.cmp(&negative_hits) {
        Ordering::Greater => "positive",
        Ordering::Equal if ties_go_positive => "positive",
        Ordering::Equal | Ordering::Less => "negative",
    };
    verdict.to_string()
}
/// Returns the eight hand-written training examples, four positive and four
/// negative, in a fixed order.
fn make_train_set() -> Vec<SentimentExample> {
    // (text, label) pairs; kept as a table so the data reads at a glance.
    const LABELLED: [(&str, &str); 8] = [
        ("This product is great!", "positive"),
        ("Absolutely wonderful experience.", "positive"),
        ("I love the new design.", "positive"),
        ("Terrible quality, fell apart immediately.", "negative"),
        ("I hate waiting in long queues.", "negative"),
        ("Horrible customer service.", "negative"),
        ("Excellent value for money.", "positive"),
        ("The food was awful.", "negative"),
    ];

    LABELLED
        .iter()
        .map(|&(text, label)| SentimentExample::new(text, label))
        .collect()
}
/// Returns the four hand-written validation examples, two positive and two
/// negative, in a fixed order.
fn make_val_set() -> Vec<SentimentExample> {
    // (text, label) pairs; kept as a table so the data reads at a glance.
    const LABELLED: [(&str, &str); 4] = [
        ("Fantastic performance on every task.", "positive"),
        ("Dreadful experience from start to finish.", "negative"),
        ("Good build quality and fast delivery.", "positive"),
        ("Bad smell, returned immediately.", "negative"),
    ];

    LABELLED
        .iter()
        .map(|&(text, label)| SentimentExample::new(text, label))
        .collect()
}
/// Runs the sentiment-instruction optimisation demo from start to finish.
///
/// Steps, in order:
/// 1. Installs a `tracing` subscriber whose filter comes from the environment
///    plus a `gepa=info` directive.
/// 2. Seeds the candidate with a single `"instructions"` entry.
/// 3. Builds the language-model configuration. When `OPENAI_API_KEY` is unset
///    or empty, a notice is printed to stderr and the configuration points at
///    `http://localhost:19999` with the model name `stub-model`.
/// 4. Calls `optimize` and prints a summary of the result.
///
/// # Errors
///
/// Returns an error if the log directive cannot be parsed or if `optimize`
/// fails.
#[tokio::main]
async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
    tracing_subscriber::fmt()
        .with_env_filter(
            tracing_subscriber::EnvFilter::from_default_env()
                // `main` already returns a boxed error, so a bad directive is
                // propagated instead of panicking.
                .add_directive("gepa=info".parse()?),
        )
        .init();

    let mut seed = Candidate::new();
    seed.insert(
        "instructions".into(),
        "Classify the sentiment of the following text as positive or negative.".into(),
    );

    let trainset = Arc::new(VecLoader::new(make_train_set()));
    let valset = Arc::new(VecLoader::new(make_val_set()));

    // A missing variable is treated the same as an empty one.
    let api_key = std::env::var("OPENAI_API_KEY").unwrap_or_default();
    let (base_url, model) = if api_key.is_empty() {
        eprintln!(
            "OPENAI_API_KEY not set — running without live reflection calls.\n\
Set the variable to enable full optimisation."
        );
        (
            "http://localhost:19999".to_string(),
            "stub-model".to_string(),
        )
    } else {
        (
            "https://api.openai.com".to_string(),
            "gpt-4o-mini".to_string(),
        )
    };

    let lm_config = LMConfig {
        model,
        api_key,
        base_url,
        temperature: Some(1.0),
        max_tokens: Some(2048),
        max_retries: 2,
    };

    let mut config = OptimizeConfig::new(
        seed,
        trainset,
        valset,
        Arc::new(SentimentAdapter),
        lm_config,
    );
    // Only `max_metric_calls` is set; the other stop conditions are left unset.
    config.stop_condition = StopConditionConfig {
        max_metric_calls: Some(50),
        max_iterations: None,
        timeout: None,
    };
    config.minibatch_size = 4;
    config.use_merge = false;
    config.str_candidate_key = Some("instructions".into());

    println!("Starting GEPA optimisation...");
    let result = optimize(config).await?;

    println!("\n=== Optimisation complete ===");
    println!("Candidates explored : {}", result.num_candidates());
    println!(
        "Total evaluate() calls : {}",
        result.total_metric_calls.unwrap_or(0)
    );
    println!(
        "Validation instances tracked : {}",
        result.num_val_instances()
    );

    // Prefer the single-string form; fall back to printing every component.
    if let Some(best_str) = result.best_candidate_str() {
        println!("\nBest instructions:\n {best_str}");
    } else if let Ok(best_map) = result.best_candidate() {
        println!("\nBest candidate:");
        for (k, v) in best_map {
            println!(" {k}: {v}");
        }
    }

    if let Ok(idx) = result.best_idx() {
        // Checked lookup: an out-of-range index skips the line rather than
        // panicking after the optimisation has already finished.
        if let Some(score) = result.val_aggregate_scores.get(idx) {
            println!("\nBest validation score: {score:.4}");
        }
    }

    Ok(())
}