//! i-self 0.4.3
//!
//! Personal developer-companion CLI: scans your repos, indexes code semantically,
//! watches your activity, and moves AI-agent sessions between tools
//! (Claude Code, Aider, Goose, OpenAI Codex CLI, Continue.dev, OpenCode).
#![allow(dead_code)]

use super::{cosine_similarity, CodeEmbedding, EmbeddingGenerator, SearchResult, SemanticConfig};
use anyhow::Result;
use std::collections::BinaryHeap;
use std::cmp::Ordering;
use tracing::{info, debug};

/// Semantic search over code embeddings.
///
/// Holds the generator used to embed query strings, the configuration
/// (which supplies `similarity_threshold`), and the in-memory list of
/// indexed embeddings that every search scans linearly.
pub struct SemanticSearch {
    generator: EmbeddingGenerator,  // embeds incoming query strings
    config: SemanticConfig,         // source of `similarity_threshold`
    embeddings: Vec<CodeEmbedding>, // the searchable index, scanned per query
}

/// Search result with ordering for priority queue.
///
/// Ordered purely by `score` — see the `Ord`/`PartialOrd`/`PartialEq`
/// impls below.
#[derive(Debug, Clone)]
struct ScoredResult {
    embedding: CodeEmbedding, // the candidate snippet
    score: f32,               // cosine similarity to the query embedding
}

impl Ord for ScoredResult {
    /// Total order on `score` so `ScoredResult` can live in a `BinaryHeap`.
    ///
    /// Uses IEEE-754 `total_cmp` instead of
    /// `partial_cmp(..).unwrap_or(Ordering::Equal)`: the latter makes every
    /// NaN compare `Equal` to every other value, which breaks transitivity
    /// and thus the `Ord` contract, and can silently corrupt heap and sort
    /// invariants. For non-NaN scores the ordering is unchanged.
    fn cmp(&self, other: &Self) -> Ordering {
        self.score.total_cmp(&other.score)
    }
}

impl PartialOrd for ScoredResult {
    // Delegates to `cmp` so the partial and total orders can never disagree —
    // the canonical form clippy's `non_canonical_partial_ord_impl` expects.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl PartialEq for ScoredResult {
    fn eq(&self, other: &Self) -> bool {
        self.score == other.score
    }
}

// NOTE(review): `Eq` asserts total (reflexive) equality, but the `PartialEq`
// impl above compares raw f32s, which is not reflexive when a score is NaN —
// confirm similarity scores can never be NaN, or compare via `total_cmp`.
impl Eq for ScoredResult {}

impl SemanticSearch {
    pub fn new(config: SemanticConfig, embeddings: Vec<CodeEmbedding>) -> Result<Self> {
        let generator = EmbeddingGenerator::new(config.clone())?;
        Ok(Self {
            generator,
            config,
            embeddings,
        })
    }

    /// Search for code semantically similar to the query
    pub async fn search(&self, query: &str, top_k: usize) -> Result<Vec<SearchResult>> {
        info!("Searching for: {}", query);

        // Generate embedding for the query
        let query_embedding = self.generator.embed_single(query).await?;
        
        // Find most similar embeddings
        let results = self.find_similar(&query_embedding, top_k);
        
        debug!("Found {} results", results.len());
        Ok(results)
    }

    /// Find similar embeddings using cosine similarity
    fn find_similar(&self, query_embedding: &[f32], top_k: usize) -> Vec<SearchResult> {
        let mut heap: BinaryHeap<ScoredResult> = BinaryHeap::new();

        for embedding in &self.embeddings {
            let score = cosine_similarity(query_embedding, &embedding.embedding);
            
            if score >= self.config.similarity_threshold {
                if heap.len() < top_k {
                    heap.push(ScoredResult {
                        embedding: embedding.clone(),
                        score,
                    });
                } else if let Some(min) = heap.peek() {
                    if score > min.score {
                        heap.pop();
                        heap.push(ScoredResult {
                            embedding: embedding.clone(),
                            score,
                        });
                    }
                }
            }
        }

        // Convert heap to sorted vector (highest score first)
        let mut results: Vec<ScoredResult> = heap.into_vec();
        results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));

        results.into_iter()
            .map(|scored| SearchResult {
                highlights: generate_highlights(&scored.embedding.content, query_embedding),
                embedding: scored.embedding,
                score: scored.score,
            })
            .collect()
    }

    /// Search with filters
    pub async fn search_filtered(
        &self,
        query: &str,
        top_k: usize,
        language: Option<&str>,
        repository: Option<&str>,
        tags: Option<&[String]>,
    ) -> Result<Vec<SearchResult>> {
        // First filter embeddings
        let filtered: Vec<&CodeEmbedding> = self.embeddings.iter()
            .filter(|e| {
                if let Some(lang) = language {
                    if e.metadata.language.to_lowercase() != lang.to_lowercase() {
                        return false;
                    }
                }
                if let Some(repo) = repository {
                    if !e.metadata.repository.contains(repo) {
                        return false;
                    }
                }
                if let Some(tag_list) = tags {
                    if !tag_list.iter().all(|tag| e.metadata.tags.contains(&tag.to_lowercase())) {
                        return false;
                    }
                }
                true
            })
            .collect();

        if filtered.is_empty() {
            return Ok(Vec::new());
        }

        // Generate query embedding
        let query_embedding = self.generator.embed_single(query).await?;

        // Score filtered embeddings
        let mut scored: Vec<(f32, &CodeEmbedding)> = filtered.into_iter()
            .map(|e| (cosine_similarity(&query_embedding, &e.embedding), e))
            .filter(|(score, _)| *score >= self.config.similarity_threshold)
            .collect();

        // Sort by score
        scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(Ordering::Equal));

        // Take top_k
        Ok(scored.into_iter()
            .take(top_k)
            .map(|(score, embedding)| SearchResult {
                highlights: generate_highlights(&embedding.content, &query_embedding),
                embedding: embedding.clone(),
                score,
            })
            .collect())
    }

    /// Add new embeddings to the search index
    pub fn add_embeddings(&mut self, embeddings: Vec<CodeEmbedding>) {
        self.embeddings.extend(embeddings);
    }

    /// Get statistics about the index
    pub fn stats(&self) -> SearchStats {
        let languages: std::collections::HashSet<String> = self.embeddings.iter()
            .map(|e| e.metadata.language.clone())
            .collect();
        
        let repositories: std::collections::HashSet<String> = self.embeddings.iter()
            .map(|e| e.metadata.repository.clone())
            .collect();

        SearchStats {
            total_embeddings: self.embeddings.len(),
            unique_languages: languages.len(),
            unique_repositories: repositories.len(),
            languages: languages.into_iter().collect(),
            repositories: repositories.into_iter().collect(),
        }
    }
}

/// Summary statistics over the embeddings currently held by a
/// `SemanticSearch` index (produced by `SemanticSearch::stats`).
#[derive(Debug, Clone, serde::Serialize)]
pub struct SearchStats {
    pub total_embeddings: usize,    // number of indexed code snippets
    pub unique_languages: usize,    // count of distinct `metadata.language` values
    pub unique_repositories: usize, // count of distinct `metadata.repository` values
    pub languages: Vec<String>,     // the distinct languages (unsorted)
    pub repositories: Vec<String>,  // the distinct repositories (unsorted)
}

/// Generate highlighted snippets from content.
///
/// Heuristic preview: scans only the first three lines of `content` and
/// keeps at most two that look like code (non-blank and not a `//` or `#`
/// comment). The query embedding is currently unused — token-level
/// highlighting is a possible future upgrade.
fn generate_highlights(content: &str, _query_embedding: &[f32]) -> Vec<String> {
    content
        .lines()
        .take(3)
        .map(str::trim)
        .filter(|line| !line.is_empty() && !line.starts_with("//") && !line.starts_with('#'))
        .take(2)
        .map(|line| line.to_string())
        .collect()
}

/// Hybrid search combining semantic and keyword matching.
///
/// Reranks the given semantic hits: each whitespace-separated query term
/// adds a fixed bonus when it appears (case-insensitively) in the snippet
/// body (+0.1) and a larger one when it appears in the function name
/// (+0.15). Returns the `top_k` results by boosted score, best first.
pub fn hybrid_search(
    semantic_results: &[SearchResult],
    keyword_query: &str,
    top_k: usize,
) -> Vec<SearchResult> {
    let query_lower = keyword_query.to_lowercase();
    let terms: Vec<&str> = query_lower.split_whitespace().collect();

    let mut reranked: Vec<(f32, SearchResult)> = Vec::with_capacity(semantic_results.len());
    for result in semantic_results {
        let body = result.embedding.content.to_lowercase();
        let mut combined = result.score;

        for term in &terms {
            if body.contains(term) {
                combined += 0.1; // term appears in the snippet body
            }
            let name_hit = result.embedding.metadata.function_name.as_ref()
                .map(|f| f.to_lowercase().contains(term))
                .unwrap_or(false);
            if name_hit {
                combined += 0.15; // stronger signal: term in the function name
            }
        }

        reranked.push((combined, result.clone()));
    }

    // Order by the boosted score, best first, then keep the top_k entries.
    reranked.sort_by(|x, y| y.0.partial_cmp(&x.0).unwrap_or(Ordering::Equal));
    reranked.truncate(top_k);
    reranked.into_iter().map(|(_, result)| result).collect()
}