use anyhow::Result;
use pdf_extract;
use serde::{Deserialize, Serialize};
use serde_json::json;
use sha2::{Digest, Sha256};
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::path::Path;
use std::sync::Arc;
use tokio::sync::Mutex;
use tracing::debug;
use crate::{
embeddings::MLXBridge,
preprocessing::{PreprocessingConfig, Preprocessor},
storage::{ChromaDocument, StorageManager},
};
const DEFAULT_NAMESPACE: &str = "rag";
const STORAGE_BATCH_SIZE: usize = 100;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[repr(u8)]
pub enum SliceLayer {
Outer = 1,
Middle = 2,
Inner = 3,
Core = 4,
}
impl SliceLayer {
pub fn target_chars(&self) -> usize {
match self {
SliceLayer::Outer => 100,
SliceLayer::Middle => 300,
SliceLayer::Inner => 600,
SliceLayer::Core => usize::MAX,
}
}
pub fn as_u8(&self) -> u8 {
*self as u8
}
pub fn from_u8(v: u8) -> Option<Self> {
match v {
1 => Some(SliceLayer::Outer),
2 => Some(SliceLayer::Middle),
3 => Some(SliceLayer::Inner),
4 => Some(SliceLayer::Core),
_ => None,
}
}
pub fn name(&self) -> &'static str {
match self {
SliceLayer::Outer => "outer",
SliceLayer::Middle => "middle",
SliceLayer::Inner => "inner",
SliceLayer::Core => "core",
}
}
}
impl std::fmt::Display for SliceLayer {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.name())
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OnionSlice {
pub id: String,
pub layer: SliceLayer,
pub content: String,
pub parent_id: Option<String>,
pub children_ids: Vec<String>,
pub keywords: Vec<String>,
}
impl OnionSlice {
pub fn generate_id(content: &str, layer: SliceLayer) -> String {
let mut hasher = DefaultHasher::new();
content.hash(&mut hasher);
layer.as_u8().hash(&mut hasher);
format!("slice_{:016x}_{}", hasher.finish(), layer.name())
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SliceMode {
#[default]
Onion,
Flat,
}
impl std::str::FromStr for SliceMode {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"onion" => Ok(SliceMode::Onion),
"flat" => Ok(SliceMode::Flat),
other => Err(format!(
"Invalid slice mode: '{}'. Use 'onion' or 'flat'",
other
)),
}
}
}
#[derive(Debug, Clone)]
pub enum IndexResult {
Indexed {
chunks_indexed: usize,
content_hash: String,
},
Skipped {
reason: String,
content_hash: String,
},
}
impl IndexResult {
pub fn is_indexed(&self) -> bool {
matches!(self, IndexResult::Indexed { .. })
}
pub fn is_skipped(&self) -> bool {
matches!(self, IndexResult::Skipped { .. })
}
pub fn content_hash(&self) -> &str {
match self {
IndexResult::Indexed { content_hash, .. } => content_hash,
IndexResult::Skipped { content_hash, .. } => content_hash,
}
}
}
pub fn compute_content_hash(content: &str) -> String {
let mut hasher = Sha256::new();
hasher.update(content.as_bytes());
let result = hasher.finalize();
result.iter().map(|b| format!("{:02x}", b)).collect()
}
#[derive(Debug, Clone)]
pub struct OnionSliceConfig {
pub outer_target: usize,
pub middle_target: usize,
pub inner_target: usize,
pub min_content_for_slicing: usize,
}
impl Default for OnionSliceConfig {
fn default() -> Self {
Self {
outer_target: 100,
middle_target: 300,
inner_target: 600,
min_content_for_slicing: 200,
}
}
}
pub fn create_onion_slices(
content: &str,
_metadata: &serde_json::Value,
config: &OnionSliceConfig,
) -> Vec<OnionSlice> {
let content = content.trim();
if content.len() < config.min_content_for_slicing {
let core_id = OnionSlice::generate_id(content, SliceLayer::Core);
let keywords = extract_keywords(content, 5);
return vec![OnionSlice {
id: core_id,
layer: SliceLayer::Core,
content: content.to_string(),
parent_id: None,
children_ids: vec![],
keywords,
}];
}
let mut slices = Vec::with_capacity(4);
let core_id = OnionSlice::generate_id(content, SliceLayer::Core);
let core_keywords = extract_keywords(content, 10);
let inner_content = extract_key_content(content, config.inner_target);
let inner_id = OnionSlice::generate_id(&inner_content, SliceLayer::Inner);
let inner_keywords = extract_keywords(&inner_content, 7);
let middle_content = extract_key_content(&inner_content, config.middle_target);
let middle_id = OnionSlice::generate_id(&middle_content, SliceLayer::Middle);
let middle_keywords = extract_keywords(&middle_content, 5);
let outer_content = create_outer_summary(&middle_content, &core_keywords, config.outer_target);
let outer_id = OnionSlice::generate_id(&outer_content, SliceLayer::Outer);
let outer_keywords = extract_keywords(&outer_content, 3);
slices.push(OnionSlice {
id: outer_id.clone(),
layer: SliceLayer::Outer,
content: outer_content,
parent_id: Some(middle_id.clone()),
children_ids: vec![],
keywords: outer_keywords,
});
slices.push(OnionSlice {
id: middle_id.clone(),
layer: SliceLayer::Middle,
content: middle_content,
parent_id: Some(inner_id.clone()),
children_ids: vec![outer_id],
keywords: middle_keywords,
});
slices.push(OnionSlice {
id: inner_id.clone(),
layer: SliceLayer::Inner,
content: inner_content,
parent_id: Some(core_id.clone()),
children_ids: vec![middle_id],
keywords: inner_keywords,
});
slices.push(OnionSlice {
id: core_id.clone(),
layer: SliceLayer::Core,
content: content.to_string(),
parent_id: None,
children_ids: vec![inner_id],
keywords: core_keywords,
});
slices
}
fn extract_keywords(text: &str, max_keywords: usize) -> Vec<String> {
use std::collections::HashMap;
const STOP_WORDS: &[&str] = &[
"the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for", "of", "with", "by",
"from", "as", "is", "was", "are", "were", "been", "be", "have", "has", "had", "do", "does",
"did", "will", "would", "could", "should", "may", "might", "must", "shall", "can", "this",
"that", "these", "those", "i", "you", "he", "she", "it", "we", "they", "what", "which",
"who", "whom", "when", "where", "why", "how", "all", "each", "every", "both", "few",
"more", "most", "other", "some", "such", "no", "not", "only", "own", "same", "so", "than",
"too", "very", "just", "also", "now", "here", "there", "then", "once", "if", "into",
"through", "during", "before", "after", "above", "below", "between", "under", "again",
"further", "about", "out", "over", "up", "down", "off", "any", "because", "until", "while",
];
let stop_set: std::collections::HashSet<&str> = STOP_WORDS.iter().copied().collect();
let mut word_counts: HashMap<String, usize> = HashMap::new();
for word in text.split_whitespace() {
let cleaned: String = word
.chars()
.filter(|c| c.is_alphanumeric())
.collect::<String>()
.to_lowercase();
if cleaned.len() >= 3 && !stop_set.contains(cleaned.as_str()) {
*word_counts.entry(cleaned).or_insert(0) += 1;
}
}
let mut words: Vec<_> = word_counts.into_iter().collect();
words.sort_by(|a, b| b.1.cmp(&a.1));
words
.into_iter()
.take(max_keywords)
.map(|(word, _)| word)
.collect()
}
fn extract_key_content(text: &str, target_chars: usize) -> String {
if text.len() <= target_chars {
return text.to_string();
}
let sentences: Vec<&str> = text
.split(['.', '!', '?'])
.map(|s| s.trim())
.filter(|s| !s.is_empty())
.collect();
if sentences.is_empty() {
return truncate_at_word_boundary(text, target_chars);
}
let keywords = extract_keywords(text, 10);
let keyword_set: std::collections::HashSet<&str> =
keywords.iter().map(|s| s.as_str()).collect();
let mut scored_sentences: Vec<(usize, f32, &str)> = sentences
.iter()
.enumerate()
.map(|(idx, sentence)| {
let mut score = 0.0_f32;
if idx == 0 {
score += 2.0;
} else if idx == sentences.len() - 1 {
score += 1.5;
}
let words: Vec<&str> = sentence.split_whitespace().collect();
let keyword_count = words
.iter()
.filter(|w| {
let cleaned: String = w
.chars()
.filter(|c| c.is_alphanumeric())
.collect::<String>()
.to_lowercase();
keyword_set.contains(cleaned.as_str())
})
.count();
if !words.is_empty() {
score += (keyword_count as f32 / words.len() as f32) * 3.0;
}
if sentence.len() < 20 {
score -= 0.5;
}
(idx, score, *sentence)
})
.collect();
scored_sentences.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
let mut selected_indices: Vec<usize> = Vec::new();
let mut total_len = 0;
for (idx, _, sentence) in &scored_sentences {
let sentence_len = sentence.len() + 2; if total_len + sentence_len > target_chars && !selected_indices.is_empty() {
break;
}
selected_indices.push(*idx);
total_len += sentence_len;
}
selected_indices.sort();
let result: Vec<&str> = selected_indices
.iter()
.filter_map(|&idx| sentences.get(idx).copied())
.collect();
if result.is_empty() {
truncate_at_word_boundary(text, target_chars)
} else {
result.join(". ") + "."
}
}
fn create_outer_summary(middle_content: &str, keywords: &[String], target_chars: usize) -> String {
let keyword_prefix = if !keywords.is_empty() {
format!(
"[{}] ",
keywords
.iter()
.take(5)
.cloned()
.collect::<Vec<_>>()
.join(", ")
)
} else {
String::new()
};
let remaining_chars = target_chars.saturating_sub(keyword_prefix.len());
let first_sentence = middle_content
.split(['.', '!', '?'])
.next()
.unwrap_or(middle_content)
.trim();
let summary = if first_sentence.len() <= remaining_chars {
first_sentence.to_string()
} else {
truncate_at_word_boundary(first_sentence, remaining_chars)
};
format!("{}{}", keyword_prefix, summary)
}
fn truncate_at_word_boundary(text: &str, max_chars: usize) -> String {
let char_count = text.chars().count();
if char_count <= max_chars {
return text.to_string();
}
let byte_idx = text
.char_indices()
.nth(max_chars)
.map(|(idx, _)| idx)
.unwrap_or(text.len());
let truncated = &text[..byte_idx];
if let Some(last_space) = truncated.rfind(' ') {
format!("{}...", &text[..last_space])
} else {
format!("{}...", truncated)
}
}
pub struct RAGPipeline {
mlx_bridge: Arc<Mutex<MLXBridge>>,
storage: Arc<StorageManager>,
}
impl RAGPipeline {
pub async fn new(
mlx_bridge: Arc<Mutex<MLXBridge>>,
storage: Arc<StorageManager>,
) -> Result<Self> {
Ok(Self {
mlx_bridge,
storage,
})
}
pub fn storage(&self) -> Arc<StorageManager> {
self.storage.clone()
}
pub fn mlx_connected_to(&self) -> String {
if let Ok(bridge) = self.mlx_bridge.try_lock() {
bridge.connected_to().to_string()
} else {
"mlx (lock held)".to_string()
}
}
pub async fn index_document(&self, path: &Path, namespace: Option<&str>) -> Result<()> {
self.index_document_with_mode(path, namespace, SliceMode::default())
.await
}
pub async fn index_document_with_mode(
&self,
path: &Path,
namespace: Option<&str>,
slice_mode: SliceMode,
) -> Result<()> {
self.index_document_internal(path, namespace, None, slice_mode)
.await
}
pub async fn index_document_with_preprocessing(
&self,
path: &Path,
namespace: Option<&str>,
preprocess_config: PreprocessingConfig,
) -> Result<()> {
self.index_document_internal(path, namespace, Some(preprocess_config), SliceMode::Flat)
.await
}
pub async fn index_document_with_dedup(
&self,
path: &Path,
namespace: Option<&str>,
slice_mode: SliceMode,
) -> Result<IndexResult> {
let text = self.extract_text(path).await?;
let ns = namespace.unwrap_or(DEFAULT_NAMESPACE);
let content_hash = compute_content_hash(&text);
if self.storage.has_content_hash(ns, &content_hash).await? {
debug!(
"Skipping duplicate content: {} (hash: {})",
path.display(),
&content_hash[..16]
);
return Ok(IndexResult::Skipped {
reason: "exact duplicate".to_string(),
content_hash,
});
}
let base_metadata = json!({
"path": path.to_str(),
"slice_mode": match slice_mode {
SliceMode::Onion => "onion",
SliceMode::Flat => "flat",
},
"content_hash": &content_hash,
});
let chunks_indexed = match slice_mode {
SliceMode::Onion => {
self.index_with_onion_slicing_and_hash(&text, ns, base_metadata, &content_hash)
.await?
}
SliceMode::Flat => {
self.index_with_flat_chunking_and_hash(
&text,
ns,
path,
base_metadata,
&content_hash,
)
.await?
}
};
Ok(IndexResult::Indexed {
chunks_indexed,
content_hash,
})
}
pub async fn index_document_with_preprocessing_and_dedup(
&self,
path: &Path,
namespace: Option<&str>,
preprocess_config: PreprocessingConfig,
) -> Result<IndexResult> {
let text = self.extract_text(path).await?;
let ns = namespace.unwrap_or(DEFAULT_NAMESPACE);
let content_hash = compute_content_hash(&text);
if self.storage.has_content_hash(ns, &content_hash).await? {
debug!(
"Skipping duplicate content: {} (hash: {})",
path.display(),
&content_hash[..16]
);
return Ok(IndexResult::Skipped {
reason: "exact duplicate".to_string(),
content_hash,
});
}
let preprocessor = Preprocessor::new(preprocess_config);
let cleaned = preprocessor.extract_semantic_content(&text);
tracing::info!(
"Preprocessing: {} chars -> {} chars ({:.1}% reduction)",
text.len(),
cleaned.len(),
(1.0 - (cleaned.len() as f32 / text.len() as f32)) * 100.0
);
let base_metadata = json!({
"path": path.to_str(),
"slice_mode": "flat",
"content_hash": &content_hash,
});
let chunks_indexed = self
.index_with_flat_chunking_and_hash(&cleaned, ns, path, base_metadata, &content_hash)
.await?;
Ok(IndexResult::Indexed {
chunks_indexed,
content_hash,
})
}
async fn index_document_internal(
&self,
path: &Path,
namespace: Option<&str>,
preprocess_config: Option<PreprocessingConfig>,
slice_mode: SliceMode,
) -> Result<()> {
let text = self.extract_text(path).await?;
let text = if let Some(config) = preprocess_config {
let preprocessor = Preprocessor::new(config);
let cleaned = preprocessor.extract_semantic_content(&text);
tracing::info!(
"Preprocessing: {} chars -> {} chars ({:.1}% reduction)",
text.len(),
cleaned.len(),
(1.0 - (cleaned.len() as f32 / text.len() as f32)) * 100.0
);
cleaned
} else {
text
};
let ns = namespace.unwrap_or(DEFAULT_NAMESPACE);
let base_metadata = json!({
"path": path.to_str(),
"slice_mode": match slice_mode {
SliceMode::Onion => "onion",
SliceMode::Flat => "flat",
}
});
match slice_mode {
SliceMode::Onion => {
self.index_with_onion_slicing(&text, ns, base_metadata)
.await
}
SliceMode::Flat => {
self.index_with_flat_chunking(&text, ns, path, base_metadata)
.await
}
}
}
async fn index_with_onion_slicing(
&self,
text: &str,
namespace: &str,
base_metadata: serde_json::Value,
) -> Result<()> {
let config = OnionSliceConfig::default();
let slices = create_onion_slices(text, &base_metadata, &config);
let total_slices = slices.len();
tracing::info!(
"Onion slicing: {} chars -> {} slices (outer/middle/inner/core)",
text.len(),
total_slices
);
let mut total_stored = 0;
for batch in slices.chunks(STORAGE_BATCH_SIZE) {
let batch_contents: Vec<String> = batch.iter().map(|s| s.content.clone()).collect();
let embeddings = self.embed_chunks(&batch_contents).await?;
let mut batch_docs = Vec::with_capacity(batch.len());
for (slice, embedding) in batch.iter().zip(embeddings.iter()) {
let mut metadata = base_metadata.clone();
if let serde_json::Value::Object(ref mut map) = metadata {
map.insert("layer".to_string(), json!(slice.layer.name()));
map.insert("keywords".to_string(), json!(slice.keywords));
}
let doc = ChromaDocument::from_onion_slice(
slice,
namespace.to_string(),
embedding.clone(),
metadata,
);
batch_docs.push(doc);
}
self.storage.add_to_store(batch_docs).await?;
total_stored += batch.len();
tracing::info!("Stored {}/{} slices", total_stored, total_slices);
}
Ok(())
}
async fn index_with_onion_slicing_and_hash(
&self,
text: &str,
namespace: &str,
base_metadata: serde_json::Value,
content_hash: &str,
) -> Result<usize> {
let config = OnionSliceConfig::default();
let slices = create_onion_slices(text, &base_metadata, &config);
let total_slices = slices.len();
tracing::info!(
"Onion slicing: {} chars -> {} slices (outer/middle/inner/core)",
text.len(),
total_slices
);
let mut total_stored = 0;
for batch in slices.chunks(STORAGE_BATCH_SIZE) {
let batch_contents: Vec<String> = batch.iter().map(|s| s.content.clone()).collect();
let embeddings = self.embed_chunks(&batch_contents).await?;
let mut batch_docs = Vec::with_capacity(batch.len());
for (slice, embedding) in batch.iter().zip(embeddings.iter()) {
let mut metadata = base_metadata.clone();
if let serde_json::Value::Object(ref mut map) = metadata {
map.insert("layer".to_string(), json!(slice.layer.name()));
map.insert("keywords".to_string(), json!(slice.keywords));
}
let doc = ChromaDocument::from_onion_slice_with_hash(
slice,
namespace.to_string(),
embedding.clone(),
metadata,
content_hash.to_string(),
);
batch_docs.push(doc);
}
self.storage.add_to_store(batch_docs).await?;
total_stored += batch.len();
tracing::info!("Stored {}/{} slices", total_stored, total_slices);
}
Ok(total_slices)
}
async fn index_with_flat_chunking(
&self,
text: &str,
namespace: &str,
path: &Path,
base_metadata: serde_json::Value,
) -> Result<()> {
let chunks = self.chunk_text(text, 512, 128)?;
let total_chunks = chunks.len();
tracing::info!(
"Flat chunking: {} chars -> {} chunks",
text.len(),
total_chunks
);
let mut total_stored = 0;
let mut global_idx = 0;
for batch in chunks.chunks(STORAGE_BATCH_SIZE) {
let embeddings = self.embed_chunks(batch).await?;
let mut batch_docs = Vec::with_capacity(batch.len());
for (chunk, embedding) in batch.iter().zip(embeddings.iter()) {
let mut metadata = base_metadata.clone();
if let serde_json::Value::Object(ref mut map) = metadata {
map.insert("chunk_index".to_string(), json!(global_idx));
map.insert("total_chunks".to_string(), json!(total_chunks));
}
let doc = ChromaDocument::new_flat(
format!("{}_{}", path.to_str().unwrap_or("unknown"), global_idx),
namespace.to_string(),
embedding.clone(),
metadata,
chunk.clone(),
);
batch_docs.push(doc);
global_idx += 1;
}
self.storage.add_to_store(batch_docs).await?;
total_stored += batch.len();
tracing::info!("Stored {}/{} chunks", total_stored, total_chunks);
}
Ok(())
}
async fn index_with_flat_chunking_and_hash(
&self,
text: &str,
namespace: &str,
path: &Path,
base_metadata: serde_json::Value,
content_hash: &str,
) -> Result<usize> {
let chunks = self.chunk_text(text, 512, 128)?;
let total_chunks = chunks.len();
tracing::info!(
"Flat chunking: {} chars -> {} chunks",
text.len(),
total_chunks
);
let mut total_stored = 0;
let mut global_idx = 0;
for batch in chunks.chunks(STORAGE_BATCH_SIZE) {
let embeddings = self.embed_chunks(batch).await?;
let mut batch_docs = Vec::with_capacity(batch.len());
for (chunk, embedding) in batch.iter().zip(embeddings.iter()) {
let mut metadata = base_metadata.clone();
if let serde_json::Value::Object(ref mut map) = metadata {
map.insert("chunk_index".to_string(), json!(global_idx));
map.insert("total_chunks".to_string(), json!(total_chunks));
}
let doc = ChromaDocument::new_flat_with_hash(
format!("{}_{}", path.to_str().unwrap_or("unknown"), global_idx),
namespace.to_string(),
embedding.clone(),
metadata,
chunk.clone(),
content_hash.to_string(),
);
batch_docs.push(doc);
global_idx += 1;
}
self.storage.add_to_store(batch_docs).await?;
total_stored += batch.len();
tracing::info!("Stored {}/{} chunks", total_stored, total_chunks);
}
Ok(total_chunks)
}
pub async fn index_text(
&self,
namespace: Option<&str>,
id: String,
text: String,
metadata: serde_json::Value,
) -> Result<String> {
self.index_text_with_mode(namespace, id, text, metadata, SliceMode::default())
.await
}
pub async fn index_text_with_mode(
&self,
namespace: Option<&str>,
id: String,
text: String,
metadata: serde_json::Value,
slice_mode: SliceMode,
) -> Result<String> {
let ns = namespace.unwrap_or(DEFAULT_NAMESPACE).to_string();
match slice_mode {
SliceMode::Onion => {
let config = OnionSliceConfig::default();
let slices = create_onion_slices(&text, &metadata, &config);
let slice_contents: Vec<String> =
slices.iter().map(|s| s.content.clone()).collect();
let embeddings = self.embed_chunks(&slice_contents).await?;
let mut documents = Vec::with_capacity(slices.len());
for (slice, embedding) in slices.iter().zip(embeddings.iter()) {
let mut meta = metadata.clone();
if let serde_json::Value::Object(ref mut map) = meta {
map.insert("layer".to_string(), json!(slice.layer.name()));
map.insert("original_id".to_string(), json!(id));
}
let doc = ChromaDocument::from_onion_slice(
slice,
ns.clone(),
embedding.clone(),
meta,
);
documents.push(doc);
}
self.storage.add_to_store(documents).await?;
Ok(slices
.iter()
.find(|s| s.layer == SliceLayer::Outer)
.map(|s| s.id.clone())
.unwrap_or(id))
}
SliceMode::Flat => {
let embedding = self.embed_query(&text).await?;
let doc = ChromaDocument::new_flat(id.clone(), ns, embedding, metadata, text);
self.storage.add_to_store(vec![doc]).await?;
Ok(id)
}
}
}
pub async fn memory_upsert(
&self,
namespace: &str,
id: String,
text: String,
metadata: serde_json::Value,
) -> Result<()> {
self.index_text(Some(namespace), id, text, metadata).await?;
Ok(())
}
pub async fn memory_get(&self, namespace: &str, id: &str) -> Result<Option<SearchResult>> {
if let Some(doc) = self.storage.get_document(namespace, id).await? {
let layer = doc.slice_layer();
return Ok(Some(SearchResult {
id: doc.id,
namespace: doc.namespace,
text: doc.document,
score: 1.0,
metadata: doc.metadata,
layer,
parent_id: doc.parent_id,
children_ids: doc.children_ids,
keywords: doc.keywords,
}));
}
Ok(None)
}
pub async fn memory_delete(&self, namespace: &str, id: &str) -> Result<usize> {
self.storage.delete_document(namespace, id).await
}
pub async fn purge_namespace(&self, namespace: &str) -> Result<usize> {
self.storage.purge_namespace(namespace).await
}
pub async fn memory_search(
&self,
namespace: &str,
query: &str,
k: usize,
) -> Result<Vec<SearchResult>> {
self.search_with_options(Some(namespace), query, k, SearchOptions::default())
.await
}
pub async fn memory_search_with_layer(
&self,
namespace: &str,
query: &str,
k: usize,
layer: Option<SliceLayer>,
) -> Result<Vec<SearchResult>> {
self.search_with_options(
Some(namespace),
query,
k,
SearchOptions {
layer_filter: layer,
},
)
.await
}
pub async fn search(&self, query: &str, k: usize) -> Result<Vec<SearchResult>> {
self.search_inner(None, query, k).await
}
pub async fn search_inner(
&self,
namespace: Option<&str>,
query: &str,
k: usize,
) -> Result<Vec<SearchResult>> {
self.search_with_options(namespace, query, k, SearchOptions::default())
.await
}
pub async fn search_with_options(
&self,
namespace: Option<&str>,
query: &str,
k: usize,
options: SearchOptions,
) -> Result<Vec<SearchResult>> {
let query_embedding = self.embed_query(query).await?;
let candidates = self
.storage
.search_store_with_layer(
namespace,
query_embedding.clone(),
k * 3,
options.layer_filter,
)
.await?;
if !candidates.is_empty() {
let documents: Vec<String> = candidates.iter().map(|c| c.document.clone()).collect();
let metadatas: Vec<serde_json::Value> =
candidates.iter().map(|c| c.metadata.clone()).collect();
let reranked = match self.mlx_bridge.lock().await.rerank(query, &documents).await {
Ok(r) => Some(r),
Err(e) => {
tracing::warn!("MLX rerank failed, using cosine fallback: {}", e);
None
}
};
let reranked = if let Some(r) = reranked {
r
} else {
let doc_embeddings = self.ensure_doc_embeddings(&documents, &candidates).await?;
let scores = doc_embeddings
.iter()
.enumerate()
.map(|(idx, emb)| (idx, cosine(&query_embedding, emb)))
.collect::<Vec<_>>();
let mut scores = scores;
scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
scores
};
let results: Vec<SearchResult> = reranked
.into_iter()
.take(k)
.filter_map(|(idx, score)| {
candidates.get(idx).map(|candidate| {
SearchResult {
id: candidate.id.clone(),
namespace: candidate.namespace.clone(),
text: candidate.document.clone(),
score,
metadata: metadatas.get(idx).cloned().unwrap_or_else(|| json!({})),
layer: candidate.slice_layer(),
parent_id: candidate.parent_id.clone(),
children_ids: candidate.children_ids.clone(),
keywords: candidate.keywords.clone(),
}
})
})
.collect();
return Ok(results);
}
Ok(vec![])
}
pub async fn expand_result(&self, namespace: &str, id: &str) -> Result<Vec<SearchResult>> {
let children = self.storage.get_children(namespace, id).await?;
Ok(children
.into_iter()
.map(|doc| {
let layer = doc.slice_layer();
SearchResult {
id: doc.id,
namespace: doc.namespace,
text: doc.document,
score: 1.0,
metadata: doc.metadata,
layer,
parent_id: doc.parent_id,
children_ids: doc.children_ids,
keywords: doc.keywords,
}
})
.collect())
}
pub async fn get_parent_result(
&self,
namespace: &str,
id: &str,
) -> Result<Option<SearchResult>> {
if let Some(parent) = self.storage.get_parent(namespace, id).await? {
let layer = parent.slice_layer();
return Ok(Some(SearchResult {
id: parent.id,
namespace: parent.namespace,
text: parent.document,
score: 1.0,
metadata: parent.metadata,
layer,
parent_id: parent.parent_id,
children_ids: parent.children_ids,
keywords: parent.keywords,
}));
}
Ok(None)
}
async fn extract_text(&self, path: &Path) -> Result<String> {
let ext = path
.extension()
.and_then(|e| e.to_str())
.unwrap_or("")
.to_lowercase();
if ext == "pdf" {
let path = path.to_path_buf();
let pdf_text =
tokio::task::spawn_blocking(move || pdf_extract::extract_text(&path)).await??;
return Ok(pdf_text);
}
tokio::fs::read_to_string(path).await.map_err(|e| e.into())
}
async fn embed_chunks(&self, chunks: &[String]) -> Result<Vec<Vec<f32>>> {
self.mlx_bridge.lock().await.embed_batch(chunks).await
}
async fn embed_query(&self, query: &str) -> Result<Vec<f32>> {
self.mlx_bridge.lock().await.embed(query).await
}
async fn ensure_doc_embeddings(
&self,
documents: &[String],
candidates: &[ChromaDocument],
) -> Result<Vec<Vec<f32>>> {
let has_all = candidates.iter().all(|c| !c.embedding.is_empty());
if has_all {
return Ok(candidates.iter().map(|c| c.embedding.clone()).collect());
}
self.mlx_bridge.lock().await.embed_batch(documents).await
}
fn chunk_text(&self, text: &str, target_size: usize, overlap: usize) -> Result<Vec<String>> {
let sentences = split_into_sentences(text);
if sentences.is_empty() {
return Ok(vec![text.to_string()]);
}
if text.chars().count() <= target_size {
return Ok(vec![text.to_string()]);
}
let mut chunks = Vec::new();
let mut current_chunk = String::new();
let mut overlap_sentences: Vec<String> = Vec::new();
let overlap_sentence_count = (overlap / 50).clamp(1, 3);
for sentence in &sentences {
let sentence_len = sentence.chars().count();
let current_len = current_chunk.chars().count();
let max_size = target_size + target_size / 2;
if current_len + sentence_len > max_size && !current_chunk.is_empty() {
chunks.push(current_chunk.trim().to_string());
current_chunk = overlap_sentences.join(" ");
if !current_chunk.is_empty() {
current_chunk.push(' ');
}
overlap_sentences.clear();
}
current_chunk.push_str(sentence);
current_chunk.push(' ');
overlap_sentences.push(sentence.clone());
if overlap_sentences.len() > overlap_sentence_count {
overlap_sentences.remove(0);
}
if current_chunk.chars().count() >= target_size {
chunks.push(current_chunk.trim().to_string());
current_chunk = overlap_sentences.join(" ");
if !current_chunk.is_empty() {
current_chunk.push(' ');
}
overlap_sentences.clear();
}
}
let remaining = current_chunk.trim();
if !remaining.is_empty() {
if remaining.chars().count() < target_size / 4 && !chunks.is_empty() {
let last_idx = chunks.len() - 1;
chunks[last_idx].push(' ');
chunks[last_idx].push_str(remaining);
} else {
chunks.push(remaining.to_string());
}
}
if chunks.is_empty() {
chunks.push(text.to_string());
}
Ok(chunks)
}
}
#[derive(Debug, Clone)]
pub struct ContextPrefixConfig {
pub include_source: bool,
pub include_section: bool,
pub include_doc_type: bool,
pub max_prefix_length: usize,
}
impl Default for ContextPrefixConfig {
fn default() -> Self {
Self {
include_source: true,
include_section: true,
include_doc_type: true,
max_prefix_length: 100,
}
}
}
#[derive(Debug, Clone)]
pub struct EnrichedChunk {
pub content: String,
pub original_content: String,
pub doc_path: String,
pub chunk_index: usize,
pub section: Option<String>,
pub doc_type: Option<String>,
}
pub fn create_enriched_chunks(
content: &str,
doc_path: &str,
chunk_size: usize,
overlap: usize,
config: &ContextPrefixConfig,
) -> Vec<EnrichedChunk> {
let doc_type = detect_doc_type(doc_path);
let filename = std::path::Path::new(doc_path)
.file_name()
.and_then(|n| n.to_str())
.unwrap_or("unknown");
let sections = extract_sections(content);
let mut enriched_chunks = Vec::new();
let mut global_chunk_index = 0;
for (section_header, section_content) in sections {
let chunks = smart_chunk_text(section_content, chunk_size, overlap);
for chunk in chunks {
let prefix = build_context_prefix(
filename,
section_header.as_deref(),
doc_type.as_deref(),
config,
);
let full_content = if prefix.is_empty() {
chunk.clone()
} else {
format!("{}\n\n{}", prefix, chunk)
};
enriched_chunks.push(EnrichedChunk {
content: full_content,
original_content: chunk,
doc_path: doc_path.to_string(),
chunk_index: global_chunk_index,
section: section_header.clone(),
doc_type: doc_type.clone(),
});
global_chunk_index += 1;
}
}
if enriched_chunks.is_empty() && !content.trim().is_empty() {
let prefix = build_context_prefix(filename, None, doc_type.as_deref(), config);
let full_content = if prefix.is_empty() {
content.to_string()
} else {
format!("{}\n\n{}", prefix, content)
};
enriched_chunks.push(EnrichedChunk {
content: full_content,
original_content: content.to_string(),
doc_path: doc_path.to_string(),
chunk_index: 0,
section: None,
doc_type,
});
}
enriched_chunks
}
fn build_context_prefix(
filename: &str,
section: Option<&str>,
doc_type: Option<&str>,
config: &ContextPrefixConfig,
) -> String {
let mut parts = Vec::new();
if config.include_source && !filename.is_empty() {
parts.push(format!("[Source: {}]", filename));
}
if config.include_section
&& let Some(sec) = section
{
parts.push(format!("[Section: {}]", sec));
}
if config.include_doc_type
&& let Some(dt) = doc_type
{
parts.push(format!("[Type: {}]", dt));
}
let prefix = parts.join(" ");
if prefix.len() > config.max_prefix_length {
prefix.chars().take(config.max_prefix_length).collect()
} else {
prefix
}
}
fn detect_doc_type(path: &str) -> Option<String> {
let ext = std::path::Path::new(path)
.extension()
.and_then(|e| e.to_str())
.map(|s| s.to_lowercase())?;
let doc_type = match ext.as_str() {
"rs" => "Rust source code",
"py" => "Python source code",
"js" | "jsx" => "JavaScript source code",
"ts" | "tsx" => "TypeScript source code",
"md" => "Markdown documentation",
"txt" => "Plain text",
"json" => "JSON data",
"yaml" | "yml" => "YAML configuration",
"toml" => "TOML configuration",
"html" => "HTML document",
"css" => "CSS stylesheet",
"sql" => "SQL query",
"sh" | "bash" => "Shell script",
"pdf" => "PDF document",
_ => return None,
};
Some(doc_type.to_string())
}
fn extract_sections(content: &str) -> Vec<(Option<String>, &str)> {
let header_pattern = regex::Regex::new(r"(?m)^(#{1,6})\s+(.+)$").ok();
if let Some(re) = header_pattern {
let mut sections = Vec::new();
let mut last_end = 0;
let mut current_header: Option<String> = None;
for caps in re.captures_iter(content) {
let match_start = caps.get(0).unwrap().start();
if match_start > last_end {
let section_content = &content[last_end..match_start];
if !section_content.trim().is_empty() {
sections.push((current_header.clone(), section_content.trim()));
}
}
current_header = Some(caps.get(2).unwrap().as_str().to_string());
last_end = caps.get(0).unwrap().end();
}
if last_end < content.len() {
let section_content = &content[last_end..];
if !section_content.trim().is_empty() {
sections.push((current_header, section_content.trim()));
}
}
if sections.is_empty() {
vec![(None, content)]
} else {
sections
}
} else {
vec![(None, content)]
}
}
fn smart_chunk_text(text: &str, target_size: usize, overlap: usize) -> Vec<String> {
let sentences = split_into_sentences(text);
if sentences.is_empty() || text.chars().count() <= target_size {
return vec![text.to_string()];
}
let mut chunks = Vec::new();
let mut current_chunk = String::new();
let mut overlap_sentences: Vec<String> = Vec::new();
let overlap_sentence_count = (overlap / 50).clamp(1, 3);
for sentence in &sentences {
let sentence_len = sentence.chars().count();
let current_len = current_chunk.chars().count();
let max_size = target_size + target_size / 2;
if current_len + sentence_len > max_size && !current_chunk.is_empty() {
chunks.push(current_chunk.trim().to_string());
current_chunk = overlap_sentences.join(" ");
if !current_chunk.is_empty() {
current_chunk.push(' ');
}
overlap_sentences.clear();
}
current_chunk.push_str(sentence);
current_chunk.push(' ');
overlap_sentences.push(sentence.clone());
if overlap_sentences.len() > overlap_sentence_count {
overlap_sentences.remove(0);
}
if current_chunk.chars().count() >= target_size {
chunks.push(current_chunk.trim().to_string());
current_chunk = overlap_sentences.join(" ");
if !current_chunk.is_empty() {
current_chunk.push(' ');
}
overlap_sentences.clear();
}
}
let remaining = current_chunk.trim();
if !remaining.is_empty() {
if remaining.chars().count() < target_size / 4 && !chunks.is_empty() {
let last_idx = chunks.len() - 1;
chunks[last_idx].push(' ');
chunks[last_idx].push_str(remaining);
} else {
chunks.push(remaining.to_string());
}
}
if chunks.is_empty() {
chunks.push(text.to_string());
}
chunks
}
fn split_into_sentences(text: &str) -> Vec<String> {
let mut sentences = Vec::new();
let mut current = String::new();
let mut chars = text.chars().peekable();
while let Some(c) = chars.next() {
current.push(c);
if matches!(c, '.' | '!' | '?') {
if let Some(&next) = chars.peek() {
if next.is_whitespace() {
let trimmed = current.trim();
let is_abbreviation = trimmed.ends_with("Mr.")
|| trimmed.ends_with("Mrs.")
|| trimmed.ends_with("Dr.")
|| trimmed.ends_with("Prof.")
|| trimmed.ends_with("vs.")
|| trimmed.ends_with("etc.")
|| trimmed.ends_with("e.g.")
|| trimmed.ends_with("i.e.")
|| (trimmed.len() >= 2 && trimmed.chars().rev().nth(1).map(|c| c.is_uppercase()).unwrap_or(false));
if !is_abbreviation {
sentences.push(current.trim().to_string());
current = String::new();
chars.next();
}
}
} else {
sentences.push(current.trim().to_string());
current = String::new();
}
} else if c == '\n' {
if let Some(&next) = chars.peek()
&& next == '\n'
{
if !current.trim().is_empty() {
sentences.push(current.trim().to_string());
current = String::new();
}
chars.next(); }
}
}
let remaining = current.trim();
if !remaining.is_empty() {
sentences.push(remaining.to_string());
}
sentences
}
#[derive(Debug, Clone, Default)]
pub struct SearchOptions {
pub layer_filter: Option<SliceLayer>,
}
impl SearchOptions {
pub fn outer_only() -> Self {
Self {
layer_filter: Some(SliceLayer::Outer),
}
}
pub fn deep() -> Self {
Self { layer_filter: None }
}
}
#[derive(Debug, Clone, serde::Serialize)]
pub struct SearchResult {
pub id: String,
pub namespace: String,
pub text: String,
pub score: f32,
pub metadata: serde_json::Value,
#[serde(skip_serializing_if = "Option::is_none")]
pub layer: Option<SliceLayer>,
#[serde(skip_serializing_if = "Option::is_none")]
pub parent_id: Option<String>,
#[serde(skip_serializing_if = "Vec::is_empty")]
pub children_ids: Vec<String>,
#[serde(skip_serializing_if = "Vec::is_empty")]
pub keywords: Vec<String>,
}
impl SearchResult {
pub fn new_legacy(
id: String,
namespace: String,
text: String,
score: f32,
metadata: serde_json::Value,
) -> Self {
Self {
id,
namespace,
text,
score,
metadata,
layer: None,
parent_id: None,
children_ids: vec![],
keywords: vec![],
}
}
pub fn can_expand(&self) -> bool {
!self.children_ids.is_empty()
}
pub fn can_drill_up(&self) -> bool {
self.parent_id.is_some()
}
}
fn cosine(a: &[f32], b: &[f32]) -> f32 {
let mut dot = 0.0_f32;
let mut norm_a = 0.0_f32;
let mut norm_b = 0.0_f32;
for (x, y) in a.iter().zip(b.iter()) {
dot += x * y;
norm_a += x * x;
norm_b += y * y;
}
if norm_a == 0.0 || norm_b == 0.0 {
return 0.0;
}
dot / (norm_a.sqrt() * norm_b.sqrt())
}