use canon_core::{Chunk, CPError, Result};
use canon_embed::EmbeddingEngine;
use canon_store::GraphStore;
use std::sync::{Arc, Mutex};
use tracing::{info, warn};
use uuid::Uuid;
/// A single search hit: the matched chunk, its relevance score, and the
/// path of the document the chunk belongs to.
#[derive(Debug, Clone, serde::Serialize)]
pub struct SearchResult {
/// The matched chunk, as retrieved from the graph store.
pub chunk: Chunk,
/// Relevance score; its scale depends on the search mode that produced it
/// (raw backend score for semantic/lexical, fused score for hybrid).
pub score: f32,
/// Path of the source document (lossy UTF-8 conversion of the stored path).
pub doc_path: String,
}
/// Query front-end combining an embedding engine with a graph store to
/// serve semantic, lexical, and rank-fused hybrid searches.
pub struct QueryEngine {
/// Shared, mutex-guarded store providing vector search, full-text search,
/// and chunk/document lookups.
graph: Arc<Mutex<GraphStore>>,
/// Shared engine that embeds query text into vectors.
embedder: Arc<EmbeddingEngine>,
}
impl QueryEngine {
    /// Creates a new query engine over a shared graph store and embedder.
    pub fn new(graph: Arc<Mutex<GraphStore>>, embedder: Arc<EmbeddingEngine>) -> Self {
        Self { graph, embedder }
    }

    /// Builds an FTS query string from raw user text: multi-word queries are
    /// phrase-quoted (embedded double quotes stripped so the phrase syntax
    /// stays well-formed); single words pass through unchanged.
    fn build_fts_query(query: &str) -> String {
        if query.contains(' ') {
            format!("\"{}\"", query.replace('"', ""))
        } else {
            query.to_string()
        }
    }

    /// Resolves a chunk id plus score into a full `SearchResult`.
    ///
    /// Returns `Ok(None)` when the chunk or its parent document no longer
    /// exists (e.g. removed between indexing and lookup), so callers can
    /// skip stale hits instead of failing the whole search.
    fn hydrate(graph: &GraphStore, chunk_id: Uuid, score: f32) -> Result<Option<SearchResult>> {
        let chunk = match graph.get_chunk(chunk_id)? {
            Some(c) => c,
            None => return Ok(None),
        };
        let doc = match graph.get_document(chunk.doc_id)? {
            Some(d) => d,
            None => return Ok(None),
        };
        Ok(Some(SearchResult {
            chunk,
            score,
            doc_path: doc.path.to_string_lossy().to_string(),
        }))
    }

    /// Hybrid search: runs semantic (vector) and lexical (FTS) retrieval,
    /// then fuses the two rankings with Reciprocal Rank Fusion (RRF).
    ///
    /// Scores are normalized to [0, 1], where 1.0 means the chunk ranked
    /// first in both result lists. A lexical-search failure degrades to
    /// semantic-only with a warning rather than erroring out.
    ///
    /// # Errors
    /// Fails if query embedding, the semantic search, or a store lookup
    /// during result hydration fails.
    pub fn search(&self, query: &str, k: usize) -> Result<Vec<SearchResult>> {
        info!("Hybrid search for: '{}'", query);
        let query_vec = self
            .embedder
            .embed(query)
            .map_err(|e| CPError::Embedding(format!("Failed to embed query: {}", e)))?;

        // Hold the lock only per-operation. `.unwrap()` panics on a
        // poisoned mutex (another thread panicked while holding it).
        let semantic_results = {
            let graph = self.graph.lock().unwrap();
            graph.search(&query_vec, k)?
        };
        let lexical_results = {
            let graph = self.graph.lock().unwrap();
            graph
                .search_lexical(&Self::build_fts_query(query), k)
                .unwrap_or_else(|e| {
                    warn!("Lexical search failed: {}. Falling back to semantic only.", e);
                    Vec::new()
                })
        };

        // RRF in scaled integer arithmetic so fusion is exactly
        // deterministic: each list contributes RRF_SCALE / (RRF_K + rank).
        const RRF_K: u64 = 60;
        const RRF_SCALE: u64 = 1_000_000;
        // Best attainable fused score: rank 0 in both lists. Used to
        // normalize final scores into [0, 1]. (The previous divisor,
        // RRF_SCALE * 2, capped "normalized" scores at ~0.017.)
        const MAX_FUSED: u64 = 2 * (RRF_SCALE / RRF_K);

        let mut scores: std::collections::HashMap<Uuid, u64> = std::collections::HashMap::new();
        {
            let graph = self.graph.lock().unwrap();
            for (i, (emb_id, _)) in semantic_results.iter().enumerate() {
                // Semantic hits are keyed by embedding id; ids that cannot
                // be mapped back to a chunk (or whose lookup errors) are
                // skipped rather than aborting the fusion.
                if let Ok(Some(chunk_id)) = graph.get_chunk_id_for_embedding(*emb_id) {
                    *scores.entry(chunk_id).or_insert(0) += RRF_SCALE / (RRF_K + i as u64);
                }
            }
            for (i, (chunk_id, _)) in lexical_results.iter().enumerate() {
                *scores.entry(*chunk_id).or_insert(0) += RRF_SCALE / (RRF_K + i as u64);
            }
        }

        // Highest fused score first; tie-break on chunk id so ordering is
        // stable across runs. Stability is irrelevant after the tie-break
        // makes the order total, so the unstable sort is safe and faster.
        let mut fused: Vec<(Uuid, u64)> = scores.into_iter().collect();
        fused.sort_unstable_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
        fused.truncate(k);

        let mut search_results = Vec::with_capacity(fused.len());
        let graph = self.graph.lock().unwrap();
        for (chunk_id, fused_score) in fused {
            let normalized = fused_score as f32 / MAX_FUSED as f32;
            if let Some(result) = Self::hydrate(&graph, chunk_id, normalized)? {
                search_results.push(result);
            }
        }
        Ok(search_results)
    }

    /// Pure semantic (vector-similarity) search; scores are the backend's
    /// raw similarity values, passed through unchanged.
    ///
    /// # Errors
    /// Fails if query embedding, the vector search, or any store lookup fails.
    pub fn search_semantic(&self, query: &str, k: usize) -> Result<Vec<SearchResult>> {
        info!("Semantic search for: '{}'", query);
        let query_vec = self
            .embedder
            .embed(query)
            .map_err(|e| CPError::Embedding(format!("Failed to embed query: {}", e)))?;
        let raw_results = {
            let graph = self.graph.lock().unwrap();
            graph.search(&query_vec, k)?
        };
        let mut search_results = Vec::with_capacity(raw_results.len());
        let graph = self.graph.lock().unwrap();
        for (emb_id, score) in raw_results {
            if let Some(chunk_id) = graph.get_chunk_id_for_embedding(emb_id)? {
                if let Some(result) = Self::hydrate(&graph, chunk_id, score)? {
                    search_results.push(result);
                }
            }
        }
        Ok(search_results)
    }

    /// Pure lexical (full-text) search; scores are the backend's raw FTS
    /// relevance values, passed through unchanged.
    ///
    /// # Errors
    /// Fails if the lexical search or any store lookup fails.
    pub fn search_lexical(&self, query: &str, k: usize) -> Result<Vec<SearchResult>> {
        info!("Lexical search for: '{}'", query);
        let raw_results = {
            let graph = self.graph.lock().unwrap();
            graph.search_lexical(&Self::build_fts_query(query), k)?
        };
        let mut search_results = Vec::with_capacity(raw_results.len());
        let graph = self.graph.lock().unwrap();
        for (chunk_id, score) in raw_results {
            if let Some(result) = Self::hydrate(&graph, chunk_id, score)? {
                search_results.push(result);
            }
        }
        Ok(search_results)
    }
}