1use crate::graph::{ArborGraph, NodeId};
8use crate::query::NodeInfo;
9use petgraph::visit::EdgeRef;
10use petgraph::Direction;
11use serde::{Deserialize, Serialize};
12use std::collections::{HashSet, VecDeque};
13use std::time::Instant;
14
15#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
17pub enum TruncationReason {
18 Complete,
20 TokenBudget,
22 MaxDepth,
24}
25
26impl std::fmt::Display for TruncationReason {
27 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28 match self {
29 TruncationReason::Complete => write!(f, "complete"),
30 TruncationReason::TokenBudget => write!(f, "token_budget"),
31 TruncationReason::MaxDepth => write!(f, "max_depth"),
32 }
33 }
34}
35
36#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct ContextNode {
39 pub node_info: NodeInfo,
41 pub token_estimate: usize,
43 pub depth: usize,
45 pub pinned: bool,
47}
48
49#[derive(Debug, Serialize, Deserialize)]
51pub struct ContextSlice {
52 pub target: NodeInfo,
54 pub nodes: Vec<ContextNode>,
56 pub total_tokens: usize,
58 pub max_tokens: usize,
60 pub truncation_reason: TruncationReason,
62 pub query_time_ms: u64,
64}
65
66impl ContextSlice {
67 pub fn summary(&self) -> String {
69 format!(
70 "Context: {} nodes, ~{} tokens ({})",
71 self.nodes.len(),
72 self.total_tokens,
73 self.truncation_reason
74 )
75 }
76
77 pub fn pinned_only(&self) -> Vec<&ContextNode> {
78 self.nodes.iter().filter(|n| n.pinned).collect()
79 }
80}
81
82use once_cell::sync::Lazy;
83use tiktoken_rs::cl100k_base;
84
85static TOKENIZER: Lazy<tiktoken_rs::CoreBPE> = Lazy::new(|| {
87 cl100k_base().expect("Failed to load cl100k_base tokenizer")
88});
89
/// Estimated text length (800 KiB) above which `estimate_tokens` skips the BPE tokenizer.
///
/// Above this length it falls back to a rough 4-characters-per-token approximation.
const LARGE_FILE_THRESHOLD: usize = 1024 * 800;
92
93fn estimate_tokens(node: &NodeInfo) -> usize {
98 let base = node.name.len() + node.qualified_name.len() + node.file.len();
99 let signature_len = node.signature.as_ref().map(|s| s.len()).unwrap_or(0);
100 let lines = (node.line_end.saturating_sub(node.line_start) + 1) as usize;
101 let estimated_chars = base + signature_len + (lines * 40);
102
103 if estimated_chars > LARGE_FILE_THRESHOLD {
105 return (estimated_chars + 3) / 4; }
107
108 let text = format!(
110 "{} {} {}{}",
111 node.qualified_name,
112 node.file,
113 node.signature.as_deref().unwrap_or(""),
114 " ".repeat(lines * 40) );
116
117 TOKENIZER.encode_with_special_tokens(&text).len()
119}
120
121impl ArborGraph {
122 pub fn slice_context(
138 &self,
139 target: NodeId,
140 max_tokens: usize,
141 max_depth: usize,
142 pinned: &[NodeId],
143 ) -> ContextSlice {
144 let start = Instant::now();
145
146 let target_node = match self.get(target) {
147 Some(node) => {
148 let mut info = NodeInfo::from(node);
149 info.centrality = self.centrality(target);
150 info
151 }
152 None => {
153 return ContextSlice {
154 target: NodeInfo {
155 id: String::new(),
156 name: String::new(),
157 qualified_name: String::new(),
158 kind: String::new(),
159 file: String::new(),
160 line_start: 0,
161 line_end: 0,
162 signature: None,
163 centrality: 0.0,
164 },
165 nodes: Vec::new(),
166 total_tokens: 0,
167 max_tokens,
168 truncation_reason: TruncationReason::Complete,
169 query_time_ms: 0,
170 };
171 }
172 };
173
174 let effective_max = if max_depth == 0 {
175 usize::MAX
176 } else {
177 max_depth
178 };
179 let effective_tokens = if max_tokens == 0 {
180 usize::MAX
181 } else {
182 max_tokens
183 };
184
185 let pinned_set: HashSet<NodeId> = pinned.iter().copied().collect();
186 let mut visited: HashSet<NodeId> = HashSet::new();
187 let mut result: Vec<ContextNode> = Vec::new();
188 let mut total_tokens = 0usize;
189 let mut truncation_reason = TruncationReason::Complete;
190
191 let mut queue: VecDeque<(NodeId, usize)> = VecDeque::new();
193
194 queue.push_back((target, 0));
196
197 while let Some((current, depth)) = queue.pop_front() {
198 if visited.contains(¤t) {
199 continue;
200 }
201
202 if depth > effective_max {
203 truncation_reason = TruncationReason::MaxDepth;
204 continue;
205 }
206
207 visited.insert(current);
208
209 let is_pinned = pinned_set.contains(¤t);
210
211 if let Some(node) = self.get(current) {
212 let mut node_info = NodeInfo::from(node);
213 node_info.centrality = self.centrality(current);
214
215 let token_est = estimate_tokens(&node_info);
216
217 let within_budget = is_pinned || total_tokens + token_est <= effective_tokens;
219
220 if within_budget {
221 total_tokens += token_est;
222
223 result.push(ContextNode {
224 node_info,
225 token_estimate: token_est,
226 depth,
227 pinned: is_pinned,
228 });
229 } else {
230 truncation_reason = TruncationReason::TokenBudget;
231 }
233 }
234
235 if depth < effective_max {
237 for edge_ref in self.graph.edges_directed(current, Direction::Incoming) {
239 let neighbor = edge_ref.source();
240 if !visited.contains(&neighbor) {
241 queue.push_back((neighbor, depth + 1));
242 }
243 }
244
245 for edge_ref in self.graph.edges_directed(current, Direction::Outgoing) {
247 let neighbor = edge_ref.target();
248 if !visited.contains(&neighbor) {
249 queue.push_back((neighbor, depth + 1));
250 }
251 }
252 }
253 }
254
255 result.sort_by(|a, b| {
257 b.pinned
258 .cmp(&a.pinned)
259 .then_with(|| a.depth.cmp(&b.depth))
260 .then_with(|| {
261 b.node_info
262 .centrality
263 .partial_cmp(&a.node_info.centrality)
264 .unwrap_or(std::cmp::Ordering::Equal)
265 })
266 });
267
268 let elapsed = start.elapsed().as_millis() as u64;
269
270 ContextSlice {
271 target: target_node,
272 nodes: result,
273 total_tokens,
274 max_tokens,
275 truncation_reason,
276 query_time_ms: elapsed,
277 }
278 }
279}
280
#[cfg(test)]
mod tests {
    use super::*;
    use crate::edge::{Edge, EdgeKind};
    use arbor_core::{CodeNode, NodeKind};

    // Builds a function node in `test.rs` whose name and qualified name are both `name`.
    fn make_node(name: &str) -> CodeNode {
        CodeNode::new(name, name, NodeKind::Function, "test.rs")
    }

    // Names of the nodes in `slice`, in slice order.
    fn names(slice: &ContextSlice) -> Vec<&str> {
        slice
            .nodes
            .iter()
            .map(|n| n.node_info.name.as_str())
            .collect()
    }

    #[test]
    fn test_empty_graph() {
        let graph = ArborGraph::new();
        let result = graph.slice_context(NodeId::new(0), 1000, 2, &[]);
        assert!(result.nodes.is_empty());
        assert_eq!(result.total_tokens, 0);
    }

    #[test]
    fn test_single_node() {
        let mut graph = ArborGraph::new();
        let id = graph.add_node(make_node("lonely"));

        let result = graph.slice_context(id, 1000, 2, &[]);
        assert_eq!(result.nodes.len(), 1);
        assert_eq!(result.nodes[0].node_info.name, "lonely");
        assert_eq!(result.truncation_reason, TruncationReason::Complete);
    }

    #[test]
    fn test_linear_chain_depth_limit() {
        let mut graph = ArborGraph::new();
        let a = graph.add_node(make_node("a"));
        let b = graph.add_node(make_node("b"));
        let c = graph.add_node(make_node("c"));
        let d = graph.add_node(make_node("d"));

        graph.add_edge(a, b, Edge::new(EdgeKind::Calls));
        graph.add_edge(b, c, Edge::new(EdgeKind::Calls));
        graph.add_edge(c, d, Edge::new(EdgeKind::Calls));

        // Depth 1 around `b` reaches its caller `a` and callee `c`, but not `d`.
        let result = graph.slice_context(b, 10000, 1, &[]);

        let names = names(&result);
        assert!(names.contains(&"b"));
        assert!(names.contains(&"a"));
        assert!(names.contains(&"c"));
        assert!(!names.contains(&"d"));
    }

    #[test]
    fn test_token_budget() {
        let mut graph = ArborGraph::new();
        let a = graph.add_node(make_node("a"));
        let b = graph.add_node(make_node("b"));
        let c = graph.add_node(make_node("c"));

        graph.add_edge(a, b, Edge::new(EdgeKind::Calls));
        graph.add_edge(b, c, Edge::new(EdgeKind::Calls));

        let result = graph.slice_context(a, 5, 10, &[]);

        assert!(result.nodes.len() < 3);
        assert_eq!(result.truncation_reason, TruncationReason::TokenBudget);
    }

    #[test]
    fn test_pinned_nodes_bypass_budget() {
        let mut graph = ArborGraph::new();
        let a = graph.add_node(make_node("a"));
        let b = graph.add_node(make_node("important_node"));
        let c = graph.add_node(make_node("c"));

        graph.add_edge(a, b, Edge::new(EdgeKind::Calls));
        graph.add_edge(b, c, Edge::new(EdgeKind::Calls));

        let result = graph.slice_context(a, 5, 10, &[b]);

        let has_important = result
            .nodes
            .iter()
            .any(|n| n.node_info.name == "important_node");
        assert!(has_important);
    }

    #[test]
    fn test_complete_traversal() {
        let mut graph = ArborGraph::new();
        let a = graph.add_node(make_node("a"));
        let b = graph.add_node(make_node("b"));

        graph.add_edge(a, b, Edge::new(EdgeKind::Calls));

        let result = graph.slice_context(a, 100000, 10, &[]);
        assert_eq!(result.truncation_reason, TruncationReason::Complete);
        assert_eq!(result.nodes.len(), 2);
    }

    #[test]
    fn test_zero_limits_mean_unlimited() {
        let mut graph = ArborGraph::new();
        let a = graph.add_node(make_node("a"));
        let b = graph.add_node(make_node("b"));
        let c = graph.add_node(make_node("c"));
        let d = graph.add_node(make_node("d"));

        graph.add_edge(a, b, Edge::new(EdgeKind::Calls));
        graph.add_edge(b, c, Edge::new(EdgeKind::Calls));
        graph.add_edge(c, d, Edge::new(EdgeKind::Calls));

        // `0` disables both the token budget and the depth limit.
        let result = graph.slice_context(a, 0, 0, &[]);
        assert_eq!(result.nodes.len(), 4);
        assert_eq!(result.truncation_reason, TruncationReason::Complete);
        // The caller's budget is echoed back unchanged.
        assert_eq!(result.max_tokens, 0);
    }

    #[test]
    fn test_ordering_pinned_then_depth() {
        let mut graph = ArborGraph::new();
        let a = graph.add_node(make_node("a"));
        let b = graph.add_node(make_node("b"));
        let c = graph.add_node(make_node("c"));

        graph.add_edge(a, b, Edge::new(EdgeKind::Calls));
        graph.add_edge(b, c, Edge::new(EdgeKind::Calls));

        let result = graph.slice_context(a, 100000, 10, &[c]);

        assert_eq!(names(&result), vec!["c", "a", "b"]);
        let depths: Vec<usize> = result.nodes.iter().map(|n| n.depth).collect();
        assert_eq!(depths, vec![2, 0, 1]);
        assert!(result.nodes[0].pinned);
    }

    #[test]
    fn test_pinned_only_and_token_accounting() {
        let mut graph = ArborGraph::new();
        let a = graph.add_node(make_node("a"));
        let b = graph.add_node(make_node("important_node"));
        let c = graph.add_node(make_node("c"));

        graph.add_edge(a, b, Edge::new(EdgeKind::Calls));
        graph.add_edge(b, c, Edge::new(EdgeKind::Calls));

        let result = graph.slice_context(a, 100000, 10, &[b]);

        let pinned = result.pinned_only();
        assert_eq!(pinned.len(), 1);
        assert_eq!(pinned[0].node_info.name, "important_node");

        let summed: usize = result.nodes.iter().map(|n| n.token_estimate).sum();
        assert_eq!(result.total_tokens, summed);
    }

    #[test]
    fn test_summary_and_display() {
        let mut graph = ArborGraph::new();
        let id = graph.add_node(make_node("lonely"));

        let result = graph.slice_context(id, 1000, 2, &[]);
        assert_eq!(
            result.summary(),
            format!("Context: 1 nodes, ~{} tokens (complete)", result.total_tokens)
        );

        assert_eq!(TruncationReason::Complete.to_string(), "complete");
        assert_eq!(TruncationReason::TokenBudget.to_string(), "token_budget");
        assert_eq!(TruncationReason::MaxDepth.to_string(), "max_depth");
    }
}