use rusqlite::{params, Connection, OptionalExtension};
use serde::{Deserialize, Serialize};
use std::fmt;
use crate::error::{EngramError, Result};
/// Depth of a [`MemoryScope`] in the `global/...` hierarchy, ordered shallowest first.
///
/// Each discriminant equals the number of `/`-separated segments that follow the
/// leading `global` segment, so a path at level `L` has `L as usize + 1` segments
/// (the rule `MemoryScope::new` enforces). The derived `Ord` follows declaration
/// order: `Global < Org < User < Session < Agent`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum ScopeLevel {
/// Root scope; the path is exactly `global`.
Global = 0,
/// One segment below `global`, e.g. `global/org:acme`.
Org = 1,
/// Two segments below `global`, e.g. `global/org:acme/user:alice`.
User = 2,
/// Three segments below `global`.
Session = 3,
/// Four segments below `global`; the deepest level `MemoryScope::parse` accepts.
Agent = 4,
}
impl fmt::Display for ScopeLevel {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ScopeLevel::Global => write!(f, "global"),
ScopeLevel::Org => write!(f, "org"),
ScopeLevel::User => write!(f, "user"),
ScopeLevel::Session => write!(f, "session"),
ScopeLevel::Agent => write!(f, "agent"),
}
}
}
/// A position in the scope hierarchy: a [`ScopeLevel`] plus its full path.
///
/// The path is `/`-separated and starts with the `global` segment, e.g.
/// `global/org:acme/user:alice`. The `new` and `parse` constructors keep `level` and
/// the path's segment count in agreement. The fields are public, however, so a value
/// built with a struct literal is not validated. Equality and hashing compare both
/// fields.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct MemoryScope {
/// Depth of this scope; expected to match the number of segments in `path`.
pub level: ScopeLevel,
/// Full `/`-separated path, beginning with `global`.
pub path: String,
}
impl MemoryScope {
pub fn new(level: ScopeLevel, path: impl Into<String>) -> Result<Self> {
let path = path.into();
let expected_segments = level as usize + 1; let actual_segments = path.split('/').count();
if actual_segments != expected_segments {
return Err(EngramError::InvalidInput(format!(
"scope path '{}' has {} segment(s) but level {:?} requires {}",
path, actual_segments, level, expected_segments
)));
}
if !path.starts_with("global") {
return Err(EngramError::InvalidInput(format!(
"scope path must start with 'global', got '{}'",
path
)));
}
Ok(Self { level, path })
}
pub fn global() -> Self {
Self {
level: ScopeLevel::Global,
path: "global".to_string(),
}
}
pub fn parse(path: &str) -> Result<Self> {
let segments: Vec<&str> = path.split('/').collect();
if segments.is_empty() || segments[0] != "global" {
return Err(EngramError::InvalidInput(format!(
"scope path must start with 'global', got '{}'",
path
)));
}
let level = match segments.len() {
1 => ScopeLevel::Global,
2 => ScopeLevel::Org,
3 => ScopeLevel::User,
4 => ScopeLevel::Session,
5 => ScopeLevel::Agent,
n => {
return Err(EngramError::InvalidInput(format!(
"scope path has {} segments; maximum supported depth is 5 (Agent)",
n
)))
}
};
Ok(Self {
level,
path: path.to_string(),
})
}
pub fn parent(&self) -> Option<MemoryScope> {
if self.level == ScopeLevel::Global {
return None;
}
let last_slash = self.path.rfind('/')?;
let parent_path = &self.path[..last_slash];
let parent_level = match self.level {
ScopeLevel::Org => ScopeLevel::Global,
ScopeLevel::User => ScopeLevel::Org,
ScopeLevel::Session => ScopeLevel::User,
ScopeLevel::Agent => ScopeLevel::Session,
ScopeLevel::Global => unreachable!(),
};
Some(MemoryScope {
level: parent_level,
path: parent_path.to_string(),
})
}
pub fn ancestors(&self) -> Vec<MemoryScope> {
let mut result = Vec::new();
let mut current = self.parent();
while let Some(scope) = current {
current = scope.parent();
result.push(scope);
}
result
}
pub fn contains(&self, other: &MemoryScope) -> bool {
if self == other {
return true;
}
if other.level <= self.level {
return false;
}
other.path.starts_with(&format!("{}/", self.path))
}
}
impl fmt::Display for MemoryScope {
    /// Writes the scope's full path (e.g. `global/org:acme`) verbatim.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.path)
    }
}
/// One node of the forest returned by `scope_tree`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScopeNode {
/// The scope this node represents.
pub scope: MemoryScope,
/// Number of memories whose `scope_path` equals exactly this scope's path;
/// memories in descendant scopes are not included (see the `GROUP BY` in `scope_tree`).
pub memory_count: i64,
/// Nodes nested directly beneath this one, as assembled by `build_tree`.
pub children: Vec<ScopeNode>,
}
/// Stores `scope` as the `scope_path` of the memory with id `memory_id`.
///
/// # Errors
/// - Returns [`EngramError::NotFound`] if the `UPDATE` touched no row, i.e. no memory
///   has that id.
/// - Propagates any database error from the `UPDATE` itself.
pub fn set_scope(conn: &Connection, memory_id: i64, scope: &MemoryScope) -> Result<()> {
    let updated = conn.execute(
        "UPDATE memories SET scope_path = ?1 WHERE id = ?2",
        params![scope.path, memory_id],
    )?;
    match updated {
        0 => Err(EngramError::NotFound(memory_id)),
        _ => Ok(()),
    }
}
/// Loads and parses the `scope_path` of the memory with id `memory_id`.
///
/// # Errors
/// - Returns [`EngramError::NotFound`] if no row has that id.
/// - Returns the [`MemoryScope::parse`] error if the stored path is malformed.
/// - Propagates database errors from the query. This includes the error rusqlite
///   reports when a NULL `scope_path` is read as a `String`.
pub fn get_scope(conn: &Connection, memory_id: i64) -> Result<MemoryScope> {
    let stored = conn
        .query_row(
            "SELECT scope_path FROM memories WHERE id = ?1",
            params![memory_id],
            |row| row.get::<_, String>(0),
        )
        .optional()?
        .ok_or(EngramError::NotFound(memory_id))?;
    MemoryScope::parse(&stored)
}
/// Returns every distinct non-NULL `scope_path` in `memories`, parsed as scopes.
///
/// Rows that cannot be read as text, or whose path fails [`MemoryScope::parse`], are
/// skipped rather than reported. The query has no `ORDER BY`, so the result order is
/// whatever SQLite produces.
///
/// # Errors
/// Propagates errors from preparing or starting the query.
pub fn list_scopes(conn: &Connection) -> Result<Vec<MemoryScope>> {
    let mut stmt =
        conn.prepare("SELECT DISTINCT scope_path FROM memories WHERE scope_path IS NOT NULL")?;
    let rows = stmt.query_map([], |row| row.get::<_, String>(0))?;
    let mut scopes = Vec::new();
    // `flatten` keeps only the rows that were read successfully.
    for path in rows.flatten() {
        if let Ok(scope) = MemoryScope::parse(&path) {
            scopes.push(scope);
        }
    }
    Ok(scopes)
}
/// Moves the memory with id `memory_id` into `new_scope`.
///
/// Delegates to [`set_scope`], which replaces the stored `scope_path`.
///
/// # Errors
/// Same as [`set_scope`]:
/// - [`EngramError::NotFound`] when no memory has that id.
/// - Database errors from the update.
pub fn move_scope(conn: &Connection, memory_id: i64, new_scope: &MemoryScope) -> Result<()> {
    set_scope(conn, memory_id, new_scope)?;
    Ok(())
}
/// Returns the ids of memories visible from `scope` whose `content` contains `query`.
///
/// Results are ordered by id, highest first.
///
/// A memory is visible when its `scope_path` is `scope`'s own path or the path of one
/// of its ancestors (up to `global`). Memories in sibling or descendant scopes are
/// excluded.
///
/// `query` is matched literally as a substring: `\`, `%` and `_` are escaped, so they
/// are not treated as `LIKE` wildcards. Apart from that, matching follows SQLite
/// `LIKE` semantics (ASCII case-insensitive by default).
///
/// # Errors
/// Propagates errors from preparing or running the query, including errors reading
/// an individual row.
pub fn search_scoped(conn: &Connection, query: &str, scope: &MemoryScope) -> Result<Vec<i64>> {
    // The scope itself followed by each ancestor, nearest first.
    let paths: Vec<String> = std::iter::once(scope.path.clone())
        .chain(scope.ancestors().into_iter().map(|ancestor| ancestor.path))
        .collect();
    let in_clause = vec!["?"; paths.len()].join(", ");
    let sql = format!(
        "SELECT id FROM memories WHERE content LIKE ? ESCAPE '\\' AND scope_path IN ({}) ORDER BY id DESC",
        in_clause
    );
    let like_pattern = format!("%{}%", escape_like(query));

    // Bind order must mirror the placeholders: the LIKE pattern, then each path.
    let mut bind_values: Vec<&dyn rusqlite::types::ToSql> = Vec::with_capacity(paths.len() + 1);
    bind_values.push(&like_pattern);
    bind_values.extend(paths.iter().map(|path| path as &dyn rusqlite::types::ToSql));

    let mut stmt = conn.prepare(&sql)?;
    let ids = stmt
        .query_map(bind_values.as_slice(), |row| row.get::<_, i64>(0))?
        .collect::<rusqlite::Result<Vec<i64>>>()?;
    Ok(ids)
}

/// Prefixes every `\`, `%` and `_` in `input` with `\`.
///
/// This lets `input` be embedded in a `LIKE` pattern declared with `ESCAPE '\'` and
/// still match only literally.
fn escape_like(input: &str) -> String {
    let mut escaped = String::with_capacity(input.len());
    for ch in input.chars() {
        if matches!(ch, '\\' | '%' | '_') {
            escaped.push('\\');
        }
        escaped.push(ch);
    }
    escaped
}
/// Builds the scope hierarchy from the per-path memory counts in `memories`.
///
/// Rows with a NULL `scope_path` are excluded by the query. Rows that fail to read,
/// or whose path does not parse as a [`MemoryScope`], are skipped. Each node's
/// `memory_count` is the number of memories stored at exactly that path. The flat
/// list of nodes is handed to `build_tree` to be assembled.
///
/// # Errors
/// Propagates errors from preparing or starting the query, and any error returned by
/// `build_tree`.
pub fn scope_tree(conn: &Connection) -> Result<Vec<ScopeNode>> {
    let mut stmt = conn.prepare(
        "SELECT scope_path, COUNT(*) as cnt FROM memories\n\
         WHERE scope_path IS NOT NULL\n\
         GROUP BY scope_path\n\
         ORDER BY scope_path",
    )?;
    let mut nodes = Vec::new();
    for (path, memory_count) in stmt
        .query_map([], |row| Ok((row.get::<_, String>(0)?, row.get::<_, i64>(1)?)))?
        .flatten()
    {
        if let Ok(scope) = MemoryScope::parse(&path) {
            nodes.push(ScopeNode {
                scope,
                memory_count,
                children: Vec::new(),
            });
        }
    }
    // Stable sort by depth; within one level the path order from SQL is kept.
    nodes.sort_by_key(|node| node.scope.level);
    build_tree(nodes)
}
/// Assembles a flat list of [`ScopeNode`]s into a forest based on scope-path parentage.
///
/// Each node is attached to the node whose path equals its parent scope's path. The
/// following nodes are returned as roots:
/// - every `global` node;
/// - any node whose parent scope has no node of its own, because no memories are
///   stored at that intermediate path.
///
/// Roots and every `children` list are ordered by scope path. This always returns
/// `Ok`.
fn build_tree(mut nodes: Vec<ScopeNode>) -> Result<Vec<ScopeNode>> {
    // Sort ascending by depth so that `pop()` yields the deepest remaining node first.
    // Its parent (exactly one level shallower) is then still in `nodes` when we look
    // for it. The parent carries its finished subtree when it is popped later.
    // Sorting the other way pops parents first, so children never find them and the
    // result comes out flat.
    nodes.sort_by(|a, b| a.scope.level.cmp(&b.scope.level));
    let mut roots: Vec<ScopeNode> = Vec::new();
    while let Some(node) = nodes.pop() {
        let parent_path = node.scope.parent().map(|p| p.path);
        let parent = match &parent_path {
            Some(path) => nodes
                .iter_mut()
                .find(|candidate| candidate.scope.path == *path),
            None => None,
        };
        match parent {
            Some(parent) => parent.children.push(node),
            None => roots.push(node),
        }
    }
    sort_nodes_by_path(&mut roots);
    Ok(roots)
}

/// Recursively orders `nodes` and all their descendants by scope path.
fn sort_nodes_by_path(nodes: &mut [ScopeNode]) {
    nodes.sort_by(|a, b| a.scope.path.cmp(&b.scope.path));
    for node in nodes.iter_mut() {
        sort_nodes_by_path(&mut node.children);
    }
}
// Unit tests for scope parsing and navigation and for the SQLite helpers. Each
// database test uses a fresh in-memory database created by `setup_db`.
#[cfg(test)]
mod tests {
use super::*;
use rusqlite::Connection;
// Creates an in-memory database with a minimal `memories` table whose
// `scope_path` column defaults to 'global'.
fn setup_db() -> Connection {
let conn = Connection::open_in_memory().unwrap();
conn.execute_batch(
"CREATE TABLE IF NOT EXISTS memories (
id INTEGER PRIMARY KEY,
content TEXT NOT NULL,
scope_path TEXT DEFAULT 'global'
);",
)
.unwrap();
conn
}
// Inserts one row with an explicit id, content and scope path; panics on failure.
fn insert(conn: &Connection, id: i64, content: &str, scope: &str) {
conn.execute(
"INSERT INTO memories (id, content, scope_path) VALUES (?1, ?2, ?3)",
params![id, content, scope],
)
.unwrap();
}
#[test]
fn test_parse_global() {
let s = MemoryScope::parse("global").unwrap();
assert_eq!(s.level, ScopeLevel::Global);
assert_eq!(s.path, "global");
}
#[test]
fn test_parse_org() {
let s = MemoryScope::parse("global/org:acme").unwrap();
assert_eq!(s.level, ScopeLevel::Org);
assert_eq!(s.path, "global/org:acme");
}
#[test]
fn test_parse_user() {
let s = MemoryScope::parse("global/org:acme/user:alice").unwrap();
assert_eq!(s.level, ScopeLevel::User);
}
#[test]
fn test_parse_invalid_prefix() {
assert!(MemoryScope::parse("org:acme").is_err());
}
// Six segments exceeds the deepest level (Agent, five segments).
#[test]
fn test_parse_too_deep() {
assert!(MemoryScope::parse("global/org:a/user:b/session:c/agent:d/extra:e").is_err());
}
// Walks Agent -> Session -> User -> Org -> Global, then off the root.
#[test]
fn test_parent() {
let agent = MemoryScope::parse("global/org:acme/user:alice/session:s1/agent:bot").unwrap();
let session = agent.parent().unwrap();
assert_eq!(session.level, ScopeLevel::Session);
assert_eq!(session.path, "global/org:acme/user:alice/session:s1");
let user = session.parent().unwrap();
assert_eq!(user.level, ScopeLevel::User);
let org = user.parent().unwrap();
assert_eq!(org.level, ScopeLevel::Org);
let global = org.parent().unwrap();
assert_eq!(global.level, ScopeLevel::Global);
assert!(global.parent().is_none());
}
// Ancestors are listed nearest first.
#[test]
fn test_ancestors() {
let user = MemoryScope::parse("global/org:acme/user:alice").unwrap();
let ancestors = user.ancestors();
assert_eq!(ancestors.len(), 2);
assert_eq!(ancestors[0].level, ScopeLevel::Org);
assert_eq!(ancestors[1].level, ScopeLevel::Global);
}
#[test]
fn test_contains_parent_contains_child() {
let global = MemoryScope::global();
let org = MemoryScope::parse("global/org:acme").unwrap();
let user = MemoryScope::parse("global/org:acme/user:alice").unwrap();
assert!(global.contains(&org));
assert!(global.contains(&user));
assert!(org.contains(&user));
}
#[test]
fn test_contains_child_does_not_contain_parent() {
let global = MemoryScope::global();
let org = MemoryScope::parse("global/org:acme").unwrap();
assert!(!org.contains(&global));
}
#[test]
fn test_contains_sibling_false() {
let alice = MemoryScope::parse("global/org:acme/user:alice").unwrap();
let bob = MemoryScope::parse("global/org:acme/user:bob").unwrap();
assert!(!alice.contains(&bob));
assert!(!bob.contains(&alice));
}
#[test]
fn test_contains_self_true() {
let s = MemoryScope::global();
assert!(s.contains(&s));
}
#[test]
fn test_set_and_get_scope() {
let conn = setup_db();
insert(&conn, 1, "hello", "global");
let new_scope = MemoryScope::parse("global/org:acme").unwrap();
set_scope(&conn, 1, &new_scope).unwrap();
let retrieved = get_scope(&conn, 1).unwrap();
assert_eq!(retrieved, new_scope);
}
#[test]
fn test_get_scope_not_found() {
let conn = setup_db();
let err = get_scope(&conn, 999).unwrap_err();
assert!(matches!(err, EngramError::NotFound(999)));
}
#[test]
fn test_set_scope_not_found() {
let conn = setup_db();
let scope = MemoryScope::global();
let err = set_scope(&conn, 999, &scope).unwrap_err();
assert!(matches!(err, EngramError::NotFound(999)));
}
// A user-level search sees memories at its own scope, its org and global, but
// never a sibling user's.
#[test]
fn test_search_scoped_ancestor_inheritance() {
let conn = setup_db();
insert(&conn, 1, "common knowledge", "global");
insert(&conn, 2, "acme org policy", "global/org:acme");
insert(
&conn,
3,
"alice personal note",
"global/org:acme/user:alice",
);
insert(&conn, 4, "bob personal note", "global/org:acme/user:bob");
let alice_scope = MemoryScope::parse("global/org:acme/user:alice").unwrap();
let ids = search_scoped(&conn, "knowledge", &alice_scope).unwrap();
assert!(ids.contains(&1), "global memory should be visible");
assert!(!ids.contains(&3));
let ids = search_scoped(&conn, "policy", &alice_scope).unwrap();
assert!(ids.contains(&2), "org memory should be visible");
let ids = search_scoped(&conn, "alice", &alice_scope).unwrap();
assert!(ids.contains(&3));
let ids = search_scoped(&conn, "bob", &alice_scope).unwrap();
assert!(
!ids.contains(&4),
"bob's memory must not be visible to alice"
);
}
#[test]
fn test_move_scope() {
let conn = setup_db();
insert(&conn, 1, "memory", "global");
let new_scope = MemoryScope::parse("global/org:acme/user:alice").unwrap();
move_scope(&conn, 1, &new_scope).unwrap();
let retrieved = get_scope(&conn, 1).unwrap();
assert_eq!(retrieved.path, "global/org:acme/user:alice");
}
// NOTE(review): only the global node's own count is asserted. How the org and
// user nodes are nested beneath it is not checked here.
#[test]
fn test_scope_tree() {
let conn = setup_db();
insert(&conn, 1, "a", "global");
insert(&conn, 2, "b", "global");
insert(&conn, 3, "c", "global/org:acme");
insert(&conn, 4, "d", "global/org:acme/user:alice");
let tree = scope_tree(&conn).unwrap();
let global_node = tree.iter().find(|n| n.scope.level == ScopeLevel::Global);
assert!(global_node.is_some(), "global node must be present");
let global_node = global_node.unwrap();
assert_eq!(global_node.memory_count, 2); }
#[test]
fn test_global_has_no_parent() {
let global = MemoryScope::global();
assert!(global.parent().is_none());
assert!(global.ancestors().is_empty());
}
#[test]
fn test_display_scope_level() {
assert_eq!(ScopeLevel::Global.to_string(), "global");
assert_eq!(ScopeLevel::Org.to_string(), "org");
assert_eq!(ScopeLevel::User.to_string(), "user");
assert_eq!(ScopeLevel::Session.to_string(), "session");
assert_eq!(ScopeLevel::Agent.to_string(), "agent");
}
#[test]
fn test_display_memory_scope() {
let s = MemoryScope::parse("global/org:acme/user:alice").unwrap();
assert_eq!(s.to_string(), "global/org:acme/user:alice");
}
// The two rows stored at global/org:acme collapse into one listed scope (DISTINCT).
#[test]
fn test_list_scopes() {
let conn = setup_db();
insert(&conn, 1, "a", "global");
insert(&conn, 2, "b", "global/org:acme");
insert(&conn, 3, "c", "global/org:acme");
let scopes = list_scopes(&conn).unwrap();
assert_eq!(scopes.len(), 2);
let paths: Vec<&str> = scopes.iter().map(|s| s.path.as_str()).collect();
assert!(paths.contains(&"global"));
assert!(paths.contains(&"global/org:acme"));
}
}