use std::collections::HashMap;
use chrono::{DateTime, Utc};
use serde::Deserialize;
use super::models::{SessionTrace, TraceEvent, TraceEventKind};
/// Trace payload for a single session, as deserialized (via serde) from the
/// wire. Convert it with `SessionTrace::from` / `.into()`.
#[derive(Debug, Clone, Deserialize)]
pub struct WireTraceResponse {
    /// Session identifier. Required: it has no `#[serde(default)]`, so a
    /// payload without it fails to deserialize.
    pub session_id: String,
    /// Agent identifier; an empty string when the field is absent.
    /// Not carried over into `SessionTrace`.
    #[serde(default)]
    pub agent_id: String,
    /// Flat list of spans in wire order; empty when the field is absent.
    /// The parent/child structure is encoded through each span's
    /// `parent_span_id`.
    #[serde(default)]
    pub spans: Vec<WireTraceSpan>,
}
/// One span of a session trace, as it appears on the wire.
#[derive(Debug, Clone, Deserialize)]
pub struct WireTraceSpan {
    /// Identifier that other spans reference through `parent_span_id`.
    /// NOTE(review): uniqueness is not checked at deserialization time.
    pub span_id: String,
    /// Id of the enclosing span. `None`, or an id that matches no span in the
    /// same response, makes this a top-level span when converted.
    #[serde(default)]
    pub parent_span_id: Option<String>,
    /// Operation name; used as the event label and, via `kind_from_span`,
    /// to classify the event kind.
    pub operation: String,
    /// Policy decision string. `kind_from_span` only tests it against
    /// `"deny"` (ASCII case-insensitive); other values are not interpreted.
    #[serde(default)]
    pub decision: Option<String>,
    /// Span start instant; also the key used to order sibling spans.
    pub start_time: DateTime<Utc>,
    /// Span end instant; `None` yields a zero duration in `duration_ms_from`.
    #[serde(default)]
    pub end_time: Option<DateTime<Utc>>,
}
impl From<WireTraceResponse> for SessionTrace {
    /// Rebuilds the flat wire span list into a forest of [`TraceEvent`]s.
    ///
    /// * A span is top-level when it has no `parent_span_id`, or when that id
    ///   matches no span in the response (orphan).
    /// * Spans that can never be reached from a top-level span, because their
    ///   parent chain loops back on itself (e.g. a span naming itself as its
    ///   parent), are promoted to the top level as well. They are appended
    ///   after the regular roots, earliest `start_time` first. No span is
    ///   silently dropped.
    /// * Every span is emitted at most once, even when several spans share a
    ///   `span_id`. This also guarantees the recursion terminates on malformed
    ///   input instead of overflowing the stack.
    /// * Siblings are ordered by `start_time`. Ties keep wire order, because
    ///   `sort_by_key` is stable.
    ///
    /// `wire.agent_id` is not carried over.
    fn from(wire: WireTraceResponse) -> Self {
        let spans = wire.spans.as_slice();
        let known_ids: std::collections::HashSet<&str> =
            spans.iter().map(|s| s.span_id.as_str()).collect();

        // Group span indices by parent id; the `None` bucket holds the roots
        // (no parent, or a parent id that is not in this response).
        let mut children_of: HashMap<Option<&str>, Vec<usize>> = HashMap::new();
        for (idx, span) in spans.iter().enumerate() {
            let parent = span
                .parent_span_id
                .as_deref()
                .filter(|p| known_ids.contains(p));
            children_of.entry(parent).or_default().push(idx);
        }
        for indices in children_of.values_mut() {
            indices.sort_by_key(|&i| spans[i].start_time);
        }

        let mut emitted = vec![false; spans.len()];
        let mut events = build_events(None, spans, &children_of, &mut emitted);

        // Anything not emitted yet sits on (or below) a parent cycle and is
        // unreachable from the roots. Surface it instead of dropping it.
        let mut stranded: Vec<usize> = (0..spans.len()).filter(|&i| !emitted[i]).collect();
        stranded.sort_by_key(|&i| spans[i].start_time);
        for idx in stranded {
            // An earlier promotion may already have emitted this span as a descendant.
            if !emitted[idx] {
                events.push(build_event(idx, spans, &children_of, &mut emitted));
            }
        }

        SessionTrace {
            session_id: wire.session_id,
            events,
        }
    }
}

/// Builds the events for every not-yet-emitted child of `parent_id`
/// (`None` = the top level), in bucket order, recursing into each subtree.
///
/// `emitted[i]` is set when span `i` becomes an event. Spans already emitted
/// are skipped, so duplicate span ids or cyclic parent links cannot cause a
/// span to appear twice or the recursion to loop forever.
/// NOTE(review): recursion depth still equals the deepest parent chain, so an
/// extremely deep trace could exhaust the stack.
fn build_events<'a>(
    parent_id: Option<&'a str>,
    spans: &'a [WireTraceSpan],
    children_of: &HashMap<Option<&'a str>, Vec<usize>>,
    emitted: &mut [bool],
) -> Vec<TraceEvent> {
    let Some(indices) = children_of.get(&parent_id) else {
        return Vec::new();
    };
    let mut events = Vec::with_capacity(indices.len());
    for &idx in indices {
        if !emitted[idx] {
            events.push(build_event(idx, spans, children_of, emitted));
        }
    }
    events
}

/// Converts span `idx` and its subtree into a [`TraceEvent`], marking it
/// as emitted.
fn build_event<'a>(
    idx: usize,
    spans: &'a [WireTraceSpan],
    children_of: &HashMap<Option<&'a str>, Vec<usize>>,
    emitted: &mut [bool],
) -> TraceEvent {
    // Mark before recursing: a descendant that leads back to this span
    // (cycle or duplicate id) then sees it as emitted and stops there.
    emitted[idx] = true;
    let span = &spans[idx];
    TraceEvent {
        kind: kind_from_span(&span.operation, span.decision.as_deref()),
        label: span.operation.clone(),
        duration_ms: duration_ms_from(span.start_time, span.end_time),
        children: build_events(Some(span.span_id.as_str()), spans, children_of, emitted),
        violation_reason: None,
    }
}
/// Returns the whole number of milliseconds elapsed from `start` to `end`.
///
/// Yields `0` when `end` is `None`, and also when `end` is at or before
/// `start`: a non-positive delta is clamped rather than wrapped. Any
/// sub-millisecond remainder is discarded by `num_milliseconds`.
pub fn duration_ms_from(start: DateTime<Utc>, end: Option<DateTime<Utc>>) -> u64 {
    match end.map(|finish| finish.signed_duration_since(start).num_milliseconds()) {
        // Guarded to be strictly positive, so the i64 -> u64 cast is lossless.
        Some(elapsed_ms) if elapsed_ms > 0 => elapsed_ms as u64,
        _ => 0,
    }
}
/// Classifies a wire span into a [`TraceEventKind`].
///
/// A `decision` equal to `"deny"` (ASCII case-insensitive) takes priority
/// over the operation. Otherwise the ASCII-lower-cased `operation` is
/// searched for substrings, most specific first:
///
/// | substring     | kind          |
/// |---------------|---------------|
/// | `tool_result` | `ToolResult`  |
/// | `tool`        | `ToolCall`    |
/// | `llm`         | `Llm`         |
/// | *(none)*      | `PolicyAllow` |
pub fn kind_from_span(operation: &str, decision: Option<&str>) -> TraceEventKind {
    if decision.is_some_and(|verdict| verdict.eq_ignore_ascii_case("deny")) {
        return TraceEventKind::PolicyDeny;
    }

    let lowered = operation.to_ascii_lowercase();
    let mentions = |needle: &str| lowered.contains(needle);

    // `tool_result` must be probed before `tool`: every operation containing
    // the former also contains the latter (as does `tool_call`).
    if mentions("tool_result") {
        TraceEventKind::ToolResult
    } else if mentions("tool") {
        TraceEventKind::ToolCall
    } else if mentions("llm") {
        TraceEventKind::Llm
    } else {
        TraceEventKind::PolicyAllow
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed reference instant (2026-01-01T00:00:00Z) shared by the
    /// timestamp-based tests.
    fn epoch() -> DateTime<Utc> {
        chrono::TimeZone::with_ymd_and_hms(&Utc, 2026, 1, 1, 0, 0, 0).unwrap()
    }

    /// Builds an `allow` span that starts `seconds_offset` seconds after
    /// [`epoch`] and lasts 100 ms.
    fn span(span_id: &str, parent: Option<&str>, op: &str, seconds_offset: i64) -> WireTraceSpan {
        let start_time = epoch() + chrono::Duration::seconds(seconds_offset);
        WireTraceSpan {
            span_id: span_id.to_owned(),
            parent_span_id: parent.map(str::to_owned),
            operation: op.to_owned(),
            decision: Some("allow".to_owned()),
            start_time,
            end_time: Some(start_time + chrono::Duration::milliseconds(100)),
        }
    }

    /// Wraps `spans` in a response for session `"sess"` with no agent id.
    fn response(spans: Vec<WireTraceSpan>) -> WireTraceResponse {
        WireTraceResponse {
            session_id: "sess".into(),
            agent_id: String::new(),
            spans,
        }
    }

    /// Labels of `events`, in order, for compact assertions.
    fn labels(events: &[TraceEvent]) -> Vec<&str> {
        events.iter().map(|event| event.label.as_str()).collect()
    }

    #[test]
    fn deny_decision_wins_over_operation() {
        for (op, verdict) in [("llm_call", "deny"), ("tool_call", "DENY")] {
            assert_eq!(kind_from_span(op, Some(verdict)), TraceEventKind::PolicyDeny);
        }
    }

    #[test]
    fn llm_operation_maps_to_llm() {
        assert_eq!(kind_from_span("llm_call", Some("allow")), TraceEventKind::Llm);
        assert_eq!(kind_from_span("LLM-Inference", None), TraceEventKind::Llm);
    }

    #[test]
    fn tool_call_and_result_are_distinct() {
        assert_eq!(kind_from_span("tool_call", None), TraceEventKind::ToolCall);
        assert_eq!(kind_from_span("tool_result", None), TraceEventKind::ToolResult);
    }

    #[test]
    fn unknown_operation_defaults_to_policy_allow() {
        assert_eq!(kind_from_span("op-42", Some("allow")), TraceEventKind::PolicyAllow);
        assert_eq!(kind_from_span("misc", None), TraceEventKind::PolicyAllow);
    }

    #[test]
    fn duration_ms_returns_zero_when_end_missing() {
        assert_eq!(duration_ms_from(Utc::now(), None), 0);
    }

    #[test]
    fn duration_ms_computes_positive_delta() {
        let start = epoch();
        let end = start + chrono::Duration::milliseconds(250);
        assert_eq!(duration_ms_from(start, Some(end)), 250);
    }

    #[test]
    fn duration_ms_clamps_negative_to_zero() {
        let start = epoch();
        let end = start - chrono::Duration::milliseconds(100);
        assert_eq!(duration_ms_from(start, Some(end)), 0);
    }

    #[test]
    fn translate_flat_spans_become_top_level_events() {
        let trace: SessionTrace =
            response(vec![span("a", None, "llm_call", 0), span("b", None, "tool_call", 1)]).into();
        assert_eq!(trace.session_id, "sess");
        assert_eq!(labels(&trace.events), ["llm_call", "tool_call"]);
        assert!(trace.events.iter().all(|event| event.children.is_empty()));
    }

    #[test]
    fn translate_builds_parent_child_tree() {
        let trace: SessionTrace = response(vec![
            span("root", None, "llm_call", 0),
            span("child1", Some("root"), "tool_call", 1),
            span("child2", Some("root"), "tool_result", 2),
        ])
        .into();
        assert_eq!(labels(&trace.events), ["llm_call"]);
        assert_eq!(labels(&trace.events[0].children), ["tool_call", "tool_result"]);
    }

    #[test]
    fn translate_sorts_siblings_by_start_time() {
        let trace: SessionTrace =
            response(vec![span("late", None, "op-3", 5), span("early", None, "op-1", 1)]).into();
        assert_eq!(labels(&trace.events), ["op-1", "op-3"]);
    }

    #[test]
    fn translate_promotes_orphans_to_top_level() {
        let trace: SessionTrace =
            response(vec![span("orphan", Some("missing-parent"), "tool_call", 0)]).into();
        assert_eq!(labels(&trace.events), ["tool_call"]);
    }
}