offline_intelligence/memory_db/
embedding_store.rs1use crate::memory_db::schema::*;
5use rusqlite::{params, Result, Row};
6use std::sync::{Arc, RwLock};
7use std::collections::HashMap;
8use tracing::{info, warn};
9use r2d2::Pool;
10use r2d2_sqlite::SqliteConnectionManager;
11use hora::core::ann_index::ANNIndex;
12use hora::core::metrics::Metric;
13use hora::index::hnsw_idx::HNSWIndex;
14use hora::index::hnsw_params::HNSWParams;
15
16#[derive(Debug, Clone, serde::Serialize)]
18pub struct EmbeddingStats {
19 pub total_embeddings: usize,
20 pub dimension: usize,
21 pub index_type: String,
22}
23
24pub struct EmbeddingStore {
25 pool: Arc<Pool<SqliteConnectionManager>>,
26 ann_index: RwLock<Option<HNSWIndex<f32, i64>>>,
28 embedding_cache: RwLock<HashMap<i64, Vec<f32>>>,
30}
31
32impl EmbeddingStore {
33 pub fn new(pool: Arc<Pool<SqliteConnectionManager>>) -> Self {
34 Self {
35 pool,
36 ann_index: RwLock::new(None),
37 embedding_cache: RwLock::new(HashMap::new()),
38 }
39 }
40
41 fn get_conn(&self) -> anyhow::Result<r2d2::PooledConnection<SqliteConnectionManager>> {
42 self.pool.get().map_err(|e| anyhow::anyhow!("Failed to get connection from pool: {}", e))
43 }
44
45 pub fn initialize_index(&self, model: &str) -> anyhow::Result<()> {
46 let conn = self.get_conn()?;
47
48 let mut stmt = conn.prepare(
49 "SELECT id, message_id, embedding FROM embeddings WHERE embedding_model = ?1"
50 )?;
51
52 let mut rows = stmt.query([model])?;
53
54 let params = HNSWParams {
56 n_neighbor: 16,
58 ef_build: 100,
60 ef_search: 50,
62 ..Default::default()
63 };
64
65 let mut index = HNSWIndex::<f32, i64>::new(
67 384, ¶ms,
69 );
70
71 let mut cache = self.embedding_cache.write().unwrap();
72
73 while let Some(row) = rows.next()? {
74 let message_id: i64 = row.get(1)?;
75 let embedding_bytes: Vec<u8> = row.get(2)?;
76 let embedding: Vec<f32> = bincode::deserialize(&embedding_bytes)
77 .map_err(|e| anyhow::anyhow!("Deserialization error: {}", e))?;
78
79 let _ = index.add(&embedding, message_id);
81 cache.insert(message_id, embedding);
82 }
83
84 index.build(Metric::CosineSimilarity)
86 .map_err(|e| anyhow::anyhow!("Failed to build index: {}", e))?;
87
88 *self.ann_index.write().unwrap() = Some(index);
89 info!("ANN index initialized with {} embeddings", cache.len());
90 Ok(())
91 }
92
93 pub fn store_embedding(&self, embedding: &Embedding) -> anyhow::Result<()> {
94 let embedding_bytes = bincode::serialize(&embedding.embedding)?;
95 let conn = self.get_conn()?;
96 conn.execute(
97 "INSERT OR REPLACE INTO embeddings (message_id, embedding, embedding_model, generated_at) VALUES (?1, ?2, ?3, ?4)",
98 params![embedding.message_id, embedding_bytes, &embedding.embedding_model, embedding.generated_at.to_rfc3339()],
99 )?;
100
101 let mut cache = self.embedding_cache.write().unwrap();
102 cache.insert(embedding.message_id, embedding.embedding.clone());
103
104 if let Some(ref mut index) = *self.ann_index.write().unwrap() {
105 let _ = index.add(&embedding.embedding, embedding.message_id);
107 index.build(Metric::CosineSimilarity)
110 .map_err(|e| anyhow::anyhow!("Failed to rebuild index: {}", e))?;
111 }
112 Ok(())
113 }
114
115 pub fn find_similar_embeddings(
116 &self,
117 query_embedding: &[f32],
118 model: &str,
119 limit: i32,
120 similarity_threshold: f32,
121 ) -> anyhow::Result<Vec<(i64, f32)>> {
122 if model.is_empty() || model.len() > 100 {
123 return Err(anyhow::anyhow!("Invalid model name"));
124 }
125
126 {
127 let index_guard = self.ann_index.read().unwrap();
128 if let Some(index) = &*index_guard {
129 let results = index.search(query_embedding, limit as usize);
130
131 let mut scored_results = Vec::new();
132 for id in &results {
133 if let Some(embedding) = self.embedding_cache.read().unwrap().get(id) {
134 let sim = cosine_similarity(query_embedding, embedding);
135 if sim >= similarity_threshold {
136 scored_results.push((*id, sim));
137 }
138 }
139 }
140
141 scored_results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
142 return Ok(scored_results);
143 }
144 }
145
146 warn!("ANN index not available, falling back to safe linear search");
147 self.find_similar_embeddings_linear(query_embedding, model, limit, similarity_threshold)
148 }
149
150 fn find_similar_embeddings_linear(
151 &self,
152 query_embedding: &[f32],
153 model: &str,
154 limit: i32,
155 similarity_threshold: f32,
156 ) -> anyhow::Result<Vec<(i64, f32)>> {
157 let conn = self.get_conn()?;
158 let mut stmt = conn.prepare(
159 "SELECT message_id, embedding FROM embeddings WHERE embedding_model = ?1"
160 )?;
161 let mut rows = stmt.query([model])?;
162
163 let mut matches = Vec::new();
164 while let Some(row) = rows.next()? {
165 let message_id: i64 = row.get(0)?;
166 let embedding_bytes: Vec<u8> = row.get(1)?;
167 let embedding: Vec<f32> = bincode::deserialize(&embedding_bytes)
168 .map_err(|e| anyhow::anyhow!("Bincode error: {}", e))?;
169
170 let sim = cosine_similarity(query_embedding, &embedding);
171 if sim >= similarity_threshold {
172 matches.push((message_id, sim));
173 }
174 }
175
176 matches.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
177 matches.truncate(limit as usize);
178 Ok(matches)
179 }
180
181 pub fn get_embedding_by_message_id(&self, message_id: i64, model: &str) -> anyhow::Result<Option<Embedding>> {
182 let conn = self.get_conn()?;
183 let mut stmt = conn.prepare(
184 "SELECT id, message_id, embedding, embedding_model, generated_at
185 FROM embeddings WHERE message_id = ?1 AND embedding_model = ?2"
186 )?;
187
188 let mut rows = stmt.query(params![message_id, model])?;
189 if let Some(row) = rows.next()? {
190 Ok(Some(self.row_to_embedding(row)?))
191 } else {
192 Ok(None)
193 }
194 }
195
196 fn row_to_embedding(&self, row: &Row) -> Result<Embedding> {
197 let embedding_bytes: Vec<u8> = row.get(2)?;
198 let embedding: Vec<f32> = bincode::deserialize(&embedding_bytes)
199 .map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?;
200
201 let generated_at_str: String = row.get(4)?;
202 let generated_at = chrono::DateTime::parse_from_rfc3339(&generated_at_str)
203 .map_err(|e| rusqlite::Error::FromSqlConversionFailure(4, rusqlite::types::Type::Text, Box::new(e)))?
204 .with_timezone(&chrono::Utc);
205
206 Ok(Embedding {
207 id: row.get(0)?,
208 message_id: row.get(1)?,
209 embedding,
210 embedding_model: row.get(3)?,
211 generated_at,
212 })
213 }
214
215 pub fn get_stats(&self) -> anyhow::Result<EmbeddingStats> {
216 let conn = self.get_conn()?;
217 let count: i64 = conn.query_row(
218 "SELECT COUNT(*) FROM embeddings",
219 [],
220 |row| row.get(0)
221 )?;
222
223 let mut stmt = conn.prepare("SELECT embedding FROM embeddings LIMIT 1")?;
224 let dimension = if let Some(row) = stmt.query([])?.next()? {
225 let embedding_bytes: Vec<u8> = row.get(0)?;
226 let embedding: Vec<f32> = bincode::deserialize(&embedding_bytes)
227 .map_err(|e| anyhow::anyhow!("Deserialization error: {}", e))?;
228 embedding.len()
229 } else {
230 0
231 };
232
233 let index_type = if self.ann_index.read().unwrap().is_some() {
234 "HNSW".to_string()
235 } else {
236 "Linear".to_string()
237 };
238
239 Ok(EmbeddingStats {
240 total_embeddings: count as usize,
241 dimension,
242 index_type,
243 })
244 }
245}
246
/// Cosine similarity between two vectors.
///
/// Returns `0.0` when the slices differ in length or when either vector has
/// zero magnitude (which includes two empty slices).
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    // Euclidean length of a vector.
    let magnitude = |v: &[f32]| v.iter().map(|c| c * c).sum::<f32>().sqrt();

    let dot_product = a.iter().zip(b).map(|(lhs, rhs)| lhs * rhs).sum::<f32>();
    let (mag_a, mag_b) = (magnitude(a), magnitude(b));

    if mag_a == 0.0 || mag_b == 0.0 {
        return 0.0;
    }
    dot_product / (mag_a * mag_b)
}