// arbor_graph/slice.rs
1//! Dynamic context slicing for LLM prompts.
2//!
3//! This module provides token-bounded context extraction from the code graph.
4//! Given a target node, it collects the minimal set of related nodes that fit
5//! within a token budget.
6
7use crate::graph::{ArborGraph, NodeId};
8use crate::query::NodeInfo;
9use petgraph::visit::EdgeRef;
10use petgraph::Direction;
11use serde::{Deserialize, Serialize};
12use std::collections::{HashSet, VecDeque};
13use std::time::Instant;
14
/// Reason for stopping context collection.
///
/// The `Display` impl renders each variant as a stable snake_case token
/// (`complete`, `token_budget`, `max_depth`) for CLI output.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TruncationReason {
    /// All reachable nodes within max_depth were included.
    Complete,
    /// Stopped because token budget was reached.
    TokenBudget,
    /// Stopped because max depth was reached.
    MaxDepth,
}
25
26impl std::fmt::Display for TruncationReason {
27    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28        match self {
29            TruncationReason::Complete => write!(f, "complete"),
30            TruncationReason::TokenBudget => write!(f, "token_budget"),
31            TruncationReason::MaxDepth => write!(f, "max_depth"),
32        }
33    }
34}
35
/// A node included in the context slice.
///
/// Produced by `ArborGraph::slice_context`; one entry per graph node that
/// made it into the token budget (or was pinned).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextNode {
    /// Node information (identity, location, signature, centrality).
    pub node_info: NodeInfo,
    /// Estimated token count for this node's source (see `estimate_tokens`).
    pub token_estimate: usize,
    /// BFS hop distance from the target node (0 = the target itself).
    pub depth: usize,
    /// Whether this node was pinned (always included, bypassing the budget).
    pub pinned: bool,
}
48
/// Result of a context slicing operation.
///
/// Returned by `ArborGraph::slice_context`. `nodes` is sorted by relevance:
/// pinned first, then ascending depth, then descending centrality.
#[derive(Debug, Serialize, Deserialize)]
pub struct ContextSlice {
    /// The target node being queried (placeholder with empty fields if the
    /// target id was not found in the graph).
    pub target: NodeInfo,
    /// Nodes included in the context, ordered by relevance.
    pub nodes: Vec<ContextNode>,
    /// Total estimated tokens in this slice (sum of `token_estimate`s).
    pub total_tokens: usize,
    /// Maximum tokens that were allowed (0 = unlimited, echoed from the call).
    pub max_tokens: usize,
    /// Why slicing stopped.
    pub truncation_reason: TruncationReason,
    /// Query time in milliseconds.
    pub query_time_ms: u64,
}
65
66impl ContextSlice {
67    /// Returns a summary suitable for CLI output.
68    pub fn summary(&self) -> String {
69        format!(
70            "Context: {} nodes, ~{} tokens ({})",
71            self.nodes.len(),
72            self.total_tokens,
73            self.truncation_reason
74        )
75    }
76
77    pub fn pinned_only(&self) -> Vec<&ContextNode> {
78        self.nodes.iter().filter(|n| n.pinned).collect()
79    }
80}
81
use once_cell::sync::Lazy;
use tiktoken_rs::cl100k_base;

// NOTE(review): mid-file `use` statements are unconventional; consider moving
// them to the top of the module with the other imports.

/// Global tokenizer instance (lazy-loaded once)
///
/// Loading the cl100k_base BPE tables is expensive, so it happens at most
/// once per process. The `expect` panics only if the bundled tables fail to
/// load — a broken-build condition, not a runtime error.
static TOKENIZER: Lazy<tiktoken_rs::CoreBPE> = Lazy::new(|| {
    cl100k_base().expect("Failed to load cl100k_base tokenizer")
});

/// Threshold for falling back to heuristic tokenizer (800 KB)
// Above this estimated character count, `estimate_tokens` skips real BPE
// encoding and uses a chars/4 heuristic to avoid tokenizing huge text.
const LARGE_FILE_THRESHOLD: usize = 800 * 1024;
92
93/// Estimates tokens for a code node using tiktoken cl100k_base.
94///
95/// Falls back to heuristic (4 chars/token) for large content to avoid
96/// performance issues with massive JS bundles.
97fn estimate_tokens(node: &NodeInfo) -> usize {
98    let base = node.name.len() + node.qualified_name.len() + node.file.len();
99    let signature_len = node.signature.as_ref().map(|s| s.len()).unwrap_or(0);
100    let lines = (node.line_end.saturating_sub(node.line_start) + 1) as usize;
101    let estimated_chars = base + signature_len + (lines * 40);
102
103    // Performance guardrail: use heuristic for very large content
104    if estimated_chars > LARGE_FILE_THRESHOLD {
105        return (estimated_chars + 3) / 4; // Heuristic fallback
106    }
107
108    // Build text representation for accurate tokenization
109    let text = format!(
110        "{} {} {}{}",
111        node.qualified_name,
112        node.file,
113        node.signature.as_deref().unwrap_or(""),
114        " ".repeat(lines * 40) // Approximate code content
115    );
116
117    // Use tiktoken for accurate count
118    TOKENIZER.encode_with_special_tokens(&text).len()
119}
120
121impl ArborGraph {
122    /// Extracts a token-bounded context slice around a target node.
123    ///
124    /// Collects nodes in BFS order:
125    /// 1. Target node itself
126    /// 2. Direct upstream (callers) at depth 1
127    /// 3. Direct downstream (callees) at depth 1
128    /// 4. Continues outward until budget or max_depth reached
129    ///
130    /// Pinned nodes are always included regardless of budget.
131    ///
132    /// # Arguments
133    /// * `target` - The node to center the slice around
134    /// * `max_tokens` - Maximum token budget (0 = unlimited)
135    /// * `max_depth` - Maximum hop distance (0 = unlimited, default: 2)
136    /// * `pinned` - Nodes that must be included regardless of budget
137    pub fn slice_context(
138        &self,
139        target: NodeId,
140        max_tokens: usize,
141        max_depth: usize,
142        pinned: &[NodeId],
143    ) -> ContextSlice {
144        let start = Instant::now();
145
146        let target_node = match self.get(target) {
147            Some(node) => {
148                let mut info = NodeInfo::from(node);
149                info.centrality = self.centrality(target);
150                info
151            }
152            None => {
153                return ContextSlice {
154                    target: NodeInfo {
155                        id: String::new(),
156                        name: String::new(),
157                        qualified_name: String::new(),
158                        kind: String::new(),
159                        file: String::new(),
160                        line_start: 0,
161                        line_end: 0,
162                        signature: None,
163                        centrality: 0.0,
164                    },
165                    nodes: Vec::new(),
166                    total_tokens: 0,
167                    max_tokens,
168                    truncation_reason: TruncationReason::Complete,
169                    query_time_ms: 0,
170                };
171            }
172        };
173
174        let effective_max = if max_depth == 0 {
175            usize::MAX
176        } else {
177            max_depth
178        };
179        let effective_tokens = if max_tokens == 0 {
180            usize::MAX
181        } else {
182            max_tokens
183        };
184
185        let pinned_set: HashSet<NodeId> = pinned.iter().copied().collect();
186        let mut visited: HashSet<NodeId> = HashSet::new();
187        let mut result: Vec<ContextNode> = Vec::new();
188        let mut total_tokens = 0usize;
189        let mut truncation_reason = TruncationReason::Complete;
190
191        // BFS queue: (node_id, depth)
192        let mut queue: VecDeque<(NodeId, usize)> = VecDeque::new();
193
194        // Start with target
195        queue.push_back((target, 0));
196
197        while let Some((current, depth)) = queue.pop_front() {
198            if visited.contains(&current) {
199                continue;
200            }
201
202            if depth > effective_max {
203                truncation_reason = TruncationReason::MaxDepth;
204                continue;
205            }
206
207            visited.insert(current);
208
209            let is_pinned = pinned_set.contains(&current);
210
211            if let Some(node) = self.get(current) {
212                let mut node_info = NodeInfo::from(node);
213                node_info.centrality = self.centrality(current);
214
215                let token_est = estimate_tokens(&node_info);
216
217                // Check budget (pinned nodes bypass budget)
218                let within_budget = is_pinned || total_tokens + token_est <= effective_tokens;
219
220                if within_budget {
221                    total_tokens += token_est;
222
223                    result.push(ContextNode {
224                        node_info,
225                        token_estimate: token_est,
226                        depth,
227                        pinned: is_pinned,
228                    });
229                } else {
230                    truncation_reason = TruncationReason::TokenBudget;
231                    // Don't add to result, but STILL explore neighbors to find pinned nodes
232                }
233            }
234
235            // Always add neighbors to queue (to find pinned nodes even when budget exceeded)
236            if depth < effective_max {
237                // Upstream (incoming)
238                for edge_ref in self.graph.edges_directed(current, Direction::Incoming) {
239                    let neighbor = edge_ref.source();
240                    if !visited.contains(&neighbor) {
241                        queue.push_back((neighbor, depth + 1));
242                    }
243                }
244
245                // Downstream (outgoing)
246                for edge_ref in self.graph.edges_directed(current, Direction::Outgoing) {
247                    let neighbor = edge_ref.target();
248                    if !visited.contains(&neighbor) {
249                        queue.push_back((neighbor, depth + 1));
250                    }
251                }
252            }
253        }
254
255        // Sort by: pinned first, then by depth, then by centrality (desc)
256        result.sort_by(|a, b| {
257            b.pinned
258                .cmp(&a.pinned)
259                .then_with(|| a.depth.cmp(&b.depth))
260                .then_with(|| {
261                    b.node_info
262                        .centrality
263                        .partial_cmp(&a.node_info.centrality)
264                        .unwrap_or(std::cmp::Ordering::Equal)
265                })
266        });
267
268        let elapsed = start.elapsed().as_millis() as u64;
269
270        ContextSlice {
271            target: target_node,
272            nodes: result,
273            total_tokens,
274            max_tokens,
275            truncation_reason,
276            query_time_ms: elapsed,
277        }
278    }
279}
280
#[cfg(test)]
mod tests {
    use super::*;
    use crate::edge::{Edge, EdgeKind};
    use arbor_core::{CodeNode, NodeKind};

    /// Builds a function node named `name` located in `test.rs`.
    fn make_node(name: &str) -> CodeNode {
        CodeNode::new(name, name, NodeKind::Function, "test.rs")
    }

    /// True when the slice contains a node with the given name.
    fn contains_name(slice: &ContextSlice, name: &str) -> bool {
        slice.nodes.iter().any(|n| n.node_info.name == name)
    }

    #[test]
    fn test_empty_graph() {
        let graph = ArborGraph::new();
        let slice = graph.slice_context(NodeId::new(0), 1000, 2, &[]);
        assert!(slice.nodes.is_empty());
        assert_eq!(slice.total_tokens, 0);
    }

    #[test]
    fn test_single_node() {
        let mut graph = ArborGraph::new();
        let id = graph.add_node(make_node("lonely"));

        let slice = graph.slice_context(id, 1000, 2, &[]);
        assert_eq!(slice.nodes.len(), 1);
        assert_eq!(slice.nodes[0].node_info.name, "lonely");
        assert_eq!(slice.truncation_reason, TruncationReason::Complete);
    }

    #[test]
    fn test_linear_chain_depth_limit() {
        // A → B → C → D
        let mut graph = ArborGraph::new();
        let a = graph.add_node(make_node("a"));
        let b = graph.add_node(make_node("b"));
        let c = graph.add_node(make_node("c"));
        let d = graph.add_node(make_node("d"));

        graph.add_edge(a, b, Edge::new(EdgeKind::Calls));
        graph.add_edge(b, c, Edge::new(EdgeKind::Calls));
        graph.add_edge(c, d, Edge::new(EdgeKind::Calls));

        // Slicing from B at max_depth = 1 reaches A (caller) and C (callee);
        // D sits at depth 2 and must be excluded.
        let slice = graph.slice_context(b, 10000, 1, &[]);

        assert!(contains_name(&slice, "b"));
        assert!(contains_name(&slice, "a"));
        assert!(contains_name(&slice, "c"));
        assert!(!contains_name(&slice, "d"));
    }

    #[test]
    fn test_token_budget() {
        let mut graph = ArborGraph::new();
        let a = graph.add_node(make_node("a"));
        let b = graph.add_node(make_node("b"));
        let c = graph.add_node(make_node("c"));

        graph.add_edge(a, b, Edge::new(EdgeKind::Calls));
        graph.add_edge(b, c, Edge::new(EdgeKind::Calls));

        // A budget of 5 tokens is far too small for three nodes, so the
        // slice must be truncated by the budget.
        let slice = graph.slice_context(a, 5, 10, &[]);

        assert!(slice.nodes.len() < 3);
        assert_eq!(slice.truncation_reason, TruncationReason::TokenBudget);
    }

    #[test]
    fn test_pinned_nodes_bypass_budget() {
        let mut graph = ArborGraph::new();
        let a = graph.add_node(make_node("a"));
        let b = graph.add_node(make_node("important_node"));
        let c = graph.add_node(make_node("c"));

        graph.add_edge(a, b, Edge::new(EdgeKind::Calls));
        graph.add_edge(b, c, Edge::new(EdgeKind::Calls));

        // The budget is tiny, but b is pinned and must survive anyway.
        let slice = graph.slice_context(a, 5, 10, &[b]);

        assert!(contains_name(&slice, "important_node"));
    }

    #[test]
    fn test_complete_traversal() {
        let mut graph = ArborGraph::new();
        let a = graph.add_node(make_node("a"));
        let b = graph.add_node(make_node("b"));

        graph.add_edge(a, b, Edge::new(EdgeKind::Calls));

        // Generous budget and depth: everything fits, nothing is truncated.
        let slice = graph.slice_context(a, 100000, 10, &[]);
        assert_eq!(slice.truncation_reason, TruncationReason::Complete);
        assert_eq!(slice.nodes.len(), 2);
    }
}