pub mod embedding;
pub mod fts;
pub mod query_mod;
pub mod search;
pub use self::embedding::{blob_to_vec, vec_to_blob};
pub use self::query_mod::map_row_to_memory;
use chrono::Utc;
use rusqlite::{Connection, OptionalExtension, Result as SqliteResult, params};
use std::path::Path;
use uuid::Uuid;
/// One stored memory row, as returned by the [`Database`] read methods.
#[derive(Clone, Debug)]
pub struct Memory {
    /// Row identifier; `Database::insert` generates it as a UUID v4 string.
    pub id: String,
    /// Project the memory belongs to; listing queries filter on it.
    pub project_id: String,
    /// The memory text (also mirrored into the `memories_fts` index).
    pub content: String,
    /// Optional caller-supplied text (the tests store JSON here; not enforced).
    pub metadata: Option<String>,
    /// Embedding vector decoded from the `embedding` BLOB column.
    // NOTE(review): allowed because not every build target reads this field — confirm.
    #[allow(dead_code)]
    pub embedding: Vec<f32>,
    /// Similarity score, presumably filled in by the search code; plain
    /// lookups such as `get_many` leave it `None`.
    pub similarity: Option<f64>,
    /// Creation time as an RFC 3339 string.
    pub created_at: String,
    /// Last-modification time as an RFC 3339 string.
    pub updated_at: String,
}
/// Errors produced by the memory store and its embedding helpers.
#[derive(Debug)]
pub enum Error {
    /// A SQLite failure, or a store-level failure (missing row, bad timestamp),
    /// carried as text.
    Sqlite(String),
    /// An embedding BLOB whose byte length was not the expected one.
    InvalidBlobSize { expected: usize, actual: usize },
    /// Two vectors whose dimension counts differ.
    MismatchedDimensions { expected: usize, actual: usize },
    /// An empty vector was given where similarity had to be computed.
    EmptyVector,
    /// An embedding was rejected; the payload says why.
    InvalidEmbedding(String),
    /// A limit argument was rejected; the payload says why.
    InvalidLimit(String),
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Sqlite(msg) => write!(f, "Database error: {}", msg),
            Self::InvalidBlobSize { expected, actual } => write!(
                f,
                "Invalid BLOB size: expected {} bytes, got {} bytes",
                expected, actual
            ),
            Self::MismatchedDimensions { expected, actual } => write!(
                f,
                "Mismatched dimensions: expected {} dimensions, got {} dimensions",
                expected, actual
            ),
            // No arguments to interpolate, so write the text directly.
            Self::EmptyVector => f.write_str("Cannot compute similarity with empty vector"),
            Self::InvalidEmbedding(msg) => write!(f, "Invalid embedding: {}", msg),
            Self::InvalidLimit(msg) => write!(f, "Invalid limit: {}", msg),
        }
    }
}

/// No variant wraps an underlying error value, so the default `source()` (None) applies.
impl std::error::Error for Error {}
impl From<rusqlite::Error> for Error {
fn from(err: rusqlite::Error) -> Self {
Error::Sqlite(err.to_string())
}
}
/// Result alias whose error side is this module's [`Error`].
pub type Result<T> = std::result::Result<T, Error>;

/// Handle to the memory store, wrapping a single SQLite connection.
///
/// [`Database::open`] creates one and applies the schema.
pub struct Database {
    conn: Connection,
}
fn create_schema(conn: &mut Connection) -> Result<()> {
conn.execute_batch(
r#"
CREATE TABLE IF NOT EXISTS memories (
id TEXT PRIMARY KEY,
project_id TEXT NOT NULL,
content TEXT NOT NULL,
embedding BLOB NOT NULL,
metadata TEXT,
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_memories_project ON memories(project_id);
CREATE VIRTUAL TABLE IF NOT EXISTS memories_fts USING fts5(
content,
project_id UNINDEXED,
tokenize='porter unicode61',
content_rowid='rowid',
content='memories'
);
CREATE TRIGGER IF NOT EXISTS memories_fts_insert AFTER INSERT ON memories BEGIN
INSERT INTO memories_fts(rowid, content, project_id)
VALUES (new.rowid, new.content, new.project_id);
END;
CREATE TRIGGER IF NOT EXISTS memories_fts_delete AFTER DELETE ON memories BEGIN
INSERT INTO memories_fts(memories_fts, rowid, content, project_id)
VALUES('delete', old.rowid, old.content, old.project_id);
END;
CREATE TRIGGER IF NOT EXISTS memories_fts_update AFTER UPDATE ON memories BEGIN
INSERT INTO memories_fts(memories_fts, rowid, content, project_id)
VALUES('delete', old.rowid, old.content, old.project_id);
INSERT INTO memories_fts(rowid, content, project_id)
VALUES (new.rowid, new.content, new.project_id);
END;
"#,
)?;
Ok(())
}
impl Database {
pub fn open(path: &Path) -> Result<Self> {
let mut conn = Connection::open(path)?;
create_schema(&mut conn)?;
Ok(Self { conn })
}
pub fn insert(
&self,
project_id: &str,
content: &str,
embedding: &[f32],
metadata: Option<&str>,
) -> Result<String> {
let id = Uuid::new_v4().to_string();
let now = Utc::now().to_rfc3339();
let blob = vec_to_blob(embedding)?;
self.conn.execute(
r#"
INSERT INTO memories (id, project_id, content, embedding, metadata, created_at, updated_at)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)
"#,
params![&id, project_id, content, &blob, metadata, &now, &now],
)?;
Ok(id)
}
#[cfg(test)]
pub(crate) fn insert_with_time(
&self,
project_id: &str,
content: &str,
embedding: &[f32],
metadata: Option<&str>,
created_at: &str,
updated_at: &str,
) -> Result<String> {
let id = Uuid::new_v4().to_string();
let blob = vec_to_blob(embedding)?;
self.conn.execute(
r#"
INSERT INTO memories (id, project_id, content, embedding, metadata, created_at, updated_at)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)
"#,
params![&id, project_id, content, &blob, metadata, created_at, updated_at],
)?;
Ok(id)
}
pub fn get(&self, id: &str) -> Result<Option<Memory>> {
let mut stmt = self.conn.prepare(
r#"
SELECT id, project_id, content, metadata, embedding, created_at, updated_at
FROM memories
WHERE id = ?1
"#,
)?;
let result = stmt.query_row([id], map_row_to_memory).optional()?;
Ok(result)
}
pub fn list(&self, project_id: &str, limit: usize) -> Result<Vec<Memory>> {
let mut stmt = self.conn.prepare(
r#"
SELECT id, project_id, content, metadata, embedding, created_at, updated_at
FROM memories
WHERE project_id = ?1
ORDER BY created_at DESC
LIMIT ?2
"#,
)?;
let memories: SqliteResult<Vec<Memory>> = stmt
.query_map(params![project_id, limit as i64], map_row_to_memory)?
.collect();
Ok(memories?)
}
pub fn update(&self, id: &str, content: &str, embedding: &[f32]) -> Result<()> {
let now = Utc::now().to_rfc3339();
let blob = vec_to_blob(embedding)?;
let rows = self.conn.execute(
r#"
UPDATE memories
SET content = ?1, embedding = ?2, updated_at = ?3
WHERE id = ?4
"#,
params![content, &blob, &now, id],
)?;
if rows == 0 {
return Err(Error::Sqlite("No memory found".to_string()));
}
Ok(())
}
pub fn delete(&self, id: &str) -> Result<bool> {
let rows = self
.conn
.execute("DELETE FROM memories WHERE id = ?1", [id])?;
Ok(rows > 0)
}
#[allow(dead_code)] pub fn list_since(
&self,
project_id: &str,
since_timestamp: &str,
limit: usize,
) -> Result<Vec<Memory>> {
let _parsed = chrono::DateTime::parse_from_rfc3339(since_timestamp)
.map_err(|e| Error::Sqlite(format!("Invalid RFC3339 timestamp: {}", e)))?;
let mut stmt = self.conn.prepare(
r#"
SELECT id, project_id, content, metadata, embedding, created_at, updated_at
FROM memories
WHERE project_id = ?1 AND created_at > ?2
ORDER BY created_at DESC
LIMIT ?3
"#,
)?;
let memories: SqliteResult<Vec<Memory>> = stmt
.query_map(
params![project_id, since_timestamp, limit as i64],
map_row_to_memory,
)?
.collect();
Ok(memories?)
}
#[allow(dead_code)] pub fn get_many(&self, ids: &[&str]) -> Result<Vec<Option<Memory>>> {
if ids.is_empty() {
return Ok(Vec::new());
}
let placeholders = ids
.iter()
.enumerate()
.map(|(i, _)| format!("?{}", i + 1))
.collect::<Vec<_>>()
.join(", ");
let query = format!(
r#"
SELECT id, project_id, content, metadata, embedding, created_at, updated_at
FROM memories
WHERE id IN ({})
"#,
placeholders
);
let mut stmt = self.conn.prepare(&query)?;
let params: Vec<&dyn rusqlite::ToSql> =
ids.iter().map(|id| id as &dyn rusqlite::ToSql).collect();
let rows: SqliteResult<Vec<(String, Memory)>> = stmt
.query_map(params.as_slice(), |row| {
let blob: Vec<u8> = row.get(4)?;
let embedding = blob_to_vec(&blob).map_err(|e| {
rusqlite::Error::FromSqlConversionFailure(
4,
rusqlite::types::Type::Blob,
Box::new(e),
)
})?;
Ok((
row.get::<_, String>(0)?,
Memory {
id: row.get(0)?,
project_id: row.get(1)?,
content: row.get(2)?,
metadata: row.get(3)?,
embedding,
similarity: None,
created_at: row.get(5)?,
updated_at: row.get(6)?,
},
))
})?
.collect();
let found_memories: std::collections::HashMap<String, Memory> = rows?.into_iter().collect();
let results: Vec<Option<Memory>> = ids
.iter()
.map(|id| found_memories.get(*id).cloned())
.collect();
Ok(results)
}
#[cfg(test)]
pub(crate) fn conn(&self) -> &Connection {
&self.conn
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::embedding::EMBEDDING_DIMS;

    /// Opens a fresh, private database for a single test.
    ///
    /// SQLite's special `:memory:` filename gives every connection its own
    /// database, so tests stay isolated and nothing is left behind on disk.
    fn create_test_db() -> Database {
        Database::open(std::path::Path::new(":memory:")).unwrap()
    }

    #[test]
    fn test_insert_and_get() {
        let db = create_test_db();
        let embedding = vec![0.1f32; EMBEDDING_DIMS];
        let id = db
            .insert("proj1", "test content", &embedding, None)
            .unwrap();
        let memory = db.get(&id).unwrap();
        assert!(memory.is_some());
        let m = memory.unwrap();
        assert_eq!(m.content, "test content");
        assert_eq!(m.project_id, "proj1");
    }

    #[test]
    fn test_insert_with_metadata() {
        let db = create_test_db();
        let embedding = vec![0.1f32; EMBEDDING_DIMS];
        let id = db
            .insert(
                "proj1",
                "test content",
                &embedding,
                Some(r#"{"key": "value"}"#),
            )
            .unwrap();
        let m = db.get(&id).unwrap().unwrap();
        assert_eq!(m.metadata, Some(r#"{"key": "value"}"#.to_string()));
    }

    #[test]
    fn test_insert_invalid_embedding() {
        let db = create_test_db();
        // A length other than EMBEDDING_DIMS must be rejected.
        let embedding = vec![0.1f32; 256];
        let result = db.insert("proj1", "test", &embedding, None);
        assert!(result.is_err());
    }

    #[test]
    fn test_get_nonexistent() {
        let db = create_test_db();
        let memory = db.get("nonexistent").unwrap();
        assert!(memory.is_none());
    }

    #[test]
    fn test_list_ordering() {
        let db = create_test_db();
        let embedding = vec![0.1f32; EMBEDDING_DIMS];
        let id1 = db
            .insert_with_time(
                "proj1",
                "first",
                &embedding,
                None,
                "2024-01-01T00:00:00Z",
                "2024-01-01T00:00:00Z",
            )
            .unwrap();
        let id2 = db
            .insert_with_time(
                "proj1",
                "second",
                &embedding,
                None,
                "2024-01-02T00:00:00Z",
                "2024-01-02T00:00:00Z",
            )
            .unwrap();
        let memories = db.list("proj1", 10).unwrap();
        assert_eq!(memories.len(), 2);
        assert_eq!(memories[0].id, id2);
        assert_eq!(memories[1].id, id1);
    }

    #[test]
    fn test_list_limit() {
        let db = create_test_db();
        let embedding = vec![0.1f32; EMBEDDING_DIMS];
        for i in 0..5 {
            db.insert("proj1", &format!("content {}", i), &embedding, None)
                .unwrap();
        }
        let memories = db.list("proj1", 2).unwrap();
        assert_eq!(memories.len(), 2);
    }

    #[test]
    fn test_update() {
        let db = create_test_db();
        let embedding = vec![0.1f32; EMBEDDING_DIMS];
        let id = db.insert("proj1", "original", &embedding, None).unwrap();
        db.update(&id, "updated", &embedding).unwrap();
        let m = db.get(&id).unwrap().unwrap();
        assert_eq!(m.content, "updated");
    }

    #[test]
    fn test_update_nonexistent() {
        let db = create_test_db();
        let embedding = vec![0.1f32; EMBEDDING_DIMS];
        let result = db.update("nonexistent", "content", &embedding);
        assert!(result.is_err());
    }

    #[test]
    fn test_delete() {
        let db = create_test_db();
        let embedding = vec![0.1f32; EMBEDDING_DIMS];
        let id = db.insert("proj1", "content", &embedding, None).unwrap();
        let deleted = db.delete(&id).unwrap();
        assert!(deleted);
        let memory = db.get(&id).unwrap();
        assert!(memory.is_none());
    }

    #[test]
    fn test_delete_nonexistent() {
        let db = create_test_db();
        let deleted = db.delete("nonexistent").unwrap();
        assert!(!deleted);
    }

    #[test]
    fn test_project_isolation() {
        let db = create_test_db();
        let embedding = vec![0.1f32; EMBEDDING_DIMS];
        db.insert("proj1", "proj1 content", &embedding, None)
            .unwrap();
        db.insert("proj2", "proj2 content", &embedding, None)
            .unwrap();
        let list1 = db.list("proj1", 10).unwrap();
        let list2 = db.list("proj2", 10).unwrap();
        assert_eq!(list1.len(), 1);
        assert_eq!(list2.len(), 1);
        assert_eq!(list1[0].project_id, "proj1");
        assert_eq!(list2[0].project_id, "proj2");
    }

    #[test]
    fn test_get_includes_embedding() {
        let db = create_test_db();
        let embedding = vec![0.1f32; EMBEDDING_DIMS];
        let id = db
            .insert("proj1", "test content", &embedding, None)
            .unwrap();
        let memory = db.get(&id).unwrap().unwrap();
        assert_eq!(memory.embedding.len(), EMBEDDING_DIMS);
        for (stored, &expected) in memory.embedding.iter().zip(embedding.iter()) {
            assert!((stored - expected).abs() < 1e-6);
        }
    }

    #[test]
    fn test_list_includes_embeddings() {
        let db = create_test_db();
        let embedding1 = vec![0.1f32; EMBEDDING_DIMS];
        let embedding2 = vec![0.2f32; EMBEDDING_DIMS];
        db.insert("proj1", "first", &embedding1, None).unwrap();
        db.insert("proj1", "second", &embedding2, None).unwrap();
        let memories = db.list("proj1", 10).unwrap();
        assert_eq!(memories.len(), 2);
        for memory in &memories {
            assert_eq!(memory.embedding.len(), EMBEDDING_DIMS);
        }
    }

    #[test]
    fn test_embedding_roundtrip() {
        let db = create_test_db();
        let original = [0.123f32, 0.456f32, 0.789f32];
        let mut full_embedding = vec![0.1f32; EMBEDDING_DIMS];
        full_embedding[0] = original[0];
        full_embedding[1] = original[1];
        full_embedding[EMBEDDING_DIMS - 1] = original[2];
        let id = db.insert("proj1", "test", &full_embedding, None).unwrap();
        let memory = db.get(&id).unwrap().unwrap();
        assert_eq!(memory.embedding.len(), EMBEDDING_DIMS);
        assert!((memory.embedding[0] - original[0]).abs() < 1e-6);
        assert!((memory.embedding[1] - original[1]).abs() < 1e-6);
        assert!((memory.embedding[EMBEDDING_DIMS - 1] - original[2]).abs() < 1e-6);
    }

    #[test]
    fn test_get_many_preserves_order_and_missing() {
        let db = create_test_db();
        let embedding = vec![0.1f32; EMBEDDING_DIMS];
        let a = db.insert("proj1", "a", &embedding, None).unwrap();
        let b = db.insert("proj1", "b", &embedding, None).unwrap();
        let results = db.get_many(&[b.as_str(), "missing", a.as_str()]).unwrap();
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].as_ref().map(|m| m.content.as_str()), Some("b"));
        assert!(results[1].is_none());
        assert_eq!(results[2].as_ref().map(|m| m.content.as_str()), Some("a"));
        assert!(db.get_many(&[]).unwrap().is_empty());
    }

    #[test]
    fn test_list_since_filters_by_created_at() {
        let db = create_test_db();
        let embedding = vec![0.1f32; EMBEDDING_DIMS];
        db.insert_with_time(
            "proj1",
            "old",
            &embedding,
            None,
            "2024-01-01T00:00:00Z",
            "2024-01-01T00:00:00Z",
        )
        .unwrap();
        let new_id = db
            .insert_with_time(
                "proj1",
                "new",
                &embedding,
                None,
                "2024-01-03T00:00:00Z",
                "2024-01-03T00:00:00Z",
            )
            .unwrap();
        let recent = db.list_since("proj1", "2024-01-02T00:00:00Z", 10).unwrap();
        assert_eq!(recent.len(), 1);
        assert_eq!(recent[0].id, new_id);
        assert!(db.list_since("proj1", "not a timestamp", 10).is_err());
    }
}