use std::sync::Arc;
use tokio::sync::Mutex;
use tracing::{debug, error, warn};
use crate::langfuse::LangfuseExporter;
use crate::models::{Observation, Session, Trace};
use crate::trace_store::{StoreError, TraceStore};
/// Buffered item count at which `submit` flushes inline (used by `BatchWriter::new`).
const DEFAULT_BATCH_SIZE: usize = 50;
/// Period of the background flush task in milliseconds (used by `BatchWriter::new`).
const DEFAULT_FLUSH_INTERVAL_MS: u64 = 5 * 1_000;
/// One unit of telemetry accepted by `BatchWriter::submit`.
///
/// Items are grouped by variant when a batch is flushed.
#[derive(Debug, Clone)]
pub enum TelemetryItem {
    /// A trace; persisted via `TraceStore::upsert_trace`.
    Trace(Trace),
    /// An observation; persisted via `TraceStore::insert_observation`.
    Observation(Observation),
    /// A session; persisted via `TraceStore::upsert_session`.
    Session(Session),
}
/// Cheaply cloneable handle to a buffered telemetry writer.
///
/// Every clone shares the same buffer, store, exporter and background flush task.
#[derive(Clone)]
pub struct BatchWriter {
    /// Shared state; cloning the handle only bumps this reference count.
    inner: Arc<BatchWriterInner>,
}
impl std::fmt::Debug for BatchWriter {
    /// Renders only the configured batch size; buffer, store and exporter are omitted.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let configured_size = self.inner.batch_size;
        let mut out = formatter.debug_struct("BatchWriter");
        out.field("batch_size", &configured_size);
        out.finish()
    }
}
struct BatchWriterInner {
buffer: Mutex<Vec<TelemetryItem>>,
store: Arc<dyn TraceStore>,
langfuse: Option<LangfuseExporter>,
batch_size: usize,
shutdown: Mutex<bool>,
}
impl BatchWriter {
#[must_use]
pub fn new(store: Arc<dyn TraceStore>) -> Self {
Self::with_config(store, DEFAULT_BATCH_SIZE, DEFAULT_FLUSH_INTERVAL_MS)
}
#[must_use]
pub fn with_config(
store: Arc<dyn TraceStore>,
batch_size: usize,
flush_interval_ms: u64,
) -> Self {
Self::with_config_and_langfuse(store, None, batch_size, flush_interval_ms)
}
#[must_use]
pub fn with_config_and_langfuse(
store: Arc<dyn TraceStore>,
langfuse: Option<LangfuseExporter>,
batch_size: usize,
flush_interval_ms: u64,
) -> Self {
let inner = Arc::new(BatchWriterInner {
buffer: Mutex::new(Vec::with_capacity(batch_size)),
store,
langfuse,
batch_size,
shutdown: Mutex::new(false),
});
let inner_clone = Arc::clone(&inner);
tokio::spawn(async move {
let mut interval =
tokio::time::interval(std::time::Duration::from_millis(flush_interval_ms));
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
loop {
interval.tick().await;
let is_shutdown = {
let guard = inner_clone.shutdown.lock().await;
*guard
};
let mut buffer = inner_clone.buffer.lock().await;
if !buffer.is_empty() {
let batch: Vec<TelemetryItem> = buffer.drain(..).collect();
drop(buffer);
Self::flush_batch(&inner_clone.store, inner_clone.langfuse.as_ref(), batch)
.await;
}
if is_shutdown {
break;
}
}
});
Self { inner }
}
pub async fn submit(&self, item: TelemetryItem) -> Result<(), StoreError> {
let mut buffer = self.inner.buffer.lock().await;
buffer.push(item);
if buffer.len() >= self.inner.batch_size {
let batch: Vec<TelemetryItem> = buffer.drain(..).collect();
drop(buffer);
Self::flush_batch(&self.inner.store, self.inner.langfuse.as_ref(), batch).await;
}
Ok(())
}
pub async fn submit_trace(&self, trace: Trace) -> Result<(), StoreError> {
self.submit(TelemetryItem::Trace(trace)).await
}
pub async fn submit_observation(&self, observation: Observation) -> Result<(), StoreError> {
self.submit(TelemetryItem::Observation(observation)).await
}
pub async fn submit_session(&self, session: Session) -> Result<(), StoreError> {
self.submit(TelemetryItem::Session(session)).await
}
pub async fn flush(&self) -> Result<(), StoreError> {
let batch: Vec<TelemetryItem> = {
let mut buffer = self.inner.buffer.lock().await;
buffer.drain(..).collect()
};
if !batch.is_empty() {
Self::flush_batch(&self.inner.store, self.inner.langfuse.as_ref(), batch).await;
}
Ok(())
}
pub async fn shutdown(self) -> Result<(), StoreError> {
*self.inner.shutdown.lock().await = true;
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
self.flush().await
}
#[expect(
clippy::cognitive_complexity,
reason = "flush_batch partitions items, writes to store, and exports to langfuse"
)]
async fn flush_batch(
store: &Arc<dyn TraceStore>,
langfuse: Option<&LangfuseExporter>,
batch: Vec<TelemetryItem>,
) {
let (sessions, traces, observations) = Self::partition_items(batch);
let mut errors = 0;
errors += Self::flush_sessions(store, &sessions).await;
errors += Self::flush_traces(store, &traces).await;
errors += Self::flush_observations(store, &observations).await;
if errors > 0 {
warn!("batch writer: {errors} items failed to write");
} else {
debug!("batch writer: flush complete");
}
if let Some(exporter) = langfuse {
for trace in &traces {
let trace_obs: Vec<Observation> = observations
.iter()
.filter(|o| o.trace_id == trace.id)
.cloned()
.collect();
if let Err(e) = exporter.export(trace, &trace_obs).await {
warn!("langfuse export failed: {e}");
}
}
}
}
fn partition_items(batch: Vec<TelemetryItem>) -> (Vec<Session>, Vec<Trace>, Vec<Observation>) {
let mut sessions = Vec::new();
let mut traces = Vec::new();
let mut observations = Vec::new();
for item in batch {
match item {
TelemetryItem::Session(s) => sessions.push(s),
TelemetryItem::Trace(t) => traces.push(t),
TelemetryItem::Observation(o) => observations.push(o),
}
}
(sessions, traces, observations)
}
async fn flush_sessions(store: &Arc<dyn TraceStore>, sessions: &[Session]) -> u32 {
let mut errors = 0;
for session in sessions {
if let Err(e) = store.upsert_session(session).await {
errors += 1;
error!("batch writer: failed to write session: {e}");
}
}
errors
}
async fn flush_traces(store: &Arc<dyn TraceStore>, traces: &[Trace]) -> u32 {
let mut errors = 0;
for trace in traces {
if let Err(e) = store.upsert_trace(trace).await {
errors += 1;
error!("batch writer: failed to write trace: {e}");
}
}
errors
}
async fn flush_observations(store: &Arc<dyn TraceStore>, observations: &[Observation]) -> u32 {
let mut errors = 0;
for obs in observations {
if let Err(e) = store.insert_observation(obs).await {
errors += 1;
error!("batch writer: failed to write observation: {e}");
}
}
errors
}
}
#[cfg(test)]
#[expect(
    clippy::clone_on_ref_ptr,
    reason = ".clone() needed for unsized coercion Arc<SqliteStore> -> Arc<dyn TraceStore>"
)]
mod tests {
    use super::*;
    use crate::sqlite_store::SqliteStore;
    use std::time::Duration;

    /// Builds an empty in-memory SQLite store shared by a test and its writer.
    async fn memory_store() -> Arc<SqliteStore> {
        Arc::new(SqliteStore::new_memory().await.unwrap())
    }

    /// An explicit `flush` persists a trace that did not fill the batch.
    #[tokio::test]
    async fn batch_writer_submit_and_flush() {
        let store = memory_store().await;
        let writer = BatchWriter::with_config(store.clone(), 2, 60_000);
        let trace = Trace::new("test");
        let id = trace.id;
        writer.submit_trace(trace).await.unwrap();
        writer.flush().await.unwrap();
        assert!(store.get_trace(id).await.unwrap().is_some());
    }

    /// Reaching `batch_size` flushes without an explicit `flush` call.
    #[tokio::test]
    async fn batch_writer_auto_flush() {
        let store = memory_store().await;
        let writer = BatchWriter::with_config(store.clone(), 2, 60_000);
        let first = Trace::new("test1");
        let second = Trace::new("test2");
        let ids = [first.id, second.id];
        writer.submit_trace(first).await.unwrap();
        writer.submit_trace(second).await.unwrap();
        tokio::time::sleep(Duration::from_millis(50)).await;
        for id in ids {
            assert!(store.get_trace(id).await.unwrap().is_some());
        }
    }

    /// `shutdown` writes items still sitting in the buffer.
    #[tokio::test]
    async fn batch_writer_shutdown() {
        let store = memory_store().await;
        let writer = BatchWriter::with_config(store.clone(), 100, 60_000);
        let trace = Trace::new("test");
        let id = trace.id;
        writer.submit_trace(trace).await.unwrap();
        writer.shutdown().await.unwrap();
        assert!(store.get_trace(id).await.unwrap().is_some());
    }

    /// A trace and its observation flushed together are stored and linked.
    #[tokio::test]
    async fn batch_writer_trace_and_observation() {
        let store = memory_store().await;
        let writer = BatchWriter::with_config(store.clone(), 100, 60_000);
        let trace = Trace::new("test");
        let trace_id = trace.id;
        writer.submit_trace(trace).await.unwrap();
        writer
            .submit_observation(Observation::span(trace_id, "test_span"))
            .await
            .unwrap();
        writer.flush().await.unwrap();

        let fetched = store.get_trace(trace_id).await.unwrap();
        assert!(fetched.is_some(), "trace should exist");
        let fetched = fetched.unwrap();
        let observation_count = fetched.observations.len();
        assert_eq!(
            observation_count,
            1,
            "expected 1 observation, got {}",
            observation_count
        );
    }

    /// The background timer flushes a partial batch on its own.
    #[tokio::test]
    async fn batch_writer_periodic_flush() {
        let store = memory_store().await;
        let writer = BatchWriter::with_config(store.clone(), 100, 50);
        let trace = Trace::new("test");
        let id = trace.id;
        writer.submit_trace(trace).await.unwrap();
        tokio::time::sleep(Duration::from_millis(150)).await;
        assert!(store.get_trace(id).await.unwrap().is_some());
    }
}