1use crate::graph::{ArborGraph, NodeId};
8use crate::query::NodeInfo;
9use petgraph::visit::EdgeRef;
10use petgraph::Direction;
11use serde::{Deserialize, Serialize};
12use std::collections::{HashSet, VecDeque};
13use std::time::Instant;
14use tracing::warn;
15
16#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
18pub enum TruncationReason {
19 Complete,
21 TokenBudget,
23 MaxDepth,
25}
26
27impl std::fmt::Display for TruncationReason {
28 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29 match self {
30 TruncationReason::Complete => write!(f, "complete"),
31 TruncationReason::TokenBudget => write!(f, "token_budget"),
32 TruncationReason::MaxDepth => write!(f, "max_depth"),
33 }
34 }
35}
36
37#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct ContextNode {
40 pub node_info: NodeInfo,
42 pub token_estimate: usize,
44 pub depth: usize,
46 pub pinned: bool,
48}
49
50#[derive(Debug, Serialize, Deserialize)]
52pub struct ContextSlice {
53 pub target: NodeInfo,
55 pub nodes: Vec<ContextNode>,
57 pub total_tokens: usize,
59 pub max_tokens: usize,
61 pub truncation_reason: TruncationReason,
63 pub query_time_ms: u64,
65}
66
67impl ContextSlice {
68 pub fn summary(&self) -> String {
70 format!(
71 "Context: {} nodes, ~{} tokens ({})",
72 self.nodes.len(),
73 self.total_tokens,
74 self.truncation_reason
75 )
76 }
77
78 pub fn pinned_only(&self) -> Vec<&ContextNode> {
79 self.nodes.iter().filter(|n| n.pinned).collect()
80 }
81}
82
83use once_cell::sync::Lazy;
84use tiktoken_rs::cl100k_base;
85
86static TOKENIZER: Lazy<Option<tiktoken_rs::CoreBPE>> = Lazy::new(|| match cl100k_base() {
89 Ok(tokenizer) => Some(tokenizer),
90 Err(error) => {
91 warn!(
92 "Failed to initialize cl100k_base tokenizer; falling back to heuristic token estimates: {}",
93 error
94 );
95 None
96 }
97});
98
99const LARGE_FILE_THRESHOLD: usize = 800 * 1024;
101
102fn estimate_tokens(node: &NodeInfo) -> usize {
107 let base = node.name.len() + node.qualified_name.len() + node.file.len();
108 let signature_len = node.signature.as_ref().map(|s| s.len()).unwrap_or(0);
109 let lines = (node.line_end.saturating_sub(node.line_start) + 1) as usize;
110 let estimated_chars = base + signature_len + (lines * 40);
111
112 if estimated_chars > LARGE_FILE_THRESHOLD {
114 return estimated_chars.div_ceil(4); }
116
117 let text = format!(
119 "{} {} {}{}",
120 node.qualified_name,
121 node.file,
122 node.signature.as_deref().unwrap_or(""),
123 " ".repeat(lines * 40) );
125
126 if let Some(tokenizer) = TOKENIZER.as_ref() {
128 tokenizer.encode_with_special_tokens(&text).len()
129 } else {
130 estimated_chars.div_ceil(4)
131 }
132}
133
134impl ArborGraph {
135 pub fn slice_context(
151 &self,
152 target: NodeId,
153 max_tokens: usize,
154 max_depth: usize,
155 pinned: &[NodeId],
156 ) -> ContextSlice {
157 let start = Instant::now();
158
159 let target_node = match self.get(target) {
160 Some(node) => {
161 let mut info = NodeInfo::from(node);
162 info.centrality = self.centrality(target);
163 info
164 }
165 None => {
166 return ContextSlice {
167 target: NodeInfo {
168 id: String::new(),
169 name: String::new(),
170 qualified_name: String::new(),
171 kind: String::new(),
172 file: String::new(),
173 line_start: 0,
174 line_end: 0,
175 signature: None,
176 centrality: 0.0,
177 },
178 nodes: Vec::new(),
179 total_tokens: 0,
180 max_tokens,
181 truncation_reason: TruncationReason::Complete,
182 query_time_ms: 0,
183 };
184 }
185 };
186
187 let effective_max = if max_depth == 0 {
188 usize::MAX
189 } else {
190 max_depth
191 };
192 let effective_tokens = if max_tokens == 0 {
193 usize::MAX
194 } else {
195 max_tokens
196 };
197
198 let pinned_set: HashSet<NodeId> = pinned.iter().copied().collect();
199 let mut visited: HashSet<NodeId> = HashSet::new();
200 let mut result: Vec<ContextNode> = Vec::new();
201 let mut total_tokens = 0usize;
202 let mut truncation_reason = TruncationReason::Complete;
203
204 let mut queue: VecDeque<(NodeId, usize)> = VecDeque::new();
206
207 queue.push_back((target, 0));
209
210 while let Some((current, depth)) = queue.pop_front() {
211 if visited.contains(¤t) {
212 continue;
213 }
214
215 if depth > effective_max {
216 truncation_reason = TruncationReason::MaxDepth;
217 continue;
218 }
219
220 visited.insert(current);
221
222 let is_pinned = pinned_set.contains(¤t);
223
224 if let Some(node) = self.get(current) {
225 let mut node_info = NodeInfo::from(node);
226 node_info.centrality = self.centrality(current);
227
228 let token_est = estimate_tokens(&node_info);
229
230 let within_budget = is_pinned || total_tokens + token_est <= effective_tokens;
232
233 if within_budget {
234 total_tokens += token_est;
235
236 result.push(ContextNode {
237 node_info,
238 token_estimate: token_est,
239 depth,
240 pinned: is_pinned,
241 });
242 } else {
243 truncation_reason = TruncationReason::TokenBudget;
244 }
246 }
247
248 if depth < effective_max {
250 for edge_ref in self.graph.edges_directed(current, Direction::Incoming) {
252 let neighbor = edge_ref.source();
253 if !visited.contains(&neighbor) {
254 queue.push_back((neighbor, depth + 1));
255 }
256 }
257
258 for edge_ref in self.graph.edges_directed(current, Direction::Outgoing) {
260 let neighbor = edge_ref.target();
261 if !visited.contains(&neighbor) {
262 queue.push_back((neighbor, depth + 1));
263 }
264 }
265 }
266 }
267
268 result.sort_by(|a, b| {
270 b.pinned
271 .cmp(&a.pinned)
272 .then_with(|| a.depth.cmp(&b.depth))
273 .then_with(|| {
274 b.node_info
275 .centrality
276 .partial_cmp(&a.node_info.centrality)
277 .unwrap_or(std::cmp::Ordering::Equal)
278 })
279 });
280
281 let elapsed = start.elapsed().as_millis() as u64;
282
283 ContextSlice {
284 target: target_node,
285 nodes: result,
286 total_tokens,
287 max_tokens,
288 truncation_reason,
289 query_time_ms: elapsed,
290 }
291 }
292}
293
#[cfg(test)]
mod tests {
    use super::*;
    use crate::edge::{Edge, EdgeKind};
    use arbor_core::{CodeNode, NodeKind};

    /// Builds a function node in `test.rs` whose name and qualified name are both `name`.
    fn make_node(name: &str) -> CodeNode {
        CodeNode::new(name, name, NodeKind::Function, "test.rs")
    }

    /// Adds a `Calls` edge from each node in `chain` to the node after it.
    fn connect_calls(graph: &mut ArborGraph, chain: &[NodeId]) {
        for link in chain.windows(2) {
            graph.add_edge(link[0], link[1], Edge::new(EdgeKind::Calls));
        }
    }

    /// Names of the nodes in `slice`, in slice order.
    fn names_in(slice: &ContextSlice) -> Vec<&str> {
        slice
            .nodes
            .iter()
            .map(|entry| entry.node_info.name.as_str())
            .collect()
    }

    #[test]
    fn test_empty_graph() {
        let graph = ArborGraph::new();
        let slice = graph.slice_context(NodeId::new(0), 1000, 2, &[]);
        assert!(slice.nodes.is_empty());
        assert_eq!(slice.total_tokens, 0);
    }

    #[test]
    fn test_single_node() {
        let mut graph = ArborGraph::new();
        let only = graph.add_node(make_node("lonely"));

        let slice = graph.slice_context(only, 1000, 2, &[]);
        assert_eq!(names_in(&slice), ["lonely"]);
        assert_eq!(slice.truncation_reason, TruncationReason::Complete);
    }

    #[test]
    fn test_linear_chain_depth_limit() {
        let mut graph = ArborGraph::new();
        let a = graph.add_node(make_node("a"));
        let b = graph.add_node(make_node("b"));
        let c = graph.add_node(make_node("c"));
        let d = graph.add_node(make_node("d"));
        connect_calls(&mut graph, &[a, b, c, d]);

        // One hop from `b` reaches `a` and `c`, but not `d`.
        let slice = graph.slice_context(b, 10000, 1, &[]);
        let names = names_in(&slice);
        for expected in ["b", "a", "c"] {
            assert!(names.contains(&expected), "missing {expected}");
        }
        assert!(!names.contains(&"d"));
    }

    #[test]
    fn test_token_budget() {
        let mut graph = ArborGraph::new();
        let a = graph.add_node(make_node("a"));
        let b = graph.add_node(make_node("b"));
        let c = graph.add_node(make_node("c"));
        connect_calls(&mut graph, &[a, b, c]);

        // A 5-token budget cannot hold all three nodes.
        let slice = graph.slice_context(a, 5, 10, &[]);
        assert!(slice.nodes.len() < 3);
        assert_eq!(slice.truncation_reason, TruncationReason::TokenBudget);
    }

    #[test]
    fn test_pinned_nodes_bypass_budget() {
        let mut graph = ArborGraph::new();
        let a = graph.add_node(make_node("a"));
        let b = graph.add_node(make_node("important_node"));
        let c = graph.add_node(make_node("c"));
        connect_calls(&mut graph, &[a, b, c]);

        let slice = graph.slice_context(a, 5, 10, &[b]);
        assert!(names_in(&slice).contains(&"important_node"));
    }

    #[test]
    fn test_complete_traversal() {
        let mut graph = ArborGraph::new();
        let a = graph.add_node(make_node("a"));
        let b = graph.add_node(make_node("b"));
        connect_calls(&mut graph, &[a, b]);

        let slice = graph.slice_context(a, 100000, 10, &[]);
        assert_eq!(slice.truncation_reason, TruncationReason::Complete);
        assert_eq!(slice.nodes.len(), 2);
    }
}