pub mod optimized;
use crate::{Chunk, Error, Result};
pub use optimized::{OptimizedRaptorTree, RaptorOptConfig};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use uuid::Uuid;
/// A single node of a [`RaptorTree`].
///
/// Leaves (level 0) are created one per input chunk; higher-level nodes carry a
/// summary of the text of the nodes they group.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RaptorNode {
/// Unique identifier. Leaves reuse the id of the chunk they were built from.
pub id: Uuid,
/// Chunk text for leaves, or the summarizer's output for summary nodes.
pub text: String,
/// Ids of the nodes grouped under this one; empty for leaves.
pub children: Vec<Uuid>,
/// Id of the summary node that groups this node, if one was created.
pub parent: Option<Uuid>,
/// Level in the tree: 0 for leaves, incremented for each summary level built above them.
pub level: usize,
/// Embedding of `text`, when one has been computed.
pub embedding: Option<Vec<f32>>,
}
/// A hierarchical summary tree over text chunks (RAPTOR-style).
///
/// Populated by `RaptorTree::build_from_chunks`, which groups consecutive nodes
/// into clusters and summarizes each cluster into a parent node.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RaptorTree {
/// Every node in the tree (leaves and summaries), keyed by node id.
pub nodes: HashMap<Uuid, RaptorNode>,
/// Ids of the top-level nodes remaining after a build.
pub roots: Vec<Uuid>,
/// Maximum number of summary levels built above the leaves.
pub max_depth: usize,
/// Number of consecutive nodes grouped under a single summary node.
pub cluster_size: usize,
}
impl RaptorTree {
pub fn new(max_depth: usize, cluster_size: usize) -> Self {
Self {
nodes: HashMap::new(),
roots: Vec::new(),
max_depth,
cluster_size,
}
}
pub async fn build_from_chunks(
&mut self,
chunks: &[Chunk],
embedder: &dyn Fn(&str) -> Result<Vec<f32>>,
summarizer: &dyn Fn(&str) -> Result<String>,
) -> Result<()> {
if chunks.is_empty() {
return Ok(());
}
let mut leaf_nodes = Vec::new();
for chunk in chunks {
let embedding = embedder(&chunk.text)?;
let node = RaptorNode {
id: chunk.id,
text: chunk.text.clone(),
children: Vec::new(),
parent: None,
level: 0,
embedding: Some(embedding),
};
self.nodes.insert(node.id, node.clone());
leaf_nodes.push(node);
}
let mut current_level_nodes = leaf_nodes;
for level in 1..=self.max_depth {
if current_level_nodes.len() <= self.cluster_size {
for node in ¤t_level_nodes {
self.roots.push(node.id);
}
break;
}
let next_level_nodes = self
.build_level(¤t_level_nodes, level, embedder, summarizer)
.await?;
current_level_nodes = next_level_nodes;
}
for node in current_level_nodes {
self.roots.push(node.id);
}
Ok(())
}
async fn build_level(
&mut self,
nodes: &[RaptorNode],
level: usize,
embedder: &dyn Fn(&str) -> Result<Vec<f32>>,
summarizer: &dyn Fn(&str) -> Result<String>,
) -> Result<Vec<RaptorNode>> {
let mut next_level_nodes = Vec::new();
for i in (0..nodes.len()).step_by(self.cluster_size) {
let cluster_end = (i + self.cluster_size).min(nodes.len());
let cluster = &nodes[i..cluster_end];
if cluster.len() == 1 {
let mut node = cluster[0].clone();
node.level = level;
next_level_nodes.push(node);
continue;
}
let cluster_texts: Vec<String> = cluster.iter().map(|n| n.text.clone()).collect();
let combined_text = cluster_texts.join("\n\n");
let summary = summarizer(&combined_text)?;
let embedding = embedder(&summary)?;
let cluster_node = RaptorNode {
id: Uuid::new_v4(),
text: summary,
children: cluster.iter().map(|n| n.id).collect(),
parent: None,
level,
embedding: Some(embedding),
};
for child in cluster {
if let Some(child_node) = self.nodes.get_mut(&child.id) {
child_node.parent = Some(cluster_node.id);
}
}
self.nodes.insert(cluster_node.id, cluster_node.clone());
next_level_nodes.push(cluster_node);
}
Ok(next_level_nodes)
}
pub fn search(&self, query_embedding: &[f32], top_k: usize) -> Result<Vec<(Uuid, f32)>> {
let mut candidates = Vec::new();
for (node_id, node) in &self.nodes {
if let Some(embedding) = &node.embedding {
if embedding.len() == query_embedding.len() {
let similarity = cosine_similarity(query_embedding, embedding);
candidates.push((*node_id, similarity));
}
}
}
candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
let mut leaf_results = Vec::new();
for (node_id, score) in candidates.into_iter().take(top_k * 2) {
leaf_results.extend(self.expand_to_leaves(node_id, score));
}
let mut unique_results = HashMap::new();
for (leaf_id, score) in leaf_results {
unique_results
.entry(leaf_id)
.and_modify(|e: &mut f32| *e = e.max(score))
.or_insert(score);
}
let mut final_results: Vec<(Uuid, f32)> = unique_results.into_iter().collect();
final_results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
Ok(final_results.into_iter().take(top_k).collect())
}
fn expand_to_leaves(&self, node_id: Uuid, score: f32) -> Vec<(Uuid, f32)> {
let mut results = Vec::new();
if let Some(node) = self.nodes.get(&node_id) {
if node.children.is_empty() {
results.push((node_id, score));
} else {
for child_id in &node.children {
results.extend(self.expand_to_leaves(*child_id, score));
}
}
}
results
}
pub fn get_leaf_nodes(&self, node_id: Uuid) -> Vec<Uuid> {
self.expand_to_leaves(node_id, 1.0)
.into_iter()
.map(|(id, _)| id)
.collect()
}
pub fn get_node(&self, node_id: &Uuid) -> Option<&RaptorNode> {
self.nodes.get(node_id)
}
pub fn stats(&self) -> RaptorStats {
let mut level_counts = HashMap::new();
let mut total_nodes = 0;
let mut leaf_nodes = 0;
for node in self.nodes.values() {
*level_counts.entry(node.level).or_insert(0) += 1;
total_nodes += 1;
if node.children.is_empty() {
leaf_nodes += 1;
}
}
RaptorStats {
total_nodes,
leaf_nodes,
max_depth: self.max_depth,
level_counts,
root_count: self.roots.len(),
}
}
}
/// Summary counts produced by `RaptorTree::stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RaptorStats {
/// Number of nodes stored in the tree (leaves and summaries).
pub total_nodes: usize,
/// Number of stored nodes that have no children.
pub leaf_nodes: usize,
/// The tree's configured `max_depth` (a limit, not necessarily the depth actually built).
pub max_depth: usize,
/// Number of stored nodes for each `level` value.
pub level_counts: HashMap<usize, usize>,
/// Number of entries in the tree's `roots` list.
pub root_count: usize,
}
/// A code entity (the tests use functions and structs) stored in a [`CodeGraph`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CodeNode {
/// Unique identifier; used as this node's key in `CodeGraph::nodes`.
pub id: Uuid,
/// Entity name.
pub name: String,
/// Free-form type label, matched exactly by `CodeGraph::find_nodes_by_type`
/// (the tests use `"function"` and `"struct"`).
pub node_type: String,
/// Source file path for the entity, matched exactly by `CodeGraph::find_nodes_by_file`.
/// NOTE(review): presumably the file that defines the entity — confirm with callers.
pub file_path: String,
/// Embedding vector for the entity; not read by `CodeGraph` itself.
pub embedding: Vec<f32>,
}
/// A directed, labelled edge between two nodes of a [`CodeGraph`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CodeEdge {
/// Id of the node the edge starts from.
pub source_id: Uuid,
/// Id of the node the edge points to.
pub target_id: Uuid,
/// Free-form relation label (the tests use `"calls"`); counted per label by `CodeGraph::stats`.
pub relation: String,
}
/// A graph of code entities and the labelled relations between them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CodeGraph {
/// All nodes, keyed by `CodeNode::id`.
pub nodes: HashMap<Uuid, CodeNode>,
/// All edges in insertion order; `add_edge` does not de-duplicate.
pub edges: Vec<CodeEdge>,
}
impl CodeGraph {
pub fn new() -> Self {
Self {
nodes: HashMap::new(),
edges: Vec::new(),
}
}
pub fn add_node(&mut self, node: CodeNode) -> Result<Uuid> {
let id = node.id;
self.nodes.insert(id, node);
Ok(id)
}
pub fn add_edge(&mut self, source_id: Uuid, target_id: Uuid, relation: String) -> Result<()> {
if !self.nodes.contains_key(&source_id) {
return Err(Error::validation(format!(
"Source node {} not found in graph",
source_id
)));
}
if !self.nodes.contains_key(&target_id) {
return Err(Error::validation(format!(
"Target node {} not found in graph",
target_id
)));
}
self.edges.push(CodeEdge {
source_id,
target_id,
relation,
});
Ok(())
}
pub fn get_node(&self, node_id: &Uuid) -> Option<&CodeNode> {
self.nodes.get(node_id)
}
pub fn get_outgoing_edges(&self, node_id: &Uuid) -> Vec<&CodeEdge> {
self.edges
.iter()
.filter(|e| &e.source_id == node_id)
.collect()
}
pub fn get_incoming_edges(&self, node_id: &Uuid) -> Vec<&CodeEdge> {
self.edges
.iter()
.filter(|e| &e.target_id == node_id)
.collect()
}
pub fn find_nodes_by_type(&self, node_type: &str) -> Vec<&CodeNode> {
self.nodes
.values()
.filter(|n| n.node_type == node_type)
.collect()
}
pub fn find_nodes_by_file(&self, file_path: &str) -> Vec<&CodeNode> {
self.nodes
.values()
.filter(|n| n.file_path == file_path)
.collect()
}
pub fn stats(&self) -> CodeGraphStats {
let mut type_counts = HashMap::new();
let mut file_counts = HashMap::new();
let mut relation_counts = HashMap::new();
for node in self.nodes.values() {
*type_counts.entry(node.node_type.clone()).or_insert(0) += 1;
*file_counts.entry(node.file_path.clone()).or_insert(0) += 1;
}
for edge in &self.edges {
*relation_counts.entry(edge.relation.clone()).or_insert(0) += 1;
}
CodeGraphStats {
total_nodes: self.nodes.len(),
total_edges: self.edges.len(),
type_counts,
file_counts,
relation_counts,
}
}
}
impl Default for CodeGraph {
fn default() -> Self {
Self::new()
}
}
/// Summary counts produced by `CodeGraph::stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CodeGraphStats {
/// Number of nodes in the graph.
pub total_nodes: usize,
/// Number of edges in the graph, duplicates included.
pub total_edges: usize,
/// Node count per `node_type` label.
pub type_counts: HashMap<String, usize>,
/// Node count per `file_path`.
pub file_counts: HashMap<String, usize>,
/// Edge count per `relation` label.
pub relation_counts: HashMap<String, usize>,
}
/// Cosine similarity of `a` and `b`: `a·b / (|a| |b|)`.
///
/// Returns `0.0` when the slices differ in length, are empty, or either has zero magnitude,
/// because no meaningful angle exists in those cases.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    let magnitude = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let (mag_a, mag_b) = (magnitude(a), magnitude(b));
    if mag_a == 0.0 || mag_b == 0.0 {
        return 0.0;
    }

    let dot_product = a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>();
    dot_product / (mag_a * mag_b)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic embedder: every input maps to the same 5-dimensional vector.
    fn mock_embedder(_text: &str) -> Result<Vec<f32>> {
        Ok(vec![0.1, 0.2, 0.3, 0.4, 0.5])
    }

    /// Summarizer stub that prefixes the first 50 characters of its input.
    fn mock_summarizer(text: &str) -> Result<String> {
        let head: String = text.chars().take(50).collect();
        Ok(format!("Summary of: {}", head))
    }

    /// Builds a `CodeNode` from the given parts.
    fn code_node(
        id: Uuid,
        name: &str,
        node_type: &str,
        file_path: &str,
        embedding: Vec<f32>,
    ) -> CodeNode {
        CodeNode {
            id,
            name: name.to_string(),
            node_type: node_type.to_string(),
            file_path: file_path.to_string(),
            embedding,
        }
    }

    #[tokio::test]
    async fn test_raptor_tree_build() {
        // A closure (not a fn) so the integer field types are inferred from `Chunk`.
        let make_chunk = |index, text: &str, start_char, end_char| Chunk {
            id: Uuid::new_v4(),
            text: text.to_string(),
            index,
            start_char,
            end_char,
            token_count: Some(5),
            section: None,
            page: None,
            embedding_ids: Default::default(),
        };
        let chunks = vec![
            make_chunk(0, "First chunk of text", 0, 20),
            make_chunk(1, "Second chunk of text", 21, 42),
            make_chunk(2, "Third chunk of text", 43, 63),
            make_chunk(3, "Fourth chunk of text", 64, 85),
        ];

        let mut tree = RaptorTree::new(2, 3);
        tree.build_from_chunks(&chunks, &mock_embedder, &mock_summarizer)
            .await
            .unwrap();

        let stats = tree.stats();
        // Every chunk stays a leaf, and at least one summary node sits above them.
        assert!(stats.total_nodes > chunks.len());
        assert_eq!(stats.leaf_nodes, chunks.len());
    }

    #[test]
    fn test_raptor_search() {
        let node_id = Uuid::new_v4();
        let mut tree = RaptorTree::new(1, 2);
        tree.nodes.insert(
            node_id,
            RaptorNode {
                id: node_id,
                text: "Test node".to_string(),
                children: vec![],
                parent: None,
                level: 0,
                embedding: Some(vec![1.0, 0.0, 0.0]),
            },
        );
        tree.roots.push(node_id);

        let results = tree.search(&[1.0, 0.0, 0.0], 5).unwrap();

        assert_eq!(results.len(), 1);
        let (top_id, top_score) = results[0];
        assert_eq!(top_id, node_id);
        assert!((top_score - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_code_graph_new() {
        let graph = CodeGraph::new();
        assert!(graph.nodes.is_empty());
        assert!(graph.edges.is_empty());
    }

    #[test]
    fn test_code_graph_add_node() {
        let node_id = Uuid::new_v4();
        let mut graph = CodeGraph::new();
        graph
            .add_node(code_node(
                node_id,
                "test_function",
                "function",
                "/src/lib.rs",
                vec![1.0, 2.0, 3.0],
            ))
            .unwrap();

        assert_eq!(graph.nodes.len(), 1);
        assert!(graph.get_node(&node_id).is_some());
    }

    #[test]
    fn test_code_graph_add_edge() {
        let (caller_id, callee_id) = (Uuid::new_v4(), Uuid::new_v4());
        let mut graph = CodeGraph::new();
        graph
            .add_node(code_node(caller_id, "caller", "function", "/src/lib.rs", vec![1.0, 2.0, 3.0]))
            .unwrap();
        graph
            .add_node(code_node(callee_id, "callee", "function", "/src/lib.rs", vec![4.0, 5.0, 6.0]))
            .unwrap();
        graph
            .add_edge(caller_id, callee_id, "calls".to_string())
            .unwrap();

        assert_eq!(graph.edges.len(), 1);
        let edge = &graph.edges[0];
        assert_eq!(edge.source_id, caller_id);
        assert_eq!(edge.target_id, callee_id);
        assert_eq!(edge.relation, "calls");
    }

    #[test]
    fn test_code_graph_add_edge_invalid_node() {
        let (node_id, missing_id) = (Uuid::new_v4(), Uuid::new_v4());
        let mut graph = CodeGraph::new();
        graph
            .add_node(code_node(node_id, "test", "function", "/src/lib.rs", vec![1.0, 2.0, 3.0]))
            .unwrap();

        assert!(graph
            .add_edge(node_id, missing_id, "calls".to_string())
            .is_err());
    }

    #[test]
    fn test_code_graph_find_by_type() {
        let mut graph = CodeGraph::new();
        graph
            .add_node(code_node(Uuid::new_v4(), "func1", "function", "/src/lib.rs", vec![1.0, 2.0, 3.0]))
            .unwrap();
        graph
            .add_node(code_node(Uuid::new_v4(), "MyStruct", "struct", "/src/lib.rs", vec![4.0, 5.0, 6.0]))
            .unwrap();

        let functions = graph.find_nodes_by_type("function");
        assert_eq!(functions.len(), 1);
        assert_eq!(functions[0].name, "func1");

        let structs = graph.find_nodes_by_type("struct");
        assert_eq!(structs.len(), 1);
        assert_eq!(structs[0].name, "MyStruct");
    }

    #[test]
    fn test_code_graph_stats() {
        let (caller_id, callee_id) = (Uuid::new_v4(), Uuid::new_v4());
        let mut graph = CodeGraph::new();
        graph
            .add_node(code_node(caller_id, "func1", "function", "/src/lib.rs", vec![1.0, 2.0, 3.0]))
            .unwrap();
        graph
            .add_node(code_node(callee_id, "func2", "function", "/src/main.rs", vec![4.0, 5.0, 6.0]))
            .unwrap();
        graph
            .add_edge(caller_id, callee_id, "calls".to_string())
            .unwrap();

        let stats = graph.stats();
        assert_eq!(stats.total_nodes, 2);
        assert_eq!(stats.total_edges, 1);
        assert_eq!(stats.type_counts.get("function"), Some(&2));
        assert_eq!(stats.relation_counts.get("calls"), Some(&1));
    }
}