use std::path::Path;
use anyhow::{Context, Result};
use parking_lot::Mutex;
use rusqlite::Connection;
/// DDL for every regular table, index, FTS5 table and trigger of the memory
/// database. It is run with `execute_batch` each time a database is opened.
/// Every statement uses `IF NOT EXISTS`, so re-running it against an existing
/// database creates only missing objects; existing definitions are NOT altered.
///
/// Sections:
/// 1. `memories` — memory entries (`tags` / `metadata` hold JSON text).
/// 2. `memories_fts` — external-content FTS5 index over
///    `memories(id, content, memory_type)` keyed by `rowid`, kept in sync by
///    the `memories_ai` / `memories_ad` / `memories_au` triggers.
/// 3. `embedding_cache` — embedding blobs keyed by `content_hash`.
/// 4. `dream_state` — simple key/value store.
/// 5. `patterns` — learning patterns.
/// 6. `projects` — projects; `name` is UNIQUE.
/// 7. `project_memory` — project <-> memory link table.
///
/// NOTE(review): `memories_au` fires on every UPDATE of `memories`, including
/// columns FTS does not index (e.g. `access_count`, `decay_score`), so each
/// such update re-indexes the row. `AFTER UPDATE OF id, content, memory_type`
/// would avoid that — confirm before changing (existing DBs keep the old trigger).
/// NOTE(review): `memories` has no INTEGER PRIMARY KEY; SQLite documents that
/// `VACUUM` may renumber implicit rowids, which would desync this rowid-keyed
/// FTS index (requiring an FTS `'rebuild'`). Confirm VACUUM is never run.
/// NOTE(review): `project_memory` declares no FOREIGN KEY constraints, so
/// `PRAGMA foreign_keys=ON` does not enforce link integrity for it.
/// NOTE(review): `embedding_cache.content_hash` is TEXT while
/// `memories.content_hash` is INTEGER — presumably different hashes; verify.
const SCHEMA: &str = r#"
-- ─────────────────────────────────────────────
-- 1. Memory entries
-- ─────────────────────────────────────────────
CREATE TABLE IF NOT EXISTS memories (
id TEXT PRIMARY KEY,
memory_type TEXT NOT NULL,
content TEXT NOT NULL,
summary TEXT,
importance REAL NOT NULL DEFAULT 0.5,
tier TEXT NOT NULL DEFAULT 'warm',
protection TEXT NOT NULL DEFAULT 'none',
source TEXT NOT NULL DEFAULT 'unknown',
session_id TEXT,
tags TEXT, -- JSON array
metadata TEXT, -- JSON object
access_count INTEGER NOT NULL DEFAULT 0,
pinned INTEGER NOT NULL DEFAULT 0,
auto_classified INTEGER NOT NULL DEFAULT 0,
session_appearances INTEGER NOT NULL DEFAULT 0,
decay_score REAL NOT NULL DEFAULT 1.0,
compaction_level INTEGER NOT NULL DEFAULT 0,
content_hash INTEGER NOT NULL DEFAULT 0,
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL,
accessed_at TEXT,
decay_rate REAL NOT NULL DEFAULT 0.01
);
CREATE INDEX IF NOT EXISTS idx_memories_type ON memories(memory_type);
CREATE INDEX IF NOT EXISTS idx_memories_session ON memories(session_id);
CREATE INDEX IF NOT EXISTS idx_memories_importance ON memories(importance);
CREATE INDEX IF NOT EXISTS idx_memories_tier ON memories(tier);
-- ─────────────────────────────────────────────
-- 2. FTS5 full-text search (BM25)
-- ─────────────────────────────────────────────
CREATE VIRTUAL TABLE IF NOT EXISTS memories_fts USING fts5(
id,
content,
memory_type,
content='memories',
content_rowid='rowid',
tokenize="unicode61"
);
-- Triggers to keep FTS in sync with memories
CREATE TRIGGER IF NOT EXISTS memories_ai AFTER INSERT ON memories BEGIN
INSERT INTO memories_fts(rowid, id, content, memory_type)
VALUES (new.rowid, new.id, new.content, new.memory_type);
END;
CREATE TRIGGER IF NOT EXISTS memories_ad AFTER DELETE ON memories BEGIN
INSERT INTO memories_fts(memories_fts, rowid, id, content, memory_type)
VALUES ('delete', old.rowid, old.id, old.content, old.memory_type);
END;
CREATE TRIGGER IF NOT EXISTS memories_au AFTER UPDATE ON memories BEGIN
INSERT INTO memories_fts(memories_fts, rowid, id, content, memory_type)
VALUES ('delete', old.rowid, old.id, old.content, old.memory_type);
INSERT INTO memories_fts(rowid, id, content, memory_type)
VALUES (new.rowid, new.id, new.content, new.memory_type);
END;
-- ─────────────────────────────────────────────
-- 3. Embedding cache
-- ─────────────────────────────────────────────
CREATE TABLE IF NOT EXISTS embedding_cache (
content_hash TEXT PRIMARY KEY,
embedding BLOB NOT NULL,
created_at TEXT NOT NULL
);
-- ─────────────────────────────────────────────
-- 4. Dream state
-- ─────────────────────────────────────────────
CREATE TABLE IF NOT EXISTS dream_state (
key TEXT PRIMARY KEY,
value TEXT NOT NULL
);
-- ─────────────────────────────────────────────
-- 5. Learning patterns
-- ─────────────────────────────────────────────
CREATE TABLE IF NOT EXISTS patterns (
id TEXT PRIMARY KEY,
strategy TEXT NOT NULL,
domain TEXT,
quality REAL NOT NULL DEFAULT 0.5,
use_count INTEGER NOT NULL DEFAULT 0,
success_rate REAL NOT NULL DEFAULT 0.0,
is_long_term INTEGER NOT NULL DEFAULT 0,
embedding BLOB,
data TEXT NOT NULL,
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL
);
-- ─────────────────────────────────────────────
-- 6. Projects (RFC-011)
-- ─────────────────────────────────────────────
CREATE TABLE IF NOT EXISTS projects (
id TEXT PRIMARY KEY,
name TEXT NOT NULL UNIQUE,
description TEXT,
paths TEXT, -- JSON array of PathBuf strings
tags TEXT, -- JSON array of strings
emoji TEXT NOT NULL DEFAULT '📦',
source TEXT NOT NULL DEFAULT 'manual',
memory_visible INTEGER NOT NULL DEFAULT 1,
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL,
last_active_at TEXT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_projects_name ON projects(name);
-- ─────────────────────────────────────────────
-- 7. Project-Memory junction (RFC-011)
-- ─────────────────────────────────────────────
CREATE TABLE IF NOT EXISTS project_memory (
project_id TEXT NOT NULL,
memory_id TEXT NOT NULL,
created_at TEXT NOT NULL DEFAULT (datetime('now')),
PRIMARY KEY (project_id, memory_id)
);
CREATE INDEX IF NOT EXISTS idx_pm_project ON project_memory(project_id);
CREATE INDEX IF NOT EXISTS idx_pm_memory ON project_memory(memory_id);
"#;
/// Template for the sqlite-vec `vec0` virtual table holding memory embeddings.
///
/// The `{DIM}` placeholder is textually replaced with the embedding dimension
/// passed to `MemoryDatabase::open` / `MemoryDatabase::open_in_memory`.
///
/// NOTE(review): because of `IF NOT EXISTS`, reopening an existing database
/// with a different dimension keeps the old `memory_vectors` definition, while
/// `MemoryDatabase::embedding_dim` reports the new value. Confirm callers never
/// change the dimension for an existing database file.
const VEC_SCHEMA_TEMPLATE: &str = r#"
CREATE VIRTUAL TABLE IF NOT EXISTS memory_vectors USING vec0(
embedding float[{DIM}]
);
"#;
/// SQLite-backed store for memories, FTS5 search, sqlite-vec embeddings,
/// learning patterns, dream state and projects.
///
/// Holds a single `rusqlite::Connection` behind a `parking_lot::Mutex`, so all
/// database access through one `MemoryDatabase` is serialized.
pub struct MemoryDatabase {
/// The wrapped connection; borrow it via `MemoryDatabase::conn`.
conn: Mutex<Connection>,
/// Dimension the `memory_vectors` table was requested with at open time.
embedding_dim: usize,
}
impl std::fmt::Debug for MemoryDatabase {
    /// Prints the type name and the embedding dimension only; the connection
    /// itself is intentionally left out of the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("MemoryDatabase");
        builder.field("embedding_dim", &self.embedding_dim);
        builder.finish()
    }
}
impl MemoryDatabase {
pub fn open(db_path: &Path, embedding_dim: usize) -> Result<Self> {
if let Some(parent) = db_path.parent() {
std::fs::create_dir_all(parent)
.with_context(|| format!("Creating memory DB directory: {}", parent.display()))?;
}
Self::register_vec_extension();
let conn = Connection::open(db_path)
.with_context(|| format!("Opening memory DB: {}", db_path.display()))?;
conn.execute_batch("PRAGMA journal_mode=WAL;")?;
conn.execute_batch("PRAGMA synchronous=NORMAL;")?;
conn.execute_batch("PRAGMA foreign_keys=ON;")?;
conn.execute_batch(SCHEMA)
.context("Initializing memory database schema")?;
conn.execute_batch(&VEC_SCHEMA_TEMPLATE.replace("{DIM}", &embedding_dim.to_string()))
.context("Initializing sqlite-vec virtual table")?;
tracing::info!(
path = %db_path.display(),
dim = embedding_dim,
"Memory database opened"
);
Ok(Self {
conn: Mutex::new(conn),
embedding_dim,
})
}
pub fn open_in_memory(embedding_dim: usize) -> Result<Self> {
Self::register_vec_extension();
let conn = Connection::open_in_memory()?;
conn.execute_batch(SCHEMA)?;
conn.execute_batch(&VEC_SCHEMA_TEMPLATE.replace("{DIM}", &embedding_dim.to_string()))?;
Ok(Self {
conn: Mutex::new(conn),
embedding_dim,
})
}
fn register_vec_extension() {
static REGISTERED: std::sync::atomic::AtomicBool =
std::sync::atomic::AtomicBool::new(false);
if !REGISTERED.swap(true, std::sync::atomic::Ordering::SeqCst) {
unsafe {
#[allow(clippy::missing_transmute_annotations)]
rusqlite::ffi::sqlite3_auto_extension(Some(std::mem::transmute(
sqlite_vec::sqlite3_vec_init as *const (),
)));
}
}
}
pub fn conn(&self) -> parking_lot::MutexGuard<'_, Connection> {
self.conn.lock()
}
pub fn embedding_dim(&self) -> usize {
self.embedding_dim
}
pub fn backup(&self, backup_path: &Path) -> Result<()> {
{
let conn = self.conn();
conn.execute_batch("PRAGMA wal_checkpoint(TRUNCATE);")?;
}
let db_path = {
let conn = self.conn();
conn.path()
.map(std::path::PathBuf::from)
.ok_or_else(|| anyhow::anyhow!("Cannot backup in-memory database"))?
};
std::fs::copy(&db_path, backup_path).with_context(|| {
format!("Copying {} to {}", db_path.display(), backup_path.display())
})?;
tracing::info!(path = %backup_path.display(), "Memory database backed up");
Ok(())
}
pub fn get_dream_state(&self, key: &str) -> Result<Option<String>> {
let conn = self.conn();
let mut stmt = conn.prepare("SELECT value FROM dream_state WHERE key = ?1")?;
let mut rows = stmt.query(rusqlite::params![key])?;
match rows.next()? {
Some(row) => Ok(Some(row.get(0)?)),
None => Ok(None),
}
}
pub fn set_dream_state(&self, key: &str, value: &str) -> Result<()> {
let conn = self.conn();
conn.execute(
"INSERT OR REPLACE INTO dream_state (key, value) VALUES (?1, ?2)",
rusqlite::params![key, value],
)?;
Ok(())
}
pub fn is_migration_complete(&self) -> bool {
self.get_dream_state("migration_v1_complete")
.ok()
.flatten()
.map(|v| v == "true")
.unwrap_or(false)
}
pub fn save_project(&self, project: &crate::project::Project) -> Result<()> {
let conn = self.conn();
conn.execute(
"INSERT OR REPLACE INTO projects
(id, name, description, paths, tags, emoji, source, memory_visible, created_at, updated_at, last_active_at)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11)",
rusqlite::params![
project.id.to_string(),
project.name,
project.description,
serde_json::to_string(&project.paths)?,
serde_json::to_string(&project.tags)?,
project.emoji,
project.source.to_string(),
project.memory_visible as i32,
project.created_at.to_rfc3339(),
project.updated_at.to_rfc3339(),
project.last_active_at.to_rfc3339(),
],
)?;
Ok(())
}
pub fn list_projects(&self) -> Result<Vec<crate::project::Project>> {
let conn = self.conn();
let mut stmt = conn.prepare(
"SELECT id, name, description, paths, tags, emoji, source, memory_visible,
created_at, updated_at, last_active_at
FROM projects ORDER BY name",
)?;
let rows = stmt.query_map([], row_to_project)?;
rows.collect::<Result<Vec<_>, _>>().map_err(Into::into)
}
pub fn delete_project(&self, id: &str) -> Result<()> {
let conn = self.conn();
conn.execute(
"DELETE FROM project_memory WHERE project_id = ?1",
rusqlite::params![id],
)?;
conn.execute("DELETE FROM projects WHERE id = ?1", rusqlite::params![id])?;
Ok(())
}
pub fn link_project_memory(&self, project_id: &str, memory_id: &str) -> Result<()> {
let conn = self.conn();
conn.execute(
"INSERT OR IGNORE INTO project_memory (project_id, memory_id, created_at) VALUES (?1, ?2, datetime('now'))",
rusqlite::params![project_id, memory_id],
)?;
Ok(())
}
pub fn unlink_project_memory(&self, project_id: &str, memory_id: &str) -> Result<()> {
let conn = self.conn();
conn.execute(
"DELETE FROM project_memory WHERE project_id = ?1 AND memory_id = ?2",
rusqlite::params![project_id, memory_id],
)?;
Ok(())
}
pub fn get_project_memory_ids(&self, project_id: &str) -> Result<Vec<String>> {
let conn = self.conn();
let mut stmt = conn.prepare(
"SELECT memory_id FROM project_memory WHERE project_id = ?1 ORDER BY created_at DESC",
)?;
let rows = stmt.query_map(rusqlite::params![project_id], |row| row.get(0))?;
rows.collect::<Result<Vec<_>, _>>().map_err(Into::into)
}
}
/// Serializes `vec` into a flat little-endian byte buffer: 4 bytes per `f32`,
/// in element order. Inverse of [`bytes_to_f32_slice`].
pub fn f32_slice_to_bytes(vec: &[f32]) -> Vec<u8> {
    let mut out = Vec::with_capacity(vec.len() * std::mem::size_of::<f32>());
    out.extend(vec.iter().flat_map(|value| value.to_le_bytes()));
    out
}
/// Decodes a little-endian byte buffer into `f32` values, 4 bytes each.
///
/// Trailing bytes that do not form a complete 4-byte group are ignored, so a
/// 5-byte input yields a single value. Inverse of [`f32_slice_to_bytes`].
pub fn bytes_to_f32_slice(bytes: &[u8]) -> Vec<f32> {
    let mut values = Vec::with_capacity(bytes.len() / 4);
    for quad in bytes.chunks_exact(4) {
        let mut raw = [0u8; 4];
        raw.copy_from_slice(quad);
        values.push(f32::from_le_bytes(raw));
    }
    values
}
fn row_to_project(row: &rusqlite::Row<'_>) -> rusqlite::Result<crate::project::Project> {
use crate::project::{Project, ProjectSource};
use chrono::{DateTime, Utc};
use std::path::PathBuf;
let id_str: String = row.get(0)?;
let name: String = row.get(1)?;
let description: String = row.get::<_, Option<String>>(2)?.unwrap_or_default();
let paths_str: String = row
.get::<_, Option<String>>(3)?
.unwrap_or_else(|| "[]".to_string());
let tags_str: String = row
.get::<_, Option<String>>(4)?
.unwrap_or_else(|| "[]".to_string());
let emoji: String = row
.get::<_, Option<String>>(5)?
.unwrap_or_else(|| "📦".to_string());
let source_str: String = row
.get::<_, Option<String>>(6)?
.unwrap_or_else(|| "manual".to_string());
let memory_visible: bool = row.get::<_, Option<i32>>(7)?.unwrap_or(1) != 0;
let created_at: String = row.get(8)?;
let updated_at: String = row.get(9)?;
let last_active_at: String = row.get(10)?;
let id = uuid::Uuid::parse_str(&id_str).map_err(|e| {
rusqlite::Error::FromSqlConversionFailure(0, rusqlite::types::Type::Text, Box::new(e))
})?;
let paths: Vec<PathBuf> = serde_json::from_str(&paths_str).unwrap_or_default();
let tags: Vec<String> = serde_json::from_str(&tags_str).unwrap_or_default();
let source = match source_str.as_str() {
"auto_detected" => ProjectSource::AutoDetected,
_ => ProjectSource::Manual,
};
Ok(Project {
id,
name,
description,
paths,
tags,
emoji,
source,
memory_visible,
created_at: created_at
.parse::<DateTime<Utc>>()
.unwrap_or_else(|_| Utc::now()),
updated_at: updated_at
.parse::<DateTime<Utc>>()
.unwrap_or_else(|_| Utc::now()),
last_active_at: last_active_at
.parse::<DateTime<Utc>>()
.unwrap_or_else(|_| Utc::now()),
})
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Names of all tables and views in `sqlite_master`. A row-read error
    /// fails the test instead of being silently skipped.
    fn schema_object_names(conn: &rusqlite::Connection) -> Vec<String> {
        conn.prepare(
            "SELECT name FROM sqlite_master WHERE type='table' OR type='view' ORDER BY name",
        )
        .unwrap()
        .query_map([], |row| row.get(0))
        .unwrap()
        .collect::<Result<Vec<String>, _>>()
        .expect("reading sqlite_master rows")
    }

    /// Ids of memories whose FTS entry matches `query`.
    fn fts_ids(conn: &rusqlite::Connection, query: &str) -> Vec<String> {
        conn.prepare("SELECT id FROM memories_fts WHERE memories_fts MATCH ?1")
            .unwrap()
            .query_map(rusqlite::params![query], |row| row.get(0))
            .unwrap()
            .collect::<Result<Vec<String>, _>>()
            .expect("reading FTS rows")
    }

    /// Inserts a minimal `fact` memory with the given id and content.
    fn insert_memory(conn: &rusqlite::Connection, id: &str, content: &str) {
        conn.execute(
            "INSERT INTO memories (id, memory_type, content, importance, tier, source, created_at, updated_at)
             VALUES (?1, 'fact', ?2, 0.5, 'warm', 'test', '2026-01-01T00:00:00Z', '2026-01-01T00:00:00Z')",
            rusqlite::params![id, content],
        )
        .unwrap();
    }

    /// Builds a `Manual` project with fixed, easily compared field values.
    fn sample_project(id: u128, name: &str) -> crate::project::Project {
        let now = chrono::Utc::now();
        crate::project::Project {
            id: uuid::Uuid::from_u128(id),
            name: name.to_string(),
            description: "desc".to_string(),
            paths: vec![std::path::PathBuf::from("/tmp/project")],
            tags: vec!["rust".to_string()],
            emoji: "🦀".to_string(),
            source: crate::project::ProjectSource::Manual,
            memory_visible: false,
            created_at: now,
            updated_at: now,
            last_active_at: now,
        }
    }

    #[test]
    fn test_db_schema_init() {
        let db = MemoryDatabase::open_in_memory(256).unwrap();
        let conn = db.conn();
        let tables = schema_object_names(&conn);
        for expected in [
            "memories",
            "embedding_cache",
            "dream_state",
            "patterns",
            "projects",
            "project_memory",
        ] {
            assert!(
                tables.iter().any(|t| t == expected),
                "{} table missing",
                expected
            );
        }
    }

    #[test]
    fn test_db_fts5_tables() {
        let db = MemoryDatabase::open_in_memory(256).unwrap();
        let conn = db.conn();
        let tables = schema_object_names(&conn);
        assert!(
            tables.iter().any(|t| t == "memories_fts"),
            "memories_fts missing"
        );
    }

    #[test]
    fn test_dream_state() {
        let db = MemoryDatabase::open_in_memory(256).unwrap();
        assert_eq!(db.get_dream_state("test_key").unwrap(), None);
        db.set_dream_state("test_key", "hello").unwrap();
        assert_eq!(
            db.get_dream_state("test_key").unwrap(),
            Some("hello".to_string())
        );
        db.set_dream_state("test_key", "updated").unwrap();
        assert_eq!(
            db.get_dream_state("test_key").unwrap(),
            Some("updated".to_string())
        );
    }

    #[test]
    fn test_migration_flag() {
        let db = MemoryDatabase::open_in_memory(256).unwrap();
        assert!(!db.is_migration_complete());
        db.set_dream_state("migration_v1_complete", "true").unwrap();
        assert!(db.is_migration_complete());
    }

    #[test]
    fn test_f32_bytes_roundtrip() {
        let original: Vec<f32> = vec![0.1, 0.2, 0.3, -1.5, 42.0, 0.0];
        let bytes = f32_slice_to_bytes(&original);
        assert_eq!(bytes.len(), original.len() * 4);
        // Little-endian encoding is bit-exact, so the round trip is lossless.
        assert_eq!(bytes_to_f32_slice(&bytes), original);
        // A trailing partial group of bytes is ignored.
        assert_eq!(bytes_to_f32_slice(&[0, 0, 128, 63, 7]), vec![1.0f32]);
    }

    #[test]
    fn test_insert_and_query_memory() {
        let db = MemoryDatabase::open_in_memory(256).unwrap();
        let conn = db.conn();
        conn.execute(
            "INSERT INTO memories (id, memory_type, content, importance, tier, source, created_at, updated_at)
             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
            rusqlite::params![
                "test-id-1",
                "fact",
                "Rust is a systems programming language",
                0.6,
                "warm",
                "test",
                "2026-01-01T00:00:00Z",
                "2026-01-01T00:00:00Z",
            ],
        )
        .unwrap();
        let count: i64 = conn
            .query_row("SELECT COUNT(*) FROM memories", [], |row| row.get(0))
            .unwrap();
        assert_eq!(count, 1);
        let content: String = conn
            .query_row(
                "SELECT content FROM memories WHERE id = ?1",
                rusqlite::params!["test-id-1"],
                |row| row.get(0),
            )
            .unwrap();
        assert_eq!(content, "Rust is a systems programming language");
    }

    #[test]
    fn test_fts5_korean_search() {
        let db = MemoryDatabase::open_in_memory(256).unwrap();
        let conn = db.conn();
        insert_memory(&conn, "kr-1", "한국어 테스트 메모리입니다");
        insert_memory(&conn, "kr-2", "영어 테스트 데이터입니다");
        let results = fts_ids(&conn, "한국어");
        assert!(
            results.contains(&"kr-1".to_string()),
            "Korean FTS should find kr-1, got: {:?}",
            results
        );
    }

    #[test]
    fn test_fts5_triggers_track_update_and_delete() {
        let db = MemoryDatabase::open_in_memory(256).unwrap();
        let conn = db.conn();
        insert_memory(&conn, "up-1", "alpha content");
        assert_eq!(fts_ids(&conn, "alpha"), vec!["up-1".to_string()]);

        conn.execute(
            "UPDATE memories SET content = 'beta content' WHERE id = 'up-1'",
            [],
        )
        .unwrap();
        assert!(fts_ids(&conn, "alpha").is_empty(), "stale FTS entry after update");
        assert_eq!(fts_ids(&conn, "beta"), vec!["up-1".to_string()]);

        conn.execute("DELETE FROM memories WHERE id = 'up-1'", []).unwrap();
        assert!(fts_ids(&conn, "beta").is_empty(), "stale FTS entry after delete");
    }

    #[test]
    fn test_fts5_bm25_scoring() {
        let db = MemoryDatabase::open_in_memory(256).unwrap();
        let conn = db.conn();
        insert_memory(&conn, "bm-1", "Rust programming language safety");
        insert_memory(&conn, "bm-2", "Python programming data science");
        insert_memory(&conn, "bm-3", "Rust Rust Rust systems programming");
        let results: Vec<(String, f64)> = conn
            .prepare(
                "SELECT m.id, -bm25(memories_fts) as score
                 FROM memories_fts f
                 JOIN memories m ON m.id = f.id
                 WHERE memories_fts MATCH 'Rust'
                 ORDER BY score DESC",
            )
            .unwrap()
            .query_map([], |row| {
                Ok((row.get::<_, String>(0)?, row.get::<_, f64>(1)?))
            })
            .unwrap()
            .collect::<Result<Vec<_>, _>>()
            .expect("reading BM25 rows");
        assert!(!results.is_empty(), "BM25 should return results");
        assert_eq!(results[0].0, "bm-3", "Most relevant should be bm-3");
    }

    #[test]
    fn test_backup_skipped_in_memory() {
        let db = MemoryDatabase::open_in_memory(256).unwrap();
        let dir = tempfile::tempdir().unwrap();
        let backup_path = dir.path().join("backup.db");
        assert!(db.backup(&backup_path).is_err());
    }

    #[test]
    fn test_backup_on_disk_roundtrip() {
        let dir = tempfile::tempdir().unwrap();
        let db_path = dir.path().join("nested").join("memory.db");
        let backup_path = dir.path().join("backup.db");

        let db = MemoryDatabase::open(&db_path, 8).unwrap();
        db.set_dream_state("k", "v").unwrap();
        db.backup(&backup_path).unwrap();
        drop(db);

        let restored = MemoryDatabase::open(&backup_path, 8).unwrap();
        assert_eq!(restored.get_dream_state("k").unwrap(), Some("v".to_string()));
    }

    #[test]
    fn test_project_roundtrip_and_links() {
        let db = MemoryDatabase::open_in_memory(8).unwrap();
        let project = sample_project(1, "alpha");
        db.save_project(&project).unwrap();

        let listed = db.list_projects().unwrap();
        assert_eq!(listed.len(), 1);
        let loaded = &listed[0];
        assert_eq!(loaded.id, project.id);
        assert_eq!(loaded.name, project.name);
        assert_eq!(loaded.description, project.description);
        assert_eq!(loaded.paths, project.paths);
        assert_eq!(loaded.tags, project.tags);
        assert_eq!(loaded.emoji, project.emoji);
        assert!(!loaded.memory_visible);
        assert!(matches!(loaded.source, crate::project::ProjectSource::Manual));

        let pid = project.id.to_string();
        db.link_project_memory(&pid, "m1").unwrap();
        db.link_project_memory(&pid, "m1").unwrap(); // duplicate link is ignored
        assert_eq!(db.get_project_memory_ids(&pid).unwrap(), vec!["m1".to_string()]);

        db.unlink_project_memory(&pid, "m1").unwrap();
        assert!(db.get_project_memory_ids(&pid).unwrap().is_empty());

        db.link_project_memory(&pid, "m2").unwrap();
        db.delete_project(&pid).unwrap();
        assert!(db.list_projects().unwrap().is_empty());
        assert!(db.get_project_memory_ids(&pid).unwrap().is_empty());
    }
}