use std::sync::{Arc, Mutex};
use trace_weft_core::{EventRecord, SpanRecord, TraceWeftSpanKind};
use trace_weft_recorder::TraceStore;
/// In-memory [`TraceStore`] that keeps recorded spans and events in
/// mutex-protected vectors.
///
/// Cloning is cheap: every clone shares the same underlying vectors via `Arc`,
/// so spans recorded through one handle are visible through all of them.
#[derive(Default, Clone)]
pub struct MemoryStore {
    /// Spans appended by `record_span`, oldest first.
    pub spans: Arc<Mutex<Vec<SpanRecord>>>,
    /// Events appended by `record_event`, oldest first.
    pub events: Arc<Mutex<Vec<EventRecord>>>,
}
impl MemoryStore {
    /// Creates an empty store; equivalent to `MemoryStore::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns a snapshot of every span recorded so far.
    ///
    /// The spans are cloned, so recordings made after this call do not affect
    /// the returned [`TraceTrajectory`]. Recorded events are not included.
    pub fn get_trajectory(&self) -> TraceTrajectory {
        // Recover from poisoning: the guarded data is a plain `Vec`, which stays
        // structurally valid even if another holder of the guard panicked, so a
        // failure elsewhere should not cascade into every snapshot.
        let spans = self
            .spans
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .clone();
        TraceTrajectory { spans }
    }

    /// Removes all recorded spans and events.
    ///
    /// Spans are cleared before events, each under its own lock, so the two
    /// clears are not atomic with respect to concurrent recorders.
    pub fn clear(&self) {
        self.spans
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .clear();
        self.events
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .clear();
    }
}
#[async_trait::async_trait]
impl TraceStore for MemoryStore {
    /// Appends `span` to the in-memory span list.
    ///
    /// # Errors
    ///
    /// Never returns an error. A poisoned lock is recovered instead of panicking.
    async fn record_span(&self, span: SpanRecord) -> anyhow::Result<()> {
        // The guarded `Vec` stays structurally valid after a panic in another
        // guard holder (the fields are `pub`), so keep recording instead of
        // turning that panic into one here.
        self.spans
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .push(span);
        Ok(())
    }

    /// Appends `event` to the in-memory event list.
    ///
    /// # Errors
    ///
    /// Never returns an error. A poisoned lock is recovered instead of panicking.
    async fn record_event(&self, event: EventRecord) -> anyhow::Result<()> {
        // Same poisoning rationale as `record_span`. The guard is a temporary
        // dropped before returning, so no std mutex is held across an await.
        self.events
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .push(event);
        Ok(())
    }
}
/// A snapshot of recorded spans, with query helpers for tool calls, errors,
/// and cost/latency/token totals.
///
/// `Clone` is derived because `SpanRecord` is already cloneable, which lets
/// callers duplicate a snapshot. `Default` yields an empty trajectory.
#[derive(Clone, Default)]
pub struct TraceTrajectory {
    /// The spans contained in this snapshot.
    pub spans: Vec<SpanRecord>,
}
impl TraceTrajectory {
pub fn contains_tool_call(&self, tool_name: &str) -> bool {
self.spans
.iter()
.any(|s| s.span_kind == TraceWeftSpanKind::Tool && s.name == tool_name)
}
pub fn has_errors(&self) -> bool {
self.spans.iter().any(|s| {
s.status == trace_weft_core::SpanStatus::Error
|| s.span_kind == TraceWeftSpanKind::Error
})
}
pub fn total_cost(&self) -> f64 {
self.spans
.iter()
.filter_map(|s| s.cost_estimate.as_ref())
.map(|c| c.amount)
.sum()
}
pub fn total_latency_ms(&self) -> u64 {
self.spans
.iter()
.filter(|s| s.parent_span_id.is_none())
.map(|s| s.latency_ms.unwrap_or(0))
.sum()
}
pub fn total_input_tokens(&self) -> u64 {
self.spans
.iter()
.filter_map(|s| s.token_usage.as_ref())
.map(|u| u.input)
.sum()
}
}