1use std::sync::Arc;
8
9use tokio::sync::Mutex;
10use tracing::{debug, error, warn};
11
12use crate::langfuse::LangfuseExporter;
13use crate::models::{Observation, Session, Trace};
14use crate::trace_store::{StoreError, TraceStore};
15
/// Batch size used by `BatchWriter::new`: once the buffer holds this many
/// items, `submit` flushes it immediately.
const DEFAULT_BATCH_SIZE: usize = 50;
/// Flush interval, in milliseconds, used by `BatchWriter::new` for the
/// background task that periodically drains the buffer.
const DEFAULT_FLUSH_INTERVAL_MS: u64 = 5_000;
20
/// A single telemetry record that can be queued in a `BatchWriter`.
///
/// The variant decides which `TraceStore` method is used when the item is
/// flushed.
#[derive(Clone, Debug)]
pub enum TelemetryItem {
    /// A trace; written with `TraceStore::upsert_trace`.
    Trace(Trace),
    /// An observation; written with `TraceStore::insert_observation`.
    Observation(Observation),
    /// A session; written with `TraceStore::upsert_session`.
    Session(Session),
}
31
/// Buffers telemetry items and writes them to a `TraceStore` in batches.
///
/// The writer is a thin handle around shared state held in an `Arc`, so
/// cloning it is cheap and every clone feeds the same buffer.
#[derive(Clone)]
pub struct BatchWriter {
    /// Shared state: buffer, store, optional exporter and configuration.
    inner: Arc<BatchWriterInner>,
}
43
44impl std::fmt::Debug for BatchWriter {
45 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46 f.debug_struct("BatchWriter")
47 .field("batch_size", &self.inner.batch_size)
48 .finish()
49 }
50}
51
/// State shared between every `BatchWriter` clone and the background flush
/// task.
struct BatchWriterInner {
    /// Items submitted but not yet flushed.
    buffer: Mutex<Vec<TelemetryItem>>,
    /// Destination store for sessions, traces and observations.
    store: Arc<dyn TraceStore>,
    /// Optional exporter that receives the traces of every flushed batch.
    langfuse: Option<LangfuseExporter>,
    /// Buffer length at which `submit` flushes immediately.
    batch_size: usize,
    /// Set to `true` by `shutdown`; read by the background task on each tick.
    shutdown: Mutex<bool>,
}
59
60impl BatchWriter {
61 #[must_use]
65 pub fn new(store: Arc<dyn TraceStore>) -> Self {
66 Self::with_config(store, DEFAULT_BATCH_SIZE, DEFAULT_FLUSH_INTERVAL_MS)
67 }
68
69 #[must_use]
71 pub fn with_config(
72 store: Arc<dyn TraceStore>,
73 batch_size: usize,
74 flush_interval_ms: u64,
75 ) -> Self {
76 Self::with_config_and_langfuse(store, None, batch_size, flush_interval_ms)
77 }
78
79 #[must_use]
81 pub fn with_config_and_langfuse(
82 store: Arc<dyn TraceStore>,
83 langfuse: Option<LangfuseExporter>,
84 batch_size: usize,
85 flush_interval_ms: u64,
86 ) -> Self {
87 let inner = Arc::new(BatchWriterInner {
88 buffer: Mutex::new(Vec::with_capacity(batch_size)),
89 store,
90 langfuse,
91 batch_size,
92 shutdown: Mutex::new(false),
93 });
94
95 let inner_clone = Arc::clone(&inner);
96 tokio::spawn(async move {
97 let mut interval =
98 tokio::time::interval(std::time::Duration::from_millis(flush_interval_ms));
99 interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
100
101 loop {
102 interval.tick().await;
103
104 let is_shutdown = {
106 let guard = inner_clone.shutdown.lock().await;
107 *guard
108 };
109
110 let mut buffer = inner_clone.buffer.lock().await;
111 if !buffer.is_empty() {
112 let batch: Vec<TelemetryItem> = buffer.drain(..).collect();
113 drop(buffer);
114 Self::flush_batch(&inner_clone.store, inner_clone.langfuse.as_ref(), batch)
115 .await;
116 }
117
118 if is_shutdown {
119 break;
120 }
121 }
122 });
123
124 Self { inner }
125 }
126
127 pub async fn submit(&self, item: TelemetryItem) -> Result<(), StoreError> {
138 let mut buffer = self.inner.buffer.lock().await;
139 buffer.push(item);
140 if buffer.len() >= self.inner.batch_size {
141 let batch: Vec<TelemetryItem> = buffer.drain(..).collect();
142 drop(buffer);
143 Self::flush_batch(&self.inner.store, self.inner.langfuse.as_ref(), batch).await;
144 }
145 Ok(())
146 }
147
148 pub async fn submit_trace(&self, trace: Trace) -> Result<(), StoreError> {
154 self.submit(TelemetryItem::Trace(trace)).await
155 }
156
157 pub async fn submit_observation(&self, observation: Observation) -> Result<(), StoreError> {
163 self.submit(TelemetryItem::Observation(observation)).await
164 }
165
166 pub async fn submit_session(&self, session: Session) -> Result<(), StoreError> {
172 self.submit(TelemetryItem::Session(session)).await
173 }
174
175 pub async fn flush(&self) -> Result<(), StoreError> {
184 let batch: Vec<TelemetryItem> = {
185 let mut buffer = self.inner.buffer.lock().await;
186 buffer.drain(..).collect()
187 };
188 if !batch.is_empty() {
189 Self::flush_batch(&self.inner.store, self.inner.langfuse.as_ref(), batch).await;
190 }
191 Ok(())
192 }
193
194 pub async fn shutdown(self) -> Result<(), StoreError> {
203 *self.inner.shutdown.lock().await = true;
204
205 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
207 self.flush().await
209 }
210
    /// Writes one drained batch to the store and, if configured, to Langfuse.
    ///
    /// Items are split by kind and written in the order sessions, traces,
    /// observations. Store failures are logged and counted by the per-kind
    /// helpers; they are never returned to the caller.
    ///
    /// When an exporter is present, every trace in the batch is exported
    /// together with the observations *from this same batch* whose `trace_id`
    /// matches. The export runs regardless of whether the store writes
    /// succeeded, and export failures are only logged.
    #[expect(
        clippy::cognitive_complexity,
        reason = "flush_batch partitions items, writes to store, and exports to langfuse"
    )]
    async fn flush_batch(
        store: &Arc<dyn TraceStore>,
        langfuse: Option<&LangfuseExporter>,
        batch: Vec<TelemetryItem>,
    ) {
        let (sessions, traces, observations) = Self::partition_items(batch);
        // Total number of items the store rejected across all three kinds.
        let mut errors = 0;
        errors += Self::flush_sessions(store, &sessions).await;
        errors += Self::flush_traces(store, &traces).await;
        errors += Self::flush_observations(store, &observations).await;
        if errors > 0 {
            warn!("batch writer: {errors} items failed to write");
        } else {
            debug!("batch writer: flush complete");
        }

        if let Some(exporter) = langfuse {
            for trace in &traces {
                // Only observations from this batch are considered; ones that
                // arrived in an earlier or later batch are not exported here.
                let trace_obs: Vec<Observation> = observations
                    .iter()
                    .filter(|o| o.trace_id == trace.id)
                    .cloned()
                    .collect();
                if let Err(e) = exporter.export(trace, &trace_obs).await {
                    warn!("langfuse export failed: {e}");
                }
            }
        }
    }
245
246 fn partition_items(batch: Vec<TelemetryItem>) -> (Vec<Session>, Vec<Trace>, Vec<Observation>) {
247 let mut sessions = Vec::new();
248 let mut traces = Vec::new();
249 let mut observations = Vec::new();
250 for item in batch {
251 match item {
252 TelemetryItem::Session(s) => sessions.push(s),
253 TelemetryItem::Trace(t) => traces.push(t),
254 TelemetryItem::Observation(o) => observations.push(o),
255 }
256 }
257 (sessions, traces, observations)
258 }
259
260 async fn flush_sessions(store: &Arc<dyn TraceStore>, sessions: &[Session]) -> u32 {
261 let mut errors = 0;
262 for session in sessions {
263 if let Err(e) = store.upsert_session(session).await {
264 errors += 1;
265 error!("batch writer: failed to write session: {e}");
266 }
267 }
268 errors
269 }
270
271 async fn flush_traces(store: &Arc<dyn TraceStore>, traces: &[Trace]) -> u32 {
272 let mut errors = 0;
273 for trace in traces {
274 if let Err(e) = store.upsert_trace(trace).await {
275 errors += 1;
276 error!("batch writer: failed to write trace: {e}");
277 }
278 }
279 errors
280 }
281
282 async fn flush_observations(store: &Arc<dyn TraceStore>, observations: &[Observation]) -> u32 {
283 let mut errors = 0;
284 for obs in observations {
285 if let Err(e) = store.insert_observation(obs).await {
286 errors += 1;
287 error!("batch writer: failed to write observation: {e}");
288 }
289 }
290 errors
291 }
292}
293
294#[cfg(test)]
295#[expect(
296 clippy::clone_on_ref_ptr,
297 reason = ".clone() needed for unsized coercion Arc<SqliteStore> -> Arc<dyn TraceStore>"
298)]
299mod tests {
300 use super::*;
301 use crate::sqlite_store::SqliteStore;
302
303 #[tokio::test]
304 async fn batch_writer_submit_and_flush() {
305 let store = Arc::new(SqliteStore::new_memory().await.unwrap());
306 let writer = BatchWriter::with_config(store.clone(), 2, 60_000);
307
308 let trace = Trace::new("test");
309 writer.submit_trace(trace.clone()).await.unwrap();
310 writer.flush().await.unwrap();
311
312 let loaded = store.get_trace(trace.id).await.unwrap();
313 assert!(loaded.is_some());
314 }
315
316 #[tokio::test]
317 async fn batch_writer_auto_flush() {
318 let store = Arc::new(SqliteStore::new_memory().await.unwrap());
319 let writer = BatchWriter::with_config(store.clone(), 2, 60_000);
320
321 let trace1 = Trace::new("test1");
322 let trace2 = Trace::new("test2");
323 writer.submit_trace(trace1.clone()).await.unwrap();
324 writer.submit_trace(trace2.clone()).await.unwrap();
325
326 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
328
329 let loaded1 = store.get_trace(trace1.id).await.unwrap();
330 let loaded2 = store.get_trace(trace2.id).await.unwrap();
331 assert!(loaded1.is_some());
332 assert!(loaded2.is_some());
333 }
334
335 #[tokio::test]
336 async fn batch_writer_shutdown() {
337 let store = Arc::new(SqliteStore::new_memory().await.unwrap());
338 let writer = BatchWriter::with_config(store.clone(), 100, 60_000);
339
340 let trace = Trace::new("test");
341 writer.submit_trace(trace.clone()).await.unwrap();
342 writer.shutdown().await.unwrap();
343
344 let loaded = store.get_trace(trace.id).await.unwrap();
345 assert!(loaded.is_some());
346 }
347
348 #[tokio::test]
349 async fn batch_writer_trace_and_observation() {
350 let store = Arc::new(SqliteStore::new_memory().await.unwrap());
351 let writer = BatchWriter::with_config(store.clone(), 100, 60_000);
352
353 let trace = Trace::new("test");
354 let trace_id = trace.id;
355 writer.submit_trace(trace).await.unwrap();
356
357 let obs = Observation::span(trace_id, "test_span");
358 writer.submit_observation(obs).await.unwrap();
359
360 writer.flush().await.unwrap();
361
362 let loaded = store.get_trace(trace_id).await.unwrap();
363 assert!(loaded.is_some(), "trace should exist");
364 let loaded = loaded.unwrap();
365 assert_eq!(
366 loaded.observations.len(),
367 1,
368 "expected 1 observation, got {}",
369 loaded.observations.len()
370 );
371 }
372
373 #[tokio::test]
374 async fn batch_writer_periodic_flush() {
375 let store = Arc::new(SqliteStore::new_memory().await.unwrap());
376 let writer = BatchWriter::with_config(store.clone(), 100, 50);
377
378 let trace = Trace::new("test");
379 writer.submit_trace(trace.clone()).await.unwrap();
380
381 tokio::time::sleep(std::time::Duration::from_millis(150)).await;
382
383 let loaded = store.get_trace(trace.id).await.unwrap();
384 assert!(loaded.is_some());
385 }
386}