Skip to main content

arbor_graph/
slice.rs

1//! Dynamic context slicing for LLM prompts.
2//!
3//! This module provides token-bounded context extraction from the code graph.
4//! Given a target node, it collects the minimal set of related nodes that fit
5//! within a token budget.
6
7use crate::graph::{ArborGraph, NodeId};
8use crate::query::NodeInfo;
9use petgraph::visit::EdgeRef;
10use petgraph::Direction;
11use serde::{Deserialize, Serialize};
12use std::collections::{HashSet, VecDeque};
13use std::time::Instant;
14
/// Reason for stopping context collection.
///
/// Rendered in snake_case for CLI output via the `Display` impl below;
/// serde serialization uses the derived variant names as-is.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TruncationReason {
    /// All reachable nodes within max_depth were included.
    Complete,
    /// Stopped because token budget was reached.
    TokenBudget,
    /// Stopped because max depth was reached.
    MaxDepth,
}
25
26impl std::fmt::Display for TruncationReason {
27    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28        match self {
29            TruncationReason::Complete => write!(f, "complete"),
30            TruncationReason::TokenBudget => write!(f, "token_budget"),
31            TruncationReason::MaxDepth => write!(f, "max_depth"),
32        }
33    }
34}
35
/// A node included in the context slice.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextNode {
    /// Node information.
    pub node_info: NodeInfo,
    /// Estimated token count for this node's source
    /// (computed by `estimate_tokens`: tiktoken cl100k_base, with a
    /// 4-chars-per-token heuristic fallback for very large content).
    pub token_estimate: usize,
    /// Hop distance from the target node (0 = the target itself).
    pub depth: usize,
    /// Whether this node was pinned (always included, bypassing the budget).
    pub pinned: bool,
}
48
/// Result of a context slicing operation.
#[derive(Debug, Serialize, Deserialize)]
pub struct ContextSlice {
    /// The target node being queried.
    pub target: NodeInfo,
    /// Nodes included in the context, ordered by relevance:
    /// pinned first, then by hop depth, then by descending centrality.
    pub nodes: Vec<ContextNode>,
    /// Total estimated tokens in this slice.
    pub total_tokens: usize,
    /// Maximum tokens that were allowed (as passed in; 0 = unlimited).
    pub max_tokens: usize,
    /// Why slicing stopped.
    pub truncation_reason: TruncationReason,
    /// Query time in milliseconds.
    pub query_time_ms: u64,
}
65
66impl ContextSlice {
67    /// Returns a summary suitable for CLI output.
68    pub fn summary(&self) -> String {
69        format!(
70            "Context: {} nodes, ~{} tokens ({})",
71            self.nodes.len(),
72            self.total_tokens,
73            self.truncation_reason
74        )
75    }
76
77    pub fn pinned_only(&self) -> Vec<&ContextNode> {
78        self.nodes.iter().filter(|n| n.pinned).collect()
79    }
80}
81
use once_cell::sync::Lazy;
use tiktoken_rs::cl100k_base;

/// Global tokenizer instance (lazy-loaded once).
///
/// NOTE(review): the closure panics on *first access* if the cl100k_base
/// vocabulary fails to load; all calls to `estimate_tokens` share this.
static TOKENIZER: Lazy<tiktoken_rs::CoreBPE> =
    Lazy::new(|| cl100k_base().expect("Failed to load cl100k_base tokenizer"));

/// Threshold for falling back to heuristic tokenizer (800 KB).
/// Compared against the *estimated character count* in `estimate_tokens`,
/// not an actual file size.
const LARGE_FILE_THRESHOLD: usize = 800 * 1024;
91
92/// Estimates tokens for a code node using tiktoken cl100k_base.
93///
94/// Falls back to heuristic (4 chars/token) for large content to avoid
95/// performance issues with massive JS bundles.
96fn estimate_tokens(node: &NodeInfo) -> usize {
97    let base = node.name.len() + node.qualified_name.len() + node.file.len();
98    let signature_len = node.signature.as_ref().map(|s| s.len()).unwrap_or(0);
99    let lines = (node.line_end.saturating_sub(node.line_start) + 1) as usize;
100    let estimated_chars = base + signature_len + (lines * 40);
101
102    // Performance guardrail: use heuristic for very large content
103    if estimated_chars > LARGE_FILE_THRESHOLD {
104        return estimated_chars.div_ceil(4); // Heuristic fallback
105    }
106
107    // Build text representation for accurate tokenization
108    let text = format!(
109        "{} {} {}{}",
110        node.qualified_name,
111        node.file,
112        node.signature.as_deref().unwrap_or(""),
113        " ".repeat(lines * 40) // Approximate code content
114    );
115
116    // Use tiktoken for accurate count
117    TOKENIZER.encode_with_special_tokens(&text).len()
118}
119
120impl ArborGraph {
121    /// Extracts a token-bounded context slice around a target node.
122    ///
123    /// Collects nodes in BFS order:
124    /// 1. Target node itself
125    /// 2. Direct upstream (callers) at depth 1
126    /// 3. Direct downstream (callees) at depth 1
127    /// 4. Continues outward until budget or max_depth reached
128    ///
129    /// Pinned nodes are always included regardless of budget.
130    ///
131    /// # Arguments
132    /// * `target` - The node to center the slice around
133    /// * `max_tokens` - Maximum token budget (0 = unlimited)
134    /// * `max_depth` - Maximum hop distance (0 = unlimited, default: 2)
135    /// * `pinned` - Nodes that must be included regardless of budget
136    pub fn slice_context(
137        &self,
138        target: NodeId,
139        max_tokens: usize,
140        max_depth: usize,
141        pinned: &[NodeId],
142    ) -> ContextSlice {
143        let start = Instant::now();
144
145        let target_node = match self.get(target) {
146            Some(node) => {
147                let mut info = NodeInfo::from(node);
148                info.centrality = self.centrality(target);
149                info
150            }
151            None => {
152                return ContextSlice {
153                    target: NodeInfo {
154                        id: String::new(),
155                        name: String::new(),
156                        qualified_name: String::new(),
157                        kind: String::new(),
158                        file: String::new(),
159                        line_start: 0,
160                        line_end: 0,
161                        signature: None,
162                        centrality: 0.0,
163                    },
164                    nodes: Vec::new(),
165                    total_tokens: 0,
166                    max_tokens,
167                    truncation_reason: TruncationReason::Complete,
168                    query_time_ms: 0,
169                };
170            }
171        };
172
173        let effective_max = if max_depth == 0 {
174            usize::MAX
175        } else {
176            max_depth
177        };
178        let effective_tokens = if max_tokens == 0 {
179            usize::MAX
180        } else {
181            max_tokens
182        };
183
184        let pinned_set: HashSet<NodeId> = pinned.iter().copied().collect();
185        let mut visited: HashSet<NodeId> = HashSet::new();
186        let mut result: Vec<ContextNode> = Vec::new();
187        let mut total_tokens = 0usize;
188        let mut truncation_reason = TruncationReason::Complete;
189
190        // BFS queue: (node_id, depth)
191        let mut queue: VecDeque<(NodeId, usize)> = VecDeque::new();
192
193        // Start with target
194        queue.push_back((target, 0));
195
196        while let Some((current, depth)) = queue.pop_front() {
197            if visited.contains(&current) {
198                continue;
199            }
200
201            if depth > effective_max {
202                truncation_reason = TruncationReason::MaxDepth;
203                continue;
204            }
205
206            visited.insert(current);
207
208            let is_pinned = pinned_set.contains(&current);
209
210            if let Some(node) = self.get(current) {
211                let mut node_info = NodeInfo::from(node);
212                node_info.centrality = self.centrality(current);
213
214                let token_est = estimate_tokens(&node_info);
215
216                // Check budget (pinned nodes bypass budget)
217                let within_budget = is_pinned || total_tokens + token_est <= effective_tokens;
218
219                if within_budget {
220                    total_tokens += token_est;
221
222                    result.push(ContextNode {
223                        node_info,
224                        token_estimate: token_est,
225                        depth,
226                        pinned: is_pinned,
227                    });
228                } else {
229                    truncation_reason = TruncationReason::TokenBudget;
230                    // Don't add to result, but STILL explore neighbors to find pinned nodes
231                }
232            }
233
234            // Always add neighbors to queue (to find pinned nodes even when budget exceeded)
235            if depth < effective_max {
236                // Upstream (incoming)
237                for edge_ref in self.graph.edges_directed(current, Direction::Incoming) {
238                    let neighbor = edge_ref.source();
239                    if !visited.contains(&neighbor) {
240                        queue.push_back((neighbor, depth + 1));
241                    }
242                }
243
244                // Downstream (outgoing)
245                for edge_ref in self.graph.edges_directed(current, Direction::Outgoing) {
246                    let neighbor = edge_ref.target();
247                    if !visited.contains(&neighbor) {
248                        queue.push_back((neighbor, depth + 1));
249                    }
250                }
251            }
252        }
253
254        // Sort by: pinned first, then by depth, then by centrality (desc)
255        result.sort_by(|a, b| {
256            b.pinned
257                .cmp(&a.pinned)
258                .then_with(|| a.depth.cmp(&b.depth))
259                .then_with(|| {
260                    b.node_info
261                        .centrality
262                        .partial_cmp(&a.node_info.centrality)
263                        .unwrap_or(std::cmp::Ordering::Equal)
264                })
265        });
266
267        let elapsed = start.elapsed().as_millis() as u64;
268
269        ContextSlice {
270            target: target_node,
271            nodes: result,
272            total_tokens,
273            max_tokens,
274            truncation_reason,
275            query_time_ms: elapsed,
276        }
277    }
278}
279
#[cfg(test)]
mod tests {
    use super::*;
    use crate::edge::{Edge, EdgeKind};
    use arbor_core::{CodeNode, NodeKind};

    /// Builds a minimal function node located in `test.rs`, with `name` used
    /// for both the plain and qualified name.
    fn make_node(name: &str) -> CodeNode {
        CodeNode::new(name, name, NodeKind::Function, "test.rs")
    }

    /// Slicing a node id that does not exist yields an empty slice.
    #[test]
    fn test_empty_graph() {
        let graph = ArborGraph::new();
        let result = graph.slice_context(NodeId::new(0), 1000, 2, &[]);
        assert!(result.nodes.is_empty());
        assert_eq!(result.total_tokens, 0);
    }

    /// An isolated node produces a one-node, non-truncated slice.
    #[test]
    fn test_single_node() {
        let mut graph = ArborGraph::new();
        let id = graph.add_node(make_node("lonely"));

        let result = graph.slice_context(id, 1000, 2, &[]);
        assert_eq!(result.nodes.len(), 1);
        assert_eq!(result.nodes[0].node_info.name, "lonely");
        assert_eq!(result.truncation_reason, TruncationReason::Complete);
    }

    /// max_depth = 1 from the middle of a chain includes only direct neighbors.
    #[test]
    fn test_linear_chain_depth_limit() {
        // A → B → C → D
        let mut graph = ArborGraph::new();
        let a = graph.add_node(make_node("a"));
        let b = graph.add_node(make_node("b"));
        let c = graph.add_node(make_node("c"));
        let d = graph.add_node(make_node("d"));

        graph.add_edge(a, b, Edge::new(EdgeKind::Calls));
        graph.add_edge(b, c, Edge::new(EdgeKind::Calls));
        graph.add_edge(c, d, Edge::new(EdgeKind::Calls));

        // Slice from B with max_depth = 1
        let result = graph.slice_context(b, 10000, 1, &[]);

        // Should include B (depth 0), A (depth 1), C (depth 1)
        // D is depth 2, excluded
        let names: Vec<&str> = result
            .nodes
            .iter()
            .map(|n| n.node_info.name.as_str())
            .collect();
        assert!(names.contains(&"b"));
        assert!(names.contains(&"a"));
        assert!(names.contains(&"c"));
        assert!(!names.contains(&"d"));
    }

    /// A tiny token budget drops nodes and reports TokenBudget truncation.
    #[test]
    fn test_token_budget() {
        let mut graph = ArborGraph::new();
        let a = graph.add_node(make_node("a"));
        let b = graph.add_node(make_node("b"));
        let c = graph.add_node(make_node("c"));

        graph.add_edge(a, b, Edge::new(EdgeKind::Calls));
        graph.add_edge(b, c, Edge::new(EdgeKind::Calls));

        // Very small budget - should truncate
        let result = graph.slice_context(a, 5, 10, &[]);

        // Should hit token budget
        assert!(result.nodes.len() < 3);
        assert_eq!(result.truncation_reason, TruncationReason::TokenBudget);
    }

    /// Pinned nodes are included even when the budget would exclude them.
    #[test]
    fn test_pinned_nodes_bypass_budget() {
        let mut graph = ArborGraph::new();
        let a = graph.add_node(make_node("a"));
        let b = graph.add_node(make_node("important_node"));
        let c = graph.add_node(make_node("c"));

        graph.add_edge(a, b, Edge::new(EdgeKind::Calls));
        graph.add_edge(b, c, Edge::new(EdgeKind::Calls));

        // Very small budget, but b is pinned
        let result = graph.slice_context(a, 5, 10, &[b]);

        // Pinned node should still be included
        let has_important = result
            .nodes
            .iter()
            .any(|n| n.node_info.name == "important_node");
        assert!(has_important);
    }

    /// A generous budget and depth covers the whole graph: Complete.
    #[test]
    fn test_complete_traversal() {
        let mut graph = ArborGraph::new();
        let a = graph.add_node(make_node("a"));
        let b = graph.add_node(make_node("b"));

        graph.add_edge(a, b, Edge::new(EdgeKind::Calls));

        // Large budget, should complete
        let result = graph.slice_context(a, 100000, 10, &[]);
        assert_eq!(result.truncation_reason, TruncationReason::Complete);
        assert_eq!(result.nodes.len(), 2);
    }
}