use std::collections::HashMap;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use serde::{Deserialize, Serialize};
use crate::error::{ClientError, Result};
use crate::ConnectionTrait;
/// Identifier of a trace run. Values created by `generate_trace_id` in this
/// file are 32 lowercase hexadecimal digits.
pub type TraceId = String;
/// Identifier of a span within a trace. Values created by `generate_span_id`
/// in this file are 16 lowercase hexadecimal digits.
pub type SpanId = String;
/// Top-level record describing one traced run.
///
/// `TraceStore` serializes this struct as JSON and stores it under the run key
/// derived from `trace_id`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TraceRun {
/// Identifier of this run; also used to build its storage key.
pub trace_id: TraceId,
/// Caller-supplied name of the run.
pub name: String,
/// Start timestamp in microseconds since the Unix epoch (set by `start_run`).
pub start_time: u64,
/// End timestamp in microseconds since the Unix epoch; `None` until `end_run`
/// updates the stored record.
pub end_time: Option<u64>,
/// Current status; `start_run` sets `Running`, `end_run` overwrites it.
pub status: TraceStatus,
/// Free-form attributes; `start_run` initializes this to an empty map.
pub attributes: HashMap<String, TraceValue>,
/// Resource description passed by the caller to `start_run`.
pub resource: HashMap<String, String>,
/// Running token total, accumulated by `update_run_metrics`.
pub total_tokens: u64,
/// Running cost total in millicents, accumulated by `update_run_metrics`.
pub cost_millicents: u64,
}
/// Lifecycle status of a [`TraceRun`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TraceStatus {
/// Assigned by `start_run` when the run is created.
Running,
/// Final status a caller may pass to `end_run`.
Ok,
/// Final status a caller may pass to `end_run`.
Error,
}
/// One span (a timed operation) belonging to a trace.
///
/// `TraceStore` serializes this struct as JSON and stores it under the span key
/// derived from `trace_id` and `span_id`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TraceSpan {
/// Trace this span belongs to.
pub trace_id: TraceId,
/// Identifier of this span (generated in `start_span`).
pub span_id: SpanId,
/// Identifier of the parent span as supplied by the caller, if any.
pub parent_span_id: Option<SpanId>,
/// Caller-supplied name of the span.
pub name: String,
/// Kind of span as supplied by the caller.
pub kind: SpanKind,
/// Start timestamp in microseconds since the Unix epoch.
pub start_time: u64,
/// End timestamp in microseconds since the Unix epoch; `None` until `end_span`.
pub end_time: Option<u64>,
/// `end_time - start_time` (saturating at zero), filled in by `end_span`.
pub duration_us: Option<u64>,
/// Status; starts as `Unset` with no message and is replaced by `end_span`.
pub status: SpanStatus,
/// Attributes merged in by `set_span_attributes`.
pub attributes: HashMap<String, TraceValue>,
/// Events appended by `add_span_event`, in insertion order.
pub events: Vec<SpanEvent>,
}
/// Kind of a [`TraceSpan`]; chosen by the caller of `start_span` and stored as-is.
///
/// NOTE(review): the variant names match the OpenTelemetry span kinds —
/// presumably the same semantics are intended; confirm with callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SpanKind {
Internal,
Server,
Client,
Producer,
Consumer,
}
/// Outcome of a span: a status code plus an optional message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpanStatus {
/// Status code; `Unset` until `end_span` stores the caller's code.
pub code: SpanStatusCode,
/// Optional message supplied by the caller of `end_span`.
pub message: Option<String>,
}
/// Status code carried by [`SpanStatus`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SpanStatusCode {
/// Initial code assigned by `start_span`.
Unset,
/// Code a caller may pass to `end_span`.
Ok,
/// Code a caller may pass to `end_span`.
Error,
}
/// A named, timestamped event recorded inside a span.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpanEvent {
/// Event name (for example `"retrieval_hit"`, `"tool_call"` or
/// `"context_packaging"` for the typed logging helpers in this file).
pub name: String,
/// Time the event was added, in microseconds since the Unix epoch.
pub timestamp: u64,
/// Event attributes supplied by the caller.
pub attributes: HashMap<String, TraceValue>,
}
/// Attribute value attached to runs, spans and span events.
///
/// The enum is `#[serde(untagged)]`, so values are serialized as the bare
/// inner value with no variant tag; on deserialization serde tries the
/// variants in the order they are declared here.
///
/// NOTE(review): with an untagged representation an empty JSON array cannot be
/// told apart between `StringArray` and `IntArray` — presumably it comes back
/// as `StringArray` (declared first); confirm if that distinction matters.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum TraceValue {
/// Text value.
String(String),
/// Signed 64-bit integer value.
Int(i64),
/// 64-bit floating-point value.
Float(f64),
/// Boolean value.
Bool(bool),
/// List of text values.
StringArray(Vec<String>),
/// List of signed 64-bit integers.
IntArray(Vec<i64>),
}
impl From<&str> for TraceValue {
fn from(s: &str) -> Self {
TraceValue::String(s.to_string())
}
}
impl From<String> for TraceValue {
fn from(s: String) -> Self {
TraceValue::String(s)
}
}
impl From<i64> for TraceValue {
fn from(i: i64) -> Self {
TraceValue::Int(i)
}
}
impl From<f64> for TraceValue {
fn from(f: f64) -> Self {
TraceValue::Float(f)
}
}
impl From<bool> for TraceValue {
fn from(b: bool) -> Self {
TraceValue::Bool(b)
}
}
/// Description of one retrieval hit, recorded by `log_retrieval_hit` as a
/// `"retrieval_hit"` span event whose attributes mirror these fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetrievalHitEvent {
/// Identifier of the retrieved document.
pub doc_id: String,
/// Score of the hit; widened to `f64` when stored as an attribute.
pub score: f32,
/// Modality label of the hit (stored as a string attribute).
pub modality: String,
/// Rank of the hit; stored as an integer attribute.
pub rank: usize,
/// Flag stored as the `"filtered"` attribute.
pub filtered: bool,
/// Name of the collection, stored as the `"collection"` attribute.
pub collection: String,
}
/// Description of one tool invocation, recorded by `log_tool_call` as a
/// `"tool_call"` span event.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallEvent {
/// Name of the tool that was called.
pub tool_name: String,
/// Arguments of the call as a string; stored verbatim.
pub arguments: String,
/// Result text, if any; `log_tool_call` shortens results longer than
/// 1000 bytes before storing them.
pub result: Option<String>,
/// Duration of the call in microseconds.
pub duration_us: u64,
/// Flag stored as the `"success"` attribute.
pub success: bool,
/// Error text, if any; stored as the `"error"` attribute when present.
pub error: Option<String>,
}
/// Description of a context-packaging step, recorded by
/// `log_context_packaging` as a `"context_packaging"` span event.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextPackagingEvent {
/// Section names, stored as a string-array attribute.
pub sections: Vec<String>,
/// Token count, stored as the `"total_tokens"` attribute.
pub total_tokens: u64,
/// Budget value, stored as the `"budget"` attribute.
pub budget: u64,
/// Flag stored as the `"truncated"` attribute.
pub truncated: bool,
}
/// Cost information passed to `log_cost`.
///
/// Within this file only `amount` and `total_millicents` are read; the other
/// fields are carried but not persisted by `log_cost`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CostEvent {
/// Label for the kind of cost (not read by `log_cost`).
pub cost_type: String,
/// Quantity consumed; `log_cost` adds it to the run's `total_tokens`.
pub amount: u64,
/// Unit price in millicents (not read by `log_cost`).
pub unit_price_millicents: f64,
/// Total cost in millicents; `log_cost` adds it to the run's `cost_millicents`.
pub total_millicents: u64,
/// Optional model name (not read by `log_cost`).
pub model: Option<String>,
}
/// Key prefix shared by every run, span and event key built in this file.
const TRACE_PREFIX: &str = "_traces/";
/// Persists trace runs and spans as JSON values through a connection `C`.
pub struct TraceStore<C: ConnectionTrait> {
/// Connection used for `get`, `put` and `scan` calls.
conn: C,
/// Fraction of runs/spans to persist; `new` uses `1.0`, `with_sampling`
/// clamps the caller's value to `0.0..=1.0`.
sample_rate: f64,
}
impl<C: ConnectionTrait> TraceStore<C> {
pub fn new(conn: C) -> Self {
Self {
conn,
sample_rate: 1.0,
}
}
pub fn with_sampling(conn: C, sample_rate: f64) -> Self {
Self {
conn,
sample_rate: sample_rate.clamp(0.0, 1.0),
}
}
fn should_sample(&self) -> bool {
if self.sample_rate >= 1.0 {
return true;
}
if self.sample_rate <= 0.0 {
return false;
}
rand::random::<f64>() < self.sample_rate
}
fn now_micros() -> u64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_micros() as u64
}
fn run_key(trace_id: &TraceId) -> Vec<u8> {
format!("{}runs/{}", TRACE_PREFIX, trace_id).into_bytes()
}
fn span_key(trace_id: &TraceId, span_id: &SpanId) -> Vec<u8> {
format!("{}spans/{}/{}", TRACE_PREFIX, trace_id, span_id).into_bytes()
}
fn spans_prefix(trace_id: &TraceId) -> Vec<u8> {
format!("{}spans/{}/", TRACE_PREFIX, trace_id).into_bytes()
}
fn event_key(trace_id: &TraceId, timestamp: u64, seq: u64) -> Vec<u8> {
format!(
"{}events/{}/{:016x}_{:08x}",
TRACE_PREFIX, trace_id, timestamp, seq
).into_bytes()
}
fn events_prefix(trace_id: &TraceId) -> Vec<u8> {
format!("{}events/{}/", TRACE_PREFIX, trace_id).into_bytes()
}
pub fn start_run(
&self,
name: impl Into<String>,
resource: HashMap<String, String>,
) -> Result<TraceRun> {
let trace_id = generate_trace_id();
let now = Self::now_micros();
let run = TraceRun {
trace_id: trace_id.clone(),
name: name.into(),
start_time: now,
end_time: None,
status: TraceStatus::Running,
attributes: HashMap::new(),
resource,
total_tokens: 0,
cost_millicents: 0,
};
if self.should_sample() {
let key = Self::run_key(&trace_id);
let value = serde_json::to_vec(&run)
.map_err(|e| ClientError::Serialization(e.to_string()))?;
self.conn.put(&key, &value)?;
}
Ok(run)
}
pub fn end_run(&self, trace_id: &TraceId, status: TraceStatus) -> Result<()> {
let key = Self::run_key(trace_id);
if let Some(data) = self.conn.get(&key)? {
let mut run: TraceRun = serde_json::from_slice(&data)
.map_err(|e| ClientError::Serialization(e.to_string()))?;
run.end_time = Some(Self::now_micros());
run.status = status;
let value = serde_json::to_vec(&run)
.map_err(|e| ClientError::Serialization(e.to_string()))?;
self.conn.put(&key, &value)?;
}
Ok(())
}
pub fn get_run(&self, trace_id: &TraceId) -> Result<Option<TraceRun>> {
let key = Self::run_key(trace_id);
if let Some(data) = self.conn.get(&key)? {
let run: TraceRun = serde_json::from_slice(&data)
.map_err(|e| ClientError::Serialization(e.to_string()))?;
Ok(Some(run))
} else {
Ok(None)
}
}
pub fn update_run_metrics(
&self,
trace_id: &TraceId,
tokens: u64,
cost_millicents: u64,
) -> Result<()> {
let key = Self::run_key(trace_id);
if let Some(data) = self.conn.get(&key)? {
let mut run: TraceRun = serde_json::from_slice(&data)
.map_err(|e| ClientError::Serialization(e.to_string()))?;
run.total_tokens += tokens;
run.cost_millicents += cost_millicents;
let value = serde_json::to_vec(&run)
.map_err(|e| ClientError::Serialization(e.to_string()))?;
self.conn.put(&key, &value)?;
}
Ok(())
}
pub fn start_span(
&self,
trace_id: &TraceId,
name: impl Into<String>,
parent_span_id: Option<SpanId>,
kind: SpanKind,
) -> Result<TraceSpan> {
let span_id = generate_span_id();
let now = Self::now_micros();
let span = TraceSpan {
trace_id: trace_id.clone(),
span_id: span_id.clone(),
parent_span_id,
name: name.into(),
kind,
start_time: now,
end_time: None,
duration_us: None,
status: SpanStatus {
code: SpanStatusCode::Unset,
message: None,
},
attributes: HashMap::new(),
events: Vec::new(),
};
if self.should_sample() {
let key = Self::span_key(trace_id, &span_id);
let value = serde_json::to_vec(&span)
.map_err(|e| ClientError::Serialization(e.to_string()))?;
self.conn.put(&key, &value)?;
}
Ok(span)
}
pub fn end_span(
&self,
trace_id: &TraceId,
span_id: &SpanId,
status: SpanStatusCode,
message: Option<String>,
) -> Result<()> {
let key = Self::span_key(trace_id, span_id);
if let Some(data) = self.conn.get(&key)? {
let mut span: TraceSpan = serde_json::from_slice(&data)
.map_err(|e| ClientError::Serialization(e.to_string()))?;
let now = Self::now_micros();
span.end_time = Some(now);
span.duration_us = Some(now.saturating_sub(span.start_time));
span.status = SpanStatus {
code: status,
message,
};
let value = serde_json::to_vec(&span)
.map_err(|e| ClientError::Serialization(e.to_string()))?;
self.conn.put(&key, &value)?;
}
Ok(())
}
pub fn add_span_event(
&self,
trace_id: &TraceId,
span_id: &SpanId,
name: impl Into<String>,
attributes: HashMap<String, TraceValue>,
) -> Result<()> {
let key = Self::span_key(trace_id, span_id);
if let Some(data) = self.conn.get(&key)? {
let mut span: TraceSpan = serde_json::from_slice(&data)
.map_err(|e| ClientError::Serialization(e.to_string()))?;
span.events.push(SpanEvent {
name: name.into(),
timestamp: Self::now_micros(),
attributes,
});
let value = serde_json::to_vec(&span)
.map_err(|e| ClientError::Serialization(e.to_string()))?;
self.conn.put(&key, &value)?;
}
Ok(())
}
pub fn set_span_attributes(
&self,
trace_id: &TraceId,
span_id: &SpanId,
attributes: HashMap<String, TraceValue>,
) -> Result<()> {
let key = Self::span_key(trace_id, span_id);
if let Some(data) = self.conn.get(&key)? {
let mut span: TraceSpan = serde_json::from_slice(&data)
.map_err(|e| ClientError::Serialization(e.to_string()))?;
span.attributes.extend(attributes);
let value = serde_json::to_vec(&span)
.map_err(|e| ClientError::Serialization(e.to_string()))?;
self.conn.put(&key, &value)?;
}
Ok(())
}
pub fn get_spans(&self, trace_id: &TraceId) -> Result<Vec<TraceSpan>> {
let prefix = Self::spans_prefix(trace_id);
let results = self.conn.scan(&prefix)?;
let mut spans = Vec::new();
for (_, value) in results {
let span: TraceSpan = serde_json::from_slice(&value)
.map_err(|e| ClientError::Serialization(e.to_string()))?;
spans.push(span);
}
spans.sort_by_key(|s| s.start_time);
Ok(spans)
}
pub fn log_retrieval_hit(
&self,
trace_id: &TraceId,
span_id: &SpanId,
hit: RetrievalHitEvent,
) -> Result<()> {
let mut attrs = HashMap::new();
attrs.insert("doc_id".to_string(), TraceValue::String(hit.doc_id));
attrs.insert("score".to_string(), TraceValue::Float(hit.score as f64));
attrs.insert("modality".to_string(), TraceValue::String(hit.modality));
attrs.insert("rank".to_string(), TraceValue::Int(hit.rank as i64));
attrs.insert("filtered".to_string(), TraceValue::Bool(hit.filtered));
attrs.insert("collection".to_string(), TraceValue::String(hit.collection));
self.add_span_event(trace_id, span_id, "retrieval_hit", attrs)
}
pub fn log_tool_call(
&self,
trace_id: &TraceId,
span_id: &SpanId,
call: ToolCallEvent,
) -> Result<()> {
let mut attrs = HashMap::new();
attrs.insert("tool_name".to_string(), TraceValue::String(call.tool_name));
attrs.insert("arguments".to_string(), TraceValue::String(call.arguments));
attrs.insert("duration_us".to_string(), TraceValue::Int(call.duration_us as i64));
attrs.insert("success".to_string(), TraceValue::Bool(call.success));
if let Some(result) = call.result {
let truncated = if result.len() > 1000 {
format!("{}...(truncated)", &result[..1000])
} else {
result
};
attrs.insert("result".to_string(), TraceValue::String(truncated));
}
if let Some(error) = call.error {
attrs.insert("error".to_string(), TraceValue::String(error));
}
self.add_span_event(trace_id, span_id, "tool_call", attrs)
}
pub fn log_context_packaging(
&self,
trace_id: &TraceId,
span_id: &SpanId,
event: ContextPackagingEvent,
) -> Result<()> {
let mut attrs = HashMap::new();
attrs.insert("sections".to_string(), TraceValue::StringArray(event.sections));
attrs.insert("total_tokens".to_string(), TraceValue::Int(event.total_tokens as i64));
attrs.insert("budget".to_string(), TraceValue::Int(event.budget as i64));
attrs.insert("truncated".to_string(), TraceValue::Bool(event.truncated));
self.add_span_event(trace_id, span_id, "context_packaging", attrs)
}
pub fn log_cost(
&self,
trace_id: &TraceId,
event: CostEvent,
) -> Result<()> {
self.update_run_metrics(trace_id, event.amount, event.total_millicents)?;
Ok(())
}
}
fn generate_trace_id() -> String {
use std::time::SystemTime;
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_nanos();
format!("{:032x}", now ^ rand::random::<u128>())
}
/// Generates a span id: a 64-bit value from the in-file `rand` module rendered
/// as 16 lowercase hexadecimal digits.
fn generate_span_id() -> String {
    let raw: u64 = rand::random();
    format!("{:016x}", raw)
}
/// Minimal in-file pseudo-random source used for sampling decisions and id
/// generation. It exposes the same `random::<T>()` call shape the rest of the
/// file uses.
///
/// Each thread owns an xorshift64 state. The generator is not
/// cryptographically secure.
mod rand {
    use std::cell::Cell;
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};
    use std::time::SystemTime;

    /// Seed used when the mixed seed happens to be zero; xorshift maps a zero
    /// state to zero forever, so zero must never be used as a seed.
    const FALLBACK_SEED: u64 = 0x9E37_79B9_7F4A_7C15;

    thread_local! {
        static SEED: Cell<u64> = Cell::new(initial_seed());
    }

    /// Builds the per-thread seed from the wall clock mixed with the randomly
    /// keyed hasher state of a fresh `RandomState`, so that threads
    /// initialised within the same clock tick do not share a sequence.
    fn initial_seed() -> u64 {
        let clock = SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)
            // Keeping only the low 64 bits is intentional; a pre-epoch clock
            // contributes 0 instead of panicking.
            .map(|elapsed| elapsed.as_nanos() as u64)
            .unwrap_or(0);
        let entropy = RandomState::new().build_hasher().finish();
        match clock ^ entropy {
            0 => FALLBACK_SEED,
            seed => seed,
        }
    }

    /// Maps 64 random bits to an `f64` in the half-open range `[0.0, 1.0)`.
    ///
    /// Only the top 53 bits are used, which is the precision of an `f64`
    /// mantissa; dividing the full 64-bit value by `u64::MAX` could round up
    /// to exactly `1.0`.
    pub fn unit_f64(bits: u64) -> f64 {
        (bits >> 11) as f64 * (1.0 / (1u64 << 53) as f64)
    }

    /// Returns the next pseudo-random value of type `T` for this thread.
    pub fn random<T: Random>() -> T {
        T::random()
    }

    /// Types that can be drawn from the per-thread generator.
    pub trait Random {
        /// Draws one value, advancing the per-thread state.
        fn random() -> Self;
    }

    impl Random for u64 {
        /// Advances the xorshift64 state and returns the new state.
        fn random() -> Self {
            SEED.with(|seed| {
                let mut s = seed.get();
                s ^= s << 13;
                s ^= s >> 7;
                s ^= s << 17;
                seed.set(s);
                s
            })
        }
    }

    impl Random for u128 {
        /// Concatenates two consecutive 64-bit draws (first draw in the high half).
        fn random() -> Self {
            let high = u64::random() as u128;
            let low = u64::random() as u128;
            (high << 64) | low
        }
    }

    impl Random for f64 {
        /// Draws a value in `[0.0, 1.0)`.
        fn random() -> Self {
            unit_f64(u64::random())
        }
    }
}
// Unit-test module; it currently contains no tests.
#[cfg(test)]
mod tests {
}