1use rusqlite::{params, Connection, OptionalExtension};
17use serde::{Deserialize, Serialize};
18use std::fmt;
19
20use crate::error::{EngramError, Result};
21
22#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
26pub enum ScopeLevel {
27 Global = 0,
28 Org = 1,
29 User = 2,
30 Session = 3,
31 Agent = 4,
32}
33
34impl fmt::Display for ScopeLevel {
35 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
36 match self {
37 ScopeLevel::Global => write!(f, "global"),
38 ScopeLevel::Org => write!(f, "org"),
39 ScopeLevel::User => write!(f, "user"),
40 ScopeLevel::Session => write!(f, "session"),
41 ScopeLevel::Agent => write!(f, "agent"),
42 }
43 }
44}
45
46#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
53pub struct MemoryScope {
54 pub level: ScopeLevel,
55 pub path: String,
56}
57
58impl MemoryScope {
59 pub fn new(level: ScopeLevel, path: impl Into<String>) -> Result<Self> {
61 let path = path.into();
62 let expected_segments = level as usize + 1; let actual_segments = path.split('/').count();
64 if actual_segments != expected_segments {
65 return Err(EngramError::InvalidInput(format!(
66 "scope path '{}' has {} segment(s) but level {:?} requires {}",
67 path, actual_segments, level, expected_segments
68 )));
69 }
70 if !path.starts_with("global") {
72 return Err(EngramError::InvalidInput(format!(
73 "scope path must start with 'global', got '{}'",
74 path
75 )));
76 }
77 Ok(Self { level, path })
78 }
79
80 pub fn global() -> Self {
82 Self {
83 level: ScopeLevel::Global,
84 path: "global".to_string(),
85 }
86 }
87
88 pub fn parse(path: &str) -> Result<Self> {
93 let segments: Vec<&str> = path.split('/').collect();
94 if segments.is_empty() || segments[0] != "global" {
95 return Err(EngramError::InvalidInput(format!(
96 "scope path must start with 'global', got '{}'",
97 path
98 )));
99 }
100 let level = match segments.len() {
101 1 => ScopeLevel::Global,
102 2 => ScopeLevel::Org,
103 3 => ScopeLevel::User,
104 4 => ScopeLevel::Session,
105 5 => ScopeLevel::Agent,
106 n => {
107 return Err(EngramError::InvalidInput(format!(
108 "scope path has {} segments; maximum supported depth is 5 (Agent)",
109 n
110 )))
111 }
112 };
113 Ok(Self {
114 level,
115 path: path.to_string(),
116 })
117 }
118
119 pub fn parent(&self) -> Option<MemoryScope> {
121 if self.level == ScopeLevel::Global {
122 return None;
123 }
124 let last_slash = self.path.rfind('/')?;
126 let parent_path = &self.path[..last_slash];
127 let parent_level = match self.level {
129 ScopeLevel::Org => ScopeLevel::Global,
130 ScopeLevel::User => ScopeLevel::Org,
131 ScopeLevel::Session => ScopeLevel::User,
132 ScopeLevel::Agent => ScopeLevel::Session,
133 ScopeLevel::Global => unreachable!(),
134 };
135 Some(MemoryScope {
136 level: parent_level,
137 path: parent_path.to_string(),
138 })
139 }
140
141 pub fn ancestors(&self) -> Vec<MemoryScope> {
144 let mut result = Vec::new();
145 let mut current = self.parent();
146 while let Some(scope) = current {
147 current = scope.parent();
148 result.push(scope);
149 }
150 result
151 }
152
153 pub fn contains(&self, other: &MemoryScope) -> bool {
158 if self == other {
159 return true;
160 }
161 if other.level <= self.level {
163 return false;
164 }
165 other.path.starts_with(&format!("{}/", self.path))
167 }
168}
169
170impl fmt::Display for MemoryScope {
171 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
172 write!(f, "{}", self.path)
173 }
174}
175
176#[derive(Debug, Clone, Serialize, Deserialize)]
180pub struct ScopeNode {
181 pub scope: MemoryScope,
182 pub memory_count: i64,
183 pub children: Vec<ScopeNode>,
184}
185
186pub fn set_scope(conn: &Connection, memory_id: i64, scope: &MemoryScope) -> Result<()> {
190 let rows = conn.execute(
191 "UPDATE memories SET scope_path = ?1 WHERE id = ?2",
192 params![scope.path, memory_id],
193 )?;
194 if rows == 0 {
195 return Err(EngramError::NotFound(memory_id));
196 }
197 Ok(())
198}
199
200pub fn get_scope(conn: &Connection, memory_id: i64) -> Result<MemoryScope> {
202 let path: Option<String> = conn
203 .query_row(
204 "SELECT scope_path FROM memories WHERE id = ?1",
205 params![memory_id],
206 |row| row.get(0),
207 )
208 .optional()?;
209
210 match path {
211 Some(p) => MemoryScope::parse(&p),
212 None => Err(EngramError::NotFound(memory_id)),
213 }
214}
215
216pub fn list_scopes(conn: &Connection) -> Result<Vec<MemoryScope>> {
218 let mut stmt =
219 conn.prepare("SELECT DISTINCT scope_path FROM memories WHERE scope_path IS NOT NULL")?;
220 let scopes = stmt
221 .query_map([], |row| row.get::<_, String>(0))?
222 .filter_map(|r| r.ok())
223 .filter_map(|path| MemoryScope::parse(&path).ok())
224 .collect();
225 Ok(scopes)
226}
227
228pub fn move_scope(conn: &Connection, memory_id: i64, new_scope: &MemoryScope) -> Result<()> {
230 set_scope(conn, memory_id, new_scope)
231}
232
233pub fn search_scoped(conn: &Connection, query: &str, scope: &MemoryScope) -> Result<Vec<i64>> {
244 let mut paths: Vec<String> = vec![scope.path.clone()];
246 for ancestor in scope.ancestors() {
247 paths.push(ancestor.path);
248 }
249
250 let placeholders: Vec<String> = paths.iter().map(|_| "?".to_string()).collect();
262 let in_clause = placeholders.join(", ");
263 let sql = format!(
264 "SELECT id FROM memories WHERE content LIKE ? AND scope_path IN ({}) ORDER BY id DESC",
265 in_clause
266 );
267
268 let like_query = format!("%{}%", query);
269 let mut stmt = conn.prepare(&sql)?;
270
271 let mut param_values: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
273 param_values.push(Box::new(like_query));
274 for p in &paths {
275 param_values.push(Box::new(p.clone()));
276 }
277
278 let refs: Vec<&dyn rusqlite::types::ToSql> = param_values.iter().map(|b| b.as_ref()).collect();
279
280 let ids: Vec<i64> = stmt
281 .query_map(refs.as_slice(), |row| row.get(0))?
282 .filter_map(|r| r.ok())
283 .collect();
284
285 Ok(ids)
286}
287
288pub fn scope_tree(conn: &Connection) -> Result<Vec<ScopeNode>> {
293 let mut stmt = conn.prepare(
295 "SELECT scope_path, COUNT(*) as cnt FROM memories
296 WHERE scope_path IS NOT NULL
297 GROUP BY scope_path
298 ORDER BY scope_path",
299 )?;
300
301 let rows: Vec<(String, i64)> = stmt
302 .query_map([], |row| Ok((row.get(0)?, row.get(1)?)))?
303 .filter_map(|r| r.ok())
304 .collect();
305
306 let mut nodes: Vec<ScopeNode> = rows
308 .into_iter()
309 .filter_map(|(path, count)| {
310 MemoryScope::parse(&path).ok().map(|scope| ScopeNode {
311 scope,
312 memory_count: count,
313 children: Vec::new(),
314 })
315 })
316 .collect();
317
318 nodes.sort_by_key(|n| n.scope.level as usize);
320
321 build_tree(nodes)
322}
323
324fn build_tree(mut nodes: Vec<ScopeNode>) -> Result<Vec<ScopeNode>> {
326 nodes.sort_by(|a, b| (b.scope.level as usize).cmp(&(a.scope.level as usize)));
328
329 let mut roots: Vec<ScopeNode> = Vec::new();
331
332 while let Some(node) = nodes.pop() {
335 if node.scope.level == ScopeLevel::Global {
336 roots.push(node);
337 continue;
338 }
339 let parent_path = match node.scope.parent() {
341 Some(p) => p.path,
342 None => {
343 roots.push(node);
344 continue;
345 }
346 };
347 if let Some(parent) = nodes.iter_mut().find(|n| n.scope.path == parent_path) {
349 parent.children.push(node);
350 } else {
351 roots.push(node);
353 }
354 }
355
356 Ok(roots)
357}
358
#[cfg(test)]
mod tests {
    use super::*;
    use rusqlite::Connection;

    /// Opens a fresh in-memory database containing a minimal `memories` table.
    fn fresh_db() -> Connection {
        let db = Connection::open_in_memory().unwrap();
        db.execute_batch(
            "CREATE TABLE IF NOT EXISTS memories (
                id INTEGER PRIMARY KEY,
                content TEXT NOT NULL,
                scope_path TEXT DEFAULT 'global'
            );",
        )
        .unwrap();
        db
    }

    /// Inserts a single memory row stored at `scope_path`.
    fn add_memory(db: &Connection, id: i64, content: &str, scope_path: &str) {
        db.execute(
            "INSERT INTO memories (id, content, scope_path) VALUES (?1, ?2, ?3)",
            params![id, content, scope_path],
        )
        .unwrap();
    }

    /// Parses a scope path that the test knows to be valid.
    fn scope(path: &str) -> MemoryScope {
        MemoryScope::parse(path).unwrap()
    }

    // ---- parsing ----

    #[test]
    fn test_parse_global() {
        let parsed = scope("global");
        assert_eq!(parsed.level, ScopeLevel::Global);
        assert_eq!(parsed.path, "global");
    }

    #[test]
    fn test_parse_org() {
        let parsed = scope("global/org:acme");
        assert_eq!(parsed.level, ScopeLevel::Org);
        assert_eq!(parsed.path, "global/org:acme");
    }

    #[test]
    fn test_parse_user() {
        assert_eq!(scope("global/org:acme/user:alice").level, ScopeLevel::User);
    }

    #[test]
    fn test_parse_invalid_prefix() {
        let outcome = MemoryScope::parse("org:acme");
        assert!(outcome.is_err());
    }

    #[test]
    fn test_parse_too_deep() {
        let six_segments = "global/org:a/user:b/session:c/agent:d/extra:e";
        assert!(MemoryScope::parse(six_segments).is_err());
    }

    // ---- hierarchy navigation ----

    #[test]
    fn test_parent() {
        let agent = scope("global/org:acme/user:alice/session:s1/agent:bot");

        let session = agent.parent().unwrap();
        assert_eq!(session.level, ScopeLevel::Session);
        assert_eq!(session.path, "global/org:acme/user:alice/session:s1");

        // Walk the rest of the chain up to the root.
        let mut current = session;
        for &expected in &[ScopeLevel::User, ScopeLevel::Org, ScopeLevel::Global] {
            current = current.parent().unwrap();
            assert_eq!(current.level, expected);
        }
        assert!(current.parent().is_none());
    }

    #[test]
    fn test_ancestors() {
        let levels: Vec<ScopeLevel> = scope("global/org:acme/user:alice")
            .ancestors()
            .iter()
            .map(|ancestor| ancestor.level)
            .collect();
        assert_eq!(levels, vec![ScopeLevel::Org, ScopeLevel::Global]);
    }

    #[test]
    fn test_global_has_no_parent() {
        let root = MemoryScope::global();
        assert!(root.parent().is_none());
        assert!(root.ancestors().is_empty());
    }

    // ---- containment ----

    #[test]
    fn test_contains_parent_contains_child() {
        let global = MemoryScope::global();
        let org = scope("global/org:acme");
        let user = scope("global/org:acme/user:alice");

        assert!(global.contains(&org));
        assert!(global.contains(&user));
        assert!(org.contains(&user));
    }

    #[test]
    fn test_contains_child_does_not_contain_parent() {
        assert!(!scope("global/org:acme").contains(&MemoryScope::global()));
    }

    #[test]
    fn test_contains_sibling_false() {
        let alice = scope("global/org:acme/user:alice");
        let bob = scope("global/org:acme/user:bob");
        assert!(!alice.contains(&bob));
        assert!(!bob.contains(&alice));
    }

    #[test]
    fn test_contains_self_true() {
        let root = MemoryScope::global();
        assert!(root.contains(&root));
    }

    // ---- persistence ----

    #[test]
    fn test_set_and_get_scope() {
        let db = fresh_db();
        add_memory(&db, 1, "hello", "global");

        let target = scope("global/org:acme");
        set_scope(&db, 1, &target).unwrap();

        assert_eq!(get_scope(&db, 1).unwrap(), target);
    }

    #[test]
    fn test_get_scope_not_found() {
        let db = fresh_db();
        let err = get_scope(&db, 999).unwrap_err();
        assert!(matches!(err, EngramError::NotFound(999)));
    }

    #[test]
    fn test_set_scope_not_found() {
        let db = fresh_db();
        let err = set_scope(&db, 999, &MemoryScope::global()).unwrap_err();
        assert!(matches!(err, EngramError::NotFound(999)));
    }

    #[test]
    fn test_move_scope() {
        let db = fresh_db();
        add_memory(&db, 1, "memory", "global");

        move_scope(&db, 1, &scope("global/org:acme/user:alice")).unwrap();

        assert_eq!(get_scope(&db, 1).unwrap().path, "global/org:acme/user:alice");
    }

    // ---- querying ----

    #[test]
    fn test_search_scoped_ancestor_inheritance() {
        let db = fresh_db();
        let fixtures = [
            (1, "common knowledge", "global"),
            (2, "acme org policy", "global/org:acme"),
            (3, "alice personal note", "global/org:acme/user:alice"),
            (4, "bob personal note", "global/org:acme/user:bob"),
        ];
        for &(id, content, path) in &fixtures {
            add_memory(&db, id, content, path);
        }

        let alice = scope("global/org:acme/user:alice");

        let hits = search_scoped(&db, "knowledge", &alice).unwrap();
        assert!(hits.contains(&1), "global memory should be visible");
        assert!(!hits.contains(&3));

        let hits = search_scoped(&db, "policy", &alice).unwrap();
        assert!(hits.contains(&2), "org memory should be visible");

        let hits = search_scoped(&db, "alice", &alice).unwrap();
        assert!(hits.contains(&3));

        let hits = search_scoped(&db, "bob", &alice).unwrap();
        assert!(!hits.contains(&4), "bob's memory must not be visible to alice");
    }

    #[test]
    fn test_scope_tree() {
        let db = fresh_db();
        add_memory(&db, 1, "a", "global");
        add_memory(&db, 2, "b", "global");
        add_memory(&db, 3, "c", "global/org:acme");
        add_memory(&db, 4, "d", "global/org:acme/user:alice");

        let tree = scope_tree(&db).unwrap();
        let global_node = tree
            .iter()
            .find(|node| node.scope.level == ScopeLevel::Global)
            .expect("global node must be present");
        assert_eq!(global_node.memory_count, 2);
    }

    #[test]
    fn test_list_scopes() {
        let db = fresh_db();
        add_memory(&db, 1, "a", "global");
        add_memory(&db, 2, "b", "global/org:acme");
        add_memory(&db, 3, "c", "global/org:acme");

        let listed = list_scopes(&db).unwrap();
        assert_eq!(listed.len(), 2);
        let has = |path: &str| listed.iter().any(|s| s.path == path);
        assert!(has("global"));
        assert!(has("global/org:acme"));
    }

    // ---- display ----

    #[test]
    fn test_display_scope_level() {
        let cases = [
            (ScopeLevel::Global, "global"),
            (ScopeLevel::Org, "org"),
            (ScopeLevel::User, "user"),
            (ScopeLevel::Session, "session"),
            (ScopeLevel::Agent, "agent"),
        ];
        for &(level, name) in &cases {
            assert_eq!(level.to_string(), name);
        }
    }

    #[test]
    fn test_display_memory_scope() {
        let rendered = scope("global/org:acme/user:alice").to_string();
        assert_eq!(rendered, "global/org:acme/user:alice");
    }
}