use std::collections::HashMap;
use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use tokio::sync::Semaphore;
use crate::datasets::evaluator::Evaluator;
use crate::datasets::types::DatasetItem;
/// Outcome of running the experiment task and scoring its output for one dataset item.
#[derive(Debug, Clone)]
pub struct ExperimentResult {
    /// Identifier of the dataset item that was processed.
    pub item_id: String,
    /// Value produced by the experiment task for this item.
    pub output: serde_json::Value,
    /// `(score name, numeric value)` pairs in the order they were produced.
    pub scores: Vec<(String, f64)>,
    /// Link to the dataset run; may be empty, in which case it is left out of the report.
    pub dataset_run_url: String,
}

impl ExperimentResult {
    /// Renders this result as a newline-terminated, multi-line text report.
    ///
    /// The report lists the item id and the task output. Each score then appears on its own
    /// line, indented by one space, or the report says `Scores: (none)` when there are no
    /// scores. The run URL is added only when it is non-empty.
    pub fn format(&self) -> String {
        self.to_string()
    }
}

impl fmt::Display for ExperimentResult {
    /// Writes exactly the text returned by [`ExperimentResult::format`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        writeln!(f, "Item: {}", self.item_id)?;
        writeln!(f, "Output: {}", self.output)?;

        if self.scores.is_empty() {
            f.write_str("Scores: (none)\n")?;
        } else {
            f.write_str("Scores:\n")?;
            for (name, value) in &self.scores {
                writeln!(f, " {name}: {value}")?;
            }
        }

        if !self.dataset_run_url.is_empty() {
            writeln!(f, "Run URL: {}", self.dataset_run_url)?;
        }
        Ok(())
    }
}
/// Builds a plain-text summary of a whole experiment run.
///
/// The report always starts with a header that gives the number of results, followed by a
/// 40-character horizontal rule.
/// - For an empty slice, the header is followed by `No results.`.
/// - Otherwise, every score name seen across all results is listed in ascending name order.
///   Each line shows the mean value to four decimals and the number of samples behind it.
/// - If no result carries any score, only the header is returned.
pub fn format_experiment_summary(results: &[ExperimentResult]) -> String {
    let mut out = format!("Experiment Summary ({} items)\n", results.len());
    out.push_str(&"─".repeat(40));
    out.push('\n');

    if results.is_empty() {
        out.push_str("No results.\n");
        return out;
    }

    // Running (sum, sample count) per score name, accumulated in encounter order.
    let mut totals: HashMap<String, (f64, usize)> = HashMap::new();
    for (name, value) in results.iter().flat_map(|r| r.scores.iter()) {
        let (sum, n) = totals.entry(name.clone()).or_insert((0.0, 0));
        *sum += *value;
        *n += 1;
    }

    if totals.is_empty() {
        return out;
    }

    // HashMap iteration order is unspecified; sort by name for stable output.
    let mut ordered: Vec<(&String, &(f64, usize))> = totals.iter().collect();
    ordered.sort_by(|x, y| x.0.cmp(y.0));

    out.push_str("Average Scores:\n");
    for (name, &(total, count)) in ordered {
        // `count` is at least 1: an entry only exists after one score was added to it.
        let avg = total / count as f64;
        out.push_str(&format!(" {name}: {avg:.4} (n={count})\n"));
    }
    out
}
/// Settings for a single experiment run.
#[derive(Debug, Clone)]
pub struct ExperimentConfig {
    /// Run name; also used as the final path segment of the run URL.
    pub name: String,
    /// Upper bound on how many dataset items are processed at the same time.
    pub max_concurrency: usize,
    /// Server root URL; trailing `/` characters are ignored when building the run URL.
    pub base_url: String,
    /// Dataset name inserted into the run URL.
    pub dataset_name: String,
}

impl Default for ExperimentConfig {
    /// Returns a config with a timestamped name (`experiment-YYYYMMDD-HHMMSS`, UTC) and a
    /// concurrency limit of 10.
    ///
    /// `base_url` and `dataset_name` are empty, so
    /// [`ExperimentConfig::dataset_run_url`] returns an empty string.
    fn default() -> Self {
        Self {
            name: format!("experiment-{}", chrono::Utc::now().format("%Y%m%d-%H%M%S")),
            max_concurrency: 10,
            base_url: String::new(),
            dataset_name: String::new(),
        }
    }
}

impl ExperimentConfig {
    /// Builds `<base_url>/datasets/<dataset_name>/runs/<name>`.
    ///
    /// Returns an empty string when either `base_url` or `dataset_name` is empty.
    /// Trailing `/` characters are stripped from `base_url` before joining. The dataset and
    /// run names are inserted verbatim; they are not percent-encoded.
    pub fn dataset_run_url(&self) -> String {
        let configured = !(self.base_url.is_empty() || self.dataset_name.is_empty());
        if configured {
            let root = self.base_url.trim_end_matches('/');
            format!("{}/datasets/{}/runs/{}", root, self.dataset_name, self.name)
        } else {
            String::new()
        }
    }
}
pub async fn run_experiment<T, E>(
items: Vec<DatasetItem>,
config: ExperimentConfig,
task_fn: T,
evaluator_fn: E,
) -> Vec<ExperimentResult>
where
T: Fn(DatasetItem) -> Pin<Box<dyn Future<Output = serde_json::Value> + Send>>
+ Send
+ Sync
+ 'static,
E: Fn(&DatasetItem, &serde_json::Value) -> Vec<(String, f64)> + Send + Sync + 'static,
{
let semaphore = Arc::new(Semaphore::new(config.max_concurrency));
let run_url = config.dataset_run_url();
let task_fn = Arc::new(task_fn);
let evaluator_fn = Arc::new(evaluator_fn);
let handles: Vec<_> = items
.into_iter()
.map(|item| {
let sem = semaphore.clone();
let task = task_fn.clone();
let eval = evaluator_fn.clone();
let url = run_url.clone();
tokio::spawn(async move {
let _permit = sem.acquire().await.expect("semaphore closed");
let output = task(item.clone()).await;
let scores = eval(&item, &output);
ExperimentResult {
item_id: item.id,
output,
scores,
dataset_run_url: url,
}
})
})
.collect();
let mut results = Vec::new();
for handle in handles {
if let Ok(result) = handle.await {
results.push(result);
}
}
results
}
/// Runs `task_fn` over every dataset item concurrently and scores each output with every
/// evaluator in `evaluators`.
///
/// Concurrency is bounded by `config.max_concurrency`. A value of `0` is treated as `1`,
/// because a semaphore with zero permits would never let any task start and the run would
/// hang forever.
///
/// For each item, the evaluators run one after another. Each receives the task output and
/// the item's `expected_output`, and their evaluations are flattened into `(name, f64)`
/// pairs:
/// - numeric values are kept as-is;
/// - booleans become `1.0` / `0.0`;
/// - categorical values are recorded as `0.0`.
///
/// An evaluator that returns an error is logged and skipped for that item.
///
/// Results come back in the order of `items`. If an item's spawned task fails to complete
/// (for example because it panicked), that item is omitted and a warning naming it is
/// logged.
pub async fn run_experiment_with_evaluators<T>(
    items: Vec<DatasetItem>,
    config: ExperimentConfig,
    task_fn: T,
    evaluators: Vec<Box<dyn Evaluator>>,
) -> Vec<ExperimentResult>
where
    T: Fn(DatasetItem) -> Pin<Box<dyn Future<Output = serde_json::Value> + Send>>
        + Send
        + Sync
        + 'static,
{
    // Collapses an evaluation value to the `f64` stored in `ExperimentResult::scores`.
    fn numeric_score(value: &langfuse_core::types::ScoreValue) -> f64 {
        use langfuse_core::types::ScoreValue;
        match value {
            ScoreValue::Numeric(v) => *v,
            ScoreValue::Boolean(true) => 1.0,
            ScoreValue::Boolean(false) => 0.0,
            // NOTE(review): categorical labels have no numeric form. Recording them as 0.0
            // (kept from the original behavior) pulls down any average taken over that name.
            ScoreValue::Categorical(_) => 0.0,
        }
    }

    // Clamp to at least one permit: `Semaphore::new(0)` would deadlock every task below.
    let semaphore = Arc::new(Semaphore::new(config.max_concurrency.max(1)));
    let run_url = config.dataset_run_url();
    let task_fn = Arc::new(task_fn);
    let evaluators: Arc<Vec<Box<dyn Evaluator>>> = Arc::new(evaluators);

    let handles: Vec<_> = items
        .into_iter()
        .map(|item| {
            // Keep the id outside the task so a failed join can still be reported.
            let item_id = item.id.clone();
            let sem = Arc::clone(&semaphore);
            let task = Arc::clone(&task_fn);
            let evals = Arc::clone(&evaluators);
            let url = run_url.clone();
            let handle = tokio::spawn(async move {
                // The semaphore is never closed in this function, so `acquire` cannot fail.
                let _permit = sem.acquire().await.expect("semaphore closed");
                let output = task(item.clone()).await;
                let mut scores = Vec::new();
                for evaluator in evals.iter() {
                    match evaluator
                        .evaluate(&output, item.expected_output.as_ref())
                        .await
                    {
                        Ok(evaluations) => {
                            for evaluation in evaluations {
                                let numeric = numeric_score(&evaluation.value);
                                scores.push((evaluation.name, numeric));
                            }
                        }
                        Err(err) => {
                            // Best-effort: one failing evaluator must not drop the item.
                            tracing::warn!(
                                item_id = %item.id,
                                error = %err,
                                "Evaluator failed for item"
                            );
                        }
                    }
                }
                ExperimentResult {
                    item_id: item.id,
                    output,
                    scores,
                    dataset_run_url: url,
                }
            });
            (item_id, handle)
        })
        .collect();

    let mut results = Vec::with_capacity(handles.len());
    for (item_id, handle) in handles {
        match handle.await {
            Ok(result) => results.push(result),
            // Previously this loss was silent; surface it so missing items are explainable.
            Err(err) => tracing::warn!(
                item_id = %item_id,
                error = %err,
                "Experiment task failed; item omitted from results"
            ),
        }
    }
    results
}