use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::path::Path;
use crate::embeddings_generator::EmbeddingsGenerator;
/// One stored chunk of text together with its embedding vector.
///
/// Chunks are held in [`EmbeddingsDatabase::chunks`] and are scored against
/// the query embedding by [`SemanticSearch::search`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingChunk {
    /// Identifier of the source this chunk belongs to; copied verbatim into
    /// search results. Presumably a file path, but the format is not
    /// validated in this file.
    pub file_path: String,
    /// Text of the chunk; copied verbatim into search results.
    pub content: String,
    /// Embedding vector compared to the query embedding via cosine
    /// similarity. Presumably computed from `content` by the generator that
    /// produced the database (not verifiable from this file).
    pub embedding: Vec<f32>,
}
/// Header information of an embeddings database, without the chunk payload.
///
/// Mirrors every field of [`EmbeddingsDatabase`] except `chunks`; produced by
/// [`SemanticSearch::metadata`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingsDatabaseMetadata {
    /// Version string as recorded in the database file.
    pub version: String,
    /// Generation timestamp as recorded in the database file (kept as an
    /// opaque string; its format is not parsed here).
    pub generated_at: String,
    /// Model name as recorded in the database file.
    pub model: String,
    /// Chunk size recorded at generation time (units are not established by
    /// this file).
    pub chunk_size: usize,
    /// Overlap size recorded at generation time (units are not established
    /// by this file).
    pub overlap_size: usize,
    /// File count as recorded in the database file.
    pub total_files: usize,
    /// Chunk count as recorded in the database file; this file does not check
    /// it against the number of chunks actually loaded.
    pub total_chunks: usize,
}
/// Full on-disk embeddings database: header fields plus every chunk.
///
/// [`SemanticSearch::new`] deserializes this type from the JSON file it is
/// given.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingsDatabase {
    /// Version string as recorded in the database file.
    pub version: String,
    /// Generation timestamp as recorded in the database file (kept as an
    /// opaque string; its format is not parsed here).
    pub generated_at: String,
    /// Model name as recorded in the database file.
    pub model: String,
    /// Chunk size recorded at generation time (units are not established by
    /// this file).
    pub chunk_size: usize,
    /// Overlap size recorded at generation time (units are not established
    /// by this file).
    pub overlap_size: usize,
    /// File count as recorded in the database file.
    pub total_files: usize,
    /// Chunk count as recorded in the database file; not cross-checked
    /// against `chunks.len()` in this file.
    pub total_chunks: usize,
    /// All stored chunks, in file order.
    pub chunks: Vec<EmbeddingChunk>,
}
/// A single hit returned by a semantic search.
#[derive(Debug, Clone)]
pub struct SearchResult {
    /// `file_path` of the matching chunk.
    pub file_path: String,
    /// `content` of the matching chunk.
    pub content: String,
    /// Cosine similarity between the query embedding and the chunk embedding.
    pub similarity: f32,
}
/// Semantic search engine over a loaded embeddings database.
///
/// Owns the database together with the generator used to embed queries.
pub struct SemanticSearch {
    /// Chunks and header data loaded from the embeddings file.
    database: EmbeddingsDatabase,
    /// Produces the embedding for each search query.
    generator: EmbeddingsGenerator,
}
impl SemanticSearch {
pub fn new<P: AsRef<Path>>(embeddings_path: P) -> Result<Self> {
let contents = std::fs::read_to_string(embeddings_path.as_ref())
.context("Failed to read embeddings file")?;
let database: EmbeddingsDatabase = serde_json::from_str(&contents)
.context("Failed to parse embeddings JSON")?;
let generator = EmbeddingsGenerator::new()
.context("Failed to initialize embeddings generator")?;
Ok(Self {
database,
generator,
})
}
pub fn metadata(&self) -> EmbeddingsDatabaseMetadata {
EmbeddingsDatabaseMetadata {
version: self.database.version.clone(),
generated_at: self.database.generated_at.clone(),
model: self.database.model.clone(),
chunk_size: self.database.chunk_size,
overlap_size: self.database.overlap_size,
total_files: self.database.total_files,
total_chunks: self.database.total_chunks,
}
}
pub fn search(&mut self, query: &str, top_n: usize) -> Result<Vec<SearchResult>> {
let query_embedding = self.generator.generate_embedding(query)
.context("Failed to generate query embedding")?;
let mut results: Vec<SearchResult> = self.database.chunks
.iter()
.map(|chunk| {
let similarity = cosine_similarity(&query_embedding, &chunk.embedding);
SearchResult {
file_path: chunk.file_path.clone(),
content: chunk.content.clone(),
similarity,
}
})
.collect();
results.sort_by(|a, b| b.similarity.partial_cmp(&a.similarity).unwrap_or(std::cmp::Ordering::Equal));
results.truncate(top_n);
Ok(results)
}
pub fn chunk_count(&self) -> usize {
self.database.chunks.len()
}
}
/// Computes the cosine similarity of two vectors.
///
/// Returns `0.0` when the slices differ in length or when either vector has
/// zero magnitude; otherwise returns `dot(a, b) / (|a| * |b|)`.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    // Euclidean norm of a vector.
    let norm = |v: &[f32]| v.iter().map(|c| c * c).sum::<f32>().sqrt();
    let (norm_a, norm_b) = (norm(a), norm(b));

    // A zero-magnitude vector has no direction; avoid dividing by zero.
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }

    let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    dot / (norm_a * norm_b)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Maximum absolute error tolerated when comparing similarities.
    const TOLERANCE: f32 = 0.0001;

    /// Asserts that `actual` lies within `TOLERANCE` of `expected`.
    fn assert_close(actual: f32, expected: f32) {
        assert!((actual - expected).abs() < TOLERANCE);
    }

    /// Identical vectors have similarity 1.
    #[test]
    fn test_cosine_similarity_identical() {
        let lhs = [1.0, 2.0, 3.0];
        let rhs = [1.0, 2.0, 3.0];
        assert_close(cosine_similarity(&lhs, &rhs), 1.0);
    }

    /// Perpendicular vectors have similarity 0.
    #[test]
    fn test_cosine_similarity_orthogonal() {
        let lhs = [1.0, 0.0];
        let rhs = [0.0, 1.0];
        assert_close(cosine_similarity(&lhs, &rhs), 0.0);
    }

    /// Vectors pointing in opposite directions have similarity -1.
    #[test]
    fn test_cosine_similarity_opposite() {
        let lhs = [1.0, 0.0];
        let rhs = [-1.0, 0.0];
        assert_close(cosine_similarity(&lhs, &rhs), -1.0);
    }

    /// Mismatched lengths yield exactly 0 rather than an error.
    #[test]
    fn test_cosine_similarity_different_lengths() {
        let shorter = [1.0, 2.0];
        let longer = [1.0, 2.0, 3.0];
        assert_eq!(cosine_similarity(&shorter, &longer), 0.0);
    }
}