Skip to main content

langfuse/datasets/
manager.rs

1//! Dataset manager: CRUD operations against the Langfuse dataset API.
2
3use std::future::Future;
4use std::pin::Pin;
5use std::sync::Arc;
6
7use langfuse_core::config::LangfuseConfig;
8use langfuse_core::error::LangfuseError;
9use serde::Deserialize;
10use tokio::sync::Semaphore;
11
12use crate::datasets::evaluator::Evaluator;
13use crate::datasets::experiment::{ExperimentConfig, ExperimentResult};
14use crate::datasets::types::{
15    CreateDatasetBody, CreateDatasetItemBody, Dataset, DatasetItem, DatasetRun,
16};
17use crate::http::retry_request;
18
19/// Wrapper for paginated dataset-items responses.
20#[derive(Debug, Deserialize)]
21struct DatasetItemsResponse {
22    data: Vec<DatasetItem>,
23}
24
25/// Wrapper for dataset-runs responses.
26#[derive(Debug, Deserialize)]
27struct DatasetRunsResponse {
28    data: Vec<DatasetRun>,
29}
30
/// Tunables for a batched evaluation run over a dataset.
#[derive(Debug, Clone)]
pub struct BatchedEvaluationConfig {
    /// Upper bound on how many task executions may be in flight at once.
    pub max_concurrency: usize,
    /// Number of dataset items requested per page while fetching.
    pub page_size: i32,
    /// Maximum number of retries for HTTP requests.
    pub max_retries: usize,
    /// Resume token: only items whose ID sorts lexicographically *after* this
    /// value are evaluated; everything at or before it is skipped.
    pub start_after: Option<String>,
    /// Name under which the experiment run is recorded.
    pub run_name: String,
}
45
46impl Default for BatchedEvaluationConfig {
47    fn default() -> Self {
48        Self {
49            max_concurrency: 10,
50            page_size: 50,
51            max_retries: 3,
52            start_after: None,
53            run_name: format!("batch-eval-{}", chrono::Utc::now().format("%Y%m%d-%H%M%S")),
54        }
55    }
56}
57
58/// Manages dataset CRUD operations against the Langfuse API.
59pub struct DatasetManager {
60    config: LangfuseConfig,
61    http_client: reqwest::Client,
62}
63
64impl std::fmt::Debug for DatasetManager {
65    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66        f.debug_struct("DatasetManager")
67            .field("config", &self.config)
68            .finish()
69    }
70}
71
72impl DatasetManager {
73    /// Create a new `DatasetManager` from the given configuration.
74    pub fn new(config: &LangfuseConfig) -> Self {
75        let http_client = crate::http::build_http_client(config);
76
77        Self {
78            config: config.clone(),
79            http_client,
80        }
81    }
82
83    /// Create a new dataset.
84    pub async fn create_dataset(&self, body: CreateDatasetBody) -> Result<Dataset, LangfuseError> {
85        let url = format!("{}/datasets", self.config.api_base_url());
86
87        let resp = self
88            .http_client
89            .post(&url)
90            .header("Authorization", self.config.basic_auth_header())
91            .json(&body)
92            .send()
93            .await?;
94
95        self.handle_response(resp).await
96    }
97
98    /// Get a dataset by name.
99    pub async fn get_dataset(&self, name: &str) -> Result<Dataset, LangfuseError> {
100        let url = format!("{}/datasets/{}", self.config.api_base_url(), name);
101
102        let resp = self
103            .http_client
104            .get(&url)
105            .header("Authorization", self.config.basic_auth_header())
106            .send()
107            .await?;
108
109        self.handle_response(resp).await
110    }
111
112    /// Delete a dataset by name.
113    ///
114    /// Sends `DELETE /api/public/datasets/{name}` with retry logic for
115    /// transient failures.
116    pub async fn delete_dataset(&self, name: &str) -> Result<(), LangfuseError> {
117        let url = format!("{}/datasets/{}", self.config.api_base_url(), name);
118        let client = self.http_client.clone();
119        let auth = self.config.basic_auth_header();
120
121        retry_request(3, || {
122            let url = url.clone();
123            let client = client.clone();
124            let auth = auth.clone();
125            async move {
126                let resp = client
127                    .delete(&url)
128                    .header("Authorization", auth)
129                    .send()
130                    .await?;
131
132                let status = resp.status();
133                if status == reqwest::StatusCode::UNAUTHORIZED {
134                    return Err(LangfuseError::Auth);
135                }
136                if !status.is_success() {
137                    let message = resp.text().await.unwrap_or_default();
138                    return Err(LangfuseError::Api {
139                        status: status.as_u16(),
140                        message,
141                    });
142                }
143                Ok(())
144            }
145        })
146        .await
147    }
148
149    /// Create a dataset item.
150    pub async fn create_item(
151        &self,
152        body: CreateDatasetItemBody,
153    ) -> Result<DatasetItem, LangfuseError> {
154        let url = format!("{}/dataset-items", self.config.api_base_url());
155
156        let resp = self
157            .http_client
158            .post(&url)
159            .header("Authorization", self.config.basic_auth_header())
160            .json(&body)
161            .send()
162            .await?;
163
164        self.handle_response(resp).await
165    }
166
167    /// Get dataset items (paginated).
168    pub async fn get_items(
169        &self,
170        dataset_name: &str,
171        page: Option<i32>,
172        limit: Option<i32>,
173    ) -> Result<Vec<DatasetItem>, LangfuseError> {
174        let url = format!("{}/dataset-items", self.config.api_base_url());
175
176        let mut req = self
177            .http_client
178            .get(&url)
179            .header("Authorization", self.config.basic_auth_header())
180            .query(&[("datasetName", dataset_name)]);
181
182        if let Some(p) = page {
183            req = req.query(&[("page", p.to_string())]);
184        }
185        if let Some(l) = limit {
186            req = req.query(&[("limit", l.to_string())]);
187        }
188
189        let resp = req.send().await?;
190        let items_resp: DatasetItemsResponse = self.handle_response(resp).await?;
191        Ok(items_resp.data)
192    }
193
194    /// Get dataset runs.
195    pub async fn get_runs(&self, dataset_name: &str) -> Result<Vec<DatasetRun>, LangfuseError> {
196        let url = format!(
197            "{}/datasets/{}/runs",
198            self.config.api_base_url(),
199            dataset_name
200        );
201
202        let resp = self
203            .http_client
204            .get(&url)
205            .header("Authorization", self.config.basic_auth_header())
206            .send()
207            .await?;
208
209        let runs_resp: DatasetRunsResponse = self.handle_response(resp).await?;
210        Ok(runs_resp.data)
211    }
212
213    /// Delete a dataset run.
214    ///
215    /// Sends `DELETE /api/public/datasets/{dataset_name}/runs/{run_name}` with
216    /// retry logic for transient failures.
217    pub async fn delete_run(
218        &self,
219        dataset_name: &str,
220        run_name: &str,
221    ) -> Result<(), LangfuseError> {
222        let url = format!(
223            "{}/datasets/{}/runs/{}",
224            self.config.api_base_url(),
225            dataset_name,
226            run_name,
227        );
228        let client = self.http_client.clone();
229        let auth = self.config.basic_auth_header();
230
231        retry_request(3, || {
232            let url = url.clone();
233            let client = client.clone();
234            let auth = auth.clone();
235            async move {
236                let resp = client
237                    .delete(&url)
238                    .header("Authorization", auth)
239                    .send()
240                    .await?;
241
242                let status = resp.status();
243                if status == reqwest::StatusCode::UNAUTHORIZED {
244                    return Err(LangfuseError::Auth);
245                }
246                if !status.is_success() {
247                    let message = resp.text().await.unwrap_or_default();
248                    return Err(LangfuseError::Api {
249                        status: status.as_u16(),
250                        message,
251                    });
252                }
253                Ok(())
254            }
255        })
256        .await
257    }
258
259    /// Run a batched evaluation over all items in a dataset.
260    ///
261    /// Fetches dataset items in pages, executes the task function on each item
262    /// with bounded concurrency, runs evaluators, and collects results.
263    ///
264    /// If `config.start_after` is set, items with IDs lexicographically before
265    /// that value are skipped (resume token support).
266    pub async fn run_batched_evaluation<T>(
267        &self,
268        dataset_name: &str,
269        batch_config: BatchedEvaluationConfig,
270        task_fn: T,
271        evaluators: Vec<Box<dyn Evaluator>>,
272    ) -> Result<Vec<ExperimentResult>, LangfuseError>
273    where
274        T: Fn(DatasetItem) -> Pin<Box<dyn Future<Output = serde_json::Value> + Send>>
275            + Send
276            + Sync
277            + 'static,
278    {
279        let experiment_config = ExperimentConfig {
280            name: batch_config.run_name,
281            max_concurrency: batch_config.max_concurrency,
282            base_url: self.config.base_url.clone(),
283            dataset_name: dataset_name.to_string(),
284        };
285
286        // Fetch all items page by page
287        let mut all_items = Vec::new();
288        let mut page = 1;
289        loop {
290            let items = self
291                .get_items(dataset_name, Some(page), Some(batch_config.page_size))
292                .await?;
293            let fetched = items.len();
294            all_items.extend(items);
295            if (fetched as i32) < batch_config.page_size {
296                break;
297            }
298            page += 1;
299        }
300
301        // Apply start_after filter (resume token)
302        if let Some(ref start_after) = batch_config.start_after {
303            all_items.retain(|item| item.id.as_str() > start_after.as_str());
304        }
305
306        // Run the experiment with evaluators
307        let semaphore = Arc::new(Semaphore::new(experiment_config.max_concurrency));
308        let run_url = experiment_config.dataset_run_url();
309        let task_fn = Arc::new(task_fn);
310        let evaluators: Arc<Vec<Box<dyn Evaluator>>> = Arc::new(evaluators);
311
312        let handles: Vec<_> = all_items
313            .into_iter()
314            .map(|item| {
315                let sem = semaphore.clone();
316                let task = task_fn.clone();
317                let evals = evaluators.clone();
318                let url = run_url.clone();
319                tokio::spawn(async move {
320                    let _permit = sem.acquire().await.expect("semaphore closed");
321                    let output = task(item.clone()).await;
322
323                    let mut scores = Vec::new();
324                    for evaluator in evals.iter() {
325                        match evaluator
326                            .evaluate(&output, item.expected_output.as_ref())
327                            .await
328                        {
329                            Ok(evaluations) => {
330                                for evaluation in evaluations {
331                                    let numeric = match evaluation.value {
332                                        langfuse_core::types::ScoreValue::Numeric(v) => v,
333                                        langfuse_core::types::ScoreValue::Boolean(b) => {
334                                            if b {
335                                                1.0
336                                            } else {
337                                                0.0
338                                            }
339                                        }
340                                        langfuse_core::types::ScoreValue::Categorical(_) => 0.0,
341                                    };
342                                    scores.push((evaluation.name, numeric));
343                                }
344                            }
345                            Err(err) => {
346                                tracing::warn!(
347                                    item_id = %item.id,
348                                    error = %err,
349                                    "Evaluator failed for item in batched evaluation"
350                                );
351                            }
352                        }
353                    }
354
355                    ExperimentResult {
356                        item_id: item.id,
357                        output,
358                        scores,
359                        dataset_run_url: url,
360                    }
361                })
362            })
363            .collect();
364
365        let mut results = Vec::new();
366        for handle in handles {
367            if let Ok(result) = handle.await {
368                results.push(result);
369            }
370        }
371
372        Ok(results)
373    }
374
375    /// Handle an HTTP response: check status, parse JSON body.
376    async fn handle_response<T: serde::de::DeserializeOwned>(
377        &self,
378        resp: reqwest::Response,
379    ) -> Result<T, LangfuseError> {
380        let status = resp.status();
381
382        if status == reqwest::StatusCode::UNAUTHORIZED {
383            return Err(LangfuseError::Auth);
384        }
385        if !status.is_success() {
386            let message = resp.text().await.unwrap_or_default();
387            return Err(LangfuseError::Api {
388                status: status.as_u16(),
389                message,
390            });
391        }
392
393        let body = resp.json::<T>().await?;
394        Ok(body)
395    }
396}