Skip to main content

offline_intelligence/memory_db/
embedding_store.rs

1// "D:\_ProjectWorks\AUDIO_Interface\Server\src\memory_db\embedding_store.rs"
2//! Embedding storage and retrieval operations with ANN indexing support
3
4use crate::memory_db::schema::*;
5use rusqlite::{params, Result, Row};
6use std::sync::{Arc, RwLock};
7use std::collections::HashMap;
8use tracing::{info, warn};
9use r2d2::Pool;
10use r2d2_sqlite::SqliteConnectionManager;
11use hora::core::ann_index::ANNIndex;
12use hora::core::metrics::Metric;
13use hora::index::hnsw_idx::HNSWIndex;
14use hora::index::hnsw_params::HNSWParams; 
15
/// Summary statistics about the stored embeddings, as produced by `EmbeddingStore::get_stats`.
#[derive(Debug, Clone, serde::Serialize)]
pub struct EmbeddingStats {
    /// Number of rows in the `embeddings` table.
    pub total_embeddings: usize,
    /// Length of one sampled stored embedding vector; 0 when the table is empty.
    pub dimension: usize,
    /// `"HNSW"` when an ANN index is currently loaded, otherwise `"Linear"`.
    pub index_type: String,
}

/// Manages embedding storage and retrieval with HNSW-based ANN indexing
pub struct EmbeddingStore {
    // Shared SQLite connection pool; every database access checks a connection out of it
    pool: Arc<Pool<SqliteConnectionManager>>,
    // ANN index for fast similarity search; `None` until `initialize_index` has run
    ann_index: RwLock<Option<HNSWIndex<f32, i64>>>,
    // In-memory copy of embedding vectors keyed by message id, used to re-score ANN candidates
    embedding_cache: RwLock<HashMap<i64, Vec<f32>>>,
}
31
32impl EmbeddingStore {
33    pub fn new(pool: Arc<Pool<SqliteConnectionManager>>) -> Self {
34        Self {
35            pool,
36            ann_index: RwLock::new(None),
37            embedding_cache: RwLock::new(HashMap::new()),
38        }
39    }
40
41    fn get_conn(&self) -> anyhow::Result<r2d2::PooledConnection<SqliteConnectionManager>> {
42        self.pool.get().map_err(|e| anyhow::anyhow!("Failed to get connection from pool: {}", e))
43    }
44
45    pub fn initialize_index(&self, model: &str) -> anyhow::Result<()> {
46        let conn = self.get_conn()?;
47        
48        let mut stmt = conn.prepare(
49            "SELECT id, message_id, embedding FROM embeddings WHERE embedding_model = ?1"
50        )?;
51        
52        let mut rows = stmt.query([model])?;
53        
54        // Fix: Use correct field names for HNSWParams based on error message
55        let params = HNSWParams {
56            // According to error: n_neighbor is likely the equivalent of 'm' in other implementations
57            n_neighbor: 16,
58            // According to error: ef_build is likely the equivalent of 'ef_construction'
59            ef_build: 100,
60            // According to error: ef_search should be available
61            ef_search: 50,
62            ..Default::default()
63        };
64        
65        // Initialize HNSW index with dimension and a reference to params
66        let mut index = HNSWIndex::<f32, i64>::new(
67            384, // dimension
68            &params,
69        );
70        
71        let mut cache = self.embedding_cache.write().unwrap();
72        
73        while let Some(row) = rows.next()? {
74            let message_id: i64 = row.get(1)?;
75            let embedding_bytes: Vec<u8> = row.get(2)?;
76            let embedding: Vec<f32> = bincode::deserialize(&embedding_bytes)
77                .map_err(|e| anyhow::anyhow!("Deserialization error: {}", e))?;
78            
79            // Ignore duplicate insert errors while rebuilding the ANN index
80            let _ = index.add(&embedding, message_id);
81            cache.insert(message_id, embedding);
82        }
83        
84        // Build the index with the Metric
85        index.build(Metric::CosineSimilarity)
86            .map_err(|e| anyhow::anyhow!("Failed to build index: {}", e))?;
87        
88        *self.ann_index.write().unwrap() = Some(index);
89        info!("ANN index initialized with {} embeddings", cache.len());
90        Ok(())
91    }
92
93    pub fn store_embedding(&self, embedding: &Embedding) -> anyhow::Result<()> {
94        let embedding_bytes = bincode::serialize(&embedding.embedding)?;
95        let conn = self.get_conn()?;
96        conn.execute(
97            "INSERT OR REPLACE INTO embeddings (message_id, embedding, embedding_model, generated_at) VALUES (?1, ?2, ?3, ?4)",
98            params![embedding.message_id, embedding_bytes, &embedding.embedding_model, embedding.generated_at.to_rfc3339()],
99        )?;
100
101        let mut cache = self.embedding_cache.write().unwrap();
102        cache.insert(embedding.message_id, embedding.embedding.clone());
103
104        if let Some(ref mut index) = *self.ann_index.write().unwrap() {
105            // Ignore add errors; rebuild ensures index stays consistent
106            let _ = index.add(&embedding.embedding, embedding.message_id);
107            // Re-building after every add is expensive, but necessary for HNSW 
108            // if you want immediate searchability of the new item.
109            index.build(Metric::CosineSimilarity)
110                .map_err(|e| anyhow::anyhow!("Failed to rebuild index: {}", e))?;
111        }
112        Ok(())
113    }
114
115    pub fn find_similar_embeddings(
116        &self,
117        query_embedding: &[f32],
118        model: &str,
119        limit: i32,
120        similarity_threshold: f32,
121    ) -> anyhow::Result<Vec<(i64, f32)>> {
122        if model.is_empty() || model.len() > 100 {
123            return Err(anyhow::anyhow!("Invalid model name"));
124        }
125
126        {
127            let index_guard = self.ann_index.read().unwrap();
128            if let Some(index) = &*index_guard {
129                let results = index.search(query_embedding, limit as usize);
130                
131                let mut scored_results = Vec::new();
132                for id in &results {
133                    if let Some(embedding) = self.embedding_cache.read().unwrap().get(id) {
134                        let sim = cosine_similarity(query_embedding, embedding);
135                        if sim >= similarity_threshold {
136                            scored_results.push((*id, sim));
137                        }
138                    }
139                }
140                
141                scored_results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
142                return Ok(scored_results);
143            }
144        }
145
146        warn!("ANN index not available, falling back to safe linear search");
147        self.find_similar_embeddings_linear(query_embedding, model, limit, similarity_threshold)
148    }
149
150    fn find_similar_embeddings_linear(
151        &self,
152        query_embedding: &[f32],
153        model: &str,
154        limit: i32,
155        similarity_threshold: f32,
156    ) -> anyhow::Result<Vec<(i64, f32)>> {
157        let conn = self.get_conn()?;
158        let mut stmt = conn.prepare(
159            "SELECT message_id, embedding FROM embeddings WHERE embedding_model = ?1"
160        )?;
161        let mut rows = stmt.query([model])?;
162        
163        let mut matches = Vec::new();
164        while let Some(row) = rows.next()? {
165            let message_id: i64 = row.get(0)?;
166            let embedding_bytes: Vec<u8> = row.get(1)?;
167            let embedding: Vec<f32> = bincode::deserialize(&embedding_bytes)
168                .map_err(|e| anyhow::anyhow!("Bincode error: {}", e))?;
169            
170            let sim = cosine_similarity(query_embedding, &embedding);
171            if sim >= similarity_threshold {
172                matches.push((message_id, sim));
173            }
174        }
175        
176        matches.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
177        matches.truncate(limit as usize);
178        Ok(matches)
179    }
180
181    pub fn get_embedding_by_message_id(&self, message_id: i64, model: &str) -> anyhow::Result<Option<Embedding>> {
182        let conn = self.get_conn()?;
183        let mut stmt = conn.prepare(
184            "SELECT id, message_id, embedding, embedding_model, generated_at
185             FROM embeddings WHERE message_id = ?1 AND embedding_model = ?2"
186        )?;
187        
188        let mut rows = stmt.query(params![message_id, model])?;
189        if let Some(row) = rows.next()? {
190            Ok(Some(self.row_to_embedding(row)?))
191        } else {
192            Ok(None)
193        }
194    }
195
196    fn row_to_embedding(&self, row: &Row) -> Result<Embedding> {
197        let embedding_bytes: Vec<u8> = row.get(2)?;
198        let embedding: Vec<f32> = bincode::deserialize(&embedding_bytes)
199            .map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?;
200        
201        let generated_at_str: String = row.get(4)?;
202        let generated_at = chrono::DateTime::parse_from_rfc3339(&generated_at_str)
203            .map_err(|e| rusqlite::Error::FromSqlConversionFailure(4, rusqlite::types::Type::Text, Box::new(e)))?
204            .with_timezone(&chrono::Utc);
205        
206        Ok(Embedding {
207            id: row.get(0)?,
208            message_id: row.get(1)?,
209            embedding,
210            embedding_model: row.get(3)?,
211            generated_at,
212        })
213    }
214
215    pub fn get_stats(&self) -> anyhow::Result<EmbeddingStats> {
216        let conn = self.get_conn()?;
217        let count: i64 = conn.query_row(
218            "SELECT COUNT(*) FROM embeddings",
219            [],
220            |row| row.get(0)
221        )?;
222        
223        let mut stmt = conn.prepare("SELECT embedding FROM embeddings LIMIT 1")?;
224        let dimension = if let Some(row) = stmt.query([])?.next()? {
225            let embedding_bytes: Vec<u8> = row.get(0)?;
226            let embedding: Vec<f32> = bincode::deserialize(&embedding_bytes)
227                .map_err(|e| anyhow::anyhow!("Deserialization error: {}", e))?;
228            embedding.len()
229        } else {
230            0
231        };
232        
233        let index_type = if self.ann_index.read().unwrap().is_some() {
234            "HNSW".to_string()
235        } else {
236            "Linear".to_string()
237        };
238        
239        Ok(EmbeddingStats {
240            total_embeddings: count as usize,
241            dimension,
242            index_type,
243        })
244    }
245}
246
/// Cosine similarity between two vectors.
///
/// Returns 0.0 when the slices differ in length or when either vector has zero
/// magnitude (which includes two empty slices).
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    let magnitude = |v: &[f32]| v.iter().map(|c| c * c).sum::<f32>().sqrt();
    let (mag_a, mag_b) = (magnitude(a), magnitude(b));
    if mag_a == 0.0 || mag_b == 0.0 {
        return 0.0;
    }

    let inner: f32 = a.iter().zip(b).map(|(p, q)| p * q).sum();
    inner / (mag_a * mag_b)
}
253}