1use std::path::Path;
9use std::sync::Arc;
10use std::time::{SystemTime, UNIX_EPOCH};
11
12use anyhow::{anyhow, Context, Result};
13use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
14use rusqlite::{params, Connection, OptionalExtension};
15use std::path::PathBuf;
16use usearch::{Index, IndexOptions, MetricKind, ScalarKind};
17
18const VECTOR_DIMS: usize = 384;
19const HNSW_M: usize = 16;
20const HNSW_EF_CONSTRUCTION: usize = 128;
21
/// One row of the `memories` table, plus values computed at read time.
#[derive(Debug, Clone)]
pub struct MemoryRecord {
    /// Primary key of the row in `memories`.
    pub id: i64,
    /// Stored text of the memory.
    pub text_content: String,
    /// Wing the memory is filed under.
    pub wing: String,
    /// Room (within the wing) the memory is filed under.
    pub room: String,
    /// Source file recorded at insert time, if any.
    pub source_file: Option<String>,
    /// Start of validity; `add_memory` stores the Unix time (seconds) of insertion.
    pub valid_from: i64,
    /// End of validity; `None` is treated as "no upper bound" by `search_room`.
    pub valid_to: Option<i64>,
    /// `1.0 - distance` as reported by the vector index for search results;
    /// `0.0` for records loaded without a vector search.
    pub score: f32,
    /// Output of `compute_decayed_importance` for the stored `importance_score`,
    /// `last_accessed` and `access_count` columns.
    pub importance: f32,
}
35
/// Optional lower/upper time bounds.
///
/// NOTE(review): not referenced anywhere in the visible part of this file;
/// presumably Unix seconds matching `MemoryRecord::valid_from` / `valid_to` —
/// confirm against callers.
#[derive(Debug, Clone, Default)]
pub struct TemporalRange {
    /// Lower bound, or `None` when unset.
    pub valid_from: Option<i64>,
    /// Upper bound, or `None` when unset.
    pub valid_to: Option<i64>,
}
42
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// # Panics
/// Panics if the system clock reports a time before the Unix epoch.
fn now_unix() -> i64 {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("system clock before Unix epoch");
    since_epoch.as_secs() as i64
}
49
50fn compute_decayed_importance(base_score: f32, last_accessed: i64, access_count: i64) -> f32 {
51 let days_since = ((now_unix() - last_accessed) as f32 / 86400.0).max(0.0);
52 let freq_boost = (1.0 + access_count as f32).ln().max(1.0);
53 base_score * 0.9f32.powf(days_since) * freq_boost
54}
55
56fn build_index() -> Result<Index> {
57 let opts = IndexOptions {
58 dimensions: VECTOR_DIMS,
59 metric: MetricKind::Cos,
60 quantization: ScalarKind::F32,
61 connectivity: HNSW_M,
62 expansion_add: HNSW_EF_CONSTRUCTION,
63 expansion_search: 64,
64 ..Default::default()
65 };
66 Index::new(&opts).map_err(|e| anyhow!("usearch index creation failed: {e}"))
67}
68
/// Memory store combining a SQLite database with a usearch vector index.
pub struct VectorStorage {
    /// Embedding model used by `embed_single` to turn text into vectors.
    pub embedder: Arc<TextEmbedding>,
    /// SQLite connection holding the `memories` and `drawers` tables.
    pub db: Connection,
    /// Vector index; `add_memory` keys each vector by its `memories.id`.
    pub index: Index,
}
75
76impl VectorStorage {
77 pub fn new(db_path: impl AsRef<Path>, index_path: impl AsRef<Path>) -> Result<Self> {
78 let cache_dir = std::env::var("MEMPALACE_MODELS_DIR")
79 .ok()
80 .map(PathBuf::from)
81 .filter(|p| p.exists())
82 .or_else(|| {
83 std::env::current_exe()
84 .ok()
85 .and_then(|exe| exe.parent().map(|p| p.join("models")))
86 .filter(|p| p.exists())
87 });
88
89 let mut init_opts =
90 InitOptions::new(EmbeddingModel::AllMiniLML6V2).with_show_download_progress(false);
91
92 if let Some(cache) = cache_dir {
93 init_opts = init_opts.with_cache_dir(cache);
94 }
95
96 let embedder =
97 TextEmbedding::try_new(init_opts).context("Failed to initialise fastembed")?;
98
99 Self::new_with_embedder(db_path, index_path, Arc::new(embedder))
100 }
101
102 pub fn new_with_embedder(
103 db_path: impl AsRef<Path>,
104 index_path: impl AsRef<Path>,
105 embedder: Arc<TextEmbedding>,
106 ) -> Result<Self> {
107 let db = Connection::open(db_path.as_ref())
109 .with_context(|| format!("Cannot open SQLite at {:?}", db_path.as_ref()))?;
110
111 db.execute_batch(
112 "PRAGMA journal_mode = WAL;
113 PRAGMA foreign_keys = ON;
114 PRAGMA synchronous = NORMAL;
115 CREATE TABLE IF NOT EXISTS memories (
116 id INTEGER PRIMARY KEY AUTOINCREMENT,
117 text_content TEXT NOT NULL,
118 wing TEXT NOT NULL,
119 room TEXT NOT NULL,
120 source_file TEXT,
121 source_mtime REAL,
122 valid_from INTEGER NOT NULL,
123 valid_to INTEGER,
124 last_accessed INTEGER DEFAULT 0,
125 access_count INTEGER DEFAULT 0,
126 importance_score REAL DEFAULT 5.0
127 );
128 CREATE INDEX IF NOT EXISTS idx_source_file ON memories (source_file);
129 CREATE INDEX IF NOT EXISTS idx_wing_room ON memories (wing, room);
130 CREATE INDEX IF NOT EXISTS idx_valid ON memories (valid_from, valid_to);
131 CREATE TABLE IF NOT EXISTS drawers (
132 id INTEGER PRIMARY KEY AUTOINCREMENT,
133 content TEXT NOT NULL,
134 wing TEXT NOT NULL,
135 room TEXT NOT NULL,
136 source_file TEXT,
137 filed_at TEXT NOT NULL,
138 embedding_id INTEGER REFERENCES memories(id)
139 );
140 CREATE INDEX IF NOT EXISTS idx_drawers_wing_room ON drawers (wing, room);
141 ",
142 )?;
143
144 {
145 let mut check_stmt = db.prepare("PRAGMA table_info(memories)")?;
146 let mut has_accessed = false;
147 let mut has_mtime = false;
148 let mut rows = check_stmt.query([])?;
149 while let Some(row) = rows.next()? {
150 let name: String = row.get(1)?;
151 if name == "last_accessed" {
152 has_accessed = true;
153 }
154 if name == "source_mtime" {
155 has_mtime = true;
156 }
157 }
158 if !has_accessed {
159 db.execute_batch(
160 "ALTER TABLE memories ADD COLUMN last_accessed INTEGER DEFAULT 0;
161 ALTER TABLE memories ADD COLUMN access_count INTEGER DEFAULT 0;
162 ALTER TABLE memories ADD COLUMN importance_score REAL DEFAULT 5.0;",
163 )?;
164 let now = now_unix();
165 db.execute("UPDATE memories SET last_accessed = ?1", params![now])?;
166 }
167 if !has_mtime {
168 db.execute_batch("ALTER TABLE memories ADD COLUMN source_mtime REAL;")?;
169 }
170 }
171
172 let index_path = index_path.as_ref();
174 let index = if index_path.exists() {
175 let idx = build_index()?;
176 idx.load(
177 index_path
178 .to_str()
179 .ok_or_else(|| anyhow!("Non-UTF8 index path"))?,
180 )
181 .map_err(|e| anyhow!("Failed to load usearch index: {e}"))?;
182 idx
183 } else {
184 build_index()?
185 };
186
187 Ok(Self {
188 embedder,
189 db,
190 index,
191 })
192 }
193
194 pub fn add_memory(
195 &mut self,
196 text: &str,
197 wing: &str,
198 room: &str,
199 source_file: Option<&str>,
200 source_mtime: Option<f64>,
201 ) -> Result<i64> {
202 let vector = self.embed_single(text)?;
203 let valid_from = now_unix();
204
205 self.db.execute(
206 "INSERT INTO memories (text_content, wing, room, source_file, source_mtime, valid_from, last_accessed, access_count, importance_score)
207 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, 0, 5.0)",
208 params![text, wing, room, source_file, source_mtime, valid_from, valid_from],
209 )?;
210
211 let row_id = self.db.last_insert_rowid();
212
213 let needed = self.index.size() + 1;
214 if needed > self.index.capacity() {
215 let new_cap = (needed * 2).max(64);
216 self.index
217 .reserve(new_cap)
218 .map_err(|e| anyhow!("usearch reserve failed: {e}"))?;
219 }
220
221 self.index
222 .add(row_id as u64, &vector)
223 .map_err(|e| anyhow!("usearch add failed: {e}"))?;
224
225 Ok(row_id)
226 }
227
228 pub fn get_source_mtime(&self, source_file: &str) -> Result<Option<f64>> {
229 let mut stmt = self.db.prepare(
230 "SELECT source_mtime FROM memories WHERE source_file = ?1 ORDER BY id DESC LIMIT 1",
231 )?;
232 let mtime = stmt
233 .query_row(params![source_file], |row| row.get::<_, Option<f64>>(0))
234 .optional()?;
235 Ok(mtime.flatten())
236 }
237
238 pub fn search_room(
239 &self,
240 query: &str,
241 wing: &str,
242 room: &str,
243 limit: usize,
244 at_time: Option<i64>,
245 ) -> Result<Vec<MemoryRecord>> {
246 if limit == 0 {
247 return Ok(vec![]);
248 }
249 let at_time = at_time.unwrap_or_else(now_unix);
250 let query_vector = self.embed_single(query)?;
251
252 let mut stmt = self.db.prepare_cached(
253 "SELECT id FROM memories
254 WHERE wing = ?1 AND room = ?2
255 AND valid_from <= ?3
256 AND (valid_to IS NULL OR valid_to >= ?3)",
257 )?;
258
259 let candidate_ids: Vec<u64> = stmt
260 .query_map(params![wing, room, at_time], |row| row.get::<_, i64>(0))?
261 .collect::<rusqlite::Result<Vec<_>>>()?
262 .into_iter()
263 .map(|id| id as u64)
264 .collect();
265
266 if candidate_ids.is_empty() {
267 return Ok(vec![]);
268 }
269
270 let candidate_set: std::collections::HashSet<u64> = candidate_ids.iter().cloned().collect();
271 let results = self
272 .index
273 .filtered_search(&query_vector, limit, |key: u64| {
274 candidate_set.contains(&key)
275 })
276 .map_err(|e| anyhow!("usearch filtered_search failed: {e}"))?;
277
278 if results.keys.is_empty() {
279 return Ok(vec![]);
280 }
281
282 let id_placeholders: String = results
283 .keys
284 .iter()
285 .enumerate()
286 .map(|(i, _)| format!("?{}", i + 1))
287 .collect::<Vec<_>>()
288 .join(", ");
289
290 let sql = format!(
291 "SELECT id, text_content, wing, room, source_file, valid_from, valid_to, last_accessed, access_count, importance_score
292 FROM memories WHERE id IN ({id_placeholders})"
293 );
294
295 let mut stmt = self.db.prepare(&sql)?;
296 let params_vec: Vec<&dyn rusqlite::ToSql> = results
297 .keys
298 .iter()
299 .map(|k| k as &dyn rusqlite::ToSql)
300 .collect();
301
302 let mut record_map: std::collections::HashMap<i64, MemoryRecord> = stmt
303 .query_map(params_vec.as_slice(), |row| {
304 let last_accessed: i64 = row.get(7)?;
305 let access_count: i64 = row.get(8)?;
306 let base_score: f32 = row.get(9)?;
307 Ok(MemoryRecord {
308 id: row.get(0)?,
309 text_content: row.get(1)?,
310 wing: row.get(2)?,
311 room: row.get(3)?,
312 source_file: row.get(4)?,
313 valid_from: row.get(5)?,
314 valid_to: row.get(6)?,
315 score: 0.0,
316 importance: compute_decayed_importance(base_score, last_accessed, access_count),
317 })
318 })?
319 .collect::<rusqlite::Result<Vec<_>>>()?
320 .into_iter()
321 .map(|r| (r.id, r))
322 .collect();
323
324 let mut ordered: Vec<MemoryRecord> = results
325 .keys
326 .iter()
327 .zip(results.distances.iter())
328 .filter_map(|(&key, &dist)| {
329 let id = key as i64;
330 record_map.remove(&id).map(|mut rec| {
331 rec.score = 1.0 - dist;
332 rec
333 })
334 })
335 .collect();
336
337 ordered.sort_by(|a, b| {
338 b.score
339 .partial_cmp(&a.score)
340 .unwrap_or(std::cmp::Ordering::Equal)
341 });
342 Ok(ordered)
343 }
344
345 pub fn search(&self, query: &str, limit: usize) -> Result<Vec<MemoryRecord>> {
346 if limit == 0 {
347 return Ok(vec![]);
348 }
349 let query_vector = self.embed_single(query)?;
350
351 let results = self
352 .index
353 .search(&query_vector, limit)
354 .map_err(|e| anyhow!("usearch search failed: {e}"))?;
355
356 if results.keys.is_empty() {
357 return Ok(vec![]);
358 }
359
360 let id_placeholders: String = results
361 .keys
362 .iter()
363 .enumerate()
364 .map(|(i, _)| format!("?{}", i + 1))
365 .collect::<Vec<_>>()
366 .join(", ");
367
368 let sql = format!(
369 "SELECT id, text_content, wing, room, source_file, valid_from, valid_to, last_accessed, access_count, importance_score
370 FROM memories WHERE id IN ({id_placeholders})"
371 );
372
373 let mut stmt = self.db.prepare(&sql)?;
374 let params_vec: Vec<&dyn rusqlite::ToSql> = results
375 .keys
376 .iter()
377 .map(|k| k as &dyn rusqlite::ToSql)
378 .collect();
379
380 let mut record_map: std::collections::HashMap<i64, MemoryRecord> = stmt
381 .query_map(params_vec.as_slice(), |row| {
382 let last_accessed: i64 = row.get(7)?;
383 let access_count: i64 = row.get(8)?;
384 let base_score: f32 = row.get(9)?;
385 Ok(MemoryRecord {
386 id: row.get(0)?,
387 text_content: row.get(1)?,
388 wing: row.get(2)?,
389 room: row.get(3)?,
390 source_file: row.get(4)?,
391 valid_from: row.get(5)?,
392 valid_to: row.get(6)?,
393 score: 0.0,
394 importance: compute_decayed_importance(base_score, last_accessed, access_count),
395 })
396 })?
397 .collect::<rusqlite::Result<Vec<_>>>()?
398 .into_iter()
399 .map(|r| (r.id, r))
400 .collect();
401
402 let mut ordered: Vec<MemoryRecord> = results
403 .keys
404 .iter()
405 .zip(results.distances.iter())
406 .filter_map(|(&key, &dist)| {
407 let id = key as i64;
408 record_map.remove(&id).map(|mut rec| {
409 rec.score = 1.0 - dist;
410 rec
411 })
412 })
413 .collect();
414
415 ordered.sort_by(|a, b| {
416 b.score
417 .partial_cmp(&a.score)
418 .unwrap_or(std::cmp::Ordering::Equal)
419 });
420 Ok(ordered)
421 }
422
423 pub fn get_memories(
424 &self,
425 wing: Option<&str>,
426 room: Option<&str>,
427 limit: usize,
428 ) -> Result<Vec<MemoryRecord>> {
429 let (sql, params_dyn): (String, Vec<Box<dyn rusqlite::ToSql>>) = match (wing, room) {
430 (Some(w), Some(r)) => (
431 format!("SELECT id, text_content, wing, room, source_file, valid_from, valid_to, last_accessed, access_count, importance_score FROM memories WHERE wing = ?1 AND room = ?2 ORDER BY valid_from DESC LIMIT {limit}"),
432 vec![Box::new(w.to_string()), Box::new(r.to_string())],
433 ),
434 (Some(w), None) => (
435 format!("SELECT id, text_content, wing, room, source_file, valid_from, valid_to, last_accessed, access_count, importance_score FROM memories WHERE wing = ?1 ORDER BY valid_from DESC LIMIT {limit}"),
436 vec![Box::new(w.to_string())],
437 ),
438 (None, Some(r)) => (
439 format!("SELECT id, text_content, wing, room, source_file, valid_from, valid_to, last_accessed, access_count, importance_score FROM memories WHERE room = ?1 ORDER BY valid_from DESC LIMIT {limit}"),
440 vec![Box::new(r.to_string())],
441 ),
442 (None, None) => (
443 format!("SELECT id, text_content, wing, room, source_file, valid_from, valid_to, last_accessed, access_count, importance_score FROM memories ORDER BY valid_from DESC LIMIT {limit}"),
444 vec![],
445 ),
446 };
447 let mut stmt = self.db.prepare(&sql)?;
448 let params_ref: Vec<&dyn rusqlite::ToSql> = params_dyn.iter().map(|p| p.as_ref()).collect();
449 let records = stmt
450 .query_map(params_ref.as_slice(), |row| {
451 let last_accessed: i64 = row.get(7)?;
452 let access_count: i64 = row.get(8)?;
453 let base_score: f32 = row.get(9)?;
454 Ok(MemoryRecord {
455 id: row.get(0)?,
456 text_content: row.get(1)?,
457 wing: row.get(2)?,
458 room: row.get(3)?,
459 source_file: row.get(4)?,
460 valid_from: row.get(5)?,
461 valid_to: row.get(6)?,
462 score: 0.0,
463 importance: compute_decayed_importance(base_score, last_accessed, access_count),
464 })
465 })?
466 .collect::<rusqlite::Result<Vec<_>>>()?;
467 Ok(records)
468 }
469
470 pub fn get_all_ids(&self, wing: Option<&str>) -> Result<Vec<i64>> {
471 if let Some(w) = wing {
472 let mut stmt = self.db.prepare("SELECT id FROM memories WHERE wing = ?1")?;
473 let ids = stmt
474 .query_map(params![w], |row| row.get(0))?
475 .collect::<rusqlite::Result<Vec<i64>>>()?;
476 Ok(ids)
477 } else {
478 let mut stmt = self.db.prepare("SELECT id FROM memories")?;
479 let ids = stmt
480 .query_map([], |row| row.get(0))?
481 .collect::<rusqlite::Result<Vec<i64>>>()?;
482 Ok(ids)
483 }
484 }
485
486 pub fn get_memory_by_id(&self, id: i64) -> Result<MemoryRecord> {
487 self.db.query_row(
488 "SELECT id, text_content, wing, room, source_file, valid_from, valid_to, last_accessed, access_count, importance_score FROM memories WHERE id = ?1",
489 params![id],
490 |row| {
491 let last_accessed: i64 = row.get(7)?;
492 let access_count: i64 = row.get(8)?;
493 let base_score: f32 = row.get(9)?;
494 Ok(MemoryRecord {
495 id: row.get(0)?,
496 text_content: row.get(1)?,
497 wing: row.get(2)?,
498 room: row.get(3)?,
499 source_file: row.get(4)?,
500 valid_from: row.get(5)?,
501 valid_to: row.get(6)?,
502 score: 0.0,
503 importance: compute_decayed_importance(base_score, last_accessed, access_count),
504 })
505 },
506 ).context("Memory not found")
507 }
508
509 pub fn update_memory_summary(&self, id: i64, new_summary: &str) -> Result<()> {
510 self.db.execute(
511 "UPDATE memories SET text_content = ?1 WHERE id = ?2",
512 params![new_summary, id],
513 )?;
514 Ok(())
515 }
516
517 pub fn touch_memory(&self, id: i64) -> Result<()> {
518 let now = now_unix();
519 self.db.execute(
520 "UPDATE memories SET access_count = access_count + 1, last_accessed = ?1 WHERE id = ?2",
521 params![now, id],
522 )?;
523 Ok(())
524 }
525
526 pub fn delete_memory(&self, id: i64) -> Result<()> {
527 self.db
528 .execute("DELETE FROM memories WHERE id = ?1", params![id])?;
529 Ok(())
530 }
531
532 pub fn has_source_file(&self, source_file: &str) -> Result<bool> {
533 let count: i64 = self.db.query_row(
534 "SELECT COUNT(*) FROM memories WHERE source_file = ?1 LIMIT 1",
535 params![source_file],
536 |row| row.get(0),
537 )?;
538 Ok(count > 0)
539 }
540
541 pub fn get_wings_rooms(&self) -> Result<Vec<(String, String)>> {
542 let mut stmt = self
543 .db
544 .prepare("SELECT DISTINCT wing, room FROM memories ORDER BY wing, room")?;
545 let pairs = stmt
546 .query_map([], |row| {
547 Ok((row.get::<_, String>(0)?, row.get::<_, String>(1)?))
548 })?
549 .collect::<rusqlite::Result<Vec<_>>>()?;
550 Ok(pairs)
551 }
552
553 pub fn save_index(&self, index_path: impl AsRef<Path>) -> Result<()> {
554 let path = index_path
555 .as_ref()
556 .to_str()
557 .ok_or_else(|| anyhow!("Non-UTF8 path"))?;
558 self.index
559 .save(path)
560 .map_err(|e| anyhow!("Save failed: {e}"))
561 }
562
563 pub fn memory_count(&self) -> Result<u64> {
564 self.db
565 .query_row("SELECT COUNT(*) FROM memories", [], |row| {
566 row.get::<_, i64>(0)
567 })
568 .map(|n| n as u64)
569 .context("Count failed")
570 }
571
    /// Returns the size reported by the vector index (`Index::size`).
    ///
    /// This reads the in-memory index only; it does not query the database,
    /// so it can differ from `memory_count`.
    pub fn index_size(&self) -> usize {
        self.index.size()
    }
575
576 pub fn embed_single(&self, text: &str) -> Result<Vec<f32>> {
577 let mut batch = self
578 .embedder
579 .embed(vec![text.to_string()], None)
580 .context("fastembed failed")?;
581 let vec = batch.pop().ok_or_else(|| anyhow!("Empty batch"))?;
582 if vec.len() != VECTOR_DIMS {
583 return Err(anyhow!("Expected {VECTOR_DIMS}-dim, got {}", vec.len()));
584 }
585 Ok(vec)
586 }
587}
588
589impl Drop for VectorStorage {
590 fn drop(&mut self) {
591 let _ = self.db.execute_batch("PRAGMA wal_checkpoint(TRUNCATE);");
592 }
593}