use std::collections::HashMap;
use async_trait::async_trait;
use indexmap::IndexMap;
use serde::Serialize;
use serde_json::Value;
/// A candidate: an insertion-ordered map from string keys to string values.
///
/// NOTE(review): keys look like component names and values the component text
/// (the tests insert `"instructions"` -> `"Do the task."`) — confirm with callers.
pub type Candidate = IndexMap<String, String>;
#[derive(Debug, Clone)]
pub struct EvaluationBatch<T, RO> {
pub outputs: Vec<RO>,
pub scores: Vec<f64>,
pub trajectories: Option<Vec<T>>,
pub objective_scores: Option<Vec<HashMap<String, f64>>>,
}
impl<T, RO> EvaluationBatch<T, RO> {
pub fn new(outputs: Vec<RO>, scores: Vec<f64>) -> Self {
Self {
outputs,
scores,
trajectories: None,
objective_scores: None,
}
}
pub fn with_trajectories(mut self, trajectories: Vec<T>) -> Self {
self.trajectories = Some(trajectories);
self
}
pub fn with_objective_scores(mut self, objective_scores: Vec<HashMap<String, f64>>) -> Self {
self.objective_scores = Some(objective_scores);
self
}
pub fn validate_lengths(
&self,
expected_len: usize,
require_trajectories: bool,
) -> crate::error::Result<()> {
if self.outputs.len() != expected_len {
return Err(crate::error::GEPAError::Evaluation(format!(
"adapter returned {} outputs for a batch of {expected_len}",
self.outputs.len()
)));
}
if self.scores.len() != expected_len {
return Err(crate::error::GEPAError::Evaluation(format!(
"adapter returned {} scores for a batch of {expected_len}",
self.scores.len()
)));
}
if let Some((idx, score)) = self
.scores
.iter()
.enumerate()
.find(|(_, score)| !score.is_finite())
{
return Err(crate::error::GEPAError::Evaluation(format!(
"adapter returned non-finite score at index {idx}: {score}"
)));
}
match (&self.trajectories, require_trajectories) {
(Some(trajectories), _) if trajectories.len() != expected_len => {
return Err(crate::error::GEPAError::Evaluation(format!(
"adapter returned {} trajectories for a batch of {expected_len}",
trajectories.len()
)));
}
(None, true) => {
return Err(crate::error::GEPAError::Evaluation(
"adapter did not return trajectories for a trace-capturing evaluation".into(),
));
}
_ => {}
}
if let Some(objective_scores) = &self.objective_scores {
if objective_scores.len() != expected_len {
return Err(crate::error::GEPAError::Evaluation(format!(
"adapter returned {} objective-score rows for a batch of {expected_len}",
objective_scores.len()
)));
}
for (idx, objectives) in objective_scores.iter().enumerate() {
if let Some((name, score)) = objectives.iter().find(|(_, score)| !score.is_finite())
{
return Err(crate::error::GEPAError::Evaluation(format!(
"adapter returned non-finite objective score at index {idx} for '{name}': {score}"
)));
}
}
}
Ok(())
}
}
impl<T, RO> EvaluationBatch<T, RO>
where
    RO: Serialize,
{
    /// Serializes every entry of `outputs` into a `serde_json::Value`, preserving order.
    ///
    /// # Errors
    ///
    /// Stops at the first output that fails to serialize and returns that
    /// `serde_json` error converted into the crate error type.
    pub fn outputs_as_json(&self) -> crate::error::Result<Vec<Value>> {
        let mut values = Vec::new();
        for output in &self.outputs {
            match serde_json::to_value(output) {
                Ok(value) => values.push(value),
                Err(err) => return Err(err.into()),
            }
        }
        Ok(values)
    }
}
/// Map from a string key to a list of JSON values.
///
/// NOTE(review): keys are presumably the names of the components being updated (the
/// test adapter builds one entry per requested component) — confirm with callers.
pub type ReflectiveDataset = HashMap<String, Vec<Value>>;
/// Adapter interface with two async operations: evaluating a candidate on a batch and
/// building a reflective dataset from an evaluation.
///
/// Type parameters:
/// - `DataInst`: element type of the input batch passed to `evaluate`.
/// - `T`: trajectory type stored in `EvaluationBatch::trajectories`.
/// - `RO`: raw output type stored in `EvaluationBatch::outputs`; must be `Serialize`.
///
/// Implementors must be `Send + Sync`. `#[async_trait]` is what allows the `async fn`
/// methods in this trait.
#[async_trait]
pub trait GEPAAdapter<DataInst, T, RO>: Send + Sync
where
DataInst: Send,
T: Send,
RO: Send + Serialize,
{
/// Evaluates `candidate` on `batch` and returns the resulting `EvaluationBatch`.
///
/// NOTE(review): `capture_traces` presumably asks the implementation to populate
/// `trajectories` (`EvaluationBatch::validate_lengths` can require them for a
/// "trace-capturing evaluation") — confirm against the call sites.
async fn evaluate(
&self,
batch: &[DataInst],
candidate: &Candidate,
capture_traces: bool,
) -> crate::error::Result<EvaluationBatch<T, RO>>;
/// Builds a `ReflectiveDataset` from `candidate`, a previous `eval_batch`, and the
/// list of `components_to_update`.
///
/// NOTE(review): the returned map is presumably keyed by the names in
/// `components_to_update` (the test adapter does this) — confirm the contract.
async fn make_reflective_dataset(
&self,
candidate: &Candidate,
eval_batch: &EvaluationBatch<T, RO>,
components_to_update: &[String],
) -> crate::error::Result<ReflectiveDataset>;
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::Result;

    /// Minimal adapter for the tests: echoes each input back as its output with a
    /// score of `1.0`, and produces an empty example list per requested component.
    struct EchoAdapter;

    #[async_trait]
    impl GEPAAdapter<String, (), String> for EchoAdapter {
        async fn evaluate(
            &self,
            batch: &[String],
            _candidate: &Candidate,
            _capture_traces: bool,
        ) -> Result<EvaluationBatch<(), String>> {
            let echoed: Vec<String> = batch.iter().cloned().collect();
            let scores = vec![1.0_f64; echoed.len()];
            Ok(EvaluationBatch::new(echoed, scores))
        }

        async fn make_reflective_dataset(
            &self,
            _candidate: &Candidate,
            _eval_batch: &EvaluationBatch<(), String>,
            components_to_update: &[String],
        ) -> Result<ReflectiveDataset> {
            let mut dataset = ReflectiveDataset::new();
            for component in components_to_update {
                dataset.insert(component.clone(), Vec::new());
            }
            Ok(dataset)
        }
    }

    /// `evaluate` returns one output and one score per input and leaves the optional
    /// fields unset.
    #[tokio::test]
    async fn evaluation_batch_round_trip() {
        let mut candidate = Candidate::new();
        candidate.insert(String::from("instructions"), String::from("Do the task."));
        let inputs = vec![String::from("hello"), String::from("world")];

        let evaluated = EchoAdapter
            .evaluate(&inputs, &candidate, false)
            .await
            .expect("evaluation should succeed");

        assert_eq!(evaluated.outputs.len(), 2);
        assert_eq!(evaluated.scores, vec![1.0, 1.0]);
        assert!(evaluated.trajectories.is_none());
        assert!(evaluated.objective_scores.is_none());
    }

    /// Every requested component gets an entry in the reflective dataset.
    #[tokio::test]
    async fn make_reflective_dataset_covers_requested_components() {
        let candidate = Candidate::new();
        let eval_batch = EvaluationBatch::new(vec![String::from("out")], vec![0.5]);
        let requested = vec![String::from("instructions"), String::from("refiner")];

        let dataset = EchoAdapter
            .make_reflective_dataset(&candidate, &eval_batch, &requested)
            .await
            .expect("reflective dataset should succeed");

        assert!(dataset.contains_key("instructions"));
        assert!(dataset.contains_key("refiner"));
    }

    /// The `with_*` builders populate the optional fields.
    #[test]
    fn evaluation_batch_builder_methods() {
        let trajectories = vec![String::from("t1"), String::from("t2")];
        let objective_scores = vec![
            HashMap::from([(String::from("accuracy"), 0.9_f64)]),
            HashMap::from([(String::from("accuracy"), 0.7_f64)]),
        ];

        let batch: EvaluationBatch<String, i32> = EvaluationBatch::new(vec![1, 2], vec![0.5, 0.8])
            .with_trajectories(trajectories)
            .with_objective_scores(objective_scores);

        assert_eq!(batch.scores.len(), 2);
        assert!(batch.trajectories.is_some());
        assert!(batch.objective_scores.is_some());
    }

    /// JSON values can be used directly as the raw output type.
    #[test]
    fn evaluation_batch_with_json_values() {
        let outputs: Vec<serde_json::Value> = vec![
            serde_json::json!({"answer": "yes"}),
            serde_json::json!({"answer": "no"}),
        ];
        let scores: Vec<f64> = vec![0.8, 0.9];

        let batch: EvaluationBatch<(), serde_json::Value> = EvaluationBatch::new(outputs, scores);

        assert_eq!(batch.scores[0], 0.8);
    }
}