1use crate::graph::{ArborGraph, NodeId};
8use crate::query::NodeInfo;
9use petgraph::visit::EdgeRef;
10use petgraph::Direction;
11use serde::{Deserialize, Serialize};
12use std::collections::{HashSet, VecDeque};
13use std::time::Instant;
14
15#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
17pub enum TruncationReason {
18 Complete,
20 TokenBudget,
22 MaxDepth,
24}
25
26impl std::fmt::Display for TruncationReason {
27 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28 match self {
29 TruncationReason::Complete => write!(f, "complete"),
30 TruncationReason::TokenBudget => write!(f, "token_budget"),
31 TruncationReason::MaxDepth => write!(f, "max_depth"),
32 }
33 }
34}
35
36#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct ContextNode {
39 pub node_info: NodeInfo,
41 pub token_estimate: usize,
43 pub depth: usize,
45 pub pinned: bool,
47}
48
49#[derive(Debug, Serialize, Deserialize)]
51pub struct ContextSlice {
52 pub target: NodeInfo,
54 pub nodes: Vec<ContextNode>,
56 pub total_tokens: usize,
58 pub max_tokens: usize,
60 pub truncation_reason: TruncationReason,
62 pub query_time_ms: u64,
64}
65
66impl ContextSlice {
67 pub fn summary(&self) -> String {
69 format!(
70 "Context: {} nodes, ~{} tokens ({})",
71 self.nodes.len(),
72 self.total_tokens,
73 self.truncation_reason
74 )
75 }
76
77 pub fn pinned_only(&self) -> Vec<&ContextNode> {
78 self.nodes.iter().filter(|n| n.pinned).collect()
79 }
80}
81
82use once_cell::sync::Lazy;
83use tiktoken_rs::cl100k_base;
84
85static TOKENIZER: Lazy<tiktoken_rs::CoreBPE> =
87 Lazy::new(|| cl100k_base().expect("Failed to load cl100k_base tokenizer"));
88
89const LARGE_FILE_THRESHOLD: usize = 800 * 1024;
91
92fn estimate_tokens(node: &NodeInfo) -> usize {
97 let base = node.name.len() + node.qualified_name.len() + node.file.len();
98 let signature_len = node.signature.as_ref().map(|s| s.len()).unwrap_or(0);
99 let lines = (node.line_end.saturating_sub(node.line_start) + 1) as usize;
100 let estimated_chars = base + signature_len + (lines * 40);
101
102 if estimated_chars > LARGE_FILE_THRESHOLD {
104 return estimated_chars.div_ceil(4); }
106
107 let text = format!(
109 "{} {} {}{}",
110 node.qualified_name,
111 node.file,
112 node.signature.as_deref().unwrap_or(""),
113 " ".repeat(lines * 40) );
115
116 TOKENIZER.encode_with_special_tokens(&text).len()
118}
119
120impl ArborGraph {
121 pub fn slice_context(
137 &self,
138 target: NodeId,
139 max_tokens: usize,
140 max_depth: usize,
141 pinned: &[NodeId],
142 ) -> ContextSlice {
143 let start = Instant::now();
144
145 let target_node = match self.get(target) {
146 Some(node) => {
147 let mut info = NodeInfo::from(node);
148 info.centrality = self.centrality(target);
149 info
150 }
151 None => {
152 return ContextSlice {
153 target: NodeInfo {
154 id: String::new(),
155 name: String::new(),
156 qualified_name: String::new(),
157 kind: String::new(),
158 file: String::new(),
159 line_start: 0,
160 line_end: 0,
161 signature: None,
162 centrality: 0.0,
163 },
164 nodes: Vec::new(),
165 total_tokens: 0,
166 max_tokens,
167 truncation_reason: TruncationReason::Complete,
168 query_time_ms: 0,
169 };
170 }
171 };
172
173 let effective_max = if max_depth == 0 {
174 usize::MAX
175 } else {
176 max_depth
177 };
178 let effective_tokens = if max_tokens == 0 {
179 usize::MAX
180 } else {
181 max_tokens
182 };
183
184 let pinned_set: HashSet<NodeId> = pinned.iter().copied().collect();
185 let mut visited: HashSet<NodeId> = HashSet::new();
186 let mut result: Vec<ContextNode> = Vec::new();
187 let mut total_tokens = 0usize;
188 let mut truncation_reason = TruncationReason::Complete;
189
190 let mut queue: VecDeque<(NodeId, usize)> = VecDeque::new();
192
193 queue.push_back((target, 0));
195
196 while let Some((current, depth)) = queue.pop_front() {
197 if visited.contains(¤t) {
198 continue;
199 }
200
201 if depth > effective_max {
202 truncation_reason = TruncationReason::MaxDepth;
203 continue;
204 }
205
206 visited.insert(current);
207
208 let is_pinned = pinned_set.contains(¤t);
209
210 if let Some(node) = self.get(current) {
211 let mut node_info = NodeInfo::from(node);
212 node_info.centrality = self.centrality(current);
213
214 let token_est = estimate_tokens(&node_info);
215
216 let within_budget = is_pinned || total_tokens + token_est <= effective_tokens;
218
219 if within_budget {
220 total_tokens += token_est;
221
222 result.push(ContextNode {
223 node_info,
224 token_estimate: token_est,
225 depth,
226 pinned: is_pinned,
227 });
228 } else {
229 truncation_reason = TruncationReason::TokenBudget;
230 }
232 }
233
234 if depth < effective_max {
236 for edge_ref in self.graph.edges_directed(current, Direction::Incoming) {
238 let neighbor = edge_ref.source();
239 if !visited.contains(&neighbor) {
240 queue.push_back((neighbor, depth + 1));
241 }
242 }
243
244 for edge_ref in self.graph.edges_directed(current, Direction::Outgoing) {
246 let neighbor = edge_ref.target();
247 if !visited.contains(&neighbor) {
248 queue.push_back((neighbor, depth + 1));
249 }
250 }
251 }
252 }
253
254 result.sort_by(|a, b| {
256 b.pinned
257 .cmp(&a.pinned)
258 .then_with(|| a.depth.cmp(&b.depth))
259 .then_with(|| {
260 b.node_info
261 .centrality
262 .partial_cmp(&a.node_info.centrality)
263 .unwrap_or(std::cmp::Ordering::Equal)
264 })
265 });
266
267 let elapsed = start.elapsed().as_millis() as u64;
268
269 ContextSlice {
270 target: target_node,
271 nodes: result,
272 total_tokens,
273 max_tokens,
274 truncation_reason,
275 query_time_ms: elapsed,
276 }
277 }
278}
279
#[cfg(test)]
mod tests {
    use super::*;
    use crate::edge::{Edge, EdgeKind};
    use arbor_core::{CodeNode, NodeKind};

    /// Creates a function node called `name` located in `test.rs`.
    fn make_node(name: &str) -> CodeNode {
        CodeNode::new(name, name, NodeKind::Function, "test.rs")
    }

    /// Adds one node per name, in order, and links consecutive nodes with
    /// `Calls` edges (first -> second -> ...). Returns the new node ids.
    fn build_chain(graph: &mut ArborGraph, names: &[&str]) -> Vec<NodeId> {
        let ids: Vec<NodeId> = names
            .iter()
            .map(|name| graph.add_node(make_node(name)))
            .collect();
        for pair in ids.windows(2) {
            graph.add_edge(pair[0], pair[1], Edge::new(EdgeKind::Calls));
        }
        ids
    }

    /// Names of the nodes in a slice, in slice order.
    fn node_names(slice: &ContextSlice) -> Vec<&str> {
        slice
            .nodes
            .iter()
            .map(|node| node.node_info.name.as_str())
            .collect()
    }

    #[test]
    fn test_empty_graph() {
        let graph = ArborGraph::new();
        let slice = graph.slice_context(NodeId::new(0), 1000, 2, &[]);

        assert!(slice.nodes.is_empty());
        assert_eq!(slice.total_tokens, 0);
    }

    #[test]
    fn test_single_node() {
        let mut graph = ArborGraph::new();
        let lonely = graph.add_node(make_node("lonely"));

        let slice = graph.slice_context(lonely, 1000, 2, &[]);

        assert_eq!(slice.nodes.len(), 1);
        assert_eq!(slice.nodes[0].node_info.name, "lonely");
        assert_eq!(slice.truncation_reason, TruncationReason::Complete);
    }

    #[test]
    fn test_linear_chain_depth_limit() {
        let mut graph = ArborGraph::new();
        let ids = build_chain(&mut graph, &["a", "b", "c", "d"]);

        // Depth 1 around `b` reaches its caller `a` and callee `c`, not `d`.
        let slice = graph.slice_context(ids[1], 10000, 1, &[]);
        let names = node_names(&slice);

        assert!(names.contains(&"b"));
        assert!(names.contains(&"a"));
        assert!(names.contains(&"c"));
        assert!(!names.contains(&"d"));
    }

    #[test]
    fn test_token_budget() {
        let mut graph = ArborGraph::new();
        let ids = build_chain(&mut graph, &["a", "b", "c"]);

        let slice = graph.slice_context(ids[0], 5, 10, &[]);

        assert!(slice.nodes.len() < 3);
        assert_eq!(slice.truncation_reason, TruncationReason::TokenBudget);
    }

    #[test]
    fn test_pinned_nodes_bypass_budget() {
        let mut graph = ArborGraph::new();
        let ids = build_chain(&mut graph, &["a", "important_node", "c"]);

        let slice = graph.slice_context(ids[0], 5, 10, &[ids[1]]);

        assert!(node_names(&slice).contains(&"important_node"));
    }

    #[test]
    fn test_complete_traversal() {
        let mut graph = ArborGraph::new();
        let ids = build_chain(&mut graph, &["a", "b"]);

        let slice = graph.slice_context(ids[0], 100000, 10, &[]);

        assert_eq!(slice.truncation_reason, TruncationReason::Complete);
        assert_eq!(slice.nodes.len(), 2);
    }
}