use std::path::Path;
use chrono::{DateTime, Utc};
use rusqlite::{Connection, params};
use crate::db::{get_max_memory_id, init_schema};
use crate::error::MemoryError;
use crate::memory::{AddMemoryResult, AmendResult, GraphStats, Memory, Stats};
use crate::relationship::{
ConnectResult, StrengthenResult, add_relationship_event, canonicalize, get_relationship,
relationship_exists,
};
use crate::search::{SearchResult, surface_candidates};
use crate::{MAX_STRENGTHEN_SET, SearchParams};
/// SQLite-backed store of memories and the relationships between them.
///
/// Wraps a single [`Connection`] together with a cached copy of the highest
/// memory id. The cache is read when the store is opened and updated by every
/// insert made through this handle.
pub struct MemoryStore {
    conn: Connection,
    // Highest memory id known to this handle; passed to the search and
    // relationship helpers. Inserts made through other connections to the same
    // database are not reflected until the store is reopened.
    cached_max_mem: i64,
}

impl MemoryStore {
    /// Opens the SQLite database at `path` and ensures the schema exists.
    ///
    /// # Errors
    /// Returns an error if the database cannot be opened or schema
    /// initialisation fails.
    pub fn open(path: impl AsRef<Path>) -> Result<Self, MemoryError> {
        Self::init(Connection::open(path)?)
    }

    /// Opens a store backed by a private in-memory SQLite database; its
    /// contents live only as long as the returned store.
    ///
    /// # Errors
    /// Returns an error if the connection cannot be created or schema
    /// initialisation fails.
    pub fn open_in_memory() -> Result<Self, MemoryError> {
        Self::init(Connection::open_in_memory()?)
    }

    /// Shared constructor: creates the schema if needed and primes the cached
    /// max memory id from the database.
    fn init(conn: Connection) -> Result<Self, MemoryError> {
        init_schema(&conn)?;
        let cached_max_mem = get_max_memory_id(&conn)?;
        Ok(Self {
            conn,
            cached_max_mem,
        })
    }

    /// Returns the highest memory id known to this handle (`0` for an empty store).
    pub fn max_memory_id(&self) -> i64 {
        self.cached_max_mem
    }

    /// Adds a memory timestamped with the current UTC time.
    ///
    /// Shorthand for [`MemoryStore::add_with_options`] without an explicit datetime.
    pub fn add(
        &mut self,
        text: &str,
        source: Option<&str>,
    ) -> Result<AddMemoryResult, MemoryError> {
        self.add_with_options(text, source, None)
    }

    /// Inserts a new memory and returns it.
    ///
    /// * `text` – memory body, stored verbatim.
    /// * `source` – optional origin label.
    /// * `datetime_str` – optional RFC 3339 timestamp; any offset is converted to
    ///   UTC. When `None`, the current time is used.
    ///
    /// The datetime is stored as an RFC 3339 string and the cached max memory id
    /// is set to the new row id.
    ///
    /// # Errors
    /// [`MemoryError::InvalidInput`] if `datetime_str` is not valid RFC 3339; a
    /// database error if the insert fails.
    pub fn add_with_options(
        &mut self,
        text: &str,
        source: Option<&str>,
        datetime_str: Option<&str>,
    ) -> Result<AddMemoryResult, MemoryError> {
        let datetime = match datetime_str {
            Some(dt_str) => DateTime::parse_from_rfc3339(dt_str)
                .map(|dt| dt.with_timezone(&Utc))
                .map_err(|e| {
                    MemoryError::InvalidInput(format!(
                        "Invalid datetime format (expected RFC3339, e.g. '2024-01-15T10:30:00Z'): {}",
                        e
                    ))
                })?,
            None => Utc::now(),
        };
        self.conn.execute(
            "INSERT INTO memories (datetime, text, source) VALUES (?1, ?2, ?3)",
            params![datetime.to_rfc3339(), text, source],
        )?;
        let id = self.conn.last_insert_rowid();
        self.cached_max_mem = id;
        Ok(AddMemoryResult {
            memory: Memory {
                id,
                datetime,
                text: text.to_string(),
                source: source.map(String::from),
            },
        })
    }

    /// Replaces the text of memory `id`, keeping its id, datetime and source.
    ///
    /// Amending is refused once any relationship links the memory to a memory
    /// with a larger id.
    ///
    /// # Errors
    /// [`MemoryError::InvalidInput`] if the memory does not exist or has a
    /// relationship to a later memory; a database error if a query fails.
    pub fn amend(&mut self, id: i64, new_text: &str) -> Result<AmendResult, MemoryError> {
        let memory = self
            .get(id)?
            .ok_or_else(|| MemoryError::InvalidInput(format!("Memory {} does not exist", id)))?;

        // Either endpoint may be the later memory, so check both directions.
        let has_later_relationship: bool = self.conn.query_row(
            "SELECT EXISTS(
                SELECT 1 FROM relationships
                WHERE (from_mem = ?1 AND to_mem > ?1)
                   OR (to_mem = ?1 AND from_mem > ?1)
            )",
            params![id],
            |row| row.get(0),
        )?;
        if has_later_relationship {
            return Err(MemoryError::InvalidInput(format!(
                "Cannot amend memory {} - it has relationships to later memories",
                id
            )));
        }

        self.conn.execute(
            "UPDATE memories SET text = ?1 WHERE id = ?2",
            params![new_text, id],
        )?;
        Ok(AmendResult {
            memory: Memory {
                text: new_text.to_string(),
                ..memory
            },
        })
    }

    /// Searches memories for `query` by delegating to [`surface_candidates`]
    /// with the cached max memory id. Query validation happens there.
    pub fn search(&self, query: &str, params: &SearchParams) -> Result<SearchResult, MemoryError> {
        surface_candidates(&self.conn, query, params, self.cached_max_mem)
    }

    /// Records one relationship event (weight `1.0`) for every unordered pair of
    /// `ids`, in a single transaction, and returns the resulting relationships.
    ///
    /// Each pair is passed through [`canonicalize`] before being stored.
    /// NOTE(review): events are stamped with the cached max memory id,
    /// presumably the store's logical clock; confirm in `relationship`.
    ///
    /// # Errors
    /// [`MemoryError::InvalidInput`] if `ids` has fewer than two or more than
    /// [`MAX_STRENGTHEN_SET`] entries, repeats an id, or names a memory that
    /// does not exist; a database error otherwise. Nothing is written on error.
    pub fn strengthen(&mut self, ids: &[i64]) -> Result<StrengthenResult, MemoryError> {
        if ids.len() > MAX_STRENGTHEN_SET {
            return Err(MemoryError::InvalidInput(format!(
                "Cannot strengthen more than {} memories at once (got {})",
                MAX_STRENGTHEN_SET,
                ids.len()
            )));
        }
        if ids.is_empty() {
            return Err(MemoryError::InvalidInput(
                "At least one memory ID is required".to_string(),
            ));
        }
        if ids.len() == 1 {
            return Err(MemoryError::InvalidInput(
                "At least two memory IDs are required to create relationships".to_string(),
            ));
        }
        self.ensure_distinct_existing(ids)?;

        let tx = self.conn.transaction()?;
        let default_params = SearchParams::default();
        let mut relationships = Vec::new();
        let mut event_count = 0;
        for (i, &a) in ids.iter().enumerate() {
            for &b in &ids[i + 1..] {
                let (from_mem, to_mem) = canonicalize(a, b);
                add_relationship_event(&tx, from_mem, to_mem, self.cached_max_mem, 1.0)?;
                event_count += 1;
                if let Some(rel) =
                    get_relationship(&tx, from_mem, to_mem, self.cached_max_mem, &default_params)?
                {
                    relationships.push(rel);
                }
            }
        }
        tx.commit()?;
        Ok(StrengthenResult {
            relationships,
            event_count,
        })
    }

    /// Creates a relationship for every unordered pair of `ids` that has none yet.
    ///
    /// Pairs that already have a relationship are not strengthened. They are
    /// reported in `skipped` as canonicalized `(from_mem, to_mem)` tuples.
    /// Everything runs in one transaction.
    ///
    /// # Errors
    /// [`MemoryError::InvalidInput`] if `ids` has fewer than two or more than
    /// [`MAX_STRENGTHEN_SET`] entries, repeats an id, or names a memory that
    /// does not exist; a database error otherwise. Nothing is written on error.
    pub fn connect(&mut self, ids: &[i64]) -> Result<ConnectResult, MemoryError> {
        if ids.len() > MAX_STRENGTHEN_SET {
            return Err(MemoryError::InvalidInput(format!(
                "Cannot connect more than {} memories at once (got {})",
                MAX_STRENGTHEN_SET,
                ids.len()
            )));
        }
        if ids.len() < 2 {
            return Err(MemoryError::InvalidInput(
                "At least two memory IDs are required to create connections".to_string(),
            ));
        }
        self.ensure_distinct_existing(ids)?;

        let tx = self.conn.transaction()?;
        let default_params = SearchParams::default();
        let mut created = Vec::new();
        let mut skipped = Vec::new();
        for (i, &a) in ids.iter().enumerate() {
            for &b in &ids[i + 1..] {
                let (from_mem, to_mem) = canonicalize(a, b);
                if relationship_exists(&tx, from_mem, to_mem)? {
                    skipped.push((from_mem, to_mem));
                    continue;
                }
                add_relationship_event(&tx, from_mem, to_mem, self.cached_max_mem, 1.0)?;
                if let Some(rel) =
                    get_relationship(&tx, from_mem, to_mem, self.cached_max_mem, &default_params)?
                {
                    created.push(rel);
                }
            }
        }
        tx.commit()?;
        Ok(ConnectResult { created, skipped })
    }

    /// Rejects id sets that repeat an id or reference a missing memory.
    ///
    /// A repeated id would otherwise yield a self-pair (`canonicalize(x, x)`)
    /// and touch the same pair more than once. A missing id would create a
    /// relationship to a row that is not in `memories`.
    fn ensure_distinct_existing(&self, ids: &[i64]) -> Result<(), MemoryError> {
        use std::collections::HashSet;

        let mut seen = HashSet::with_capacity(ids.len());
        if let Some(dup) = ids.iter().copied().find(|&id| !seen.insert(id)) {
            return Err(MemoryError::InvalidInput(format!(
                "Memory ID {} appears more than once",
                dup
            )));
        }
        let existing: HashSet<i64> = self.get_many(ids)?.into_iter().map(|m| m.id).collect();
        if let Some(missing) = ids.iter().copied().find(|id| !existing.contains(id)) {
            return Err(MemoryError::InvalidInput(format!(
                "Memory {} does not exist",
                missing
            )));
        }
        Ok(())
    }

    /// Fetches a single memory by id, or `None` if no such row exists.
    pub fn get(&self, id: i64) -> Result<Option<Memory>, MemoryError> {
        let mut stmt = self
            .conn
            .prepare("SELECT id, datetime, text, source FROM memories WHERE id = ?1")?;
        let mut rows = stmt.query(params![id])?;
        match rows.next()? {
            Some(row) => Ok(Some(Self::row_to_memory(row)?)),
            None => Ok(None),
        }
    }

    /// Returns the `n` memories with the highest ids, highest first.
    pub fn tail(&self, n: usize) -> Result<Vec<Memory>, MemoryError> {
        self.list(None, None, Some(n))
    }

    /// Lists memories whose id lies in the inclusive range `[from_id, to_id]`.
    ///
    /// Either bound may be omitted. Results are ordered by id descending. At
    /// most `limit` rows are returned when a limit is given.
    pub fn list(
        &self,
        from_id: Option<i64>,
        to_id: Option<i64>,
        limit: Option<usize>,
    ) -> Result<Vec<Memory>, MemoryError> {
        let mut conditions: Vec<&str> = Vec::new();
        let mut bind: Vec<Box<dyn rusqlite::ToSql>> = Vec::new();
        if let Some(from) = from_id {
            conditions.push("id >= ?");
            bind.push(Box::new(from));
        }
        if let Some(to) = to_id {
            conditions.push("id <= ?");
            bind.push(Box::new(to));
        }

        let mut query = String::from("SELECT id, datetime, text, source FROM memories");
        if !conditions.is_empty() {
            query.push_str(" WHERE ");
            query.push_str(&conditions.join(" AND "));
        }
        query.push_str(" ORDER BY id DESC");
        // The LIMIT placeholder comes last, matching the push order of `bind`.
        if let Some(n) = limit {
            query.push_str(" LIMIT ?");
            bind.push(Box::new(Self::sql_limit(n)));
        }

        let mut stmt = self.conn.prepare(&query)?;
        let bind_refs: Vec<&dyn rusqlite::ToSql> = bind.iter().map(|p| p.as_ref()).collect();
        let memories = stmt
            .query_map(bind_refs.as_slice(), Self::row_to_memory)?
            .collect::<Result<Vec<_>, _>>()?;
        Ok(memories)
    }

    /// Returns up to `limit` randomly ordered memories that take part in no
    /// relationship (as either endpoint).
    pub fn stray(&self, limit: usize) -> Result<Vec<Memory>, MemoryError> {
        let mut stmt = self.conn.prepare(
            "SELECT m.id, m.datetime, m.text, m.source
             FROM memories m
             WHERE NOT EXISTS (
                 SELECT 1 FROM relationships r
                 WHERE r.from_mem = m.id OR r.to_mem = m.id
             )
             ORDER BY RANDOM()
             LIMIT ?1",
        )?;
        let memories = stmt
            .query_map(params![Self::sql_limit(limit)], Self::row_to_memory)?
            .collect::<Result<Vec<_>, _>>()?;
        Ok(memories)
    }

    /// Converts a caller-supplied row limit to the `i64` bound for SQLite.
    ///
    /// Values above `i64::MAX` saturate instead of wrapping (an `as` cast would
    /// turn them into negative numbers).
    fn sql_limit(n: usize) -> i64 {
        i64::try_from(n).unwrap_or(i64::MAX)
    }

    /// Collects summary statistics about memories, relationships and the
    /// relationship graph.
    ///
    /// * `relationship_count` – distinct `(from_mem, to_mem)` pairs.
    /// * `relationship_event_count` – total rows in `relationships`.
    /// * `unique_sources` – distinct non-NULL sources, sorted.
    pub fn stats(&self) -> Result<Stats, MemoryError> {
        let (memory_count, min_memory_id, max_memory_id): (i64, Option<i64>, Option<i64>) =
            self.conn.query_row(
                "SELECT COUNT(*), MIN(id), MAX(id) FROM memories",
                [],
                |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?)),
            )?;
        let relationship_count: i64 = self.conn.query_row(
            "SELECT COUNT(*) FROM (SELECT DISTINCT from_mem, to_mem FROM relationships)",
            [],
            |row| row.get(0),
        )?;
        let relationship_event_count: i64 =
            self.conn
                .query_row("SELECT COUNT(*) FROM relationships", [], |row| row.get(0))?;

        let mut stmt = self.conn.prepare(
            "SELECT DISTINCT source FROM memories WHERE source IS NOT NULL ORDER BY source",
        )?;
        // Propagate row errors instead of silently dropping sources.
        let unique_sources = stmt
            .query_map([], |row| row.get::<_, String>(0))?
            .collect::<Result<Vec<String>, _>>()?;

        let graph = self.compute_graph_stats()?;
        Ok(Stats {
            memory_count,
            min_memory_id,
            max_memory_id,
            relationship_count,
            relationship_event_count,
            unique_sources,
            graph,
        })
    }

    /// Computes connectivity statistics over the undirected relationship graph.
    ///
    /// Nodes are the ids appearing in `relationships`; repeated events for the
    /// same pair collapse to one edge. `stray_count` counts memories with no
    /// relationship at all. Degree figures (max, average, leaves with degree 1)
    /// cover only nodes that have at least one edge.
    fn compute_graph_stats(&self) -> Result<GraphStats, MemoryError> {
        use std::collections::{HashMap, HashSet};

        let stray_count: i64 = self.conn.query_row(
            "SELECT COUNT(*) FROM memories m
             WHERE NOT EXISTS (
                 SELECT 1 FROM relationships r
                 WHERE r.from_mem = m.id OR r.to_mem = m.id
             )",
            [],
            |row| row.get(0),
        )?;

        let mut adj: HashMap<i64, HashSet<i64>> = HashMap::new();
        let mut stmt = self
            .conn
            .prepare("SELECT DISTINCT from_mem, to_mem FROM relationships")?;
        let edges = stmt.query_map([], |row| Ok((row.get::<_, i64>(0)?, row.get::<_, i64>(1)?)))?;
        for edge in edges {
            let (from, to) = edge?;
            adj.entry(from).or_default().insert(to);
            adj.entry(to).or_default().insert(from);
        }

        let node_count = adj.len();
        let degree_sum: i64 = adj.values().map(|n| n.len() as i64).sum();
        let max_degree = adj.values().map(|n| n.len() as i64).max().unwrap_or(0);
        let leaf_count = adj.values().filter(|n| n.len() == 1).count() as i64;
        let avg_degree = if node_count == 0 {
            0.0
        } else {
            degree_sum as f64 / node_count as f64
        };

        // Connected components via iterative DFS. A node is marked visited when
        // it is first pushed, so each node is counted exactly once.
        let mut visited: HashSet<i64> = HashSet::with_capacity(node_count);
        let mut island_count = 0i64;
        let mut largest_island_size = 0i64;
        for &start in adj.keys() {
            if !visited.insert(start) {
                continue;
            }
            let mut stack = vec![start];
            let mut size = 0i64;
            while let Some(node) = stack.pop() {
                size += 1;
                if let Some(neighbors) = adj.get(&node) {
                    for &next in neighbors {
                        if visited.insert(next) {
                            stack.push(next);
                        }
                    }
                }
            }
            island_count += 1;
            largest_island_size = largest_island_size.max(size);
        }

        Ok(GraphStats {
            stray_count,
            island_count,
            largest_island_size,
            leaf_count,
            max_degree,
            avg_degree,
        })
    }

    /// Fetches every existing memory whose id is in `ids`, ordered by id.
    ///
    /// Unknown ids are simply absent from the result, and a repeated id yields
    /// one row. NOTE(review): one bind parameter is used per id, so very large
    /// slices can exceed SQLite's host-parameter limit.
    pub fn get_many(&self, ids: &[i64]) -> Result<Vec<Memory>, MemoryError> {
        if ids.is_empty() {
            return Ok(vec![]);
        }
        let placeholders = vec!["?"; ids.len()].join(",");
        let query = format!(
            "SELECT id, datetime, text, source FROM memories WHERE id IN ({}) ORDER BY id",
            placeholders
        );
        let mut stmt = self.conn.prepare(&query)?;
        let bind: Vec<&dyn rusqlite::ToSql> =
            ids.iter().map(|id| id as &dyn rusqlite::ToSql).collect();
        let memories = stmt
            .query_map(bind.as_slice(), Self::row_to_memory)?
            .collect::<Result<Vec<_>, _>>()?;
        Ok(memories)
    }

    /// Maps a `SELECT id, datetime, text, source` row to a [`Memory`].
    ///
    /// NOTE(review): a `datetime` value that is not valid RFC 3339 is replaced
    /// with the current time instead of being reported. Rows written by
    /// `add_with_options` always parse, so only externally written data can
    /// hit this fallback.
    fn row_to_memory(row: &rusqlite::Row) -> Result<Memory, rusqlite::Error> {
        let raw_datetime: String = row.get(1)?;
        let datetime = DateTime::parse_from_rfc3339(&raw_datetime)
            .map(|dt| dt.with_timezone(&Utc))
            .unwrap_or_else(|_| Utc::now());
        Ok(Memory {
            id: row.get(0)?,
            datetime,
            text: row.get(2)?,
            source: row.get(3)?,
        })
    }
}

// Tests for the id-set validation and limit handling introduced alongside the
// store methods above; kept next to them so they travel with that code.
#[cfg(test)]
mod validation_tests {
    use super::*;

    fn store_with(n: usize) -> MemoryStore {
        let mut store = MemoryStore::open_in_memory().unwrap();
        for i in 0..n {
            store.add(&format!("mem{}", i + 1), None).unwrap();
        }
        store
    }

    #[test]
    fn strengthen_rejects_duplicate_ids() {
        let mut store = store_with(2);
        let err = store.strengthen(&[1, 2, 1]).unwrap_err();
        assert!(matches!(err, MemoryError::InvalidInput(_)));
        assert_eq!(store.stats().unwrap().relationship_event_count, 0);
    }

    #[test]
    fn strengthen_rejects_missing_ids() {
        let mut store = store_with(2);
        let err = store.strengthen(&[1, 99]).unwrap_err();
        assert!(matches!(err, MemoryError::InvalidInput(_)));
        assert_eq!(store.stats().unwrap().relationship_event_count, 0);
    }

    #[test]
    fn connect_rejects_duplicate_and_missing_ids() {
        let mut store = store_with(2);
        assert!(matches!(
            store.connect(&[2, 2]).unwrap_err(),
            MemoryError::InvalidInput(_)
        ));
        assert!(matches!(
            store.connect(&[1, 42]).unwrap_err(),
            MemoryError::InvalidInput(_)
        ));
        assert_eq!(store.stats().unwrap().relationship_count, 0);
    }

    #[test]
    fn oversized_limits_return_everything() {
        let store = store_with(3);
        assert_eq!(store.list(None, None, Some(usize::MAX)).unwrap().len(), 3);
        assert_eq!(store.stray(usize::MAX).unwrap().len(), 3);
    }

    #[test]
    fn get_many_orders_by_id() {
        let store = store_with(3);
        let ids: Vec<i64> = store
            .get_many(&[3, 1, 2])
            .unwrap()
            .iter()
            .map(|m| m.id)
            .collect();
        assert_eq!(ids, vec![1, 2, 3]);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_open_in_memory() {
        let store = MemoryStore::open_in_memory().unwrap();
        assert_eq!(store.max_memory_id(), 0);
    }

    #[test]
    fn test_add_memory() {
        let mut store = MemoryStore::open_in_memory().unwrap();
        let result = store.add("Test memory", Some("test")).unwrap();
        assert_eq!(result.memory.text, "Test memory");
        assert_eq!(result.memory.source, Some("test".to_string()));
        assert_eq!(store.max_memory_id(), 1);
        let result2 = store.add("Second memory", None).unwrap();
        assert_eq!(result2.memory.id, 2);
        assert_eq!(store.max_memory_id(), 2);
    }

    #[test]
    fn test_add_with_explicit_datetime() {
        let mut store = MemoryStore::open_in_memory().unwrap();
        let expected = DateTime::parse_from_rfc3339("2024-01-15T10:30:00Z")
            .unwrap()
            .with_timezone(&Utc);
        // A non-UTC offset is normalised to the same instant in UTC.
        let result = store
            .add_with_options("Dated", None, Some("2024-01-15T12:30:00+02:00"))
            .unwrap();
        assert_eq!(result.memory.datetime, expected);
        assert_eq!(store.get(1).unwrap().unwrap().datetime, expected);
        // An unparseable datetime is rejected and nothing is inserted.
        assert!(matches!(
            store.add_with_options("Bad", None, Some("not a date")),
            Err(MemoryError::InvalidInput(_))
        ));
        assert_eq!(store.max_memory_id(), 1);
    }

    #[test]
    fn test_get_memory() {
        let mut store = MemoryStore::open_in_memory().unwrap();
        store.add("Test memory", Some("test")).unwrap();
        let mem = store.get(1).unwrap().unwrap();
        assert_eq!(mem.text, "Test memory");
        assert_eq!(mem.source, Some("test".to_string()));
        let none = store.get(999).unwrap();
        assert!(none.is_none());
    }

    #[test]
    fn test_get_many() {
        let mut store = MemoryStore::open_in_memory().unwrap();
        store.add("First", None).unwrap();
        store.add("Second", None).unwrap();
        store.add("Third", None).unwrap();
        let memories = store.get_many(&[1, 3]).unwrap();
        assert_eq!(memories.len(), 2);
        let empty = store.get_many(&[]).unwrap();
        assert!(empty.is_empty());
    }

    #[test]
    fn test_strengthen() {
        let mut store = MemoryStore::open_in_memory().unwrap();
        store.add("mem1", None).unwrap();
        store.add("mem2", None).unwrap();
        store.add("mem3", None).unwrap();
        let result = store.strengthen(&[1, 2, 3]).unwrap();
        assert_eq!(result.event_count, 3);
        assert_eq!(result.relationships.len(), 3);
        for rel in &result.relationships {
            assert_eq!(rel.event_count, 1);
            assert!((rel.effective_strength - 1.0).abs() < 0.001);
        }
        let result = store.strengthen(&[1, 2]).unwrap();
        assert_eq!(result.event_count, 1);
        assert_eq!(result.relationships.len(), 1);
        assert_eq!(result.relationships[0].event_count, 2);
        assert!((result.relationships[0].effective_strength - 2.0).abs() < 0.001);
    }

    #[test]
    fn test_strengthen_validation() {
        let mut store = MemoryStore::open_in_memory().unwrap();
        let err = store.strengthen(&[]).unwrap_err();
        assert!(matches!(err, MemoryError::InvalidInput(_)));
        let err = store.strengthen(&[1]).unwrap_err();
        assert!(matches!(err, MemoryError::InvalidInput(_)));
        let ids: Vec<i64> = (1..=15).collect();
        let err = store.strengthen(&ids).unwrap_err();
        assert!(matches!(err, MemoryError::InvalidInput(_)));
    }

    #[test]
    fn test_search() {
        let mut store = MemoryStore::open_in_memory().unwrap();
        store.add("First memory about cats", None).unwrap();
        store.add("Second memory about dogs", None).unwrap();
        store.add("Third memory about birds", None).unwrap();
        let params = SearchParams::default();
        let err = store.search("", &params).unwrap_err();
        assert!(matches!(err, MemoryError::InvalidInput(_)));
        let result = store.search("cats", &params).unwrap();
        assert!(!result.memories.is_empty());
        assert!(result.memories[0].memory.text.contains("cats"));
    }

    #[test]
    fn test_list() {
        let mut store = MemoryStore::open_in_memory().unwrap();
        store.add("First", Some("src1")).unwrap();
        store.add("Second", Some("src1")).unwrap();
        store.add("Third", Some("src2")).unwrap();
        store.add("Fourth", None).unwrap();
        store.add("Fifth", Some("src2")).unwrap();

        let all = store.list(None, None, None).unwrap();
        assert_eq!(all.len(), 5);
        assert_eq!(all[0].id, 5);
        assert_eq!(all[4].id, 1);

        let from_3 = store.list(Some(3), None, None).unwrap();
        assert_eq!(from_3.len(), 3);
        assert_eq!(from_3[0].id, 5);
        assert_eq!(from_3[2].id, 3);

        let to_3 = store.list(None, Some(3), None).unwrap();
        assert_eq!(to_3.len(), 3);
        assert_eq!(to_3[0].id, 3);
        assert_eq!(to_3[2].id, 1);

        let range = store.list(Some(2), Some(4), None).unwrap();
        assert_eq!(range.len(), 3);
        assert_eq!(range[0].id, 4);
        assert_eq!(range[2].id, 2);

        let limited = store.list(None, None, Some(2)).unwrap();
        assert_eq!(limited.len(), 2);
        assert_eq!(limited[0].id, 5);
        assert_eq!(limited[1].id, 4);

        let range_limited = store.list(Some(1), Some(5), Some(2)).unwrap();
        assert_eq!(range_limited.len(), 2);
    }

    #[test]
    fn test_tail_uses_list() {
        let mut store = MemoryStore::open_in_memory().unwrap();
        store.add("First", None).unwrap();
        store.add("Second", None).unwrap();
        store.add("Third", None).unwrap();
        let tail = store.tail(2).unwrap();
        assert_eq!(tail.len(), 2);
        assert_eq!(tail[0].id, 3);
        assert_eq!(tail[1].id, 2);
    }

    #[test]
    fn test_stray() {
        let mut store = MemoryStore::open_in_memory().unwrap();
        store.add("First", None).unwrap();
        store.add("Second", None).unwrap();
        store.add("Third", None).unwrap();
        store.connect(&[1, 2]).unwrap();
        let stray = store.stray(10).unwrap();
        assert_eq!(stray.len(), 1);
        assert_eq!(stray[0].id, 3);
        assert!(store.stray(0).unwrap().is_empty());
    }

    #[test]
    fn test_stats() {
        let mut store = MemoryStore::open_in_memory().unwrap();
        let stats = store.stats().unwrap();
        assert_eq!(stats.memory_count, 0);
        assert_eq!(stats.min_memory_id, None);
        assert_eq!(stats.max_memory_id, None);
        assert_eq!(stats.relationship_count, 0);
        assert_eq!(stats.relationship_event_count, 0);
        assert!(stats.unique_sources.is_empty());

        store.add("First", Some("src1")).unwrap();
        store.add("Second", Some("src2")).unwrap();
        store.add("Third", Some("src1")).unwrap();
        store.add("Fourth", None).unwrap();
        let stats = store.stats().unwrap();
        assert_eq!(stats.memory_count, 4);
        assert_eq!(stats.min_memory_id, Some(1));
        assert_eq!(stats.max_memory_id, Some(4));
        assert_eq!(stats.unique_sources, vec!["src1", "src2"]);
        assert_eq!(stats.graph.stray_count, 4);
        assert_eq!(stats.graph.island_count, 0);
        assert_eq!(stats.graph.largest_island_size, 0);

        store.strengthen(&[1, 2]).unwrap();
        store.strengthen(&[1, 2, 3]).unwrap();
        let stats = store.stats().unwrap();
        assert_eq!(stats.relationship_count, 3);
        assert_eq!(stats.relationship_event_count, 4);
        assert_eq!(stats.graph.stray_count, 1);
        assert_eq!(stats.graph.island_count, 1);
        assert_eq!(stats.graph.largest_island_size, 3);
        assert_eq!(stats.graph.leaf_count, 0);
        assert_eq!(stats.graph.max_degree, 2);
        assert!((stats.graph.avg_degree - 2.0).abs() < 1e-9);
    }

    #[test]
    fn test_graph_stats_multiple_islands() {
        let mut store = MemoryStore::open_in_memory().unwrap();
        store.add("a", None).unwrap();
        store.add("b", None).unwrap();
        store.add("c", None).unwrap();
        store.add("d", None).unwrap();
        store.add("e", None).unwrap();
        store.connect(&[1, 2]).unwrap();
        store.connect(&[3, 4]).unwrap();
        let stats = store.stats().unwrap();
        assert_eq!(stats.graph.stray_count, 1);
        assert_eq!(stats.graph.island_count, 2);
        assert_eq!(stats.graph.largest_island_size, 2);
        assert_eq!(stats.graph.leaf_count, 4);
        assert_eq!(stats.graph.max_degree, 1);
        assert!((stats.graph.avg_degree - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_connect() {
        let mut store = MemoryStore::open_in_memory().unwrap();
        store.add("mem1", None).unwrap();
        store.add("mem2", None).unwrap();
        store.add("mem3", None).unwrap();
        store.add("mem4", None).unwrap();

        let result = store.connect(&[1, 2]).unwrap();
        assert_eq!(result.created.len(), 1);
        assert!(result.skipped.is_empty());
        assert_eq!(result.created[0].from_mem, 1);
        assert_eq!(result.created[0].to_mem, 2);

        let result = store.connect(&[1, 2]).unwrap();
        assert!(result.created.is_empty());
        assert_eq!(result.skipped.len(), 1);
        assert_eq!(result.skipped[0], (1, 2));

        let result = store.connect(&[1, 2, 3]).unwrap();
        assert_eq!(result.created.len(), 2);
        assert_eq!(result.skipped.len(), 1);
        assert_eq!(result.skipped[0], (1, 2));

        let stats = store.stats().unwrap();
        assert_eq!(stats.relationship_count, 3);
        assert_eq!(stats.relationship_event_count, 3);
    }

    #[test]
    fn test_connect_validation() {
        let mut store = MemoryStore::open_in_memory().unwrap();
        let err = store.connect(&[1]).unwrap_err();
        assert!(matches!(err, MemoryError::InvalidInput(_)));
        let err = store.connect(&[]).unwrap_err();
        assert!(matches!(err, MemoryError::InvalidInput(_)));
        let ids: Vec<i64> = (1..=15).collect();
        let err = store.connect(&ids).unwrap_err();
        assert!(matches!(err, MemoryError::InvalidInput(_)));
    }

    #[test]
    fn test_amend() {
        let mut store = MemoryStore::open_in_memory().unwrap();
        store.add("Original text", Some("test")).unwrap();
        store.add("Second memory", None).unwrap();
        store.add("Third memory", None).unwrap();

        let result = store.amend(3, "Updated third").unwrap();
        assert_eq!(result.memory.id, 3);
        assert_eq!(result.memory.text, "Updated third");
        let mem = store.get(3).unwrap().unwrap();
        assert_eq!(mem.text, "Updated third");

        let result = store.amend(1, "Updated first").unwrap();
        assert_eq!(result.memory.text, "Updated first");

        // Once memory 1 is linked to the later memory 2 it can no longer be amended,
        // while memory 2 (whose only link points backwards) still can.
        store.connect(&[1, 2]).unwrap();
        let err = store.amend(1, "Try again").unwrap_err();
        assert!(matches!(err, MemoryError::InvalidInput(_)));
        let result = store.amend(2, "Updated second").unwrap();
        assert_eq!(result.memory.text, "Updated second");

        let err = store.amend(999, "Nope").unwrap_err();
        assert!(matches!(err, MemoryError::InvalidInput(_)));
    }
}