Skip to main content

kaizen/store/
span_tree.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2//! In-memory span tree assembled from flat `ToolSpanView` rows.
3
4use crate::metrics::types::ToolSpanView;
5use serde::{Deserialize, Serialize};
6use std::collections::{HashMap, HashSet};
7
/// One node of an assembled span tree: a span row plus the nodes for its child spans.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpanNode {
    /// The span row this node wraps.
    pub span: ToolSpanView,
    /// Child nodes: when built by `build_tree`, the spans whose `parent_span_id` names this
    /// span, in input order (a child that would close a parent cycle is omitted).
    pub children: Vec<SpanNode>,
    /// When built by `build_tree`, copied from `span.subtree_cost_usd_e6` (0 when absent).
    /// NOTE(review): the `_usd_e6` suffix suggests millionths of a US dollar — confirm.
    pub subtree_cost_usd_e6: i64,
    /// When built by `build_tree`, copied from `span.subtree_token_count` (0 when absent).
    pub subtree_token_count: u64,
}
15
16/// Assemble a forest of `SpanNode` from a flat ordered list.
17pub fn build_tree(spans: Vec<ToolSpanView>) -> Vec<SpanNode> {
18    let input = SpanInput::new(spans);
19    let mut seen = HashSet::new();
20    let mut roots = build_ids(input.root_ids(), &input, &mut seen);
21    roots.extend(build_ids(input.remaining_ids(&seen), &input, &mut seen));
22    roots
23}
24
25struct SpanInput {
26    order: Vec<String>,
27    nodes: HashMap<String, ToolSpanView>,
28    children: HashMap<String, Vec<String>>,
29}
30
31impl SpanInput {
32    fn new(spans: Vec<ToolSpanView>) -> Self {
33        let mut input = Self {
34            order: vec![],
35            nodes: HashMap::new(),
36            children: HashMap::new(),
37        };
38        spans.into_iter().for_each(|span| input.insert(span));
39        input.link_children();
40        input
41    }
42
43    fn insert(&mut self, span: ToolSpanView) {
44        if self.nodes.contains_key(&span.span_id) {
45            return;
46        }
47        self.order.push(span.span_id.clone());
48        self.nodes.insert(span.span_id.clone(), span);
49    }
50
51    fn link_children(&mut self) {
52        let edges: Vec<_> = self
53            .order
54            .iter()
55            .filter_map(|id| self.parent_edge(id))
56            .collect();
57        edges.into_iter().for_each(|(p, c)| {
58            self.children.entry(p).or_default().push(c);
59        });
60    }
61
62    fn parent_edge(&self, id: &str) -> Option<(String, String)> {
63        let parent = self.nodes[id].parent_span_id.as_ref()?;
64        self.nodes
65            .contains_key(parent)
66            .then(|| (parent.clone(), id.into()))
67    }
68
69    fn root_ids(&self) -> Vec<String> {
70        self.order
71            .iter()
72            .filter(|id| self.is_root(id))
73            .cloned()
74            .collect()
75    }
76
77    fn is_root(&self, id: &str) -> bool {
78        self.nodes[id]
79            .parent_span_id
80            .as_ref()
81            .is_none_or(|p| !self.nodes.contains_key(p))
82    }
83
84    fn remaining_ids(&self, seen: &HashSet<String>) -> Vec<String> {
85        self.order
86            .iter()
87            .filter(|id| !seen.contains(*id))
88            .cloned()
89            .collect()
90    }
91}
92
93fn build_ids(ids: Vec<String>, input: &SpanInput, seen: &mut HashSet<String>) -> Vec<SpanNode> {
94    ids.into_iter()
95        .filter_map(|id| assemble(&id, input, seen, &mut vec![]))
96        .collect()
97}
98
99fn assemble(
100    id: &str,
101    input: &SpanInput,
102    seen: &mut HashSet<String>,
103    stack: &mut Vec<String>,
104) -> Option<SpanNode> {
105    if seen.contains(id) || stack.iter().any(|item| item == id) {
106        return None;
107    }
108    let span = input.nodes.get(id)?.clone();
109    stack.push(id.into());
110    let children = child_nodes(id, input, seen, stack);
111    stack.pop();
112    seen.insert(id.into());
113    Some(span_node(span, children))
114}
115
116fn child_nodes(
117    id: &str,
118    input: &SpanInput,
119    seen: &mut HashSet<String>,
120    stack: &mut Vec<String>,
121) -> Vec<SpanNode> {
122    input
123        .children
124        .get(id)
125        .into_iter()
126        .flatten()
127        .filter_map(|child| assemble(child, input, seen, stack))
128        .collect()
129}
130
131fn span_node(span: ToolSpanView, children: Vec<SpanNode>) -> SpanNode {
132    let subtree_cost_usd_e6 = span.subtree_cost_usd_e6.unwrap_or_default();
133    let subtree_token_count = span.subtree_token_count.unwrap_or_default() as u64;
134    SpanNode {
135        span,
136        children,
137        subtree_cost_usd_e6,
138        subtree_token_count,
139    }
140}
141
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn build_tree_keeps_grandchildren() {
        let roots = build_tree(vec![
            span("root", None),
            span("child", Some("root")),
            span("grandchild", Some("child")),
        ]);
        assert_eq!(ids(&roots), vec!["root"]);
        assert_eq!(ids(&roots[0].children), vec!["child"]);
        assert_eq!(ids(&roots[0].children[0].children), vec!["grandchild"]);
    }

    #[test]
    fn build_tree_keeps_missing_parent_and_cycle_nodes_once() {
        let roots = build_tree(vec![
            span("orphan", Some("missing")),
            span("a", Some("b")),
            span("b", Some("a")),
        ]);
        assert_eq!(flat_ids(&roots), vec!["orphan", "a", "b"]);
    }

    #[test]
    fn build_tree_preserves_sibling_input_order() {
        let roots = build_tree(vec![
            span("b", Some("root")),
            span("root", None),
            span("a", Some("root")),
        ]);
        assert_eq!(ids(&roots), vec!["root"]);
        assert_eq!(ids(&roots[0].children), vec!["b", "a"]);
    }

    #[test]
    fn build_tree_keeps_first_duplicate_and_breaks_self_cycle() {
        let roots = build_tree(vec![
            span("root", None),
            span("self", Some("self")),
            // Duplicate id: ignored, so `root` stays a parentless root.
            span("root", Some("self")),
        ]);
        assert_eq!(ids(&roots), vec!["root", "self"]);
        assert!(roots.iter().all(|n| n.children.is_empty()));
    }

    #[test]
    fn build_tree_copies_subtree_totals_from_span() {
        let mut priced = span("root", None);
        priced.subtree_cost_usd_e6 = Some(2_500);
        let roots = build_tree(vec![priced]);
        assert_eq!(roots[0].subtree_cost_usd_e6, 2_500);
        assert_eq!(roots[0].subtree_token_count, 0);
    }

    #[test]
    fn build_tree_of_empty_input_is_empty() {
        assert!(build_tree(vec![]).is_empty());
    }

    /// A span with the given id and optional parent id; every other field is defaulted.
    fn span(id: &str, parent: Option<&str>) -> ToolSpanView {
        ToolSpanView {
            span_id: id.into(),
            parent_span_id: parent.map(str::to_string),
            ..Default::default()
        }
    }

    /// Span ids of `nodes`, without descending into children.
    fn ids(nodes: &[SpanNode]) -> Vec<&str> {
        nodes.iter().map(|n| n.span.span_id.as_str()).collect()
    }

    /// Span ids of `nodes` and all their descendants, in pre-order.
    fn flat_ids(nodes: &[SpanNode]) -> Vec<&str> {
        nodes
            .iter()
            .flat_map(|n| std::iter::once(n.span.span_id.as_str()).chain(flat_ids(&n.children)))
            .collect()
    }
}