use crate::span_record::SpanRecord;
use anyhow::{Context, Result};
use std::collections::HashMap;
use trueno_graph::{CsrGraph, NodeId};
/// Parent → child causality graph built from the spans of a single trace.
///
/// Nodes are spans, addressed by [`NodeId`]s equal to each span's index in the slice passed
/// to [`CausalGraph::from_spans`]; edges point from a parent span to its child span.
pub struct CausalGraph {
    /// Parent → child edges, weighted by the child span's `duration_nanos` converted to `f32`.
    graph: CsrGraph,
    /// Full span record for every node.
    span_metadata: HashMap<NodeId, SpanRecord>,
    /// Reverse index from a span's 8-byte `span_id` to its node; when ids repeat, the span
    /// inserted last wins.
    span_id_to_node: HashMap<[u8; 8], NodeId>,
    /// Nodes whose span reported `is_root()` when the graph was built, in input order.
    roots: Vec<NodeId>,
}
impl CausalGraph {
    /// Builds a causal graph from the spans of a single trace.
    ///
    /// Every span becomes one node whose [`NodeId`] is the span's index in `spans`. For each
    /// span whose `parent_span_id` matches the `span_id` of a span in `spans`, an edge
    /// parent → child is added. The edge is weighted by the child's `duration_nanos` converted
    /// to `f32`, so durations above 2^24 ns may be rounded. A span whose parent id is not found
    /// in `spans` gets no incoming edge. Nodes whose span returns `true` from
    /// [`SpanRecord::is_root`] are recorded as roots, in input order.
    ///
    /// If several spans share a `span_id`, span-id lookups return the last of them. Children
    /// that reference that id are attached to that last span.
    ///
    /// An empty `spans` slice yields an empty graph.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - not every span has the first span's `trace_id`;
    /// - there are more spans than `u32` node ids can index;
    /// - the underlying [`CsrGraph`] rejects an edge.
    pub fn from_spans(spans: &[SpanRecord]) -> Result<Self> {
        if spans.is_empty() {
            return Ok(Self {
                graph: CsrGraph::new(),
                span_metadata: HashMap::new(),
                span_id_to_node: HashMap::new(),
                roots: Vec::new(),
            });
        }

        let trace_id = spans[0].trace_id;
        if let Some(mismatch) = spans.iter().find(|span| span.trace_id != trace_id) {
            anyhow::bail!(
                "All spans must have same trace_id. Expected {:?}, got {:?}",
                hex::encode(trace_id),
                hex::encode(mismatch.trace_id)
            );
        }

        // Node ids are `u32` indices into `spans`. Refuse input whose indices would wrap
        // rather than silently aliasing nodes.
        if spans.len() - 1 > u32::MAX as usize {
            anyhow::bail!("Too many spans ({}) to index with u32 node ids", spans.len());
        }

        let mut graph = CsrGraph::new();
        let mut span_metadata = HashMap::with_capacity(spans.len());
        let mut span_id_to_node = HashMap::with_capacity(spans.len());
        let mut roots = Vec::new();

        // Pass 1: register every span first, so a parent listed after its child still resolves.
        for (idx, span) in spans.iter().enumerate() {
            // Cannot truncate: the length check above bounds `idx` by `u32::MAX`.
            let node_id = NodeId(idx as u32);
            span_metadata.insert(node_id, span.clone());
            span_id_to_node.insert(span.span_id, node_id);
            if span.is_root() {
                roots.push(node_id);
            }
        }

        // Pass 2: link each span to its parent when that parent is part of this batch.
        for (idx, span) in spans.iter().enumerate() {
            let child_node = NodeId(idx as u32);
            let parent_node = span
                .parent_span_id
                .and_then(|parent_id| span_id_to_node.get(&parent_id).copied());
            if let Some(parent_node) = parent_node {
                graph
                    .add_edge(parent_node, child_node, span.duration_nanos as f32)
                    .context("Failed to add parent-child edge")?;
            }
        }

        Ok(Self { graph, span_metadata, span_id_to_node, roots })
    }

    /// Number of nodes, i.e. the number of spans passed to [`CausalGraph::from_spans`].
    pub fn node_count(&self) -> usize {
        self.span_metadata.len()
    }

    /// Number of parent → child edges stored in the underlying [`CsrGraph`].
    pub fn edge_count(&self) -> usize {
        self.graph.num_edges()
    }

    /// Nodes whose span reported `is_root()` at construction time, in input order.
    pub fn roots(&self) -> &[NodeId] {
        &self.roots
    }

    /// Returns the span stored for `node`, or `None` if no span has that node id.
    pub fn get_span(&self, node: NodeId) -> Option<&SpanRecord> {
        self.span_metadata.get(&node)
    }

    /// Returns the direct children of `node` paired with their edge weights.
    ///
    /// Each weight is the child span's duration in nanoseconds, as `f32`. Children come back
    /// in the order the underlying [`CsrGraph`] stores them.
    ///
    /// # Errors
    ///
    /// Currently never fails; the `Result` is part of the public signature.
    pub fn children(&self, node: NodeId) -> Result<Vec<(NodeId, f32)>> {
        let (neighbors, weights) = self.graph.adjacency(node);
        Ok(neighbors.iter().zip(weights.iter()).map(|(&n, &w)| (NodeId(n), w)).collect())
    }

    /// Borrows the underlying [`CsrGraph`].
    pub fn as_csr_graph(&self) -> &CsrGraph {
        &self.graph
    }

    /// Returns `root` followed by every node reachable from it via parent → child edges.
    ///
    /// The order is depth-first pre-order. Children are pushed onto the work stack in
    /// adjacency order, so the last child of a node is explored first. Each node appears
    /// exactly once.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`CausalGraph::children`].
    pub fn descendants(&self, root: NodeId) -> Result<Vec<NodeId>> {
        let mut order = Vec::new();
        // A set gives O(1) membership tests; `order` alone would make the walk O(n²).
        let mut seen = std::collections::HashSet::new();
        let mut stack = vec![root];
        while let Some(node) = stack.pop() {
            if !seen.insert(node) {
                continue;
            }
            order.push(node);
            for (child, _weight) in self.children(node)? {
                if !seen.contains(&child) {
                    stack.push(child);
                }
            }
        }
        Ok(order)
    }

    /// Returns `true` if the parent → child edges contain no directed cycle.
    ///
    /// Every node is examined, not only those reachable from [`CausalGraph::roots`]. A cycle
    /// that no root can reach, such as spans naming each other as parent, is therefore still
    /// reported. The traversal uses an explicit stack, so deep span chains cannot overflow the
    /// call stack.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`CausalGraph::children`].
    pub fn is_dag(&self) -> Result<bool> {
        let mut visited = std::collections::HashSet::new();
        let mut rec_stack = std::collections::HashSet::new();
        // Roots first, then every remaining node, so components without a root are not skipped.
        let starts = self.roots.iter().chain(self.span_metadata.keys()).copied();
        for start in starts {
            if !visited.contains(&start)
                && self.has_cycle_dfs(start, &mut visited, &mut rec_stack)?
            {
                return Ok(false);
            }
        }
        Ok(true)
    }

    /// Iterative depth-first search from `start`. Reports whether an edge back to a node on the
    /// current DFS path (a back edge, hence a cycle) is found.
    ///
    /// `visited` accumulates every node explored across calls. `rec_stack` holds the nodes on
    /// the current path; it is emptied again when the search completes without finding a cycle.
    fn has_cycle_dfs(
        &self,
        start: NodeId,
        visited: &mut std::collections::HashSet<NodeId>,
        rec_stack: &mut std::collections::HashSet<NodeId>,
    ) -> Result<bool> {
        visited.insert(start);
        rec_stack.insert(start);
        // Each frame is a node on the current path plus its not-yet-examined children.
        let mut path = vec![(start, self.children(start)?.into_iter())];
        while let Some((node, pending)) = path.last_mut() {
            let node = *node;
            match pending.next() {
                Some((child, _weight)) => {
                    if rec_stack.contains(&child) {
                        return Ok(true);
                    }
                    if visited.insert(child) {
                        rec_stack.insert(child);
                        path.push((child, self.children(child)?.into_iter()));
                    }
                }
                None => {
                    // All children of `node` are done: it leaves the current path.
                    path.pop();
                    rec_stack.remove(&node);
                }
            }
        }
        Ok(false)
    }

    /// Looks up the node for an 8-byte span id.
    ///
    /// Returns `None` if no span has that id. If several spans share the id, returns the node
    /// of the last such span in the input.
    pub fn get_node_by_span_id(&self, span_id: &[u8; 8]) -> Option<NodeId> {
        self.span_id_to_node.get(span_id).copied()
    }

    /// Looks up a span record by its 8-byte span id (see [`CausalGraph::get_node_by_span_id`]).
    pub fn get_span_by_id(&self, span_id: &[u8; 8]) -> Option<&SpanRecord> {
        self.get_node_by_span_id(span_id).and_then(|node| self.get_span(node))
    }
}
// Compile-time guard: the build fails if a future field makes `CausalGraph` lose `Send` or
// `Sync`.
static_assertions::assert_impl_all!(CausalGraph: Send, Sync);
#[cfg(test)]
mod tests {
    use super::*;
    use crate::span_record::{SpanKind, StatusCode};

    /// Builds an `Internal`, `Ok` span in trace `[1; 16]` with span id `[span_id; 8]` and,
    /// when `parent_id` is given, parent id `[parent_id; 8]`.
    ///
    /// Start time is `logical_clock * 1000`; end time is that plus `duration_nanos`.
    fn create_test_span(
        span_id: u8,
        parent_id: Option<u8>,
        logical_clock: u64,
        duration_nanos: u64,
    ) -> SpanRecord {
        SpanRecord::new(
            [1; 16],
            [span_id; 8],
            parent_id.map(|p| [p; 8]),
            format!("span_{}", span_id),
            SpanKind::Internal,
            logical_clock * 1000,
            logical_clock * 1000 + duration_nanos,
            logical_clock,
            StatusCode::Ok,
            String::new(),
            HashMap::new(),
            HashMap::new(),
            1234,
            5678,
        )
    }

    #[test]
    fn test_empty_graph() {
        let graph = CausalGraph::from_spans(&[]).expect("test");
        assert_eq!(graph.node_count(), 0);
        assert_eq!(graph.edge_count(), 0);
        assert!(graph.roots().is_empty());
    }

    #[test]
    fn test_single_node() {
        let span = create_test_span(1, None, 0, 1000);
        let graph = CausalGraph::from_spans(&[span]).expect("test");
        assert_eq!(graph.node_count(), 1);
        assert_eq!(graph.edge_count(), 0);
        assert_eq!(graph.roots(), &[NodeId(0)]);
    }

    #[test]
    fn test_parent_child() {
        let root = create_test_span(1, None, 0, 1000);
        let child = create_test_span(2, Some(1), 1, 2000);
        let graph = CausalGraph::from_spans(&[root, child]).expect("test");
        assert_eq!(graph.node_count(), 2);
        assert_eq!(graph.edge_count(), 1);
        assert_eq!(graph.roots().len(), 1);
        let children = graph.children(NodeId(0)).expect("test");
        assert_eq!(children.len(), 1);
        assert_eq!(children[0].0, NodeId(1));
        // The edge weight is the child's duration.
        assert_eq!(children[0].1, 2000.0);
    }

    #[test]
    fn test_tree_structure() {
        let root = create_test_span(1, None, 0, 1000);
        let child1 = create_test_span(2, Some(1), 1, 500);
        let child2 = create_test_span(3, Some(1), 2, 700);
        let grandchild = create_test_span(4, Some(2), 3, 300);
        let graph = CausalGraph::from_spans(&[root, child1, child2, grandchild]).expect("test");
        assert_eq!(graph.node_count(), 4);
        assert_eq!(graph.edge_count(), 3);
        assert_eq!(graph.roots().len(), 1);

        let root_children = graph.children(NodeId(0)).expect("test");
        assert_eq!(root_children.len(), 2);
        assert!(root_children.iter().any(|&(node, _)| node == NodeId(1)));
        assert!(root_children.iter().any(|&(node, _)| node == NodeId(2)));

        let grandchildren = graph.children(NodeId(1)).expect("test");
        assert_eq!(grandchildren, vec![(NodeId(3), 300.0)]);

        let leaf_children = graph.children(NodeId(2)).expect("test");
        assert!(leaf_children.is_empty());
        assert!(graph.is_dag().expect("test"));
    }

    #[test]
    fn test_get_span_metadata() {
        let root = create_test_span(1, None, 0, 1000);
        let graph = CausalGraph::from_spans(std::slice::from_ref(&root)).expect("test");
        let span = graph.get_span(NodeId(0)).expect("test");
        assert_eq!(span.span_name, "span_1");
        assert_eq!(span.logical_clock, 0);
    }

    #[test]
    fn test_descendants() {
        let root = create_test_span(1, None, 0, 1000);
        let child1 = create_test_span(2, Some(1), 1, 500);
        let child2 = create_test_span(3, Some(1), 2, 700);
        let grandchild = create_test_span(4, Some(2), 3, 300);
        let graph = CausalGraph::from_spans(&[root, child1, child2, grandchild]).expect("test");

        let desc = graph.descendants(NodeId(0)).expect("test");
        assert_eq!(desc.len(), 4);
        assert_eq!(desc[0], NodeId(0));
        for node in &[NodeId(1), NodeId(2), NodeId(3)] {
            assert!(desc.contains(node));
        }

        let subtree = graph.descendants(NodeId(1)).expect("test");
        assert_eq!(subtree, vec![NodeId(1), NodeId(3)]);
    }

    #[test]
    fn test_is_dag() {
        let root = create_test_span(1, None, 0, 1000);
        let child = create_test_span(2, Some(1), 1, 500);
        let graph = CausalGraph::from_spans(&[root, child]).expect("test");
        assert!(graph.is_dag().expect("test"));
    }

    #[test]
    fn test_span_id_lookup() {
        let root = create_test_span(1, None, 0, 1000);
        let child = create_test_span(2, Some(1), 1, 500);
        let graph = CausalGraph::from_spans(&[root, child]).expect("test");
        assert_eq!(graph.get_node_by_span_id(&[2; 8]), Some(NodeId(1)));
        assert_eq!(graph.get_span_by_id(&[2; 8]).expect("test").span_name, "span_2");
        assert_eq!(graph.get_node_by_span_id(&[9; 8]), None);
        assert!(graph.get_span_by_id(&[9; 8]).is_none());
    }

    #[test]
    fn test_inconsistent_trace_id() {
        let mut span1 = create_test_span(1, None, 0, 1000);
        let span2 = create_test_span(2, Some(1), 1, 500);
        span1.trace_id = [2; 16];
        let result = CausalGraph::from_spans(&[span1, span2]);
        let err = result.err().expect("mismatched trace ids must be rejected");
        assert!(err.to_string().contains("same trace_id"));
    }
}