use chrono::Utc;
use storage::SqlitePool;
use thiserror::Error;
use uuid::Uuid;
#[derive(Debug, Error)]
pub enum EpisodicError {
#[error("SQLite error: {0}")]
Sqlite(#[from] storage::sqlite::SqliteError),
#[error("Episode not found: {0}")]
NotFound(String),
}
/// One stored conversational turn, as read back from the `episodes` table.
#[derive(Clone, Debug)]
pub struct Episode {
    /// Episode id; `EpisodicStore::store_episode` generates a UUID v4 string.
    pub id: String,
    /// Id of the session this episode belongs to.
    pub session_id: String,
    /// Namespace label; `store_episode` uses `"personal"` when none is given.
    /// Child namespaces are written as `"<parent>/<child>"`.
    pub namespace: String,
    /// Speaker role, e.g. `"user"` or `"assistant"`.
    pub role: String,
    /// Episode text after decryption by the pool.
    pub content: String,
    /// Stored timestamp string (format is whatever the schema writes — TODO confirm).
    pub timestamp: String,
    /// Caller-supplied importance score.
    pub importance: f64,
    /// Decay rate as stored in the database (not written by this module).
    pub decay_rate: f64,
    /// Number of times `EpisodicStore::reinforce` has been applied.
    pub reinforcement_count: i32,
    /// RFC 3339 time of the last reinforcement, if any.
    pub last_accessed: Option<String>,
    /// Optional agent that produced or owns the episode.
    pub agent: Option<String>,
}
/// A conversation session row from the `sessions` table.
#[derive(Clone, Debug)]
pub struct Session {
    /// Session id (a UUID v4 string when created via `EpisodicStore::create_session`).
    pub id: String,
    /// Start time; not written by this module, presumably a schema default — TODO confirm.
    pub started_at: String,
    /// RFC 3339 end time set by `EpisodicStore::end_session`; `None` while open.
    pub ended_at: Option<String>,
    /// Caller-supplied channel label, e.g. `"cli"`.
    pub channel: String,
}
/// One hit returned by `EpisodicStore::search_bm25`.
#[derive(Clone, Debug)]
pub struct FtsResult {
    /// Id of the matching episode.
    pub episode_id: String,
    /// Plaintext copy of the content as stored in the `episodes_fts` index.
    pub content: String,
    /// FTS5 `rank` value; `search_bm25` returns hits in ascending rank order.
    pub rank: f64,
    /// Timestamp of the matching episode.
    pub timestamp: String,
    /// Agent recorded on the matching episode, if any.
    pub agent: Option<String>,
    /// Importance score of the matching episode.
    pub importance: f64,
}
/// Reduces free-form user text to a space-separated list of words.
///
/// Every maximal run of alphanumeric characters becomes one word. Every other
/// character acts as a separator: whitespace, punctuation, and FTS5 syntax such as
/// `"`, `*`, `:`, `(` and `)`. Words are joined with single spaces, so input with no
/// alphanumerics yields an empty string.
///
/// Alphanumeric words such as `AND`, `OR` and `NOT` are kept verbatim. FTS5 treats
/// them as operators when unquoted, so a caller building a `MATCH` expression should
/// still quote each word.
pub(crate) fn sanitize_fts5_query(query: &str) -> String {
    let words: Vec<&str> = query
        .split(|ch: char| !ch.is_alphanumeric())
        .filter(|word| !word.is_empty())
        .collect();
    words.join(" ")
}
/// Persistence for sessions and episodes on top of a shared [`SqlitePool`].
pub struct EpisodicStore {
    /// Pool used for all queries; it also encrypts and decrypts episode content.
    db: SqlitePool,
}
impl EpisodicStore {
    /// Wraps an already-open pool. This constructor does not create any tables.
    pub fn new(db: SqlitePool) -> Self {
        Self { db }
    }

    /// Borrows the underlying connection pool.
    pub fn pool(&self) -> &SqlitePool {
        &self.db
    }

    /// Inserts a new session for `channel` and returns its freshly generated UUID v4 id.
    ///
    /// # Errors
    /// [`EpisodicError::Sqlite`] if the insert fails.
    pub fn create_session(&self, channel: &str) -> Result<String, EpisodicError> {
        let id = Uuid::new_v4().to_string();
        self.db.with_conn(|conn| {
            conn.execute(
                "INSERT INTO sessions (id, channel) VALUES (?1, ?2)",
                rusqlite::params![id, channel],
            )?;
            Ok(())
        })?;
        Ok(id)
    }

    /// Creates the session `session_id` on `channel` if it does not already exist.
    ///
    /// Uses `INSERT OR IGNORE`, so an existing row (including its channel) is left
    /// untouched. This assumes `id` is unique in `sessions` — confirm against the schema.
    pub fn ensure_session(&self, session_id: &str, channel: &str) -> Result<(), EpisodicError> {
        self.db.with_conn(|conn| {
            conn.execute(
                "INSERT OR IGNORE INTO sessions (id, channel) VALUES (?1, ?2)",
                rusqlite::params![session_id, channel],
            )?;
            Ok(())
        })?;
        Ok(())
    }

    /// Sets `ended_at` of `session_id` to the current UTC time (RFC 3339).
    ///
    /// Succeeds without changing anything when no session has that id; unlike
    /// [`Self::reinforce`], no `NotFound` error is raised.
    pub fn end_session(&self, session_id: &str) -> Result<(), EpisodicError> {
        let now = Utc::now().to_rfc3339();
        self.db.with_conn(|conn| {
            conn.execute(
                "UPDATE sessions SET ended_at = ?1 WHERE id = ?2",
                rusqlite::params![now, session_id],
            )?;
            Ok(())
        })?;
        Ok(())
    }

    /// Loads a single session by id.
    ///
    /// # Errors
    /// - [`EpisodicError::NotFound`] (carrying `session_id`) if no such session exists.
    /// - [`EpisodicError::Sqlite`] for any other database failure.
    pub fn get_session(&self, session_id: &str) -> Result<Session, EpisodicError> {
        let result = self.db.with_conn(|conn| {
            conn.query_row(
                "SELECT id, started_at, ended_at, channel FROM sessions WHERE id = ?1",
                [session_id],
                |row| {
                    Ok(Session {
                        id: row.get(0)?,
                        started_at: row.get(1)?,
                        ended_at: row.get(2)?,
                        channel: row.get(3)?,
                    })
                },
            )
            .map_err(|e| e.into())
        });
        match result {
            Ok(session) => Ok(session),
            Err(storage::sqlite::SqliteError::Rusqlite(rusqlite::Error::QueryReturnedNoRows)) => {
                Err(EpisodicError::NotFound(session_id.to_string()))
            }
            Err(e) => Err(EpisodicError::Sqlite(e)),
        }
    }

    /// Stores one episode in `session_id` and returns its new UUID v4 id.
    ///
    /// - `namespace` defaults to `"personal"`.
    /// - Content is passed through the pool's `encrypt_content` before it is written to
    ///   `episodes`.
    /// - The plaintext is also indexed in `episodes_fts` under the same rowid, but only
    ///   when the pool is not encrypted. Episodes stored in an encrypted pool are
    ///   therefore never returned by [`Self::search_bm25`].
    /// - Both inserts run in a single transaction.
    pub fn store_episode(
        &self,
        session_id: &str,
        role: &str,
        content: &str,
        importance: f64,
        namespace: Option<&str>,
        agent: Option<&str>,
    ) -> Result<String, EpisodicError> {
        let id = Uuid::new_v4().to_string();
        let encrypted_content = self.db.encrypt_content(content);
        let is_encrypted = self.db.is_encrypted();
        let namespace = namespace.unwrap_or("personal");
        self.db.with_conn(|conn| {
            let tx = conn.unchecked_transaction()?;
            tx.execute(
                "INSERT INTO episodes (id, session_id, namespace, role, content, importance, agent)
                 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
                rusqlite::params![
                    id,
                    session_id,
                    namespace,
                    role,
                    encrypted_content,
                    importance,
                    agent
                ],
            )?;
            // The FTS row must share the episode's rowid so `search_bm25` can join back.
            // Read it from the transaction handle that performed the insert.
            let row_id = tx.last_insert_rowid();
            if !is_encrypted {
                // Never index plaintext when the store encrypts content at rest.
                tx.execute(
                    "INSERT INTO episodes_fts (rowid, content) VALUES (?1, ?2)",
                    rusqlite::params![row_id, content],
                )?;
            }
            tx.commit()?;
            Ok(())
        })?;
        Ok(id)
    }

    /// Returns up to `limit` episodes of `session_id`, oldest first.
    ///
    /// Episodes with equal timestamps are returned in insertion (rowid) order. Episodes
    /// whose content the pool cannot decrypt are skipped.
    ///
    /// # Errors
    /// [`EpisodicError::Sqlite`] on query or row-decoding failures.
    pub fn get_session_history(
        &self,
        session_id: &str,
        limit: usize,
    ) -> Result<Vec<Episode>, EpisodicError> {
        let limit = Self::sql_limit(limit);
        Ok(self.db.with_conn(|conn| {
            let mut stmt = conn.prepare(
                "SELECT id, session_id, role, content, timestamp,
                        namespace, importance, decay_rate, reinforcement_count, last_accessed, agent
                 FROM episodes
                 WHERE session_id = ?1
                 ORDER BY timestamp ASC, rowid ASC
                 LIMIT ?2",
            )?;
            Ok(self.query_decrypted_episodes(&mut stmt, rusqlite::params![session_id, limit])?)
        })?)
    }

    /// Loads one episode by id, or `None` if it does not exist.
    pub fn get_episode(&self, episode_id: &str) -> Result<Option<Episode>, EpisodicError> {
        Ok(self.db.with_conn(|conn| {
            let mut stmt = conn.prepare(
                "SELECT id, session_id, role, content, timestamp,
                        namespace, importance, decay_rate, reinforcement_count, last_accessed, agent
                 FROM episodes
                 WHERE id = ?1",
            )?;
            let mut rows = stmt.query([episode_id])?;
            match rows.next()? {
                Some(row) => {
                    let (mut episode, raw) = Self::episode_from_row(row)?;
                    // NOTE(review): unlike the list queries, which skip undecryptable rows,
                    // this falls back to the raw stored content — confirm this is intended.
                    episode.content = self.db.try_decrypt_content(&raw).unwrap_or(raw);
                    Ok(Some(episode))
                }
                None => Ok(None),
            }
        })?)
    }

    /// Increments `reinforcement_count` and sets `last_accessed` to now (UTC, RFC 3339).
    ///
    /// # Errors
    /// [`EpisodicError::NotFound`] if no episode has `episode_id`.
    pub fn reinforce(&self, episode_id: &str) -> Result<(), EpisodicError> {
        let now = Utc::now().to_rfc3339();
        let rows = self.db.with_conn(|conn| {
            let rows = conn.execute(
                "UPDATE episodes SET reinforcement_count = reinforcement_count + 1,
                 last_accessed = ?1
                 WHERE id = ?2",
                rusqlite::params![now, episode_id],
            )?;
            Ok(rows)
        })?;
        if rows == 0 {
            return Err(EpisodicError::NotFound(episode_id.to_string()));
        }
        Ok(())
    }

    /// Full-text search over indexed episode content, ordered by FTS5 `rank`.
    ///
    /// `query` is reduced to plain words by `sanitize_fts5_query`. Each word is then
    /// quoted, so the words are ANDed together and punctuation or operator words can
    /// never produce FTS5 syntax. An empty sanitized query returns no results.
    ///
    /// Optional filters:
    /// - `namespace` matches that namespace exactly or any `<namespace>/...` child.
    /// - `agent` matches the agent exactly.
    pub fn search_bm25(
        &self,
        query: &str,
        limit: usize,
        namespace: Option<&str>,
        agent: Option<&str>,
    ) -> Result<Vec<FtsResult>, EpisodicError> {
        let sanitized = sanitize_fts5_query(query);
        if sanitized.is_empty() {
            return Ok(Vec::new());
        }
        // Quote every word so `AND` / `OR` / `NOT` that survive sanitising are matched
        // literally rather than parsed as operators. Unquoted, a trailing `OR` is an FTS5
        // syntax error. Words are purely alphanumeric, so no `"` escaping is needed.
        let match_expr = sanitized
            .split(' ')
            .map(|word| format!("\"{word}\""))
            .collect::<Vec<_>>()
            .join(" ");
        Ok(self.db.with_conn(|conn| {
            let mut sql = String::from(
                "SELECT e.id, f.content, f.rank, e.timestamp, e.agent, e.importance
                 FROM episodes_fts f
                 JOIN episodes e ON e.rowid = f.rowid
                 WHERE episodes_fts MATCH ?1",
            );
            let mut params: Vec<Box<dyn rusqlite::types::ToSql>> = vec![Box::new(match_expr)];
            if let Some(ns) = namespace {
                let exact = params.len() + 1;
                let prefix = params.len() + 2;
                // Prefix compare via `substr` rather than `LIKE`: `_`/`%` in a namespace
                // must not act as wildcards, and the match must stay case-sensitive
                // like the `=` branch.
                sql.push_str(&format!(
                    " AND (e.namespace = ?{exact} OR substr(e.namespace, 1, length(?{prefix})) = ?{prefix})"
                ));
                params.push(Box::new(ns.to_string()));
                params.push(Box::new(Self::namespace_prefix(ns)));
            }
            if let Some(a) = agent {
                sql.push_str(&format!(" AND e.agent = ?{}", params.len() + 1));
                params.push(Box::new(a.to_string()));
            }
            sql.push_str(&format!(" ORDER BY f.rank LIMIT ?{}", params.len() + 1));
            params.push(Box::new(Self::sql_limit(limit)));
            let mut stmt = conn.prepare(&sql)?;
            let param_refs: Vec<&dyn rusqlite::types::ToSql> =
                params.iter().map(|p| p.as_ref()).collect();
            let results = stmt
                .query_map(param_refs.as_slice(), |row| {
                    Ok(FtsResult {
                        episode_id: row.get(0)?,
                        content: row.get(1)?,
                        rank: row.get(2)?,
                        timestamp: row.get(3)?,
                        agent: row.get(4)?,
                        importance: row.get(5)?,
                    })
                })?
                .collect::<Result<Vec<_>, _>>()?;
            Ok(results)
        })?)
    }

    /// Total number of rows in `episodes`, across all sessions and namespaces.
    pub fn count(&self) -> Result<i64, EpisodicError> {
        Ok(self.db.with_conn(|conn| {
            let count: i64 =
                conn.query_row("SELECT COUNT(*) FROM episodes", [], |row| row.get(0))?;
            Ok(count)
        })?)
    }

    /// Returns whether any episode exists in `namespace` or one of its `<namespace>/...`
    /// children. With `None`, returns whether any episode exists at all.
    pub fn has_episodes_in_namespace(
        &self,
        namespace: Option<&str>,
    ) -> Result<bool, EpisodicError> {
        Ok(self.db.with_conn(|conn| {
            let exists: i64 = match namespace {
                Some(ns) => {
                    let prefix = Self::namespace_prefix(ns);
                    conn.query_row(
                        "SELECT EXISTS(SELECT 1 FROM episodes
                         WHERE namespace = ?1 OR substr(namespace, 1, length(?2)) = ?2
                         LIMIT 1)",
                        rusqlite::params![ns, prefix],
                        |row| row.get(0),
                    )?
                }
                None => conn.query_row(
                    "SELECT EXISTS(SELECT 1 FROM episodes LIMIT 1)",
                    [],
                    |row| row.get(0),
                )?,
            };
            Ok(exists != 0)
        })?)
    }

    /// Returns up to `limit` most recent episodes, newest first.
    ///
    /// - When `namespace` is given, it matches that namespace exactly or any
    ///   `<namespace>/...` child.
    /// - Ties on timestamp are broken by newest insertion (rowid) first.
    /// - Episodes whose content the pool cannot decrypt are skipped.
    pub fn recent(
        &self,
        limit: usize,
        namespace: Option<&str>,
    ) -> Result<Vec<Episode>, EpisodicError> {
        let limit = Self::sql_limit(limit);
        Ok(self.db.with_conn(|conn| {
            let episodes = match namespace {
                Some(ns) => {
                    let mut stmt = conn.prepare(
                        "SELECT id, session_id, role, content, timestamp,
                                namespace, importance, decay_rate, reinforcement_count, last_accessed, agent
                         FROM episodes
                         WHERE namespace = ?1 OR substr(namespace, 1, length(?2)) = ?2
                         ORDER BY timestamp DESC, rowid DESC
                         LIMIT ?3",
                    )?;
                    let prefix = Self::namespace_prefix(ns);
                    self.query_decrypted_episodes(&mut stmt, rusqlite::params![ns, prefix, limit])?
                }
                None => {
                    let mut stmt = conn.prepare(
                        "SELECT id, session_id, role, content, timestamp,
                                namespace, importance, decay_rate, reinforcement_count, last_accessed, agent
                         FROM episodes
                         ORDER BY timestamp DESC, rowid DESC
                         LIMIT ?1",
                    )?;
                    self.query_decrypted_episodes(&mut stmt, [limit])?
                }
            };
            Ok(episodes)
        })?)
    }

    /// Runs `stmt` with `params` and decodes each row with [`Self::episode_from_row`].
    ///
    /// Row-decoding errors are propagated. Rows whose stored content the pool's
    /// `try_decrypt_content` cannot recover are skipped.
    fn query_decrypted_episodes<P: rusqlite::Params>(
        &self,
        stmt: &mut rusqlite::Statement<'_>,
        params: P,
    ) -> rusqlite::Result<Vec<Episode>> {
        let mut episodes = Vec::new();
        for row in stmt.query_map(params, Self::episode_from_row)? {
            let (mut episode, raw) = row?;
            if let Some(content) = self.db.try_decrypt_content(&raw) {
                episode.content = content;
                episodes.push(episode);
            }
        }
        Ok(episodes)
    }

    /// Decodes a row selected in the standard episode column order.
    ///
    /// The expected order is `id, session_id, role, content, timestamp, namespace,
    /// importance, decay_rate, reinforcement_count, last_accessed, agent`.
    ///
    /// Returns the episode with an empty `content`, together with the raw stored
    /// content (column 3) so the caller can decrypt it.
    fn episode_from_row(row: &rusqlite::Row<'_>) -> rusqlite::Result<(Episode, String)> {
        let raw: String = row.get(3)?;
        let episode = Episode {
            id: row.get(0)?,
            session_id: row.get(1)?,
            role: row.get(2)?,
            content: String::new(),
            timestamp: row.get(4)?,
            namespace: row.get(5)?,
            importance: row.get(6)?,
            decay_rate: row.get(7)?,
            reinforcement_count: row.get(8)?,
            last_accessed: row.get(9)?,
            agent: row.get(10)?,
        };
        Ok((episode, raw))
    }

    /// Prefix shared by every child namespace of `ns` (`"<ns>/"`).
    fn namespace_prefix(ns: &str) -> String {
        format!("{ns}/")
    }

    /// Converts a row limit into an SQLite `LIMIT` value.
    ///
    /// Saturates at `i64::MAX` instead of wrapping to a negative number.
    fn sql_limit(limit: usize) -> i64 {
        i64::try_from(limit).unwrap_or(i64::MAX)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a store over a brand-new in-memory database.
    fn test_store() -> EpisodicStore {
        let pool = SqlitePool::open_memory().unwrap();
        EpisodicStore::new(pool)
    }

    /// Stores one episode with no agent and returns its id, panicking on failure.
    fn remember(
        store: &EpisodicStore,
        session: &str,
        role: &str,
        content: &str,
        importance: f64,
        namespace: Option<&str>,
    ) -> String {
        store
            .store_episode(session, role, content, importance, namespace, None)
            .unwrap()
    }

    #[test]
    fn test_create_session() {
        let store = test_store();
        let id = store.create_session("cli").unwrap();
        assert!(!id.is_empty());

        let fetched = store.get_session(&id).unwrap();
        assert_eq!(fetched.channel, "cli");
        assert!(fetched.ended_at.is_none());
    }

    #[test]
    fn test_end_session() {
        let store = test_store();
        let id = store.create_session("cli").unwrap();
        store.end_session(&id).unwrap();
        assert!(store.get_session(&id).unwrap().ended_at.is_some());
    }

    #[test]
    fn test_store_and_retrieve_episodes() {
        let store = test_store();
        let session = store.create_session("cli").unwrap();
        let turns = [
            ("user", "Hello Brain!", 0.5),
            ("assistant", "Hello! How can I help?", 0.5),
            ("user", "What's the weather?", 0.3),
        ];
        for &(role, content, importance) in &turns {
            remember(&store, &session, role, content, importance, None);
        }

        let history = store.get_session_history(&session, 10).unwrap();
        assert_eq!(history.len(), 3);
        assert_eq!(history[0].role, "user");
        assert_eq!(history[0].content, "Hello Brain!");
        assert_eq!(history[1].role, "assistant");
    }

    #[test]
    fn test_episode_count() {
        let store = test_store();
        let session = store.create_session("cli").unwrap();
        assert_eq!(store.count().unwrap(), 0);

        remember(&store, &session, "user", "Test message", 0.5, None);
        assert_eq!(store.count().unwrap(), 1);
    }

    #[test]
    fn test_reinforce() {
        let store = test_store();
        let session = store.create_session("cli").unwrap();
        let ep_id = remember(&store, &session, "user", "Important fact", 0.8, None);

        let before = store.get_session_history(&session, 10).unwrap();
        assert_eq!(before[0].reinforcement_count, 0);

        for _ in 0..2 {
            store.reinforce(&ep_id).unwrap();
        }

        let after = store.get_session_history(&session, 10).unwrap();
        assert_eq!(after[0].reinforcement_count, 2);
        assert!(after[0].last_accessed.is_some());
    }

    #[test]
    fn test_bm25_search() {
        let store = test_store();
        let session = store.create_session("cli").unwrap();
        let notes = [
            ("I love programming in Rust", 0.7),
            ("Python is great for scripting", 0.5),
            ("Rust has amazing performance", 0.8),
        ];
        for &(content, importance) in &notes {
            remember(&store, &session, "user", content, importance, None);
        }

        let results = store.search_bm25("Rust", 10, None, None).unwrap();
        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|r| r.content.contains("Rust")));
    }

    #[test]
    fn test_recent_episodes() {
        let store = test_store();
        let cli = store.create_session("cli").unwrap();
        let whatsapp = store.create_session("whatsapp").unwrap();
        remember(&store, &cli, "user", "First message", 0.5, None);
        remember(&store, &whatsapp, "user", "Second message", 0.5, None);

        let recent = store.recent(10, None).unwrap();
        assert_eq!(recent.len(), 2);
        let contents: Vec<&str> = recent.iter().map(|e| e.content.as_str()).collect();
        for expected in &["First message", "Second message"] {
            assert!(contents.contains(expected));
        }
    }

    #[test]
    fn test_namespace_filtered_search_and_recent() {
        let store = test_store();
        let session = store.create_session("cli").unwrap();
        remember(&store, &session, "user", "Rust memory model notes", 0.7, Some("work"));
        remember(&store, &session, "user", "Rust hobby project", 0.7, Some("personal"));

        let work_hits = store.search_bm25("Rust", 10, Some("work"), None).unwrap();
        assert_eq!(work_hits.len(), 1);
        assert!(work_hits[0].content.contains("memory model"));

        let personal_recent = store.recent(10, Some("personal")).unwrap();
        assert_eq!(personal_recent.len(), 1);
        assert_eq!(personal_recent[0].namespace, "personal");
    }

    #[test]
    fn test_search_bm25_apostrophe_query_no_syntax_error() {
        let store = test_store();
        let session = store.create_session("cli").unwrap();
        remember(
            &store,
            &session,
            "user",
            "I've completed Brain project using Rust programming language",
            0.8,
            None,
        );

        let results = store
            .search_bm25(
                "I've completed brain project using rust programing language.",
                10,
                None,
                None,
            )
            .unwrap();
        assert!(!results.is_empty());
    }
}