agentroot_core/search/
vector.rs1use super::{SearchOptions, SearchResult, SearchSource};
6use crate::db::vectors::cosine_similarity;
7use crate::db::{docid_from_hash, Database};
8use crate::error::Result;
9use crate::llm::Embedder;
10use std::collections::HashMap;
11
12impl Database {
13 pub async fn search_vec(
15 &self,
16 query: &str,
17 embedder: &dyn Embedder,
18 options: &SearchOptions,
19 ) -> Result<Vec<SearchResult>> {
20 let query_embedding = embedder.embed(&format_query_for_embedding(query)).await?;
22
23 let stored_embeddings = if let Some(ref coll) = options.collection {
25 self.get_embeddings_for_collection(coll)?
26 } else {
27 self.get_all_embeddings()?
28 };
29
30 let mut similarities: Vec<(String, f32)> = stored_embeddings
32 .iter()
33 .map(|(hash_seq, embedding)| {
34 let sim = cosine_similarity(&query_embedding, embedding);
35 (hash_seq.clone(), sim)
36 })
37 .collect();
38
39 similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
41
42 let fetch_limit = options.limit * 3;
44 let top_candidates: Vec<_> = similarities.into_iter().take(fetch_limit).collect();
45
46 let mut results = Vec::new();
48 for (hash_seq, score) in top_candidates {
49 if let Some(result) = self.get_search_result_for_hash_seq(&hash_seq, score, options)? {
50 results.push(result);
51 }
52 }
53
54 let mut best_by_hash: HashMap<String, SearchResult> = HashMap::new();
56 for result in results {
57 let existing = best_by_hash.get(&result.hash);
58 if existing.is_none() || existing.unwrap().score < result.score {
59 best_by_hash.insert(result.hash.clone(), result);
60 }
61 }
62
63 let mut final_results: Vec<SearchResult> = best_by_hash.into_values().collect();
64 final_results.sort_by(|a, b| {
65 b.score
66 .partial_cmp(&a.score)
67 .unwrap_or(std::cmp::Ordering::Equal)
68 });
69
70 let filtered: Vec<SearchResult> = final_results
72 .into_iter()
73 .filter(|r| r.score >= options.min_score)
74 .take(options.limit)
75 .collect();
76
77 Ok(filtered)
78 }
79
80 fn get_search_result_for_hash_seq(
82 &self,
83 hash_seq: &str,
84 score: f32,
85 options: &SearchOptions,
86 ) -> Result<Option<SearchResult>> {
87 let parts: Vec<&str> = hash_seq.rsplitn(2, '_').collect();
89 if parts.len() != 2 {
90 return Ok(None);
91 }
92 let hash = parts[1];
93
94 let result = self.conn.query_row(
95 "SELECT
96 'agentroot://' || d.collection || '/' || d.path as filepath,
97 d.collection || '/' || d.path as display_path,
98 d.title,
99 d.hash,
100 d.collection,
101 d.modified_at,
102 c.doc,
103 LENGTH(c.doc),
104 cv.pos,
105 d.llm_summary,
106 d.llm_title,
107 d.llm_keywords,
108 d.llm_category,
109 d.llm_difficulty,
110 d.user_metadata
111 FROM documents d
112 JOIN content c ON c.hash = d.hash
113 JOIN content_vectors cv ON cv.hash = d.hash
114 WHERE d.hash = ?1 AND d.active = 1
115 LIMIT 1",
116 rusqlite::params![hash],
117 |row| {
118 let keywords_json: Option<String> = row.get(11)?;
119 let keywords =
120 keywords_json.and_then(|json| serde_json::from_str::<Vec<String>>(&json).ok());
121
122 let user_metadata_json: Option<String> = row.get(14)?;
123 let user_metadata = user_metadata_json
124 .and_then(|json| crate::db::UserMetadata::from_json(&json).ok());
125
126 Ok(SearchResult {
127 filepath: row.get(0)?,
128 display_path: row.get(1)?,
129 title: row.get(2)?,
130 hash: row.get(3)?,
131 collection_name: row.get(4)?,
132 modified_at: row.get(5)?,
133 body: if options.full_content {
134 Some(row.get(6)?)
135 } else {
136 None
137 },
138 body_length: row.get(7)?,
139 docid: docid_from_hash(&row.get::<_, String>(3)?),
140 context: None,
141 score: score as f64,
142 source: SearchSource::Vector,
143 chunk_pos: Some(row.get(8)?),
144 llm_summary: row.get(9)?,
145 llm_title: row.get(10)?,
146 llm_keywords: keywords,
147 llm_category: row.get(12)?,
148 llm_difficulty: row.get(13)?,
149 user_metadata,
150 })
151 },
152 );
153
154 match result {
155 Ok(r) => Ok(Some(r)),
156 Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
157 Err(e) => Err(e.into()),
158 }
159 }
160}
161
/// Prepends the `search_query: ` marker to `query`, producing the text that
/// is handed to the embedder.
fn format_query_for_embedding(query: &str) -> String {
    const PREFIX: &str = "search_query: ";

    let mut prefixed = String::with_capacity(PREFIX.len() + query.len());
    prefixed.push_str(PREFIX);
    prefixed.push_str(query);
    prefixed
}