1use std::future::Future;
4use std::pin::Pin;
5use std::sync::Arc;
6
7use langfuse_core::config::LangfuseConfig;
8use langfuse_core::error::LangfuseError;
9use serde::Deserialize;
10use tokio::sync::Semaphore;
11
12use crate::datasets::evaluator::Evaluator;
13use crate::datasets::experiment::{ExperimentConfig, ExperimentResult};
14use crate::datasets::types::{
15 CreateDatasetBody, CreateDatasetItemBody, Dataset, DatasetItem, DatasetRun,
16};
17use crate::http::retry_request;
18
19#[derive(Debug, Deserialize)]
21struct DatasetItemsResponse {
22 data: Vec<DatasetItem>,
23}
24
25#[derive(Debug, Deserialize)]
27struct DatasetRunsResponse {
28 data: Vec<DatasetRun>,
29}
30
/// Settings consumed by `DatasetManager::run_batched_evaluation`.
#[derive(Debug, Clone)]
pub struct BatchedEvaluationConfig {
    /// Upper bound on the number of dataset items processed at the same time.
    pub max_concurrency: usize,
    /// Value sent as the `limit` query parameter when dataset items are fetched page by page.
    pub page_size: i32,
    /// Retry budget.
    ///
    /// NOTE(review): `run_batched_evaluation` does not read this field in the code
    /// visible in this file — confirm where it is meant to be honoured.
    pub max_retries: usize,
    /// When set, only items whose id compares (as a string) strictly greater than this
    /// value are evaluated.
    pub start_after: Option<String>,
    /// Name given to the experiment run created for this evaluation.
    pub run_name: String,
}
45
46impl Default for BatchedEvaluationConfig {
47 fn default() -> Self {
48 Self {
49 max_concurrency: 10,
50 page_size: 50,
51 max_retries: 3,
52 start_after: None,
53 run_name: format!("batch-eval-{}", chrono::Utc::now().format("%Y%m%d-%H%M%S")),
54 }
55 }
56}
57
58pub struct DatasetManager {
60 config: LangfuseConfig,
61 http_client: reqwest::Client,
62}
63
64impl std::fmt::Debug for DatasetManager {
65 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66 f.debug_struct("DatasetManager")
67 .field("config", &self.config)
68 .finish()
69 }
70}
71
72impl DatasetManager {
73 pub fn new(config: &LangfuseConfig) -> Self {
75 let http_client = crate::http::build_http_client(config);
76
77 Self {
78 config: config.clone(),
79 http_client,
80 }
81 }
82
83 pub async fn create_dataset(&self, body: CreateDatasetBody) -> Result<Dataset, LangfuseError> {
85 let url = format!("{}/datasets", self.config.api_base_url());
86
87 let resp = self
88 .http_client
89 .post(&url)
90 .header("Authorization", self.config.basic_auth_header())
91 .json(&body)
92 .send()
93 .await?;
94
95 self.handle_response(resp).await
96 }
97
98 pub async fn get_dataset(&self, name: &str) -> Result<Dataset, LangfuseError> {
100 let url = format!("{}/datasets/{}", self.config.api_base_url(), name);
101
102 let resp = self
103 .http_client
104 .get(&url)
105 .header("Authorization", self.config.basic_auth_header())
106 .send()
107 .await?;
108
109 self.handle_response(resp).await
110 }
111
112 pub async fn delete_dataset(&self, name: &str) -> Result<(), LangfuseError> {
117 let url = format!("{}/datasets/{}", self.config.api_base_url(), name);
118 let client = self.http_client.clone();
119 let auth = self.config.basic_auth_header();
120
121 retry_request(3, || {
122 let url = url.clone();
123 let client = client.clone();
124 let auth = auth.clone();
125 async move {
126 let resp = client
127 .delete(&url)
128 .header("Authorization", auth)
129 .send()
130 .await?;
131
132 let status = resp.status();
133 if status == reqwest::StatusCode::UNAUTHORIZED {
134 return Err(LangfuseError::Auth);
135 }
136 if !status.is_success() {
137 let message = resp.text().await.unwrap_or_default();
138 return Err(LangfuseError::Api {
139 status: status.as_u16(),
140 message,
141 });
142 }
143 Ok(())
144 }
145 })
146 .await
147 }
148
149 pub async fn create_item(
151 &self,
152 body: CreateDatasetItemBody,
153 ) -> Result<DatasetItem, LangfuseError> {
154 let url = format!("{}/dataset-items", self.config.api_base_url());
155
156 let resp = self
157 .http_client
158 .post(&url)
159 .header("Authorization", self.config.basic_auth_header())
160 .json(&body)
161 .send()
162 .await?;
163
164 self.handle_response(resp).await
165 }
166
167 pub async fn get_items(
169 &self,
170 dataset_name: &str,
171 page: Option<i32>,
172 limit: Option<i32>,
173 ) -> Result<Vec<DatasetItem>, LangfuseError> {
174 let url = format!("{}/dataset-items", self.config.api_base_url());
175
176 let mut req = self
177 .http_client
178 .get(&url)
179 .header("Authorization", self.config.basic_auth_header())
180 .query(&[("datasetName", dataset_name)]);
181
182 if let Some(p) = page {
183 req = req.query(&[("page", p.to_string())]);
184 }
185 if let Some(l) = limit {
186 req = req.query(&[("limit", l.to_string())]);
187 }
188
189 let resp = req.send().await?;
190 let items_resp: DatasetItemsResponse = self.handle_response(resp).await?;
191 Ok(items_resp.data)
192 }
193
194 pub async fn get_runs(&self, dataset_name: &str) -> Result<Vec<DatasetRun>, LangfuseError> {
196 let url = format!(
197 "{}/datasets/{}/runs",
198 self.config.api_base_url(),
199 dataset_name
200 );
201
202 let resp = self
203 .http_client
204 .get(&url)
205 .header("Authorization", self.config.basic_auth_header())
206 .send()
207 .await?;
208
209 let runs_resp: DatasetRunsResponse = self.handle_response(resp).await?;
210 Ok(runs_resp.data)
211 }
212
213 pub async fn delete_run(
218 &self,
219 dataset_name: &str,
220 run_name: &str,
221 ) -> Result<(), LangfuseError> {
222 let url = format!(
223 "{}/datasets/{}/runs/{}",
224 self.config.api_base_url(),
225 dataset_name,
226 run_name,
227 );
228 let client = self.http_client.clone();
229 let auth = self.config.basic_auth_header();
230
231 retry_request(3, || {
232 let url = url.clone();
233 let client = client.clone();
234 let auth = auth.clone();
235 async move {
236 let resp = client
237 .delete(&url)
238 .header("Authorization", auth)
239 .send()
240 .await?;
241
242 let status = resp.status();
243 if status == reqwest::StatusCode::UNAUTHORIZED {
244 return Err(LangfuseError::Auth);
245 }
246 if !status.is_success() {
247 let message = resp.text().await.unwrap_or_default();
248 return Err(LangfuseError::Api {
249 status: status.as_u16(),
250 message,
251 });
252 }
253 Ok(())
254 }
255 })
256 .await
257 }
258
259 pub async fn run_batched_evaluation<T>(
267 &self,
268 dataset_name: &str,
269 batch_config: BatchedEvaluationConfig,
270 task_fn: T,
271 evaluators: Vec<Box<dyn Evaluator>>,
272 ) -> Result<Vec<ExperimentResult>, LangfuseError>
273 where
274 T: Fn(DatasetItem) -> Pin<Box<dyn Future<Output = serde_json::Value> + Send>>
275 + Send
276 + Sync
277 + 'static,
278 {
279 let experiment_config = ExperimentConfig {
280 name: batch_config.run_name,
281 max_concurrency: batch_config.max_concurrency,
282 base_url: self.config.base_url.clone(),
283 dataset_name: dataset_name.to_string(),
284 };
285
286 let mut all_items = Vec::new();
288 let mut page = 1;
289 loop {
290 let items = self
291 .get_items(dataset_name, Some(page), Some(batch_config.page_size))
292 .await?;
293 let fetched = items.len();
294 all_items.extend(items);
295 if (fetched as i32) < batch_config.page_size {
296 break;
297 }
298 page += 1;
299 }
300
301 if let Some(ref start_after) = batch_config.start_after {
303 all_items.retain(|item| item.id.as_str() > start_after.as_str());
304 }
305
306 let semaphore = Arc::new(Semaphore::new(experiment_config.max_concurrency));
308 let run_url = experiment_config.dataset_run_url();
309 let task_fn = Arc::new(task_fn);
310 let evaluators: Arc<Vec<Box<dyn Evaluator>>> = Arc::new(evaluators);
311
312 let handles: Vec<_> = all_items
313 .into_iter()
314 .map(|item| {
315 let sem = semaphore.clone();
316 let task = task_fn.clone();
317 let evals = evaluators.clone();
318 let url = run_url.clone();
319 tokio::spawn(async move {
320 let _permit = sem.acquire().await.expect("semaphore closed");
321 let output = task(item.clone()).await;
322
323 let mut scores = Vec::new();
324 for evaluator in evals.iter() {
325 match evaluator
326 .evaluate(&output, item.expected_output.as_ref())
327 .await
328 {
329 Ok(evaluations) => {
330 for evaluation in evaluations {
331 let numeric = match evaluation.value {
332 langfuse_core::types::ScoreValue::Numeric(v) => v,
333 langfuse_core::types::ScoreValue::Boolean(b) => {
334 if b {
335 1.0
336 } else {
337 0.0
338 }
339 }
340 langfuse_core::types::ScoreValue::Categorical(_) => 0.0,
341 };
342 scores.push((evaluation.name, numeric));
343 }
344 }
345 Err(err) => {
346 tracing::warn!(
347 item_id = %item.id,
348 error = %err,
349 "Evaluator failed for item in batched evaluation"
350 );
351 }
352 }
353 }
354
355 ExperimentResult {
356 item_id: item.id,
357 output,
358 scores,
359 dataset_run_url: url,
360 }
361 })
362 })
363 .collect();
364
365 let mut results = Vec::new();
366 for handle in handles {
367 if let Ok(result) = handle.await {
368 results.push(result);
369 }
370 }
371
372 Ok(results)
373 }
374
375 async fn handle_response<T: serde::de::DeserializeOwned>(
377 &self,
378 resp: reqwest::Response,
379 ) -> Result<T, LangfuseError> {
380 let status = resp.status();
381
382 if status == reqwest::StatusCode::UNAUTHORIZED {
383 return Err(LangfuseError::Auth);
384 }
385 if !status.is_success() {
386 let message = resp.text().await.unwrap_or_default();
387 return Err(LangfuseError::Api {
388 status: status.as_u16(),
389 message,
390 });
391 }
392
393 let body = resp.json::<T>().await?;
394 Ok(body)
395 }
396}