// offline_intelligence/memory_db/embedding_store.rs
use std::collections::{HashMap, HashSet};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, RwLock};

use hora::core::ann_index::ANNIndex;
use hora::core::metrics::Metric;
use hora::index::hnsw_idx::HNSWIndex;
use hora::index::hnsw_params::HNSWParams;
use r2d2::Pool;
use r2d2_sqlite::SqliteConnectionManager;
use rusqlite::{params, Result, Row};
use tracing::{info, warn};

use crate::memory_db::schema::*;

17#[derive(Debug, Clone, serde::Serialize)]
19pub struct EmbeddingStats {
20 pub total_embeddings: usize,
21 pub dimension: usize,
22 pub index_type: String,
23}
24
25pub struct EmbeddingStore {
26 pool: Arc<Pool<SqliteConnectionManager>>,
27 ann_index: RwLock<Option<HNSWIndex<f32, i64>>>,
29 embedding_cache: RwLock<HashMap<i64, Vec<f32>>>,
31 index_dirty: AtomicBool,
34}
35
36impl EmbeddingStore {
37 pub fn new(pool: Arc<Pool<SqliteConnectionManager>>) -> Self {
38 Self {
39 pool,
40 ann_index: RwLock::new(None),
41 embedding_cache: RwLock::new(HashMap::new()),
42 index_dirty: AtomicBool::new(false),
43 }
44 }
45
46 fn get_conn(&self) -> anyhow::Result<r2d2::PooledConnection<SqliteConnectionManager>> {
47 self.pool.get().map_err(|e| anyhow::anyhow!("Failed to get connection from pool: {}", e))
48 }
49
50 pub fn initialize_index(&self, model: &str) -> anyhow::Result<()> {
51 let conn = self.get_conn()?;
52
53 let mut stmt = conn.prepare(
54 "SELECT id, message_id, embedding FROM embeddings WHERE embedding_model = ?1"
55 )?;
56
57 let mut rows = stmt.query([model])?;
58
59 let params = HNSWParams {
61 n_neighbor: 16,
63 ef_build: 100,
65 ef_search: 50,
67 ..Default::default()
68 };
69
70 let mut index = HNSWIndex::<f32, i64>::new(
72 384, ¶ms,
74 );
75
76 let mut cache = self.embedding_cache.write().unwrap();
77
78 while let Some(row) = rows.next()? {
79 let message_id: i64 = row.get(1)?;
80 let embedding_bytes: Vec<u8> = row.get(2)?;
81 let embedding: Vec<f32> = bincode::deserialize(&embedding_bytes)
82 .map_err(|e| anyhow::anyhow!("Deserialization error: {}", e))?;
83
84 let _ = index.add(&embedding, message_id);
86 cache.insert(message_id, embedding);
87 }
88
89 index.build(Metric::CosineSimilarity)
91 .map_err(|e| anyhow::anyhow!("Failed to build index: {}", e))?;
92
93 *self.ann_index.write().unwrap() = Some(index);
94 self.index_dirty.store(false, Ordering::Release);
95 info!("ANN index initialized with {} embeddings", cache.len());
96 Ok(())
97 }
98
99 pub fn store_embedding(&self, embedding: &Embedding) -> anyhow::Result<()> {
100 let embedding_bytes = bincode::serialize(&embedding.embedding)?;
101 let conn = self.get_conn()?;
102 conn.execute(
103 "INSERT OR REPLACE INTO embeddings (message_id, embedding, embedding_model, generated_at) VALUES (?1, ?2, ?3, ?4)",
104 params![embedding.message_id, embedding_bytes, &embedding.embedding_model, embedding.generated_at.to_rfc3339()],
105 )?;
106
107 let mut cache = self.embedding_cache.write().unwrap();
108 cache.insert(embedding.message_id, embedding.embedding.clone());
109
110 if let Some(ref mut index) = *self.ann_index.write().unwrap() {
111 let _ = index.add(&embedding.embedding, embedding.message_id);
114 self.index_dirty.store(true, Ordering::Release);
115 }
116 Ok(())
117 }
118
119 pub fn find_similar_embeddings(
120 &self,
121 query_embedding: &[f32],
122 model: &str,
123 limit: i32,
124 similarity_threshold: f32,
125 ) -> anyhow::Result<Vec<(i64, f32)>> {
126 if model.is_empty() || model.len() > 100 {
127 return Err(anyhow::anyhow!("Invalid model name"));
128 }
129
130 if self.index_dirty.load(Ordering::Acquire) {
134 let mut index_guard = self.ann_index.write().unwrap();
135 if self.index_dirty.load(Ordering::Acquire) {
136 if let Some(ref mut index) = *index_guard {
137 if let Err(e) = index.build(Metric::CosineSimilarity) {
138 warn!("HNSW rebuild failed, search will use stale index: {}", e);
139 }
140 }
141 self.index_dirty.store(false, Ordering::Release);
142 }
143 }
145
146 {
147 let index_guard = self.ann_index.read().unwrap();
148 if let Some(index) = &*index_guard {
149 let results = index.search(query_embedding, limit as usize);
150
151 let mut scored_results = Vec::new();
152 for id in &results {
153 if let Some(embedding) = self.embedding_cache.read().unwrap().get(id) {
154 let sim = cosine_similarity(query_embedding, embedding);
155 if sim >= similarity_threshold {
156 scored_results.push((*id, sim));
157 }
158 }
159 }
160
161 scored_results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
162 return Ok(scored_results);
163 }
164 }
165
166 warn!("ANN index not available, falling back to safe linear search");
167 self.find_similar_embeddings_linear(query_embedding, model, limit, similarity_threshold)
168 }
169
170 fn find_similar_embeddings_linear(
171 &self,
172 query_embedding: &[f32],
173 model: &str,
174 limit: i32,
175 similarity_threshold: f32,
176 ) -> anyhow::Result<Vec<(i64, f32)>> {
177 let conn = self.get_conn()?;
178 let mut stmt = conn.prepare(
179 "SELECT message_id, embedding FROM embeddings WHERE embedding_model = ?1"
180 )?;
181 let mut rows = stmt.query([model])?;
182
183 let mut matches = Vec::new();
184 while let Some(row) = rows.next()? {
185 let message_id: i64 = row.get(0)?;
186 let embedding_bytes: Vec<u8> = row.get(1)?;
187 let embedding: Vec<f32> = bincode::deserialize(&embedding_bytes)
188 .map_err(|e| anyhow::anyhow!("Bincode error: {}", e))?;
189
190 let sim = cosine_similarity(query_embedding, &embedding);
191 if sim >= similarity_threshold {
192 matches.push((message_id, sim));
193 }
194 }
195
196 matches.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
197 matches.truncate(limit as usize);
198 Ok(matches)
199 }
200
201 pub fn get_embedding_by_message_id(&self, message_id: i64, model: &str) -> anyhow::Result<Option<Embedding>> {
202 let conn = self.get_conn()?;
203 let mut stmt = conn.prepare(
204 "SELECT id, message_id, embedding, embedding_model, generated_at
205 FROM embeddings WHERE message_id = ?1 AND embedding_model = ?2"
206 )?;
207
208 let mut rows = stmt.query(params![message_id, model])?;
209 if let Some(row) = rows.next()? {
210 Ok(Some(self.row_to_embedding(row)?))
211 } else {
212 Ok(None)
213 }
214 }
215
216 fn row_to_embedding(&self, row: &Row) -> Result<Embedding> {
217 let embedding_bytes: Vec<u8> = row.get(2)?;
218 let embedding: Vec<f32> = bincode::deserialize(&embedding_bytes)
219 .map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?;
220
221 let generated_at_str: String = row.get(4)?;
222 let generated_at = chrono::DateTime::parse_from_rfc3339(&generated_at_str)
223 .map_err(|e| rusqlite::Error::FromSqlConversionFailure(4, rusqlite::types::Type::Text, Box::new(e)))?
224 .with_timezone(&chrono::Utc);
225
226 Ok(Embedding {
227 id: row.get(0)?,
228 message_id: row.get(1)?,
229 embedding,
230 embedding_model: row.get(3)?,
231 generated_at,
232 })
233 }
234
235 pub fn get_stats(&self) -> anyhow::Result<EmbeddingStats> {
236 let conn = self.get_conn()?;
237 let count: i64 = conn.query_row(
238 "SELECT COUNT(*) FROM embeddings",
239 [],
240 |row| row.get(0)
241 )?;
242
243 let mut stmt = conn.prepare("SELECT embedding FROM embeddings LIMIT 1")?;
244 let dimension = if let Some(row) = stmt.query([])?.next()? {
245 let embedding_bytes: Vec<u8> = row.get(0)?;
246 let embedding: Vec<f32> = bincode::deserialize(&embedding_bytes)
247 .map_err(|e| anyhow::anyhow!("Deserialization error: {}", e))?;
248 embedding.len()
249 } else {
250 0
251 };
252
253 let index_type = if self.ann_index.read().unwrap().is_some() {
254 "HNSW".to_string()
255 } else {
256 "Linear".to_string()
257 };
258
259 Ok(EmbeddingStats {
260 total_embeddings: count as usize,
261 dimension,
262 index_type,
263 })
264 }
265}
266
/// Cosine similarity between two vectors.
///
/// Returns `0.0` when the slices differ in length or when either has zero magnitude;
/// otherwise returns `dot(a, b) / (|a| * |b|)`.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    let magnitude = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let (mag_a, mag_b) = (magnitude(a), magnitude(b));
    if mag_a == 0.0 || mag_b == 0.0 {
        return 0.0;
    }

    let dot = a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>();
    dot / (mag_a * mag_b)
}