Skip to main content

mempalace_rs/
vector_storage.rs

1// vector_storage.rs — MemPalace Pure-Rust Storage Engine (replaces ChromaDB)
2//
3// Zero-network, embedded storage combining:
4//   • fastembed-rs  → CPU/ONNX text embeddings (AllMiniLML6V2, 384-dim)
5//   • rusqlite      → relational source of truth
6//   • usearch       → SIMD-accelerated HNSW ANN index
7
8use std::path::Path;
9use std::sync::Arc;
10use std::time::{SystemTime, UNIX_EPOCH};
11
12use anyhow::{anyhow, Context, Result};
13use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
14use rusqlite::{params, Connection, OptionalExtension};
15use std::path::PathBuf;
16use usearch::{Index, IndexOptions, MetricKind, ScalarKind};
17
/// Length every embedding vector must have; used as the index `dimensions`
/// and checked against each vector returned by the embedder.
const VECTOR_DIMS: usize = 384;
/// Value passed to the usearch index as `connectivity`.
const HNSW_M: usize = 16;
/// Value passed to the usearch index as `expansion_add`.
const HNSW_EF_CONSTRUCTION: usize = 128;
21
/// One row of the `memories` table, hydrated for callers.
///
/// `score` and `importance` are not stored as-is: they are filled in by the
/// query methods that build the record.
#[derive(Clone, Debug)]
pub struct MemoryRecord {
    /// Primary key of the row in `memories`; also used as the vector key.
    pub id: i64,
    /// The stored text of the memory.
    pub text_content: String,
    /// Wing label the memory was filed under.
    pub wing: String,
    /// Room label the memory was filed under.
    pub room: String,
    /// Originating file, when one was recorded.
    pub source_file: Option<String>,
    /// Start of the validity window (Unix seconds).
    pub valid_from: i64,
    /// End of the validity window, or `None` when open-ended.
    pub valid_to: Option<i64>,
    /// Similarity score (`1.0 - distance`) for vector searches; `0.0` for
    /// records produced by plain SQL lookups.
    pub score: f32,
    /// Stored importance after time decay and access-frequency boost.
    pub importance: f32,
}
35
/// An optional pair of timestamps bounding when a memory is considered valid.
///
/// Both ends default to `None`.
// NOTE(review): units are presumably Unix seconds like `MemoryRecord::valid_from`
// — confirm against callers; nothing in this file constructs this type.
#[derive(Clone, Debug, Default)]
pub struct TemporalRange {
    /// Lower bound of the window, if any.
    pub valid_from: Option<i64>,
    /// Upper bound of the window, if any.
    pub valid_to: Option<i64>,
}
42
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// # Panics
///
/// Panics if the system clock reports a time earlier than the Unix epoch.
fn now_unix() -> i64 {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("system clock before Unix epoch");
    since_epoch.as_secs() as i64
}
49
50fn compute_decayed_importance(base_score: f32, last_accessed: i64, access_count: i64) -> f32 {
51    let days_since = ((now_unix() - last_accessed) as f32 / 86400.0).max(0.0);
52    let freq_boost = (1.0 + access_count as f32).ln().max(1.0);
53    base_score * 0.9f32.powf(days_since) * freq_boost
54}
55
56fn build_index() -> Result<Index> {
57    let opts = IndexOptions {
58        dimensions: VECTOR_DIMS,
59        metric: MetricKind::Cos,
60        quantization: ScalarKind::F32,
61        connectivity: HNSW_M,
62        expansion_add: HNSW_EF_CONSTRUCTION,
63        expansion_search: 64,
64        ..Default::default()
65    };
66    Index::new(&opts).map_err(|e| anyhow!("usearch index creation failed: {e}"))
67}
68
/// The pure-Rust vector storage engine combining SQLite metadata and usearch HNSW index.
///
/// Rows in the `memories` table and vectors in `index` are tied together by
/// id: the SQLite row id is used as the vector key.
pub struct VectorStorage {
    /// Text embedding model used to turn memory text and queries into vectors.
    pub embedder: Arc<TextEmbedding>,
    /// SQLite connection holding the `memories` and `drawers` tables.
    pub db: Connection,
    /// usearch index holding one vector per memory, keyed by row id.
    pub index: Index,
}
75
76impl VectorStorage {
77    pub fn new(db_path: impl AsRef<Path>, index_path: impl AsRef<Path>) -> Result<Self> {
78        let cache_dir = std::env::var("MEMPALACE_MODELS_DIR")
79            .ok()
80            .map(PathBuf::from)
81            .filter(|p| p.exists())
82            .or_else(|| {
83                std::env::current_exe()
84                    .ok()
85                    .and_then(|exe| exe.parent().map(|p| p.join("models")))
86                    .filter(|p| p.exists())
87            });
88
89        let mut init_opts =
90            InitOptions::new(EmbeddingModel::AllMiniLML6V2).with_show_download_progress(false);
91
92        if let Some(cache) = cache_dir {
93            init_opts = init_opts.with_cache_dir(cache);
94        }
95
96        let embedder =
97            TextEmbedding::try_new(init_opts).context("Failed to initialise fastembed")?;
98
99        Self::new_with_embedder(db_path, index_path, Arc::new(embedder))
100    }
101
102    pub fn new_with_embedder(
103        db_path: impl AsRef<Path>,
104        index_path: impl AsRef<Path>,
105        embedder: Arc<TextEmbedding>,
106    ) -> Result<Self> {
107        // 2. SQLite
108        let db = Connection::open(db_path.as_ref())
109            .with_context(|| format!("Cannot open SQLite at {:?}", db_path.as_ref()))?;
110
111        db.execute_batch(
112            "PRAGMA journal_mode = WAL;
113             PRAGMA foreign_keys = ON;
114             PRAGMA synchronous = NORMAL;
115             CREATE TABLE IF NOT EXISTS memories (
116                id INTEGER PRIMARY KEY AUTOINCREMENT,
117                text_content TEXT NOT NULL,
118                wing TEXT NOT NULL,
119                room TEXT NOT NULL,
120                source_file TEXT,
121                source_mtime REAL,
122                valid_from INTEGER NOT NULL,
123                valid_to INTEGER,
124                last_accessed INTEGER DEFAULT 0,
125                access_count INTEGER DEFAULT 0,
126                importance_score REAL DEFAULT 5.0
127             );
128             CREATE INDEX IF NOT EXISTS idx_source_file ON memories (source_file);
129             CREATE INDEX IF NOT EXISTS idx_wing_room ON memories (wing, room);
130             CREATE INDEX IF NOT EXISTS idx_valid ON memories (valid_from, valid_to);
131             CREATE TABLE IF NOT EXISTS drawers (
132                id INTEGER PRIMARY KEY AUTOINCREMENT,
133                content TEXT NOT NULL,
134                wing TEXT NOT NULL,
135                room TEXT NOT NULL,
136                source_file TEXT,
137                filed_at TEXT NOT NULL,
138                embedding_id INTEGER REFERENCES memories(id)
139             );
140             CREATE INDEX IF NOT EXISTS idx_drawers_wing_room ON drawers (wing, room);
141            ",
142        )?;
143
144        {
145            let mut check_stmt = db.prepare("PRAGMA table_info(memories)")?;
146            let mut has_accessed = false;
147            let mut has_mtime = false;
148            let mut rows = check_stmt.query([])?;
149            while let Some(row) = rows.next()? {
150                let name: String = row.get(1)?;
151                if name == "last_accessed" {
152                    has_accessed = true;
153                }
154                if name == "source_mtime" {
155                    has_mtime = true;
156                }
157            }
158            if !has_accessed {
159                db.execute_batch(
160                    "ALTER TABLE memories ADD COLUMN last_accessed INTEGER DEFAULT 0;
161                     ALTER TABLE memories ADD COLUMN access_count INTEGER DEFAULT 0;
162                     ALTER TABLE memories ADD COLUMN importance_score REAL DEFAULT 5.0;",
163                )?;
164                let now = now_unix();
165                db.execute("UPDATE memories SET last_accessed = ?1", params![now])?;
166            }
167            if !has_mtime {
168                db.execute_batch("ALTER TABLE memories ADD COLUMN source_mtime REAL;")?;
169            }
170        }
171
172        // 3. usearch HNSW index
173        let index_path = index_path.as_ref();
174        let index = if index_path.exists() {
175            let idx = build_index()?;
176            idx.load(
177                index_path
178                    .to_str()
179                    .ok_or_else(|| anyhow!("Non-UTF8 index path"))?,
180            )
181            .map_err(|e| anyhow!("Failed to load usearch index: {e}"))?;
182            idx
183        } else {
184            build_index()?
185        };
186
187        Ok(Self {
188            embedder,
189            db,
190            index,
191        })
192    }
193
194    pub fn add_memory(
195        &mut self,
196        text: &str,
197        wing: &str,
198        room: &str,
199        source_file: Option<&str>,
200        source_mtime: Option<f64>,
201    ) -> Result<i64> {
202        let vector = self.embed_single(text)?;
203        let valid_from = now_unix();
204
205        self.db.execute(
206            "INSERT INTO memories (text_content, wing, room, source_file, source_mtime, valid_from, last_accessed, access_count, importance_score)
207             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, 0, 5.0)",
208            params![text, wing, room, source_file, source_mtime, valid_from, valid_from],
209        )?;
210
211        let row_id = self.db.last_insert_rowid();
212
213        let needed = self.index.size() + 1;
214        if needed > self.index.capacity() {
215            let new_cap = (needed * 2).max(64);
216            self.index
217                .reserve(new_cap)
218                .map_err(|e| anyhow!("usearch reserve failed: {e}"))?;
219        }
220
221        self.index
222            .add(row_id as u64, &vector)
223            .map_err(|e| anyhow!("usearch add failed: {e}"))?;
224
225        Ok(row_id)
226    }
227
228    pub fn get_source_mtime(&self, source_file: &str) -> Result<Option<f64>> {
229        let mut stmt = self.db.prepare(
230            "SELECT source_mtime FROM memories WHERE source_file = ?1 ORDER BY id DESC LIMIT 1",
231        )?;
232        let mtime = stmt
233            .query_row(params![source_file], |row| row.get::<_, Option<f64>>(0))
234            .optional()?;
235        Ok(mtime.flatten())
236    }
237
238    pub fn search_room(
239        &self,
240        query: &str,
241        wing: &str,
242        room: &str,
243        limit: usize,
244        at_time: Option<i64>,
245    ) -> Result<Vec<MemoryRecord>> {
246        if limit == 0 {
247            return Ok(vec![]);
248        }
249        let at_time = at_time.unwrap_or_else(now_unix);
250        let query_vector = self.embed_single(query)?;
251
252        let mut stmt = self.db.prepare_cached(
253            "SELECT id FROM memories
254             WHERE wing = ?1 AND room = ?2
255               AND valid_from <= ?3
256               AND (valid_to IS NULL OR valid_to >= ?3)",
257        )?;
258
259        let candidate_ids: Vec<u64> = stmt
260            .query_map(params![wing, room, at_time], |row| row.get::<_, i64>(0))?
261            .collect::<rusqlite::Result<Vec<_>>>()?
262            .into_iter()
263            .map(|id| id as u64)
264            .collect();
265
266        if candidate_ids.is_empty() {
267            return Ok(vec![]);
268        }
269
270        let candidate_set: std::collections::HashSet<u64> = candidate_ids.iter().cloned().collect();
271        let results = self
272            .index
273            .filtered_search(&query_vector, limit, |key: u64| {
274                candidate_set.contains(&key)
275            })
276            .map_err(|e| anyhow!("usearch filtered_search failed: {e}"))?;
277
278        if results.keys.is_empty() {
279            return Ok(vec![]);
280        }
281
282        let id_placeholders: String = results
283            .keys
284            .iter()
285            .enumerate()
286            .map(|(i, _)| format!("?{}", i + 1))
287            .collect::<Vec<_>>()
288            .join(", ");
289
290        let sql = format!(
291            "SELECT id, text_content, wing, room, source_file, valid_from, valid_to, last_accessed, access_count, importance_score
292             FROM memories WHERE id IN ({id_placeholders})"
293        );
294
295        let mut stmt = self.db.prepare(&sql)?;
296        let params_vec: Vec<&dyn rusqlite::ToSql> = results
297            .keys
298            .iter()
299            .map(|k| k as &dyn rusqlite::ToSql)
300            .collect();
301
302        let mut record_map: std::collections::HashMap<i64, MemoryRecord> = stmt
303            .query_map(params_vec.as_slice(), |row| {
304                let last_accessed: i64 = row.get(7)?;
305                let access_count: i64 = row.get(8)?;
306                let base_score: f32 = row.get(9)?;
307                Ok(MemoryRecord {
308                    id: row.get(0)?,
309                    text_content: row.get(1)?,
310                    wing: row.get(2)?,
311                    room: row.get(3)?,
312                    source_file: row.get(4)?,
313                    valid_from: row.get(5)?,
314                    valid_to: row.get(6)?,
315                    score: 0.0,
316                    importance: compute_decayed_importance(base_score, last_accessed, access_count),
317                })
318            })?
319            .collect::<rusqlite::Result<Vec<_>>>()?
320            .into_iter()
321            .map(|r| (r.id, r))
322            .collect();
323
324        let mut ordered: Vec<MemoryRecord> = results
325            .keys
326            .iter()
327            .zip(results.distances.iter())
328            .filter_map(|(&key, &dist)| {
329                let id = key as i64;
330                record_map.remove(&id).map(|mut rec| {
331                    rec.score = 1.0 - dist;
332                    rec
333                })
334            })
335            .collect();
336
337        ordered.sort_by(|a, b| {
338            b.score
339                .partial_cmp(&a.score)
340                .unwrap_or(std::cmp::Ordering::Equal)
341        });
342        Ok(ordered)
343    }
344
345    pub fn search(&self, query: &str, limit: usize) -> Result<Vec<MemoryRecord>> {
346        if limit == 0 {
347            return Ok(vec![]);
348        }
349        let query_vector = self.embed_single(query)?;
350
351        let results = self
352            .index
353            .search(&query_vector, limit)
354            .map_err(|e| anyhow!("usearch search failed: {e}"))?;
355
356        if results.keys.is_empty() {
357            return Ok(vec![]);
358        }
359
360        let id_placeholders: String = results
361            .keys
362            .iter()
363            .enumerate()
364            .map(|(i, _)| format!("?{}", i + 1))
365            .collect::<Vec<_>>()
366            .join(", ");
367
368        let sql = format!(
369            "SELECT id, text_content, wing, room, source_file, valid_from, valid_to, last_accessed, access_count, importance_score
370             FROM memories WHERE id IN ({id_placeholders})"
371        );
372
373        let mut stmt = self.db.prepare(&sql)?;
374        let params_vec: Vec<&dyn rusqlite::ToSql> = results
375            .keys
376            .iter()
377            .map(|k| k as &dyn rusqlite::ToSql)
378            .collect();
379
380        let mut record_map: std::collections::HashMap<i64, MemoryRecord> = stmt
381            .query_map(params_vec.as_slice(), |row| {
382                let last_accessed: i64 = row.get(7)?;
383                let access_count: i64 = row.get(8)?;
384                let base_score: f32 = row.get(9)?;
385                Ok(MemoryRecord {
386                    id: row.get(0)?,
387                    text_content: row.get(1)?,
388                    wing: row.get(2)?,
389                    room: row.get(3)?,
390                    source_file: row.get(4)?,
391                    valid_from: row.get(5)?,
392                    valid_to: row.get(6)?,
393                    score: 0.0,
394                    importance: compute_decayed_importance(base_score, last_accessed, access_count),
395                })
396            })?
397            .collect::<rusqlite::Result<Vec<_>>>()?
398            .into_iter()
399            .map(|r| (r.id, r))
400            .collect();
401
402        let mut ordered: Vec<MemoryRecord> = results
403            .keys
404            .iter()
405            .zip(results.distances.iter())
406            .filter_map(|(&key, &dist)| {
407                let id = key as i64;
408                record_map.remove(&id).map(|mut rec| {
409                    rec.score = 1.0 - dist;
410                    rec
411                })
412            })
413            .collect();
414
415        ordered.sort_by(|a, b| {
416            b.score
417                .partial_cmp(&a.score)
418                .unwrap_or(std::cmp::Ordering::Equal)
419        });
420        Ok(ordered)
421    }
422
423    pub fn get_memories(
424        &self,
425        wing: Option<&str>,
426        room: Option<&str>,
427        limit: usize,
428    ) -> Result<Vec<MemoryRecord>> {
429        let (sql, params_dyn): (String, Vec<Box<dyn rusqlite::ToSql>>) = match (wing, room) {
430            (Some(w), Some(r)) => (
431                format!("SELECT id, text_content, wing, room, source_file, valid_from, valid_to, last_accessed, access_count, importance_score FROM memories WHERE wing = ?1 AND room = ?2 ORDER BY valid_from DESC LIMIT {limit}"),
432                vec![Box::new(w.to_string()), Box::new(r.to_string())],
433            ),
434            (Some(w), None) => (
435                format!("SELECT id, text_content, wing, room, source_file, valid_from, valid_to, last_accessed, access_count, importance_score FROM memories WHERE wing = ?1 ORDER BY valid_from DESC LIMIT {limit}"),
436                vec![Box::new(w.to_string())],
437            ),
438            (None, Some(r)) => (
439                format!("SELECT id, text_content, wing, room, source_file, valid_from, valid_to, last_accessed, access_count, importance_score FROM memories WHERE room = ?1 ORDER BY valid_from DESC LIMIT {limit}"),
440                vec![Box::new(r.to_string())],
441            ),
442            (None, None) => (
443                format!("SELECT id, text_content, wing, room, source_file, valid_from, valid_to, last_accessed, access_count, importance_score FROM memories ORDER BY valid_from DESC LIMIT {limit}"),
444                vec![],
445            ),
446        };
447        let mut stmt = self.db.prepare(&sql)?;
448        let params_ref: Vec<&dyn rusqlite::ToSql> = params_dyn.iter().map(|p| p.as_ref()).collect();
449        let records = stmt
450            .query_map(params_ref.as_slice(), |row| {
451                let last_accessed: i64 = row.get(7)?;
452                let access_count: i64 = row.get(8)?;
453                let base_score: f32 = row.get(9)?;
454                Ok(MemoryRecord {
455                    id: row.get(0)?,
456                    text_content: row.get(1)?,
457                    wing: row.get(2)?,
458                    room: row.get(3)?,
459                    source_file: row.get(4)?,
460                    valid_from: row.get(5)?,
461                    valid_to: row.get(6)?,
462                    score: 0.0,
463                    importance: compute_decayed_importance(base_score, last_accessed, access_count),
464                })
465            })?
466            .collect::<rusqlite::Result<Vec<_>>>()?;
467        Ok(records)
468    }
469
470    pub fn get_all_ids(&self, wing: Option<&str>) -> Result<Vec<i64>> {
471        if let Some(w) = wing {
472            let mut stmt = self.db.prepare("SELECT id FROM memories WHERE wing = ?1")?;
473            let ids = stmt
474                .query_map(params![w], |row| row.get(0))?
475                .collect::<rusqlite::Result<Vec<i64>>>()?;
476            Ok(ids)
477        } else {
478            let mut stmt = self.db.prepare("SELECT id FROM memories")?;
479            let ids = stmt
480                .query_map([], |row| row.get(0))?
481                .collect::<rusqlite::Result<Vec<i64>>>()?;
482            Ok(ids)
483        }
484    }
485
486    pub fn get_memory_by_id(&self, id: i64) -> Result<MemoryRecord> {
487        self.db.query_row(
488            "SELECT id, text_content, wing, room, source_file, valid_from, valid_to, last_accessed, access_count, importance_score FROM memories WHERE id = ?1",
489            params![id],
490            |row| {
491                let last_accessed: i64 = row.get(7)?;
492                let access_count: i64 = row.get(8)?;
493                let base_score: f32 = row.get(9)?;
494                Ok(MemoryRecord {
495                    id: row.get(0)?,
496                    text_content: row.get(1)?,
497                    wing: row.get(2)?,
498                    room: row.get(3)?,
499                    source_file: row.get(4)?,
500                    valid_from: row.get(5)?,
501                    valid_to: row.get(6)?,
502                    score: 0.0,
503                    importance: compute_decayed_importance(base_score, last_accessed, access_count),
504                })
505            },
506        ).context("Memory not found")
507    }
508
509    pub fn update_memory_summary(&self, id: i64, new_summary: &str) -> Result<()> {
510        self.db.execute(
511            "UPDATE memories SET text_content = ?1 WHERE id = ?2",
512            params![new_summary, id],
513        )?;
514        Ok(())
515    }
516
517    pub fn touch_memory(&self, id: i64) -> Result<()> {
518        let now = now_unix();
519        self.db.execute(
520            "UPDATE memories SET access_count = access_count + 1, last_accessed = ?1 WHERE id = ?2",
521            params![now, id],
522        )?;
523        Ok(())
524    }
525
526    pub fn delete_memory(&self, id: i64) -> Result<()> {
527        self.db
528            .execute("DELETE FROM memories WHERE id = ?1", params![id])?;
529        Ok(())
530    }
531
532    pub fn has_source_file(&self, source_file: &str) -> Result<bool> {
533        let count: i64 = self.db.query_row(
534            "SELECT COUNT(*) FROM memories WHERE source_file = ?1 LIMIT 1",
535            params![source_file],
536            |row| row.get(0),
537        )?;
538        Ok(count > 0)
539    }
540
541    pub fn get_wings_rooms(&self) -> Result<Vec<(String, String)>> {
542        let mut stmt = self
543            .db
544            .prepare("SELECT DISTINCT wing, room FROM memories ORDER BY wing, room")?;
545        let pairs = stmt
546            .query_map([], |row| {
547                Ok((row.get::<_, String>(0)?, row.get::<_, String>(1)?))
548            })?
549            .collect::<rusqlite::Result<Vec<_>>>()?;
550        Ok(pairs)
551    }
552
553    pub fn save_index(&self, index_path: impl AsRef<Path>) -> Result<()> {
554        let path = index_path
555            .as_ref()
556            .to_str()
557            .ok_or_else(|| anyhow!("Non-UTF8 path"))?;
558        self.index
559            .save(path)
560            .map_err(|e| anyhow!("Save failed: {e}"))
561    }
562
563    pub fn memory_count(&self) -> Result<u64> {
564        self.db
565            .query_row("SELECT COUNT(*) FROM memories", [], |row| {
566                row.get::<_, i64>(0)
567            })
568            .map(|n| n as u64)
569            .context("Count failed")
570    }
571
572    pub fn index_size(&self) -> usize {
573        self.index.size()
574    }
575
576    pub fn embed_single(&self, text: &str) -> Result<Vec<f32>> {
577        let mut batch = self
578            .embedder
579            .embed(vec![text.to_string()], None)
580            .context("fastembed failed")?;
581        let vec = batch.pop().ok_or_else(|| anyhow!("Empty batch"))?;
582        if vec.len() != VECTOR_DIMS {
583            return Err(anyhow!("Expected {VECTOR_DIMS}-dim, got {}", vec.len()));
584        }
585        Ok(vec)
586    }
587}
588
589impl Drop for VectorStorage {
590    fn drop(&mut self) {
591        let _ = self.db.execute_batch("PRAGMA wal_checkpoint(TRUNCATE);");
592    }
593}