// arbor_graph/slice.rs

1//! Dynamic context slicing for LLM prompts.
2//!
3//! This module provides token-bounded context extraction from the code graph.
4//! Given a target node, it collects the minimal set of related nodes that fit
5//! within a token budget.
6
7use crate::graph::{ArborGraph, NodeId};
8use crate::query::NodeInfo;
9use petgraph::visit::EdgeRef;
10use petgraph::Direction;
11use serde::{Deserialize, Serialize};
12use std::collections::{HashSet, VecDeque};
13use std::time::Instant;
14use tracing::warn;
15
16/// Reason for stopping context collection.
17#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
18pub enum TruncationReason {
19    /// All reachable nodes within max_depth were included.
20    Complete,
21    /// Stopped because token budget was reached.
22    TokenBudget,
23    /// Stopped because max depth was reached.
24    MaxDepth,
25}
26
27impl std::fmt::Display for TruncationReason {
28    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29        match self {
30            TruncationReason::Complete => write!(f, "complete"),
31            TruncationReason::TokenBudget => write!(f, "token_budget"),
32            TruncationReason::MaxDepth => write!(f, "max_depth"),
33        }
34    }
35}
36
37/// A node included in the context slice.
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct ContextNode {
40    /// Node information.
41    pub node_info: NodeInfo,
42    /// Estimated token count for this node's source.
43    pub token_estimate: usize,
44    /// Hop distance from the target node.
45    pub depth: usize,
46    /// Whether this node was pinned (always included).
47    pub pinned: bool,
48}
49
50/// Result of a context slicing operation.
51#[derive(Debug, Serialize, Deserialize)]
52pub struct ContextSlice {
53    /// The target node being queried.
54    pub target: NodeInfo,
55    /// Nodes included in the context, ordered by relevance.
56    pub nodes: Vec<ContextNode>,
57    /// Total estimated tokens in this slice.
58    pub total_tokens: usize,
59    /// Maximum tokens that were allowed.
60    pub max_tokens: usize,
61    /// Why slicing stopped.
62    pub truncation_reason: TruncationReason,
63    /// Query time in milliseconds.
64    pub query_time_ms: u64,
65}
66
67impl ContextSlice {
68    /// Returns a summary suitable for CLI output.
69    pub fn summary(&self) -> String {
70        format!(
71            "Context: {} nodes, ~{} tokens ({})",
72            self.nodes.len(),
73            self.total_tokens,
74            self.truncation_reason
75        )
76    }
77
78    pub fn pinned_only(&self) -> Vec<&ContextNode> {
79        self.nodes.iter().filter(|n| n.pinned).collect()
80    }
81}
82
83use once_cell::sync::Lazy;
84use tiktoken_rs::cl100k_base;
85
86/// Global tokenizer instance (lazy-loaded once).
87/// Falls back to heuristic mode if tokenizer init fails.
88static TOKENIZER: Lazy<Option<tiktoken_rs::CoreBPE>> = Lazy::new(|| match cl100k_base() {
89    Ok(tokenizer) => Some(tokenizer),
90    Err(error) => {
91        warn!(
92            "Failed to initialize cl100k_base tokenizer; falling back to heuristic token estimates: {}",
93            error
94        );
95        None
96    }
97});
98
99/// Threshold for falling back to heuristic tokenizer (800 KB)
100const LARGE_FILE_THRESHOLD: usize = 800 * 1024;
101
102/// Estimates tokens for a code node using tiktoken cl100k_base.
103///
104/// Falls back to heuristic (4 chars/token) for large content to avoid
105/// performance issues with massive JS bundles.
106fn estimate_tokens(node: &NodeInfo) -> usize {
107    let base = node.name.len() + node.qualified_name.len() + node.file.len();
108    let signature_len = node.signature.as_ref().map(|s| s.len()).unwrap_or(0);
109    let lines = (node.line_end.saturating_sub(node.line_start) + 1) as usize;
110    let estimated_chars = base + signature_len + (lines * 40);
111
112    // Performance guardrail: use heuristic for very large content
113    if estimated_chars > LARGE_FILE_THRESHOLD {
114        return estimated_chars.div_ceil(4); // Heuristic fallback
115    }
116
117    // Build text representation for accurate tokenization
118    let text = format!(
119        "{} {} {}{}",
120        node.qualified_name,
121        node.file,
122        node.signature.as_deref().unwrap_or(""),
123        " ".repeat(lines * 40) // Approximate code content
124    );
125
126    // Use tiktoken for accurate count, fall back to heuristic if unavailable.
127    if let Some(tokenizer) = TOKENIZER.as_ref() {
128        tokenizer.encode_with_special_tokens(&text).len()
129    } else {
130        estimated_chars.div_ceil(4)
131    }
132}
133
134impl ArborGraph {
135    /// Extracts a token-bounded context slice around a target node.
136    ///
137    /// Collects nodes in BFS order:
138    /// 1. Target node itself
139    /// 2. Direct upstream (callers) at depth 1
140    /// 3. Direct downstream (callees) at depth 1
141    /// 4. Continues outward until budget or max_depth reached
142    ///
143    /// Pinned nodes are always included regardless of budget.
144    ///
145    /// # Arguments
146    /// * `target` - The node to center the slice around
147    /// * `max_tokens` - Maximum token budget (0 = unlimited)
148    /// * `max_depth` - Maximum hop distance (0 = unlimited, default: 2)
149    /// * `pinned` - Nodes that must be included regardless of budget
150    pub fn slice_context(
151        &self,
152        target: NodeId,
153        max_tokens: usize,
154        max_depth: usize,
155        pinned: &[NodeId],
156    ) -> ContextSlice {
157        let start = Instant::now();
158
159        let target_node = match self.get(target) {
160            Some(node) => {
161                let mut info = NodeInfo::from(node);
162                info.centrality = self.centrality(target);
163                info
164            }
165            None => {
166                return ContextSlice {
167                    target: NodeInfo {
168                        id: String::new(),
169                        name: String::new(),
170                        qualified_name: String::new(),
171                        kind: String::new(),
172                        file: String::new(),
173                        line_start: 0,
174                        line_end: 0,
175                        signature: None,
176                        centrality: 0.0,
177                    },
178                    nodes: Vec::new(),
179                    total_tokens: 0,
180                    max_tokens,
181                    truncation_reason: TruncationReason::Complete,
182                    query_time_ms: 0,
183                };
184            }
185        };
186
187        let effective_max = if max_depth == 0 {
188            usize::MAX
189        } else {
190            max_depth
191        };
192        let effective_tokens = if max_tokens == 0 {
193            usize::MAX
194        } else {
195            max_tokens
196        };
197
198        let pinned_set: HashSet<NodeId> = pinned.iter().copied().collect();
199        let mut visited: HashSet<NodeId> = HashSet::new();
200        let mut result: Vec<ContextNode> = Vec::new();
201        let mut total_tokens = 0usize;
202        let mut truncation_reason = TruncationReason::Complete;
203
204        // BFS queue: (node_id, depth)
205        let mut queue: VecDeque<(NodeId, usize)> = VecDeque::new();
206
207        // Start with target
208        queue.push_back((target, 0));
209
210        while let Some((current, depth)) = queue.pop_front() {
211            if visited.contains(&current) {
212                continue;
213            }
214
215            if depth > effective_max {
216                truncation_reason = TruncationReason::MaxDepth;
217                continue;
218            }
219
220            visited.insert(current);
221
222            let is_pinned = pinned_set.contains(&current);
223
224            if let Some(node) = self.get(current) {
225                let mut node_info = NodeInfo::from(node);
226                node_info.centrality = self.centrality(current);
227
228                let token_est = estimate_tokens(&node_info);
229
230                // Check budget (pinned nodes bypass budget)
231                let within_budget = is_pinned || total_tokens + token_est <= effective_tokens;
232
233                if within_budget {
234                    total_tokens += token_est;
235
236                    result.push(ContextNode {
237                        node_info,
238                        token_estimate: token_est,
239                        depth,
240                        pinned: is_pinned,
241                    });
242                } else {
243                    truncation_reason = TruncationReason::TokenBudget;
244                    // Don't add to result, but STILL explore neighbors to find pinned nodes
245                }
246            }
247
248            // Always add neighbors to queue (to find pinned nodes even when budget exceeded)
249            if depth < effective_max {
250                // Upstream (incoming)
251                for edge_ref in self.graph.edges_directed(current, Direction::Incoming) {
252                    let neighbor = edge_ref.source();
253                    if !visited.contains(&neighbor) {
254                        queue.push_back((neighbor, depth + 1));
255                    }
256                }
257
258                // Downstream (outgoing)
259                for edge_ref in self.graph.edges_directed(current, Direction::Outgoing) {
260                    let neighbor = edge_ref.target();
261                    if !visited.contains(&neighbor) {
262                        queue.push_back((neighbor, depth + 1));
263                    }
264                }
265            }
266        }
267
268        // Sort by: pinned first, then by depth, then by centrality (desc)
269        result.sort_by(|a, b| {
270            b.pinned
271                .cmp(&a.pinned)
272                .then_with(|| a.depth.cmp(&b.depth))
273                .then_with(|| {
274                    b.node_info
275                        .centrality
276                        .partial_cmp(&a.node_info.centrality)
277                        .unwrap_or(std::cmp::Ordering::Equal)
278                })
279        });
280
281        let elapsed = start.elapsed().as_millis() as u64;
282
283        ContextSlice {
284            target: target_node,
285            nodes: result,
286            total_tokens,
287            max_tokens,
288            truncation_reason,
289            query_time_ms: elapsed,
290        }
291    }
292}
293
#[cfg(test)]
mod tests {
    use super::*;
    use crate::edge::{Edge, EdgeKind};
    use arbor_core::{CodeNode, NodeKind};

    /// Builds a function node in `test.rs`, passing `name` for both leading
    /// string arguments of `CodeNode::new`.
    fn make_node(name: &str) -> CodeNode {
        CodeNode::new(name, name, NodeKind::Function, "test.rs")
    }

    /// Builds a graph with one node per name and a `Calls` edge between each
    /// consecutive pair (`names[0] → names[1] → …`). Returns the graph and
    /// the node ids in the order of `names`.
    fn call_chain(names: &[&str]) -> (ArborGraph, Vec<NodeId>) {
        let mut graph = ArborGraph::new();
        let ids: Vec<NodeId> = names
            .iter()
            .map(|&name| graph.add_node(make_node(name)))
            .collect();
        for pair in ids.windows(2) {
            graph.add_edge(pair[0], pair[1], Edge::new(EdgeKind::Calls));
        }
        (graph, ids)
    }

    /// Slicing around an id that is not in the graph yields an empty slice.
    #[test]
    fn test_empty_graph() {
        let graph = ArborGraph::new();
        let slice = graph.slice_context(NodeId::new(0), 1000, 2, &[]);
        assert!(slice.nodes.is_empty());
        assert_eq!(slice.total_tokens, 0);
    }

    /// An isolated node produces a complete slice containing only itself.
    #[test]
    fn test_single_node() {
        let (graph, ids) = call_chain(&["lonely"]);

        let slice = graph.slice_context(ids[0], 1000, 2, &[]);
        assert_eq!(slice.nodes.len(), 1);
        assert_eq!(slice.nodes[0].node_info.name, "lonely");
        assert_eq!(slice.truncation_reason, TruncationReason::Complete);
    }

    /// With `max_depth = 1`, only direct neighbors of the target are included.
    #[test]
    fn test_linear_chain_depth_limit() {
        // A → B → C → D
        let (graph, ids) = call_chain(&["a", "b", "c", "d"]);

        // Slice from B with max_depth = 1
        let slice = graph.slice_context(ids[1], 10000, 1, &[]);

        // Expect B (depth 0), A and C (depth 1); D sits at depth 2.
        let names: Vec<&str> = slice
            .nodes
            .iter()
            .map(|entry| entry.node_info.name.as_str())
            .collect();
        assert!(names.contains(&"b"));
        assert!(names.contains(&"a"));
        assert!(names.contains(&"c"));
        assert!(!names.contains(&"d"));
    }

    /// A tiny budget drops nodes and reports `TokenBudget`.
    #[test]
    fn test_token_budget() {
        let (graph, ids) = call_chain(&["a", "b", "c"]);

        // Very small budget - should truncate
        let slice = graph.slice_context(ids[0], 5, 10, &[]);

        assert!(slice.nodes.len() < 3);
        assert_eq!(slice.truncation_reason, TruncationReason::TokenBudget);
    }

    /// A pinned node is included even when the budget is far too small.
    #[test]
    fn test_pinned_nodes_bypass_budget() {
        let (graph, ids) = call_chain(&["a", "important_node", "c"]);

        // Very small budget, but the middle node is pinned
        let slice = graph.slice_context(ids[0], 5, 10, &[ids[1]]);

        let has_important = slice
            .nodes
            .iter()
            .any(|entry| entry.node_info.name == "important_node");
        assert!(has_important);
    }

    /// A generous budget and depth cover the whole graph.
    #[test]
    fn test_complete_traversal() {
        let (graph, ids) = call_chain(&["a", "b"]);

        // Large budget, should complete
        let slice = graph.slice_context(ids[0], 100000, 10, &[]);
        assert_eq!(slice.truncation_reason, TruncationReason::Complete);
        assert_eq!(slice.nodes.len(), 2);
    }
}
404}