use crate::{Chunk, Document, Error, Result};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::sync::{Arc, RwLock};
use uuid::Uuid;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizedRaptorNode {
pub id: Uuid,
pub text: String,
pub children: Vec<Uuid>,
pub parent: Option<Uuid>,
pub level: usize,
pub embedding: Option<Vec<f32>>,
#[serde(skip)]
pub embedding_cached: bool,
#[serde(skip)]
pub last_accessed: Option<std::time::Instant>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RaptorOptConfig {
pub max_depth: usize,
pub cluster_size: usize,
pub beam_width: usize,
pub cache_size: usize,
pub parallel_clustering: bool,
pub min_similarity: f32,
}
impl Default for RaptorOptConfig {
fn default() -> Self {
Self {
max_depth: 3,
cluster_size: 5,
beam_width: 3,
cache_size: 1000,
parallel_clustering: true,
min_similarity: 0.1,
}
}
}
pub struct OptimizedRaptorTree {
nodes: HashMap<Uuid, OptimizedRaptorNode>,
roots: Vec<Uuid>,
config: RaptorOptConfig,
embedding_cache: Arc<RwLock<HashMap<Uuid, Vec<f32>>>>,
access_order: Arc<RwLock<Vec<Uuid>>>,
}
impl OptimizedRaptorTree {
pub fn new(config: RaptorOptConfig) -> Self {
Self {
nodes: HashMap::new(),
roots: Vec::new(),
config,
embedding_cache: Arc::new(RwLock::new(HashMap::new())),
access_order: Arc::new(RwLock::new(Vec::new())),
}
}
pub async fn build_from_chunks(
&mut self,
chunks: &[Chunk],
embedder: &(dyn Fn(&str) -> Result<Vec<f32>> + Sync + Send),
summarizer: &(dyn Fn(&str) -> Result<String> + Sync + Send),
) -> Result<()> {
if chunks.is_empty() {
return Ok(());
}
let leaf_nodes = self.create_leaf_nodes(chunks, embedder).await?;
let mut current_level_nodes = leaf_nodes;
for level in 1..=self.config.max_depth {
if current_level_nodes.len() <= self.config.cluster_size {
for node in ¤t_level_nodes {
self.roots.push(node.id);
}
break;
}
let next_level_nodes = if self.config.parallel_clustering {
self.build_level_parallel(¤t_level_nodes, level, embedder, summarizer)
.await?
} else {
self.build_level_sequential(¤t_level_nodes, level, embedder, summarizer)
.await?
};
current_level_nodes = next_level_nodes;
}
for node in current_level_nodes {
self.roots.push(node.id);
}
Ok(())
}
async fn create_leaf_nodes(
&mut self,
chunks: &[Chunk],
embedder: &(dyn Fn(&str) -> Result<Vec<f32>> + Sync + Send),
) -> Result<Vec<OptimizedRaptorNode>> {
let mut leaf_nodes = Vec::new();
for chunk in chunks {
let embedding = embedder(&chunk.text)?;
self.cache_embedding(chunk.id, embedding.clone());
let node = OptimizedRaptorNode {
id: chunk.id,
text: chunk.text.clone(),
children: Vec::new(),
parent: None,
level: 0,
embedding: Some(embedding),
embedding_cached: true,
last_accessed: Some(std::time::Instant::now()),
};
self.nodes.insert(node.id, node.clone());
leaf_nodes.push(node);
}
Ok(leaf_nodes)
}
async fn build_level_sequential(
&mut self,
nodes: &[OptimizedRaptorNode],
level: usize,
embedder: &(dyn Fn(&str) -> Result<Vec<f32>> + Sync + Send),
summarizer: &(dyn Fn(&str) -> Result<String> + Sync + Send),
) -> Result<Vec<OptimizedRaptorNode>> {
let mut next_level_nodes = Vec::new();
for i in (0..nodes.len()).step_by(self.config.cluster_size) {
let cluster_end = (i + self.config.cluster_size).min(nodes.len());
let cluster = &nodes[i..cluster_end];
let cluster_node = self.create_cluster_node(cluster, level, embedder, summarizer)?;
next_level_nodes.push(cluster_node);
}
Ok(next_level_nodes)
}
async fn build_level_parallel(
&mut self,
nodes: &[OptimizedRaptorNode],
level: usize,
embedder: &(dyn Fn(&str) -> Result<Vec<f32>> + Sync + Send),
summarizer: &(dyn Fn(&str) -> Result<String> + Sync + Send),
) -> Result<Vec<OptimizedRaptorNode>> {
use rayon::prelude::*;
let clusters: Vec<_> = (0..nodes.len()).step_by(self.config.cluster_size).collect();
let next_level_nodes: Result<Vec<_>> = clusters
.par_iter()
.map(|&i| {
let cluster_end = (i + self.config.cluster_size).min(nodes.len());
let cluster = &nodes[i..cluster_end];
Self::create_cluster_node_pure(
cluster,
level,
embedder,
summarizer,
self.config.cache_size,
&self.embedding_cache,
&self.access_order,
)
})
.collect();
let new_nodes = next_level_nodes?;
let mut result_nodes = Vec::new();
for node in new_nodes {
let node_id = node.id;
for child_id in &node.children {
if let Some(child_node) = self.nodes.get_mut(child_id) {
child_node.parent = Some(node_id);
}
}
self.nodes.insert(node_id, node.clone());
result_nodes.push(node);
}
Ok(result_nodes)
}
fn create_cluster_node(
&mut self,
cluster: &[OptimizedRaptorNode],
level: usize,
embedder: &(dyn Fn(&str) -> Result<Vec<f32>> + Sync + Send),
summarizer: &(dyn Fn(&str) -> Result<String> + Sync + Send),
) -> Result<OptimizedRaptorNode> {
let node = Self::create_cluster_node_pure(
cluster,
level,
embedder,
summarizer,
self.config.cache_size,
&self.embedding_cache,
&self.access_order,
)?;
for child in cluster {
if let Some(child_node) = self.nodes.get_mut(&child.id) {
child_node.parent = Some(node.id);
}
}
self.nodes.insert(node.id, node.clone());
Ok(node)
}
fn create_cluster_node_pure(
cluster: &[OptimizedRaptorNode],
level: usize,
embedder: &(dyn Fn(&str) -> Result<Vec<f32>> + Sync + Send),
summarizer: &(dyn Fn(&str) -> Result<String> + Sync + Send),
cache_size: usize,
embedding_cache: &Arc<RwLock<HashMap<Uuid, Vec<f32>>>>,
access_order: &Arc<RwLock<Vec<Uuid>>>,
) -> Result<OptimizedRaptorNode> {
if cluster.len() == 1 {
let mut node = cluster[0].clone();
node.level = level;
return Ok(node);
}
let cluster_texts: Vec<String> = cluster.iter().map(|n| n.text.clone()).collect();
let combined_text = cluster_texts.join("\n\n");
let summary = summarizer(&combined_text)?;
let embedding = embedder(&summary)?;
let cluster_node_id = Uuid::new_v4();
{
if let Ok(mut cache) = embedding_cache.write() {
while cache.len() >= cache_size {
if let Ok(mut order) = access_order.write() {
if let Some(oldest_id) = order.first().cloned() {
order.remove(0);
cache.remove(&oldest_id);
} else {
break;
}
} else {
break;
}
}
cache.insert(cluster_node_id, embedding.clone());
if let Ok(mut order) = access_order.write() {
order.push(cluster_node_id);
}
}
}
let cluster_node = OptimizedRaptorNode {
id: cluster_node_id,
text: summary,
children: cluster.iter().map(|n| n.id).collect(),
parent: None,
level,
embedding: Some(embedding),
embedding_cached: true,
last_accessed: Some(std::time::Instant::now()),
};
Ok(cluster_node)
}
pub fn search_beam(&self, query_embedding: &[f32], top_k: usize) -> Result<Vec<(Uuid, f32)>> {
let mut beam: Vec<(Uuid, f32)> = Vec::new();
for root_id in &self.roots {
if let Some(node) = self.nodes.get(root_id) {
if let Some(embedding) = self.get_embedding(&node.id) {
let score = cosine_similarity(query_embedding, &embedding);
if score >= self.config.min_similarity {
beam.push((*root_id, score));
}
}
}
}
if beam.is_empty() {
return Ok(Vec::new());
}
let mut current_level = self.config.max_depth;
while current_level > 0 {
beam.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
beam.truncate(self.config.beam_width);
let mut next_beam = Vec::new();
for (node_id, _parent_score) in &beam {
if let Some(node) = self.nodes.get(node_id) {
if node.level == 0 {
next_beam.push((*node_id, *_parent_score));
continue;
}
for child_id in &node.children {
if let Some(child) = self.nodes.get(child_id) {
if let Some(embedding) = self.get_embedding(&child.id) {
let score = cosine_similarity(query_embedding, &embedding);
if score >= self.config.min_similarity {
next_beam.push((*child_id, score));
}
}
}
}
}
}
if next_beam.is_empty() {
break;
}
beam = next_beam;
current_level -= 1;
}
beam.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
beam.truncate(top_k);
Ok(beam)
}
fn get_embedding(&self, node_id: &Uuid) -> Option<Vec<f32>> {
if let Ok(cache) = self.embedding_cache.read() {
if let Some(embedding) = cache.get(node_id) {
if let Ok(mut order) = self.access_order.write() {
order.retain(|id| id != node_id);
order.push(*node_id);
}
return Some(embedding.clone());
}
}
if let Some(node) = self.nodes.get(node_id) {
if let Some(ref embedding) = node.embedding {
self.cache_embedding(*node_id, embedding.clone());
return Some(embedding.clone());
}
}
None
}
fn cache_embedding(&self, node_id: Uuid, embedding: Vec<f32>) {
if let Ok(mut cache) = self.embedding_cache.write() {
while cache.len() >= self.config.cache_size {
if let Ok(mut order) = self.access_order.write() {
if let Some(oldest_id) = order.first().cloned() {
order.remove(0);
cache.remove(&oldest_id);
} else {
break;
}
} else {
break;
}
}
cache.insert(node_id, embedding);
if let Ok(mut order) = self.access_order.write() {
order.push(node_id);
}
}
}
pub fn stats(&self) -> RaptorOptStats {
let mut level_counts = HashMap::new();
let mut cached_count = 0;
for node in self.nodes.values() {
*level_counts.entry(node.level).or_insert(0) += 1;
if node.embedding_cached {
cached_count += 1;
}
}
let cache_size = self.embedding_cache.read().map(|c| c.len()).unwrap_or(0);
RaptorOptStats {
total_nodes: self.nodes.len(),
leaf_nodes: level_counts.get(&0).copied().unwrap_or(0),
max_depth: self.config.max_depth,
level_counts,
root_count: self.roots.len(),
cache_size,
cache_hit_rate: (cached_count as f32) / (self.nodes.len() as f32),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RaptorOptStats {
pub total_nodes: usize,
pub leaf_nodes: usize,
pub max_depth: usize,
pub level_counts: HashMap<usize, usize>,
pub root_count: usize,
pub cache_size: usize,
pub cache_hit_rate: f32,
}
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
if a.len() != b.len() || a.is_empty() {
return 0.0;
}
let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm_a == 0.0 || norm_b == 0.0 {
return 0.0;
}
dot / (norm_a * norm_b)
}
#[cfg(test)]
mod tests {
use super::*;
fn mock_embedder(_text: &str) -> Result<Vec<f32>> {
Ok(vec![0.1, 0.2, 0.3, 0.4, 0.5])
}
fn mock_summarizer(text: &str) -> Result<String> {
Ok(format!(
"Summary: {}",
text.chars().take(50).collect::<String>()
))
}
#[tokio::test]
async fn test_beam_search() {
let config = RaptorOptConfig {
max_depth: 2,
cluster_size: 3,
beam_width: 2,
..Default::default()
};
let mut tree = OptimizedRaptorTree::new(config);
let chunks = vec![
Chunk {
id: Uuid::new_v4(),
text: "Test chunk 1".to_string(),
index: 0,
start_char: 0,
end_char: 12,
token_count: Some(3),
section: None,
page: None,
embedding_ids: Default::default(),
},
Chunk {
id: Uuid::new_v4(),
text: "Test chunk 2".to_string(),
index: 1,
start_char: 13,
end_char: 25,
token_count: Some(3),
section: None,
page: None,
embedding_ids: Default::default(),
},
];
tree.build_from_chunks(&chunks, &mock_embedder, &mock_summarizer)
.await
.unwrap();
let query_embedding = vec![0.1, 0.2, 0.3, 0.4, 0.5];
let results = tree.search_beam(&query_embedding, 5).unwrap();
assert!(!results.is_empty());
assert!(results[0].1 > 0.0);
}
#[tokio::test]
async fn test_caching() {
let config = RaptorOptConfig {
cache_size: 2,
..Default::default()
};
let tree = OptimizedRaptorTree::new(config);
let id1 = Uuid::new_v4();
let id2 = Uuid::new_v4();
let id3 = Uuid::new_v4();
tree.cache_embedding(id1, vec![1.0, 2.0]);
tree.cache_embedding(id2, vec![3.0, 4.0]);
assert!(tree.get_embedding(&id1).is_some());
assert!(tree.get_embedding(&id2).is_some());
tree.cache_embedding(id3, vec![5.0, 6.0]);
assert!(tree.get_embedding(&id3).is_some());
}
}