// agentroot_core/search/vector.rs

//! Vector similarity search
//!
//! Computes cosine similarity between the query embedding and stored embeddings.
4
5use super::{SearchOptions, SearchResult, SearchSource};
6use crate::db::vectors::cosine_similarity;
7use crate::db::{docid_from_hash, Database};
8use crate::error::Result;
9use crate::llm::Embedder;
10use std::collections::HashMap;
11
12impl Database {
13    /// Perform vector similarity search
14    pub async fn search_vec(
15        &self,
16        query: &str,
17        embedder: &dyn Embedder,
18        options: &SearchOptions,
19    ) -> Result<Vec<SearchResult>> {
20        // Get query embedding
21        let query_embedding = embedder.embed(&format_query_for_embedding(query)).await?;
22
23        // Get all stored embeddings (optionally filtered by collection)
24        let stored_embeddings = if let Some(ref coll) = options.collection {
25            self.get_embeddings_for_collection(coll)?
26        } else {
27            self.get_all_embeddings()?
28        };
29
30        // Compute similarities
31        let mut similarities: Vec<(String, f32)> = stored_embeddings
32            .iter()
33            .map(|(hash_seq, embedding)| {
34                let sim = cosine_similarity(&query_embedding, embedding);
35                (hash_seq.clone(), sim)
36            })
37            .collect();
38
39        // Sort by similarity (descending)
40        similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
41
42        // Take top candidates (3x limit for deduplication)
43        let fetch_limit = options.limit * 3;
44        let top_candidates: Vec<_> = similarities.into_iter().take(fetch_limit).collect();
45
46        // Fetch document details for top candidates
47        let mut results = Vec::new();
48        for (hash_seq, score) in top_candidates {
49            if let Some(result) = self.get_search_result_for_hash_seq(&hash_seq, score, options)? {
50                results.push(result);
51            }
52        }
53
54        // Deduplicate: keep best chunk per document
55        let mut best_by_hash: HashMap<String, SearchResult> = HashMap::new();
56        for result in results {
57            let existing = best_by_hash.get(&result.hash);
58            if existing.is_none() || existing.unwrap().score < result.score {
59                best_by_hash.insert(result.hash.clone(), result);
60            }
61        }
62
63        let mut final_results: Vec<SearchResult> = best_by_hash.into_values().collect();
64        final_results.sort_by(|a, b| {
65            b.score
66                .partial_cmp(&a.score)
67                .unwrap_or(std::cmp::Ordering::Equal)
68        });
69
70        // Filter by min_score and limit
71        let filtered: Vec<SearchResult> = final_results
72            .into_iter()
73            .filter(|r| r.score >= options.min_score)
74            .take(options.limit)
75            .collect();
76
77        Ok(filtered)
78    }
79
80    /// Get search result for a hash_seq
81    fn get_search_result_for_hash_seq(
82        &self,
83        hash_seq: &str,
84        score: f32,
85        options: &SearchOptions,
86    ) -> Result<Option<SearchResult>> {
87        // Parse hash_seq (format: "hash_seq")
88        let parts: Vec<&str> = hash_seq.rsplitn(2, '_').collect();
89        if parts.len() != 2 {
90            return Ok(None);
91        }
92        let hash = parts[1];
93
94        let result = self.conn.query_row(
95            "SELECT
96                'agentroot://' || d.collection || '/' || d.path as filepath,
97                d.collection || '/' || d.path as display_path,
98                d.title,
99                d.hash,
100                d.collection,
101                d.modified_at,
102                c.doc,
103                LENGTH(c.doc),
104                cv.pos,
105                d.llm_summary,
106                d.llm_title,
107                d.llm_keywords,
108                d.llm_category,
109                d.llm_difficulty,
110                d.user_metadata
111             FROM documents d
112             JOIN content c ON c.hash = d.hash
113             JOIN content_vectors cv ON cv.hash = d.hash
114             WHERE d.hash = ?1 AND d.active = 1
115             LIMIT 1",
116            rusqlite::params![hash],
117            |row| {
118                let keywords_json: Option<String> = row.get(11)?;
119                let keywords =
120                    keywords_json.and_then(|json| serde_json::from_str::<Vec<String>>(&json).ok());
121
122                let user_metadata_json: Option<String> = row.get(14)?;
123                let user_metadata = user_metadata_json
124                    .and_then(|json| crate::db::UserMetadata::from_json(&json).ok());
125
126                Ok(SearchResult {
127                    filepath: row.get(0)?,
128                    display_path: row.get(1)?,
129                    title: row.get(2)?,
130                    hash: row.get(3)?,
131                    collection_name: row.get(4)?,
132                    modified_at: row.get(5)?,
133                    body: if options.full_content {
134                        Some(row.get(6)?)
135                    } else {
136                        None
137                    },
138                    body_length: row.get(7)?,
139                    docid: docid_from_hash(&row.get::<_, String>(3)?),
140                    context: None,
141                    score: score as f64,
142                    source: SearchSource::Vector,
143                    chunk_pos: Some(row.get(8)?),
144                    llm_summary: row.get(9)?,
145                    llm_title: row.get(10)?,
146                    llm_keywords: keywords,
147                    llm_category: row.get(12)?,
148                    llm_difficulty: row.get(13)?,
149                    user_metadata,
150                })
151            },
152        );
153
154        match result {
155            Ok(r) => Ok(Some(r)),
156            Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
157            Err(e) => Err(e.into()),
158        }
159    }
160}
161
/// Build the text that is sent to the embedder for a search query.
///
/// The query is returned with the literal tag `search_query: ` prepended.
// NOTE(review): presumably pairs with a document-side prefix applied at
// indexing time — confirm against the embedding code.
fn format_query_for_embedding(query: &str) -> String {
    const QUERY_PREFIX: &str = "search_query: ";

    let mut tagged = String::with_capacity(QUERY_PREFIX.len() + query.len());
    tagged.push_str(QUERY_PREFIX);
    tagged.push_str(query);
    tagged
}