use crate::graph::{ArborGraph, NodeId};
use crate::query::NodeInfo;
use petgraph::visit::EdgeRef;
use petgraph::Direction;
use serde::{Deserialize, Serialize};
use std::collections::{HashSet, VecDeque};
use std::time::Instant;
use tracing::warn;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TruncationReason {
Complete,
TokenBudget,
MaxDepth,
}
impl std::fmt::Display for TruncationReason {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
TruncationReason::Complete => write!(f, "complete"),
TruncationReason::TokenBudget => write!(f, "token_budget"),
TruncationReason::MaxDepth => write!(f, "max_depth"),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextNode {
pub node_info: NodeInfo,
pub token_estimate: usize,
pub depth: usize,
pub pinned: bool,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct ContextSlice {
pub target: NodeInfo,
pub nodes: Vec<ContextNode>,
pub total_tokens: usize,
pub max_tokens: usize,
pub truncation_reason: TruncationReason,
pub query_time_ms: u64,
}
impl ContextSlice {
pub fn summary(&self) -> String {
format!(
"Context: {} nodes, ~{} tokens ({})",
self.nodes.len(),
self.total_tokens,
self.truncation_reason
)
}
pub fn pinned_only(&self) -> Vec<&ContextNode> {
self.nodes.iter().filter(|n| n.pinned).collect()
}
}
use once_cell::sync::Lazy;
use tiktoken_rs::cl100k_base;
static TOKENIZER: Lazy<Option<tiktoken_rs::CoreBPE>> = Lazy::new(|| match cl100k_base() {
Ok(tokenizer) => Some(tokenizer),
Err(error) => {
warn!(
"Failed to initialize cl100k_base tokenizer; falling back to heuristic token estimates: {}",
error
);
None
}
});
const LARGE_FILE_THRESHOLD: usize = 800 * 1024;
fn estimate_tokens(node: &NodeInfo) -> usize {
let base = node.name.len() + node.qualified_name.len() + node.file.len();
let signature_len = node.signature.as_ref().map(|s| s.len()).unwrap_or(0);
let lines = (node.line_end.saturating_sub(node.line_start) + 1) as usize;
let estimated_chars = base + signature_len + (lines * 40);
if estimated_chars > LARGE_FILE_THRESHOLD {
return estimated_chars.div_ceil(4); }
let text = format!(
"{} {} {}{}",
node.qualified_name,
node.file,
node.signature.as_deref().unwrap_or(""),
" ".repeat(lines * 40) );
if let Some(tokenizer) = TOKENIZER.as_ref() {
tokenizer.encode_with_special_tokens(&text).len()
} else {
estimated_chars.div_ceil(4)
}
}
impl ArborGraph {
pub fn slice_context(
&self,
target: NodeId,
max_tokens: usize,
max_depth: usize,
pinned: &[NodeId],
) -> ContextSlice {
let start = Instant::now();
let target_node = match self.get(target) {
Some(node) => {
let mut info = NodeInfo::from(node);
info.centrality = self.centrality(target);
info
}
None => {
return ContextSlice {
target: NodeInfo {
id: String::new(),
name: String::new(),
qualified_name: String::new(),
kind: String::new(),
file: String::new(),
line_start: 0,
line_end: 0,
signature: None,
centrality: 0.0,
},
nodes: Vec::new(),
total_tokens: 0,
max_tokens,
truncation_reason: TruncationReason::Complete,
query_time_ms: 0,
};
}
};
let effective_max = if max_depth == 0 {
usize::MAX
} else {
max_depth
};
let effective_tokens = if max_tokens == 0 {
usize::MAX
} else {
max_tokens
};
let pinned_set: HashSet<NodeId> = pinned.iter().copied().collect();
let mut visited: HashSet<NodeId> = HashSet::new();
let mut result: Vec<ContextNode> = Vec::new();
let mut total_tokens = 0usize;
let mut truncation_reason = TruncationReason::Complete;
let mut queue: VecDeque<(NodeId, usize)> = VecDeque::new();
queue.push_back((target, 0));
while let Some((current, depth)) = queue.pop_front() {
if visited.contains(¤t) {
continue;
}
if depth > effective_max {
truncation_reason = TruncationReason::MaxDepth;
continue;
}
visited.insert(current);
let is_pinned = pinned_set.contains(¤t);
if let Some(node) = self.get(current) {
let mut node_info = NodeInfo::from(node);
node_info.centrality = self.centrality(current);
let token_est = estimate_tokens(&node_info);
let within_budget = is_pinned || total_tokens + token_est <= effective_tokens;
if within_budget {
total_tokens += token_est;
result.push(ContextNode {
node_info,
token_estimate: token_est,
depth,
pinned: is_pinned,
});
} else {
truncation_reason = TruncationReason::TokenBudget;
}
}
if depth < effective_max {
for edge_ref in self.graph.edges_directed(current, Direction::Incoming) {
let neighbor = edge_ref.source();
if !visited.contains(&neighbor) {
queue.push_back((neighbor, depth + 1));
}
}
for edge_ref in self.graph.edges_directed(current, Direction::Outgoing) {
let neighbor = edge_ref.target();
if !visited.contains(&neighbor) {
queue.push_back((neighbor, depth + 1));
}
}
}
}
result.sort_by(|a, b| {
b.pinned
.cmp(&a.pinned)
.then_with(|| a.depth.cmp(&b.depth))
.then_with(|| {
b.node_info
.centrality
.partial_cmp(&a.node_info.centrality)
.unwrap_or(std::cmp::Ordering::Equal)
})
});
let elapsed = start.elapsed().as_millis() as u64;
ContextSlice {
target: target_node,
nodes: result,
total_tokens,
max_tokens,
truncation_reason,
query_time_ms: elapsed,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::edge::{Edge, EdgeKind};
use arbor_core::{CodeNode, NodeKind};
fn make_node(name: &str) -> CodeNode {
CodeNode::new(name, name, NodeKind::Function, "test.rs")
}
#[test]
fn test_empty_graph() {
let graph = ArborGraph::new();
let result = graph.slice_context(NodeId::new(0), 1000, 2, &[]);
assert!(result.nodes.is_empty());
assert_eq!(result.total_tokens, 0);
}
#[test]
fn test_single_node() {
let mut graph = ArborGraph::new();
let id = graph.add_node(make_node("lonely"));
let result = graph.slice_context(id, 1000, 2, &[]);
assert_eq!(result.nodes.len(), 1);
assert_eq!(result.nodes[0].node_info.name, "lonely");
assert_eq!(result.truncation_reason, TruncationReason::Complete);
}
#[test]
fn test_linear_chain_depth_limit() {
let mut graph = ArborGraph::new();
let a = graph.add_node(make_node("a"));
let b = graph.add_node(make_node("b"));
let c = graph.add_node(make_node("c"));
let d = graph.add_node(make_node("d"));
graph.add_edge(a, b, Edge::new(EdgeKind::Calls));
graph.add_edge(b, c, Edge::new(EdgeKind::Calls));
graph.add_edge(c, d, Edge::new(EdgeKind::Calls));
let result = graph.slice_context(b, 10000, 1, &[]);
let names: Vec<&str> = result
.nodes
.iter()
.map(|n| n.node_info.name.as_str())
.collect();
assert!(names.contains(&"b"));
assert!(names.contains(&"a"));
assert!(names.contains(&"c"));
assert!(!names.contains(&"d"));
}
#[test]
fn test_token_budget() {
let mut graph = ArborGraph::new();
let a = graph.add_node(make_node("a"));
let b = graph.add_node(make_node("b"));
let c = graph.add_node(make_node("c"));
graph.add_edge(a, b, Edge::new(EdgeKind::Calls));
graph.add_edge(b, c, Edge::new(EdgeKind::Calls));
let result = graph.slice_context(a, 5, 10, &[]);
assert!(result.nodes.len() < 3);
assert_eq!(result.truncation_reason, TruncationReason::TokenBudget);
}
#[test]
fn test_pinned_nodes_bypass_budget() {
let mut graph = ArborGraph::new();
let a = graph.add_node(make_node("a"));
let b = graph.add_node(make_node("important_node"));
let c = graph.add_node(make_node("c"));
graph.add_edge(a, b, Edge::new(EdgeKind::Calls));
graph.add_edge(b, c, Edge::new(EdgeKind::Calls));
let result = graph.slice_context(a, 5, 10, &[b]);
let has_important = result
.nodes
.iter()
.any(|n| n.node_info.name == "important_node");
assert!(has_important);
}
#[test]
fn test_complete_traversal() {
let mut graph = ArborGraph::new();
let a = graph.add_node(make_node("a"));
let b = graph.add_node(make_node("b"));
graph.add_edge(a, b, Edge::new(EdgeKind::Calls));
let result = graph.slice_context(a, 100000, 10, &[]);
assert_eq!(result.truncation_reason, TruncationReason::Complete);
assert_eq!(result.nodes.len(), 2);
}
}