Skip to main content

offline_intelligence/memory_db/
embedding_store.rs

1// "D:\_ProjectWorks\AUDIO_Interface\Server\src\memory_db\embedding_store.rs"
2//! Embedding storage and retrieval operations with ANN indexing support
3
4use crate::memory_db::schema::*;
5use rusqlite::{params, Result, Row};
6use std::sync::{Arc, RwLock};
7use std::sync::atomic::{AtomicBool, Ordering};
8use std::collections::HashMap;
9use tracing::{info, warn};
10use r2d2::Pool;
11use r2d2_sqlite::SqliteConnectionManager;
12use hora::core::ann_index::ANNIndex;
13use hora::core::metrics::Metric;
14use hora::index::hnsw_idx::HNSWIndex;
15use hora::index::hnsw_params::HNSWParams; 
16
17/// Manages embedding storage and retrieval with HNSW-based ANN indexing
18#[derive(Debug, Clone, serde::Serialize)]
19pub struct EmbeddingStats {
20    pub total_embeddings: usize,
21    pub dimension: usize,
22    pub index_type: String,
23}
24
25pub struct EmbeddingStore {
26    pool: Arc<Pool<SqliteConnectionManager>>,
27    // ANN index for fast similarity search
28    ann_index: RwLock<Option<HNSWIndex<f32, i64>>>,
29    // In-memory cache for linear search fallbacks
30    embedding_cache: RwLock<HashMap<i64, Vec<f32>>>,
31    // Set to true when embeddings are added but the index hasn't been rebuilt yet.
32    // The rebuild is deferred until the next search (lazy strategy).
33    index_dirty: AtomicBool,
34}
35
36impl EmbeddingStore {
37    pub fn new(pool: Arc<Pool<SqliteConnectionManager>>) -> Self {
38        Self {
39            pool,
40            ann_index: RwLock::new(None),
41            embedding_cache: RwLock::new(HashMap::new()),
42            index_dirty: AtomicBool::new(false),
43        }
44    }
45
46    fn get_conn(&self) -> anyhow::Result<r2d2::PooledConnection<SqliteConnectionManager>> {
47        self.pool.get().map_err(|e| anyhow::anyhow!("Failed to get connection from pool: {}", e))
48    }
49
50    pub fn initialize_index(&self, model: &str) -> anyhow::Result<()> {
51        let conn = self.get_conn()?;
52        
53        let mut stmt = conn.prepare(
54            "SELECT id, message_id, embedding FROM embeddings WHERE embedding_model = ?1"
55        )?;
56        
57        let mut rows = stmt.query([model])?;
58        
59        // Fix: Use correct field names for HNSWParams based on error message
60        let params = HNSWParams {
61            // According to error: n_neighbor is likely the equivalent of 'm' in other implementations
62            n_neighbor: 16,
63            // According to error: ef_build is likely the equivalent of 'ef_construction'
64            ef_build: 100,
65            // According to error: ef_search should be available
66            ef_search: 50,
67            ..Default::default()
68        };
69        
70        // Initialize HNSW index with dimension and a reference to params
71        let mut index = HNSWIndex::<f32, i64>::new(
72            384, // dimension
73            &params,
74        );
75        
76        let mut cache = self.embedding_cache.write().unwrap();
77        
78        while let Some(row) = rows.next()? {
79            let message_id: i64 = row.get(1)?;
80            let embedding_bytes: Vec<u8> = row.get(2)?;
81            let embedding: Vec<f32> = bincode::deserialize(&embedding_bytes)
82                .map_err(|e| anyhow::anyhow!("Deserialization error: {}", e))?;
83            
84            // Ignore duplicate insert errors while rebuilding the ANN index
85            let _ = index.add(&embedding, message_id);
86            cache.insert(message_id, embedding);
87        }
88        
89        // Build the index with the Metric
90        index.build(Metric::CosineSimilarity)
91            .map_err(|e| anyhow::anyhow!("Failed to build index: {}", e))?;
92        
93        *self.ann_index.write().unwrap() = Some(index);
94        self.index_dirty.store(false, Ordering::Release);
95        info!("ANN index initialized with {} embeddings", cache.len());
96        Ok(())
97    }
98
99    pub fn store_embedding(&self, embedding: &Embedding) -> anyhow::Result<()> {
100        let embedding_bytes = bincode::serialize(&embedding.embedding)?;
101        let conn = self.get_conn()?;
102        conn.execute(
103            "INSERT OR REPLACE INTO embeddings (message_id, embedding, embedding_model, generated_at) VALUES (?1, ?2, ?3, ?4)",
104            params![embedding.message_id, embedding_bytes, &embedding.embedding_model, embedding.generated_at.to_rfc3339()],
105        )?;
106
107        let mut cache = self.embedding_cache.write().unwrap();
108        cache.insert(embedding.message_id, embedding.embedding.clone());
109
110        if let Some(ref mut index) = *self.ann_index.write().unwrap() {
111            // Add to the graph structure; the build (which is the expensive O(n log n) step)
112            // is deferred until the next search via the dirty flag.
113            let _ = index.add(&embedding.embedding, embedding.message_id);
114            self.index_dirty.store(true, Ordering::Release);
115        }
116        Ok(())
117    }
118
119    pub fn find_similar_embeddings(
120        &self,
121        query_embedding: &[f32],
122        model: &str,
123        limit: i32,
124        similarity_threshold: f32,
125    ) -> anyhow::Result<Vec<(i64, f32)>> {
126        if model.is_empty() || model.len() > 100 {
127            return Err(anyhow::anyhow!("Invalid model name"));
128        }
129
130        // Lazy rebuild: if new embeddings have been added since the last build,
131        // rebuild now before searching. Uses a double-check under the write lock
132        // to avoid redundant rebuilds when multiple threads enter concurrently.
133        if self.index_dirty.load(Ordering::Acquire) {
134            let mut index_guard = self.ann_index.write().unwrap();
135            if self.index_dirty.load(Ordering::Acquire) {
136                if let Some(ref mut index) = *index_guard {
137                    if let Err(e) = index.build(Metric::CosineSimilarity) {
138                        warn!("HNSW rebuild failed, search will use stale index: {}", e);
139                    }
140                }
141                self.index_dirty.store(false, Ordering::Release);
142            }
143            // write lock drops here before the read-lock search below
144        }
145
146        {
147            let index_guard = self.ann_index.read().unwrap();
148            if let Some(index) = &*index_guard {
149                let results = index.search(query_embedding, limit as usize);
150                
151                let mut scored_results = Vec::new();
152                for id in &results {
153                    if let Some(embedding) = self.embedding_cache.read().unwrap().get(id) {
154                        let sim = cosine_similarity(query_embedding, embedding);
155                        if sim >= similarity_threshold {
156                            scored_results.push((*id, sim));
157                        }
158                    }
159                }
160                
161                scored_results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
162                return Ok(scored_results);
163            }
164        }
165
166        warn!("ANN index not available, falling back to safe linear search");
167        self.find_similar_embeddings_linear(query_embedding, model, limit, similarity_threshold)
168    }
169
170    fn find_similar_embeddings_linear(
171        &self,
172        query_embedding: &[f32],
173        model: &str,
174        limit: i32,
175        similarity_threshold: f32,
176    ) -> anyhow::Result<Vec<(i64, f32)>> {
177        let conn = self.get_conn()?;
178        let mut stmt = conn.prepare(
179            "SELECT message_id, embedding FROM embeddings WHERE embedding_model = ?1"
180        )?;
181        let mut rows = stmt.query([model])?;
182        
183        let mut matches = Vec::new();
184        while let Some(row) = rows.next()? {
185            let message_id: i64 = row.get(0)?;
186            let embedding_bytes: Vec<u8> = row.get(1)?;
187            let embedding: Vec<f32> = bincode::deserialize(&embedding_bytes)
188                .map_err(|e| anyhow::anyhow!("Bincode error: {}", e))?;
189            
190            let sim = cosine_similarity(query_embedding, &embedding);
191            if sim >= similarity_threshold {
192                matches.push((message_id, sim));
193            }
194        }
195        
196        matches.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
197        matches.truncate(limit as usize);
198        Ok(matches)
199    }
200
201    pub fn get_embedding_by_message_id(&self, message_id: i64, model: &str) -> anyhow::Result<Option<Embedding>> {
202        let conn = self.get_conn()?;
203        let mut stmt = conn.prepare(
204            "SELECT id, message_id, embedding, embedding_model, generated_at
205             FROM embeddings WHERE message_id = ?1 AND embedding_model = ?2"
206        )?;
207        
208        let mut rows = stmt.query(params![message_id, model])?;
209        if let Some(row) = rows.next()? {
210            Ok(Some(self.row_to_embedding(row)?))
211        } else {
212            Ok(None)
213        }
214    }
215
216    fn row_to_embedding(&self, row: &Row) -> Result<Embedding> {
217        let embedding_bytes: Vec<u8> = row.get(2)?;
218        let embedding: Vec<f32> = bincode::deserialize(&embedding_bytes)
219            .map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?;
220        
221        let generated_at_str: String = row.get(4)?;
222        let generated_at = chrono::DateTime::parse_from_rfc3339(&generated_at_str)
223            .map_err(|e| rusqlite::Error::FromSqlConversionFailure(4, rusqlite::types::Type::Text, Box::new(e)))?
224            .with_timezone(&chrono::Utc);
225        
226        Ok(Embedding {
227            id: row.get(0)?,
228            message_id: row.get(1)?,
229            embedding,
230            embedding_model: row.get(3)?,
231            generated_at,
232        })
233    }
234
235    pub fn get_stats(&self) -> anyhow::Result<EmbeddingStats> {
236        let conn = self.get_conn()?;
237        let count: i64 = conn.query_row(
238            "SELECT COUNT(*) FROM embeddings",
239            [],
240            |row| row.get(0)
241        )?;
242        
243        let mut stmt = conn.prepare("SELECT embedding FROM embeddings LIMIT 1")?;
244        let dimension = if let Some(row) = stmt.query([])?.next()? {
245            let embedding_bytes: Vec<u8> = row.get(0)?;
246            let embedding: Vec<f32> = bincode::deserialize(&embedding_bytes)
247                .map_err(|e| anyhow::anyhow!("Deserialization error: {}", e))?;
248            embedding.len()
249        } else {
250            0
251        };
252        
253        let index_type = if self.ann_index.read().unwrap().is_some() {
254            "HNSW".to_string()
255        } else {
256            "Linear".to_string()
257        };
258        
259        Ok(EmbeddingStats {
260            total_embeddings: count as usize,
261            dimension,
262            index_type,
263        })
264    }
265}
266
/// Cosine similarity between `a` and `b`.
///
/// Returns `0.0` when the slices differ in length or when either one has zero
/// magnitude (which includes empty slices).
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    // Euclidean length of a vector.
    let magnitude = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();

    let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    let (norm_a, norm_b) = (magnitude(a), magnitude(b));

    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }
    dot / (norm_a * norm_b)
}