#![allow(dead_code)]
use super::{cosine_similarity, CodeEmbedding, EmbeddingGenerator, SearchResult, SemanticConfig};
use anyhow::Result;
use std::collections::BinaryHeap;
use std::cmp::Ordering;
use tracing::{info, debug};
/// In-memory semantic search index over pre-computed code embeddings.
/// Queries are embedded on demand and compared by cosine similarity.
pub struct SemanticSearch {
// Embeds incoming query strings at search time.
generator: EmbeddingGenerator,
// Carries the similarity threshold (and the generator's settings).
config: SemanticConfig,
// The searchable corpus; scanned linearly on every query.
embeddings: Vec<CodeEmbedding>,
}
// A candidate embedding paired with its similarity score. The Ord/Eq impls
// below compare by `score` alone, so instances can sit in a `BinaryHeap`
// during top-k selection.
#[derive(Debug, Clone)]
struct ScoredResult {
// The stored code snippet (cloned out of the index).
embedding: CodeEmbedding,
// Cosine similarity against the query embedding.
score: f32,
}
/// Reversed ordering on `score`: the *lowest* score compares greatest.
///
/// This makes `BinaryHeap<ScoredResult>` behave as a min-heap, so `peek()`
/// exposes the weakest retained candidate — exactly the element a top-k
/// selection must evict. (With the previous natural ascending order the heap
/// was a max-heap, `peek()` returned the *best* candidate, and top-k eviction
/// threw away the highest-scoring results.)
impl Ord for ScoredResult {
    fn cmp(&self, other: &Self) -> Ordering {
        // NaN scores are treated as equal; in practice the similarity
        // threshold filter (`score >= threshold` is false for NaN) keeps
        // NaN out of `ScoredResult` entirely.
        other
            .score
            .partial_cmp(&self.score)
            .unwrap_or(Ordering::Equal)
    }
}
impl PartialOrd for ScoredResult {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // Delegate to `cmp` so the two orderings can never disagree.
        Some(self.cmp(other))
    }
}
impl PartialEq for ScoredResult {
    fn eq(&self, other: &Self) -> bool {
        // Equality considers the score only, consistent with `cmp`.
        self.score == other.score
    }
}
// Sound only because NaN never reaches `ScoredResult` (see `cmp` note).
impl Eq for ScoredResult {}
impl SemanticSearch {
pub fn new(config: SemanticConfig, embeddings: Vec<CodeEmbedding>) -> Result<Self> {
let generator = EmbeddingGenerator::new(config.clone())?;
Ok(Self {
generator,
config,
embeddings,
})
}
pub async fn search(&self, query: &str, top_k: usize) -> Result<Vec<SearchResult>> {
info!("Searching for: {}", query);
let query_embedding = self.generator.embed_single(query).await?;
let results = self.find_similar(&query_embedding, top_k);
debug!("Found {} results", results.len());
Ok(results)
}
fn find_similar(&self, query_embedding: &[f32], top_k: usize) -> Vec<SearchResult> {
let mut heap: BinaryHeap<ScoredResult> = BinaryHeap::new();
for embedding in &self.embeddings {
let score = cosine_similarity(query_embedding, &embedding.embedding);
if score >= self.config.similarity_threshold {
if heap.len() < top_k {
heap.push(ScoredResult {
embedding: embedding.clone(),
score,
});
} else if let Some(min) = heap.peek() {
if score > min.score {
heap.pop();
heap.push(ScoredResult {
embedding: embedding.clone(),
score,
});
}
}
}
}
let mut results: Vec<ScoredResult> = heap.into_vec();
results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
results.into_iter()
.map(|scored| SearchResult {
highlights: generate_highlights(&scored.embedding.content, query_embedding),
embedding: scored.embedding,
score: scored.score,
})
.collect()
}
pub async fn search_filtered(
&self,
query: &str,
top_k: usize,
language: Option<&str>,
repository: Option<&str>,
tags: Option<&[String]>,
) -> Result<Vec<SearchResult>> {
let filtered: Vec<&CodeEmbedding> = self.embeddings.iter()
.filter(|e| {
if let Some(lang) = language {
if e.metadata.language.to_lowercase() != lang.to_lowercase() {
return false;
}
}
if let Some(repo) = repository {
if !e.metadata.repository.contains(repo) {
return false;
}
}
if let Some(tag_list) = tags {
if !tag_list.iter().all(|tag| e.metadata.tags.contains(&tag.to_lowercase())) {
return false;
}
}
true
})
.collect();
if filtered.is_empty() {
return Ok(Vec::new());
}
let query_embedding = self.generator.embed_single(query).await?;
let mut scored: Vec<(f32, &CodeEmbedding)> = filtered.into_iter()
.map(|e| (cosine_similarity(&query_embedding, &e.embedding), e))
.filter(|(score, _)| *score >= self.config.similarity_threshold)
.collect();
scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(Ordering::Equal));
Ok(scored.into_iter()
.take(top_k)
.map(|(score, embedding)| SearchResult {
highlights: generate_highlights(&embedding.content, &query_embedding),
embedding: embedding.clone(),
score,
})
.collect())
}
pub fn add_embeddings(&mut self, embeddings: Vec<CodeEmbedding>) {
self.embeddings.extend(embeddings);
}
pub fn stats(&self) -> SearchStats {
let languages: std::collections::HashSet<String> = self.embeddings.iter()
.map(|e| e.metadata.language.clone())
.collect();
let repositories: std::collections::HashSet<String> = self.embeddings.iter()
.map(|e| e.metadata.repository.clone())
.collect();
SearchStats {
total_embeddings: self.embeddings.len(),
unique_languages: languages.len(),
unique_repositories: repositories.len(),
languages: languages.into_iter().collect(),
repositories: repositories.into_iter().collect(),
}
}
}
/// Summary statistics for a `SemanticSearch` index, as produced by
/// `SemanticSearch::stats`.
#[derive(Debug, Clone, serde::Serialize)]
pub struct SearchStats {
/// Number of embeddings held in the index.
pub total_embeddings: usize,
/// Count of distinct languages across all embeddings.
pub unique_languages: usize,
/// Count of distinct repositories across all embeddings.
pub unique_repositories: usize,
/// The distinct language names.
pub languages: Vec<String>,
/// The distinct repository names.
pub repositories: Vec<String>,
}
/// Picks up to two preview lines from the first three lines of `content`.
/// Blank lines and lines opening with `//` or `#` comment markers are
/// skipped. The query embedding is currently unused; the parameter is kept
/// for a future relevance-aware highlighter.
fn generate_highlights(content: &str, _query_embedding: &[f32]) -> Vec<String> {
    content
        .lines()
        .take(3)
        .map(str::trim)
        .filter(|line| !line.is_empty() && !line.starts_with("//") && !line.starts_with('#'))
        .take(2)
        .map(str::to_string)
        .collect()
}
/// Re-ranks semantic hits with a lightweight keyword boost: each query term
/// found in a snippet's content adds 0.1 to its score, and each term found
/// in its function name adds a further 0.15. Matching is case-insensitive.
/// Returns the `top_k` boosted results, best first.
pub fn hybrid_search(
    semantic_results: &[SearchResult],
    keyword_query: &str,
    top_k: usize,
) -> Vec<SearchResult> {
    let query_lower = keyword_query.to_lowercase();
    let terms: Vec<&str> = query_lower.split_whitespace().collect();
    let mut ranked: Vec<(f32, SearchResult)> = semantic_results
        .iter()
        .map(|result| {
            let body = result.embedding.content.to_lowercase();
            // Lowercase the function name once instead of per term.
            let name = result
                .embedding
                .metadata
                .function_name
                .as_ref()
                .map(|f| f.to_lowercase());
            let mut boosted = result.score;
            for term in &terms {
                if body.contains(term) {
                    boosted += 0.1;
                }
                if name.as_deref().map_or(false, |n| n.contains(term)) {
                    boosted += 0.15;
                }
            }
            (boosted, result.clone())
        })
        .collect();
    ranked.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(Ordering::Equal));
    ranked.into_iter().take(top_k).map(|(_, hit)| hit).collect()
}