kaizen/store/
span_tree.rs1use crate::metrics::types::ToolSpanView;
5use serde::{Deserialize, Serialize};
6use std::collections::{HashMap, HashSet};
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct SpanNode {
10 pub span: ToolSpanView,
11 pub children: Vec<SpanNode>,
12 pub subtree_cost_usd_e6: i64,
13 pub subtree_token_count: u64,
14}
15
16pub fn build_tree(spans: Vec<ToolSpanView>) -> Vec<SpanNode> {
18 let input = SpanInput::new(spans);
19 let mut seen = HashSet::new();
20 let mut roots = build_ids(input.root_ids(), &input, &mut seen);
21 roots.extend(build_ids(input.remaining_ids(&seen), &input, &mut seen));
22 roots
23}
24
25struct SpanInput {
26 order: Vec<String>,
27 nodes: HashMap<String, ToolSpanView>,
28 children: HashMap<String, Vec<String>>,
29}
30
31impl SpanInput {
32 fn new(spans: Vec<ToolSpanView>) -> Self {
33 let mut input = Self {
34 order: vec![],
35 nodes: HashMap::new(),
36 children: HashMap::new(),
37 };
38 spans.into_iter().for_each(|span| input.insert(span));
39 input.link_children();
40 input
41 }
42
43 fn insert(&mut self, span: ToolSpanView) {
44 if self.nodes.contains_key(&span.span_id) {
45 return;
46 }
47 self.order.push(span.span_id.clone());
48 self.nodes.insert(span.span_id.clone(), span);
49 }
50
51 fn link_children(&mut self) {
52 let edges: Vec<_> = self
53 .order
54 .iter()
55 .filter_map(|id| self.parent_edge(id))
56 .collect();
57 edges.into_iter().for_each(|(p, c)| {
58 self.children.entry(p).or_default().push(c);
59 });
60 }
61
62 fn parent_edge(&self, id: &str) -> Option<(String, String)> {
63 let parent = self.nodes[id].parent_span_id.as_ref()?;
64 self.nodes
65 .contains_key(parent)
66 .then(|| (parent.clone(), id.into()))
67 }
68
69 fn root_ids(&self) -> Vec<String> {
70 self.order
71 .iter()
72 .filter(|id| self.is_root(id))
73 .cloned()
74 .collect()
75 }
76
77 fn is_root(&self, id: &str) -> bool {
78 self.nodes[id]
79 .parent_span_id
80 .as_ref()
81 .is_none_or(|p| !self.nodes.contains_key(p))
82 }
83
84 fn remaining_ids(&self, seen: &HashSet<String>) -> Vec<String> {
85 self.order
86 .iter()
87 .filter(|id| !seen.contains(*id))
88 .cloned()
89 .collect()
90 }
91}
92
93fn build_ids(ids: Vec<String>, input: &SpanInput, seen: &mut HashSet<String>) -> Vec<SpanNode> {
94 ids.into_iter()
95 .filter_map(|id| assemble(&id, input, seen, &mut vec![]))
96 .collect()
97}
98
99fn assemble(
100 id: &str,
101 input: &SpanInput,
102 seen: &mut HashSet<String>,
103 stack: &mut Vec<String>,
104) -> Option<SpanNode> {
105 if seen.contains(id) || stack.iter().any(|item| item == id) {
106 return None;
107 }
108 let span = input.nodes.get(id)?.clone();
109 stack.push(id.into());
110 let children = child_nodes(id, input, seen, stack);
111 stack.pop();
112 seen.insert(id.into());
113 Some(span_node(span, children))
114}
115
116fn child_nodes(
117 id: &str,
118 input: &SpanInput,
119 seen: &mut HashSet<String>,
120 stack: &mut Vec<String>,
121) -> Vec<SpanNode> {
122 input
123 .children
124 .get(id)
125 .into_iter()
126 .flatten()
127 .filter_map(|child| assemble(child, input, seen, stack))
128 .collect()
129}
130
131fn span_node(span: ToolSpanView, children: Vec<SpanNode>) -> SpanNode {
132 let subtree_cost_usd_e6 = span.subtree_cost_usd_e6.unwrap_or_default();
133 let subtree_token_count = span.subtree_token_count.unwrap_or_default() as u64;
134 SpanNode {
135 span,
136 children,
137 subtree_cost_usd_e6,
138 subtree_token_count,
139 }
140}
141
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn build_tree_keeps_grandchildren() {
        let roots = build_tree(vec![
            span("root", None),
            span("child", Some("root")),
            span("grandchild", Some("child")),
        ]);
        assert_eq!(ids(&roots), vec!["root"]);
        assert_eq!(ids(&roots[0].children), vec!["child"]);
        assert_eq!(ids(&roots[0].children[0].children), vec!["grandchild"]);
    }

    #[test]
    fn build_tree_keeps_missing_parent_and_cycle_nodes_once() {
        let roots = build_tree(vec![
            span("orphan", Some("missing")),
            span("a", Some("b")),
            span("b", Some("a")),
        ]);
        assert_eq!(flat_ids(&roots), vec!["orphan", "a", "b"]);
    }

    #[test]
    fn build_tree_of_empty_input_is_empty() {
        assert!(build_tree(vec![]).is_empty());
    }

    #[test]
    fn build_tree_keeps_first_span_for_duplicate_ids() {
        // The second "x" (which would hang under "y") is dropped: first wins.
        let roots = build_tree(vec![
            span("x", None),
            span("x", Some("y")),
            span("y", None),
        ]);
        assert_eq!(ids(&roots), vec!["x", "y"]);
        assert!(roots[1].children.is_empty());
    }

    #[test]
    fn build_tree_preserves_sibling_input_order() {
        let roots = build_tree(vec![
            span("b", Some("root")),
            span("root", None),
            span("a", Some("root")),
        ]);
        assert_eq!(ids(&roots), vec!["root"]);
        assert_eq!(ids(&roots[0].children), vec!["b", "a"]);
    }

    #[test]
    fn build_tree_handles_self_parent_span() {
        let roots = build_tree(vec![span("loop", Some("loop"))]);
        assert_eq!(ids(&roots), vec!["loop"]);
        assert!(roots[0].children.is_empty());
    }

    #[test]
    fn build_tree_defaults_missing_subtree_totals_to_zero() {
        let roots = build_tree(vec![span("root", None)]);
        assert_eq!(roots[0].subtree_cost_usd_e6, 0);
        assert_eq!(roots[0].subtree_token_count, 0);
    }

    fn span(id: &str, parent: Option<&str>) -> ToolSpanView {
        ToolSpanView {
            span_id: id.into(),
            parent_span_id: parent.map(str::to_string),
            ..Default::default()
        }
    }

    fn ids(nodes: &[SpanNode]) -> Vec<&str> {
        nodes.iter().map(|n| n.span.span_id.as_str()).collect()
    }

    fn flat_ids(nodes: &[SpanNode]) -> Vec<&str> {
        nodes
            .iter()
            .flat_map(|n| std::iter::once(n.span.span_id.as_str()).chain(flat_ids(&n.children)))
            .collect()
    }
}