// langfuse/datasets/experiment.rs

1//! Experiment runner: execute a task function on each dataset item, then evaluate.
2
3use std::collections::HashMap;
4use std::fmt;
5use std::future::Future;
6use std::pin::Pin;
7use std::sync::Arc;
8
9use tokio::sync::Semaphore;
10
11use crate::datasets::evaluator::Evaluator;
12use crate::datasets::types::DatasetItem;
13
/// Result of running an experiment task on a single dataset item.
#[derive(Debug, Clone)]
pub struct ExperimentResult {
    /// ID of the dataset item that was processed.
    pub item_id: String,
    /// Output value produced by the task function for this item.
    pub output: serde_json::Value,
    /// `(metric_name, score)` pairs from the evaluators; empty when no
    /// evaluator produced a score for this item.
    pub scores: Vec<(String, f64)>,
    /// URL to the dataset run in the Langfuse UI; empty when the experiment
    /// config has no base URL or dataset name.
    pub dataset_run_url: String,
}
26
27impl ExperimentResult {
28    /// Format a summary of this experiment result.
29    pub fn format(&self) -> String {
30        let mut summary = format!("Item: {}\n", self.item_id);
31        summary.push_str(&format!("Output: {}\n", self.output));
32        if self.scores.is_empty() {
33            summary.push_str("Scores: (none)\n");
34        } else {
35            summary.push_str("Scores:\n");
36            for (name, value) in &self.scores {
37                summary.push_str(&format!("  {name}: {value}\n"));
38            }
39        }
40        if !self.dataset_run_url.is_empty() {
41            summary.push_str(&format!("Run URL: {}\n", self.dataset_run_url));
42        }
43        summary
44    }
45}
46
47impl fmt::Display for ExperimentResult {
48    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
49        write!(f, "{}", self.format())
50    }
51}
52
53/// Format a summary of multiple experiment results.
54///
55/// Shows total count, per-metric averages, and individual item scores.
56pub fn format_experiment_summary(results: &[ExperimentResult]) -> String {
57    let mut summary = format!("Experiment Summary ({} items)\n", results.len());
58    summary.push_str(&"─".repeat(40));
59    summary.push('\n');
60
61    if results.is_empty() {
62        summary.push_str("No results.\n");
63        return summary;
64    }
65
66    // Aggregate scores by name
67    let mut score_sums: HashMap<String, (f64, usize)> = HashMap::new();
68    for result in results {
69        for (name, value) in &result.scores {
70            let entry = score_sums.entry(name.clone()).or_insert((0.0, 0));
71            entry.0 += value;
72            entry.1 += 1;
73        }
74    }
75
76    if !score_sums.is_empty() {
77        summary.push_str("Average Scores:\n");
78        let mut names: Vec<_> = score_sums.keys().collect();
79        names.sort();
80        for name in names {
81            let (total, count) = score_sums[name];
82            let avg = total / count as f64;
83            summary.push_str(&format!("  {name}: {avg:.4} (n={count})\n"));
84        }
85    }
86
87    summary
88}
89
/// Configuration for running an experiment.
#[derive(Debug, Clone)]
pub struct ExperimentConfig {
    /// Name of the experiment run; also used as the final path segment of
    /// the dataset run URL.
    pub name: String,
    /// Maximum number of task functions allowed to execute concurrently.
    pub max_concurrency: usize,
    /// Base URL for constructing dataset run URLs. Leave empty to disable
    /// URL generation (see `ExperimentConfig::dataset_run_url`).
    pub base_url: String,
    /// Dataset name for constructing dataset run URLs. Leave empty to
    /// disable URL generation.
    pub dataset_name: String,
}
102
103impl Default for ExperimentConfig {
104    fn default() -> Self {
105        Self {
106            name: format!("experiment-{}", chrono::Utc::now().format("%Y%m%d-%H%M%S")),
107            max_concurrency: 10,
108            base_url: String::new(),
109            dataset_name: String::new(),
110        }
111    }
112}
113
114impl ExperimentConfig {
115    /// Build the dataset run URL from config fields.
116    pub fn dataset_run_url(&self) -> String {
117        if self.base_url.is_empty() || self.dataset_name.is_empty() {
118            return String::new();
119        }
120        format!(
121            "{}/datasets/{}/runs/{}",
122            self.base_url.trim_end_matches('/'),
123            self.dataset_name,
124            self.name,
125        )
126    }
127}
128
129/// Run an experiment: execute a task function on each dataset item, then evaluate.
130///
131/// The `task_fn` is called for each item to produce an output value.
132/// The `evaluator_fn` compares the output against the item (including its
133/// `expected_output`) and returns a list of named scores.
134///
135/// Concurrency is bounded by [`ExperimentConfig::max_concurrency`].
136pub async fn run_experiment<T, E>(
137    items: Vec<DatasetItem>,
138    config: ExperimentConfig,
139    task_fn: T,
140    evaluator_fn: E,
141) -> Vec<ExperimentResult>
142where
143    T: Fn(DatasetItem) -> Pin<Box<dyn Future<Output = serde_json::Value> + Send>>
144        + Send
145        + Sync
146        + 'static,
147    E: Fn(&DatasetItem, &serde_json::Value) -> Vec<(String, f64)> + Send + Sync + 'static,
148{
149    let semaphore = Arc::new(Semaphore::new(config.max_concurrency));
150    let run_url = config.dataset_run_url();
151    let task_fn = Arc::new(task_fn);
152    let evaluator_fn = Arc::new(evaluator_fn);
153
154    let handles: Vec<_> = items
155        .into_iter()
156        .map(|item| {
157            let sem = semaphore.clone();
158            let task = task_fn.clone();
159            let eval = evaluator_fn.clone();
160            let url = run_url.clone();
161            tokio::spawn(async move {
162                let _permit = sem.acquire().await.expect("semaphore closed");
163                let output = task(item.clone()).await;
164                let scores = eval(&item, &output);
165                ExperimentResult {
166                    item_id: item.id,
167                    output,
168                    scores,
169                    dataset_run_url: url,
170                }
171            })
172        })
173        .collect();
174
175    let mut results = Vec::new();
176    for handle in handles {
177        if let Ok(result) = handle.await {
178            results.push(result);
179        }
180    }
181    results
182}
183
184/// Run an experiment with trait-based evaluators.
185///
186/// Similar to [`run_experiment`], but accepts a list of [`Evaluator`] trait
187/// objects instead of a simple closure. Each evaluator is called after the
188/// task function, and all evaluation results are converted to `(name, f64)`
189/// score tuples.
190///
191/// The original `evaluator_fn` is still called first (if provided), then
192/// each trait evaluator is invoked.
193pub async fn run_experiment_with_evaluators<T>(
194    items: Vec<DatasetItem>,
195    config: ExperimentConfig,
196    task_fn: T,
197    evaluators: Vec<Box<dyn Evaluator>>,
198) -> Vec<ExperimentResult>
199where
200    T: Fn(DatasetItem) -> Pin<Box<dyn Future<Output = serde_json::Value> + Send>>
201        + Send
202        + Sync
203        + 'static,
204{
205    let semaphore = Arc::new(Semaphore::new(config.max_concurrency));
206    let run_url = config.dataset_run_url();
207    let task_fn = Arc::new(task_fn);
208    let evaluators: Arc<Vec<Box<dyn Evaluator>>> = Arc::new(evaluators);
209
210    let handles: Vec<_> = items
211        .into_iter()
212        .map(|item| {
213            let sem = semaphore.clone();
214            let task = task_fn.clone();
215            let evals = evaluators.clone();
216            let url = run_url.clone();
217            tokio::spawn(async move {
218                let _permit = sem.acquire().await.expect("semaphore closed");
219                let output = task(item.clone()).await;
220
221                let mut scores = Vec::new();
222                for evaluator in evals.iter() {
223                    match evaluator
224                        .evaluate(&output, item.expected_output.as_ref())
225                        .await
226                    {
227                        Ok(evaluations) => {
228                            for evaluation in evaluations {
229                                let numeric = match evaluation.value {
230                                    langfuse_core::types::ScoreValue::Numeric(v) => v,
231                                    langfuse_core::types::ScoreValue::Boolean(b) => {
232                                        if b {
233                                            1.0
234                                        } else {
235                                            0.0
236                                        }
237                                    }
238                                    langfuse_core::types::ScoreValue::Categorical(_) => 0.0,
239                                };
240                                scores.push((evaluation.name, numeric));
241                            }
242                        }
243                        Err(err) => {
244                            tracing::warn!(
245                                item_id = %item.id,
246                                error = %err,
247                                "Evaluator failed for item"
248                            );
249                        }
250                    }
251                }
252
253                ExperimentResult {
254                    item_id: item.id,
255                    output,
256                    scores,
257                    dataset_run_url: url,
258                }
259            })
260        })
261        .collect();
262
263    let mut results = Vec::new();
264    for handle in handles {
265        if let Ok(result) = handle.await {
266            results.push(result);
267        }
268    }
269    results
270}