use crate::EmbeddingModel;
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use tracing::{debug, info};
/// A piece of text that can be added to a [`VectorSearcher`] index.
///
/// Derives serde `Serialize`/`Deserialize`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Document {
    /// Identifier for the document; `VectorSearcher` does not inspect it.
    pub id: String,
    /// Source path of the document; `VectorSearcher` only uses it in log messages.
    pub path: PathBuf,
    /// Text passed to the embedding model when the document is indexed.
    pub content: String,
}
/// A single hit returned by [`VectorSearcher::search`].
#[derive(Debug, Clone)]
pub struct SearchResult {
    /// Clone of the matching indexed document.
    pub document: Document,
    /// Similarity between the query embedding and the document embedding.
    /// Results are ordered from highest to lowest score, so higher means more similar.
    pub score: f32,
}
/// In-memory semantic search index that ranks documents by the similarity
/// of their embeddings to a query embedding.
///
/// Searching is a linear scan over every stored embedding.
pub struct VectorSearcher {
    /// Model used to embed both document contents and queries.
    model: EmbeddingModel,
    /// Indexed documents, in insertion order.
    documents: Vec<Document>,
    /// Embeddings kept parallel to `documents`: `index[i]` is the embedding of `documents[i]`.
    index: Vec<Vec<f32>>,
}
impl VectorSearcher {
pub fn new(model: EmbeddingModel) -> Self {
Self { model, documents: Vec::new(), index: Vec::new() }
}
pub fn add_document(&mut self, doc: Document) -> Result<()> {
info!("添加文档到索引: {:?}", doc.path);
let embedding = self.model.encode(&doc.content)?;
self.documents.push(doc);
self.index.push(embedding);
debug!("当前索引文档数: {}", self.documents.len());
Ok(())
}
pub fn add_documents(&mut self, docs: Vec<Document>) -> Result<usize> {
let total = docs.len();
info!("批量添加 {} 个文档到索引", total);
let mut success_count = 0;
for doc in docs {
if self.add_document(doc).is_ok() {
success_count += 1;
}
}
info!("成功添加 {}/{} 个文档", success_count, total);
Ok(success_count)
}
pub fn search(&mut self, query: &str, top_k: usize) -> Result<Vec<SearchResult>> {
if self.documents.is_empty() {
return Ok(Vec::new());
}
debug!("语义搜索: \"{}\"", query);
let query_embedding = self.model.encode(query)?;
let mut scores: Vec<(usize, f32)> = Vec::new();
for (i, doc_embedding) in self.index.iter().enumerate() {
let score = EmbeddingModel::cosine_similarity(&query_embedding, doc_embedding);
scores.push((i, score));
}
scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
let results: Vec<SearchResult> = scores
.into_iter()
.take(top_k)
.map(|(i, score)| SearchResult { document: self.documents[i].clone(), score })
.collect();
debug!("找到 {} 个结果", results.len());
Ok(results)
}
pub fn document_count(&self) -> usize {
self.documents.len()
}
pub fn clear(&mut self) {
self.documents.clear();
self.index.clear();
}
}
/// Cosine similarity of two vectors: `a·b / (‖a‖·‖b‖)`.
///
/// Returns `0.0` whenever the similarity is undefined:
/// - the slices differ in length or are empty;
/// - either vector has a zero or NaN norm;
/// - the quotient itself is NaN (e.g. infinite components).
///
/// Otherwise the result is clamped to `[-1.0, 1.0]` to absorb rounding error.
///
/// The sums are accumulated in `f64` for two reasons. Squaring `f32` components
/// larger than about 1.8e19 would overflow `f32` to infinity and turn the
/// quotient into NaN. The wider accumulator also reduces rounding error on
/// long vectors.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }
    // One pass computes the dot product and both squared norms.
    let (dot, sq_norm_a, sq_norm_b) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, sq_a, sq_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, sq_a + x * x, sq_b + y * y)
        },
    );
    let norm_a = sq_norm_a.sqrt();
    let norm_b = sq_norm_b.sqrt();
    // `> 0.0` is false for NaN, so NaN norms fall through to 0.0 as well.
    if norm_a > 0.0 && norm_b > 0.0 {
        let cos = dot / (norm_a * norm_b);
        if cos.is_nan() {
            0.0
        } else {
            cos.clamp(-1.0, 1.0) as f32
        }
    } else {
        0.0
    }
}
// Unit tests for the free `cosine_similarity` function.
#[cfg(test)]
mod tests {
    use super::*;

    /// Fails the test unless `actual` lies within 1e-6 of `expected`.
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < 1e-6,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn test_cosine_similarity() {
        // (lhs, rhs, expected similarity)
        let cases: [(Vec<f32>, Vec<f32>, f32); 3] = [
            (vec![1.0, 0.0, 0.0], vec![1.0, 0.0, 0.0], 1.0), // identical
            (vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0], 0.0), // orthogonal
            (vec![1.0, 1.0], vec![-1.0, -1.0], -1.0),        // opposite
        ];
        for (lhs, rhs, expected) in &cases {
            assert_close(cosine_similarity(lhs, rhs), *expected);
        }
    }

    #[test]
    fn test_cosine_similarity_edge_cases() {
        let no_dims: &[f32] = &[];
        assert_eq!(cosine_similarity(no_dims, no_dims), 0.0);

        // Mismatched lengths yield 0.0 rather than a truncated comparison.
        let short: &[f32] = &[1.0, 2.0];
        let long: &[f32] = &[1.0, 2.0, 3.0];
        assert_eq!(cosine_similarity(short, long), 0.0);

        // A zero vector has zero norm, so the similarity is reported as 0.0.
        let origin: &[f32] = &[0.0; 3];
        assert_eq!(cosine_similarity(origin, long), 0.0);
    }
}