langfuse/datasets/
experiment.rs1use std::collections::HashMap;
4use std::fmt;
5use std::future::Future;
6use std::pin::Pin;
7use std::sync::Arc;
8
9use tokio::sync::Semaphore;
10
11use crate::datasets::evaluator::Evaluator;
12use crate::datasets::types::DatasetItem;
13
14#[derive(Debug, Clone)]
16pub struct ExperimentResult {
17 pub item_id: String,
19 pub output: serde_json::Value,
21 pub scores: Vec<(String, f64)>,
23 pub dataset_run_url: String,
25}
26
27impl ExperimentResult {
28 pub fn format(&self) -> String {
30 let mut summary = format!("Item: {}\n", self.item_id);
31 summary.push_str(&format!("Output: {}\n", self.output));
32 if self.scores.is_empty() {
33 summary.push_str("Scores: (none)\n");
34 } else {
35 summary.push_str("Scores:\n");
36 for (name, value) in &self.scores {
37 summary.push_str(&format!(" {name}: {value}\n"));
38 }
39 }
40 if !self.dataset_run_url.is_empty() {
41 summary.push_str(&format!("Run URL: {}\n", self.dataset_run_url));
42 }
43 summary
44 }
45}
46
47impl fmt::Display for ExperimentResult {
48 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
49 write!(f, "{}", self.format())
50 }
51}
52
53pub fn format_experiment_summary(results: &[ExperimentResult]) -> String {
57 let mut summary = format!("Experiment Summary ({} items)\n", results.len());
58 summary.push_str(&"─".repeat(40));
59 summary.push('\n');
60
61 if results.is_empty() {
62 summary.push_str("No results.\n");
63 return summary;
64 }
65
66 let mut score_sums: HashMap<String, (f64, usize)> = HashMap::new();
68 for result in results {
69 for (name, value) in &result.scores {
70 let entry = score_sums.entry(name.clone()).or_insert((0.0, 0));
71 entry.0 += value;
72 entry.1 += 1;
73 }
74 }
75
76 if !score_sums.is_empty() {
77 summary.push_str("Average Scores:\n");
78 let mut names: Vec<_> = score_sums.keys().collect();
79 names.sort();
80 for name in names {
81 let (total, count) = score_sums[name];
82 let avg = total / count as f64;
83 summary.push_str(&format!(" {name}: {avg:.4} (n={count})\n"));
84 }
85 }
86
87 summary
88}
89
/// Settings that control how an experiment is run and labelled.
#[derive(Clone, Debug)]
pub struct ExperimentConfig {
    /// Name of the run; used as the last path segment of the dataset run URL.
    pub name: String,
    /// Number of semaphore permits used to bound how many items are processed
    /// at the same time.
    pub max_concurrency: usize,
    /// Base URL used to build the dataset run URL. Leave empty to disable URL
    /// generation.
    pub base_url: String,
    /// Dataset name used to build the dataset run URL. Leave empty to disable
    /// URL generation.
    pub dataset_name: String,
}
102
103impl Default for ExperimentConfig {
104 fn default() -> Self {
105 Self {
106 name: format!("experiment-{}", chrono::Utc::now().format("%Y%m%d-%H%M%S")),
107 max_concurrency: 10,
108 base_url: String::new(),
109 dataset_name: String::new(),
110 }
111 }
112}
113
114impl ExperimentConfig {
115 pub fn dataset_run_url(&self) -> String {
117 if self.base_url.is_empty() || self.dataset_name.is_empty() {
118 return String::new();
119 }
120 format!(
121 "{}/datasets/{}/runs/{}",
122 self.base_url.trim_end_matches('/'),
123 self.dataset_name,
124 self.name,
125 )
126 }
127}
128
129pub async fn run_experiment<T, E>(
137 items: Vec<DatasetItem>,
138 config: ExperimentConfig,
139 task_fn: T,
140 evaluator_fn: E,
141) -> Vec<ExperimentResult>
142where
143 T: Fn(DatasetItem) -> Pin<Box<dyn Future<Output = serde_json::Value> + Send>>
144 + Send
145 + Sync
146 + 'static,
147 E: Fn(&DatasetItem, &serde_json::Value) -> Vec<(String, f64)> + Send + Sync + 'static,
148{
149 let semaphore = Arc::new(Semaphore::new(config.max_concurrency));
150 let run_url = config.dataset_run_url();
151 let task_fn = Arc::new(task_fn);
152 let evaluator_fn = Arc::new(evaluator_fn);
153
154 let handles: Vec<_> = items
155 .into_iter()
156 .map(|item| {
157 let sem = semaphore.clone();
158 let task = task_fn.clone();
159 let eval = evaluator_fn.clone();
160 let url = run_url.clone();
161 tokio::spawn(async move {
162 let _permit = sem.acquire().await.expect("semaphore closed");
163 let output = task(item.clone()).await;
164 let scores = eval(&item, &output);
165 ExperimentResult {
166 item_id: item.id,
167 output,
168 scores,
169 dataset_run_url: url,
170 }
171 })
172 })
173 .collect();
174
175 let mut results = Vec::new();
176 for handle in handles {
177 if let Ok(result) = handle.await {
178 results.push(result);
179 }
180 }
181 results
182}
183
184pub async fn run_experiment_with_evaluators<T>(
194 items: Vec<DatasetItem>,
195 config: ExperimentConfig,
196 task_fn: T,
197 evaluators: Vec<Box<dyn Evaluator>>,
198) -> Vec<ExperimentResult>
199where
200 T: Fn(DatasetItem) -> Pin<Box<dyn Future<Output = serde_json::Value> + Send>>
201 + Send
202 + Sync
203 + 'static,
204{
205 let semaphore = Arc::new(Semaphore::new(config.max_concurrency));
206 let run_url = config.dataset_run_url();
207 let task_fn = Arc::new(task_fn);
208 let evaluators: Arc<Vec<Box<dyn Evaluator>>> = Arc::new(evaluators);
209
210 let handles: Vec<_> = items
211 .into_iter()
212 .map(|item| {
213 let sem = semaphore.clone();
214 let task = task_fn.clone();
215 let evals = evaluators.clone();
216 let url = run_url.clone();
217 tokio::spawn(async move {
218 let _permit = sem.acquire().await.expect("semaphore closed");
219 let output = task(item.clone()).await;
220
221 let mut scores = Vec::new();
222 for evaluator in evals.iter() {
223 match evaluator
224 .evaluate(&output, item.expected_output.as_ref())
225 .await
226 {
227 Ok(evaluations) => {
228 for evaluation in evaluations {
229 let numeric = match evaluation.value {
230 langfuse_core::types::ScoreValue::Numeric(v) => v,
231 langfuse_core::types::ScoreValue::Boolean(b) => {
232 if b {
233 1.0
234 } else {
235 0.0
236 }
237 }
238 langfuse_core::types::ScoreValue::Categorical(_) => 0.0,
239 };
240 scores.push((evaluation.name, numeric));
241 }
242 }
243 Err(err) => {
244 tracing::warn!(
245 item_id = %item.id,
246 error = %err,
247 "Evaluator failed for item"
248 );
249 }
250 }
251 }
252
253 ExperimentResult {
254 item_id: item.id,
255 output,
256 scores,
257 dataset_run_url: url,
258 }
259 })
260 })
261 .collect();
262
263 let mut results = Vec::new();
264 for handle in handles {
265 if let Ok(result) = handle.await {
266 results.push(result);
267 }
268 }
269 results
270}