Skip to main content

engine/
graph.rs

1//! CE-5: Memory Knowledge Graph
2//!
3//! Persistent SQLite sidecar that stores directed edges between memories.
4//! Edge building is asynchronous (spawned tokio task) — write latency unaffected.
5//! BFS traversal targets ≤100 ms for depth ≤3 on a 1000-memory namespace.
6
7use std::collections::{HashMap, HashSet, VecDeque};
8use std::sync::{Arc, Mutex};
9
10use rusqlite::{params, Connection};
11use serde::{Deserialize, Serialize};
12
13/// How two memories are related.
14#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
15#[serde(rename_all = "snake_case")]
16pub enum EdgeType {
17    /// Cosine similarity ≥ `RELATED_TO_THRESHOLD` (0.85).
18    RelatedTo,
19    /// Both memories share at least one `entity:*` tag (CE-4).
20    SharesEntity,
21    /// Memory A was created before memory B and they are related.
22    Precedes,
23    /// Explicitly linked via `POST /v1/memories/:id/links`.
24    LinkedBy,
25}
26
27impl EdgeType {
28    pub fn as_str(&self) -> &'static str {
29        match self {
30            EdgeType::RelatedTo => "related_to",
31            EdgeType::SharesEntity => "shares_entity",
32            EdgeType::Precedes => "precedes",
33            EdgeType::LinkedBy => "linked_by",
34        }
35    }
36
37    fn from_str(s: &str) -> Self {
38        match s {
39            "shares_entity" => EdgeType::SharesEntity,
40            "precedes" => EdgeType::Precedes,
41            "linked_by" => EdgeType::LinkedBy,
42            _ => EdgeType::RelatedTo,
43        }
44    }
45}
46
47/// A directed graph edge between two memories.
48#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct GraphEdge {
50    pub from_id: String,
51    pub to_id: String,
52    pub edge_type: EdgeType,
53    /// Cosine similarity for `related_to`; 1.0 for explicit/entity edges.
54    pub weight: f32,
55    pub created_at: u64,
56    pub namespace: String,
57}
58
59/// A node returned in graph traversal results.
60#[derive(Debug, Clone, Serialize, Deserialize)]
61pub struct GraphNode {
62    pub memory_id: String,
63    pub depth: u32,
64    /// Edges connecting this node to the previous node in the traversal.
65    pub incoming_edges: Vec<GraphEdge>,
66}
67
68/// Export of all graph edges for an agent namespace.
69#[derive(Debug, Clone, Serialize, Deserialize)]
70pub struct GraphExport {
71    pub namespace: String,
72    pub node_count: usize,
73    pub edge_count: usize,
74    pub edges: Vec<GraphEdge>,
75}
76
/// Minimum cosine similarity for two memories to be linked by a `related_to` edge.
const RELATED_TO_THRESHOLD: f32 = 0.85;
/// Budget of `related_to` + `shares_entity` edges created for one new memory.
const MAX_EDGES_PER_MEMORY: usize = 50;
79
80// ---------------------------------------------------------------------------
81// OBS-1: Audit event types
82// ---------------------------------------------------------------------------
83
84/// A business-event audit record.
85#[derive(Debug, Clone, Serialize, Deserialize)]
86pub struct AuditEvent {
87    pub id: i64,
88    pub event_type: String,
89    pub agent_id: String,
90    #[serde(skip_serializing_if = "Option::is_none")]
91    pub memory_id: Option<String>,
92    #[serde(skip_serializing_if = "Option::is_none")]
93    pub session_id: Option<String>,
94    #[serde(skip_serializing_if = "Option::is_none")]
95    pub importance: Option<f32>,
96    /// Unix milliseconds
97    pub timestamp: u64,
98}
99
/// Write-side counterpart of `AuditEvent`: every field except the row id,
/// which SQLite assigns via `AUTOINCREMENT` on insert.
#[derive(Debug, Clone)]
pub struct AuditEventInsert {
    /// Event name, e.g. `"memory.stored"`.
    pub event_type: String,
    /// Agent the event belongs to.
    pub agent_id: String,
    /// Related memory, if any.
    pub memory_id: Option<String>,
    /// Related session, if any.
    pub session_id: Option<String>,
    /// Importance score, if any.
    pub importance: Option<f32>,
    /// Event time in Unix milliseconds.
    pub timestamp: u64,
}
111
112/// Thread-safe graph engine backed by a SQLite database.
113#[derive(Clone)]
114pub struct MemoryGraphEngine {
115    conn: Arc<Mutex<Connection>>,
116}
117
118impl MemoryGraphEngine {
119    /// Open (or create) a graph database at the given path.
120    /// Pass `":memory:"` for an ephemeral in-process database.
121    pub fn open(path: &str) -> Result<Self, rusqlite::Error> {
122        let conn = Connection::open(path)?;
123        conn.execute_batch(
124            "PRAGMA journal_mode=WAL;
125             PRAGMA synchronous=NORMAL;
126             CREATE TABLE IF NOT EXISTS edges (
127                 from_id    TEXT NOT NULL,
128                 to_id      TEXT NOT NULL,
129                 edge_type  TEXT NOT NULL,
130                 weight     REAL NOT NULL DEFAULT 1.0,
131                 created_at INTEGER NOT NULL,
132                 namespace  TEXT NOT NULL,
133                 PRIMARY KEY (from_id, to_id, edge_type)
134             );
135             CREATE INDEX IF NOT EXISTS idx_edges_from    ON edges(from_id);
136             CREATE INDEX IF NOT EXISTS idx_edges_to      ON edges(to_id);
137             CREATE INDEX IF NOT EXISTS idx_edges_ns      ON edges(namespace);
138             CREATE TABLE IF NOT EXISTS audit_events (
139                 id         INTEGER PRIMARY KEY AUTOINCREMENT,
140                 event_type TEXT NOT NULL,
141                 agent_id   TEXT NOT NULL,
142                 memory_id  TEXT,
143                 session_id TEXT,
144                 importance REAL,
145                 timestamp  INTEGER NOT NULL
146             );
147             CREATE INDEX IF NOT EXISTS idx_audit_agent ON audit_events(agent_id);
148             CREATE INDEX IF NOT EXISTS idx_audit_type  ON audit_events(event_type);
149             CREATE INDEX IF NOT EXISTS idx_audit_ts    ON audit_events(timestamp);",
150        )?;
151        Ok(Self {
152            conn: Arc::new(Mutex::new(conn)),
153        })
154    }
155
156    /// Upsert a single directed edge (idempotent).
157    pub fn upsert_edge(&self, edge: &GraphEdge) -> Result<(), rusqlite::Error> {
158        let conn = self.conn.lock().unwrap_or_else(|p| p.into_inner());
159        conn.execute(
160            "INSERT OR REPLACE INTO edges
161                 (from_id, to_id, edge_type, weight, created_at, namespace)
162             VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
163            params![
164                edge.from_id,
165                edge.to_id,
166                edge.edge_type.as_str(),
167                edge.weight,
168                edge.created_at as i64,
169                edge.namespace,
170            ],
171        )?;
172        Ok(())
173    }
174
175    /// Delete all edges involving a memory (used when a memory is forgotten).
176    pub fn remove_memory(&self, memory_id: &str) -> Result<(), rusqlite::Error> {
177        let conn = self.conn.lock().unwrap_or_else(|p| p.into_inner());
178        conn.execute(
179            "DELETE FROM edges WHERE from_id = ?1 OR to_id = ?1",
180            params![memory_id],
181        )?;
182        Ok(())
183    }
184
185    /// Return all edges incident to a memory (both outbound and inbound).
186    pub fn get_edges(&self, memory_id: &str) -> Vec<GraphEdge> {
187        let conn = self.conn.lock().unwrap_or_else(|p| p.into_inner());
188        let mut stmt = match conn.prepare(
189            "SELECT from_id, to_id, edge_type, weight, created_at, namespace
190             FROM edges
191             WHERE from_id = ?1 OR to_id = ?1",
192        ) {
193            Ok(s) => s,
194            Err(_) => return Vec::new(),
195        };
196        stmt.query_map(params![memory_id], row_to_edge)
197            .map(|rows| rows.filter_map(|r| r.ok()).collect())
198            .unwrap_or_default()
199    }
200
201    /// BFS traversal starting from `root_id` up to `max_depth`.
202    /// Returns nodes in BFS order, each annotated with depth and incoming edges.
203    pub fn traverse(&self, root_id: &str, max_depth: u32, namespace: &str) -> Vec<GraphNode> {
204        let conn = self.conn.lock().unwrap_or_else(|p| p.into_inner());
205
206        let mut visited: HashSet<String> = HashSet::new();
207        let mut queue: VecDeque<(String, u32)> = VecDeque::new();
208        let mut result: Vec<GraphNode> = Vec::new();
209
210        visited.insert(root_id.to_string());
211        queue.push_back((root_id.to_string(), 0));
212
213        // Root node with no incoming edges
214        result.push(GraphNode {
215            memory_id: root_id.to_string(),
216            depth: 0,
217            incoming_edges: Vec::new(),
218        });
219
220        while let Some((current, depth)) = queue.pop_front() {
221            if depth >= max_depth {
222                continue;
223            }
224
225            // Get all neighbors (both directions) in this namespace
226            let mut stmt = match conn.prepare(
227                "SELECT from_id, to_id, edge_type, weight, created_at, namespace
228                 FROM edges
229                 WHERE (from_id = ?1 OR to_id = ?1) AND namespace = ?2",
230            ) {
231                Ok(s) => s,
232                Err(_) => continue,
233            };
234
235            let edges: Vec<GraphEdge> = stmt
236                .query_map(params![current, namespace], row_to_edge)
237                .map(|rows| rows.filter_map(|r| r.ok()).collect())
238                .unwrap_or_default();
239
240            // Group edges by neighbor
241            let mut neighbor_edges: HashMap<String, Vec<GraphEdge>> = HashMap::new();
242            for edge in &edges {
243                let neighbor = if edge.from_id == current {
244                    edge.to_id.clone()
245                } else {
246                    edge.from_id.clone()
247                };
248                if !visited.contains(&neighbor) {
249                    neighbor_edges
250                        .entry(neighbor)
251                        .or_default()
252                        .push(edge.clone());
253                }
254            }
255
256            for (neighbor, inc_edges) in neighbor_edges {
257                visited.insert(neighbor.clone());
258                queue.push_back((neighbor.clone(), depth + 1));
259                result.push(GraphNode {
260                    memory_id: neighbor,
261                    depth: depth + 1,
262                    incoming_edges: inc_edges,
263                });
264            }
265        }
266
267        result
268    }
269
270    /// BFS shortest path between two memories.
271    /// Returns the sequence of memory IDs from `from_id` to `to_id`, inclusive.
272    /// Returns `None` if no path exists.
273    pub fn shortest_path(
274        &self,
275        from_id: &str,
276        to_id: &str,
277        namespace: &str,
278    ) -> Option<Vec<String>> {
279        if from_id == to_id {
280            return Some(vec![from_id.to_string()]);
281        }
282
283        let conn = self.conn.lock().unwrap_or_else(|p| p.into_inner());
284        let mut visited: HashSet<String> = HashSet::new();
285        let mut queue: VecDeque<Vec<String>> = VecDeque::new();
286
287        visited.insert(from_id.to_string());
288        queue.push_back(vec![from_id.to_string()]);
289
290        while let Some(path) = queue.pop_front() {
291            let current = path.last().unwrap();
292
293            let mut stmt = conn
294                .prepare(
295                    "SELECT from_id, to_id FROM edges
296                     WHERE (from_id = ?1 OR to_id = ?1) AND namespace = ?2",
297                )
298                .ok()?;
299
300            let neighbors: Vec<String> = stmt
301                .query_map(params![current, namespace], |row| {
302                    let from: String = row.get(0)?;
303                    let to: String = row.get(1)?;
304                    Ok((from, to))
305                })
306                .ok()?
307                .filter_map(|r| r.ok())
308                .map(|(from, to)| if from == *current { to } else { from })
309                .collect();
310
311            for neighbor in neighbors {
312                if visited.contains(&neighbor) {
313                    continue;
314                }
315                let mut new_path = path.clone();
316                new_path.push(neighbor.clone());
317                if neighbor == to_id {
318                    return Some(new_path);
319                }
320                visited.insert(neighbor);
321                queue.push_back(new_path);
322            }
323        }
324
325        None
326    }
327
328    /// Export all edges in a namespace.
329    pub fn export(&self, namespace: &str) -> GraphExport {
330        let conn = self.conn.lock().unwrap_or_else(|p| p.into_inner());
331
332        let edges: Vec<GraphEdge> = {
333            let mut stmt = match conn.prepare(
334                "SELECT from_id, to_id, edge_type, weight, created_at, namespace
335                 FROM edges WHERE namespace = ?1",
336            ) {
337                Ok(s) => s,
338                Err(_) => {
339                    return GraphExport {
340                        namespace: namespace.to_string(),
341                        node_count: 0,
342                        edge_count: 0,
343                        edges: Vec::new(),
344                    }
345                }
346            };
347            stmt.query_map(params![namespace], row_to_edge)
348                .map(|rows| rows.filter_map(|r| r.ok()).collect())
349                .unwrap_or_default()
350        };
351
352        // Count unique node IDs
353        let mut nodes: HashSet<String> = HashSet::new();
354        for e in &edges {
355            nodes.insert(e.from_id.clone());
356            nodes.insert(e.to_id.clone());
357        }
358
359        GraphExport {
360            namespace: namespace.to_string(),
361            node_count: nodes.len(),
362            edge_count: edges.len(),
363            edges,
364        }
365    }
366
367    // -----------------------------------------------------------------------
368    // OBS-1: Audit log methods
369    // -----------------------------------------------------------------------
370
371    /// Insert a business-event audit record.
372    pub fn insert_audit_event(&self, event: &AuditEventInsert) -> Result<(), rusqlite::Error> {
373        let conn = self.conn.lock().unwrap_or_else(|p| p.into_inner());
374        conn.execute(
375            "INSERT INTO audit_events
376                 (event_type, agent_id, memory_id, session_id, importance, timestamp)
377             VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
378            params![
379                event.event_type,
380                event.agent_id,
381                event.memory_id,
382                event.session_id,
383                event.importance,
384                event.timestamp as i64,
385            ],
386        )?;
387        Ok(())
388    }
389
390    /// Query audit events with optional filters.
391    ///
392    /// Uses always-bound params with IS-NULL guards so the query is pre-compiled
393    /// once and indexes are still usable for the common single-filter case.
394    pub fn query_audit_events(
395        &self,
396        agent_id: Option<&str>,
397        event_type: Option<&str>,
398        from_ts: Option<u64>,
399        to_ts: Option<u64>,
400        limit: usize,
401    ) -> Vec<AuditEvent> {
402        let conn = self.conn.lock().unwrap_or_else(|p| p.into_inner());
403        let limit = limit.min(10_000) as i64;
404        let mut stmt = match conn.prepare(
405            "SELECT id, event_type, agent_id, memory_id, session_id, importance, timestamp
406             FROM audit_events
407             WHERE (?1 IS NULL OR agent_id = ?1)
408               AND (?2 IS NULL OR event_type = ?2)
409               AND (?3 IS NULL OR timestamp >= ?3)
410               AND (?4 IS NULL OR timestamp <= ?4)
411             ORDER BY timestamp DESC
412             LIMIT ?5",
413        ) {
414            Ok(s) => s,
415            Err(_) => return Vec::new(),
416        };
417        let from_ts_val = from_ts.map(|v| v as i64);
418        let to_ts_val = to_ts.map(|v| v as i64);
419        stmt.query_map(
420            params![agent_id, event_type, from_ts_val, to_ts_val, limit],
421            |row| {
422                Ok(AuditEvent {
423                    id: row.get(0)?,
424                    event_type: row.get(1)?,
425                    agent_id: row.get(2)?,
426                    memory_id: row.get(3)?,
427                    session_id: row.get(4)?,
428                    importance: row.get(5)?,
429                    timestamp: row.get::<_, i64>(6)? as u64,
430                })
431            },
432        )
433        .map(|rows| rows.filter_map(|r| r.ok()).collect())
434        .unwrap_or_default()
435    }
436
437    /// Build edges for a newly stored memory.
438    ///
439    /// This is called from `store_memory` inside a `tokio::spawn` so it does
440    /// not block the HTTP response.  It computes:
441    /// - `related_to`   edges for all existing memories with cosine ≥ 0.85
442    /// - `shares_entity` edges for memories sharing at least one `entity:*` tag
443    /// - `precedes`     edges when memory A was stored before memory B and they
444    ///   also qualify as `related_to`
445    pub fn build_edges_for_new_memory(
446        &self,
447        new_id: &str,
448        new_embedding: &[f32],
449        new_tags: &[String],
450        new_created_at: u64,
451        namespace: &str,
452        existing: &[(String, Vec<f32>, Vec<String>, u64)], // (id, embedding, tags, created_at)
453    ) {
454        let now = std::time::SystemTime::now()
455            .duration_since(std::time::UNIX_EPOCH)
456            .unwrap_or_default()
457            .as_secs();
458
459        let new_entity_tags: HashSet<&str> = new_tags
460            .iter()
461            .filter(|t| t.starts_with("entity:"))
462            .map(|t| t.as_str())
463            .collect();
464
465        let mut edge_count = 0usize;
466
467        for (other_id, other_embedding, other_tags, other_created_at) in existing {
468            if other_id == new_id || edge_count >= MAX_EDGES_PER_MEMORY {
469                break;
470            }
471
472            let similarity = cosine_similarity(new_embedding, other_embedding);
473
474            if similarity >= RELATED_TO_THRESHOLD {
475                let _ = self.upsert_edge(&GraphEdge {
476                    from_id: new_id.to_string(),
477                    to_id: other_id.clone(),
478                    edge_type: EdgeType::RelatedTo,
479                    weight: similarity,
480                    created_at: now,
481                    namespace: namespace.to_string(),
482                });
483                edge_count += 1;
484
485                // `precedes` edge: older memory precedes newer one
486                if *other_created_at < new_created_at {
487                    let _ = self.upsert_edge(&GraphEdge {
488                        from_id: other_id.clone(),
489                        to_id: new_id.to_string(),
490                        edge_type: EdgeType::Precedes,
491                        weight: 1.0,
492                        created_at: now,
493                        namespace: namespace.to_string(),
494                    });
495                } else {
496                    let _ = self.upsert_edge(&GraphEdge {
497                        from_id: new_id.to_string(),
498                        to_id: other_id.clone(),
499                        edge_type: EdgeType::Precedes,
500                        weight: 1.0,
501                        created_at: now,
502                        namespace: namespace.to_string(),
503                    });
504                }
505            }
506
507            // shares_entity: any common entity tag
508            let other_entity_tags: HashSet<&str> = other_tags
509                .iter()
510                .filter(|t| t.starts_with("entity:"))
511                .map(|t| t.as_str())
512                .collect();
513            if !new_entity_tags.is_empty()
514                && new_entity_tags
515                    .intersection(&other_entity_tags)
516                    .next()
517                    .is_some()
518            {
519                let _ = self.upsert_edge(&GraphEdge {
520                    from_id: new_id.to_string(),
521                    to_id: other_id.clone(),
522                    edge_type: EdgeType::SharesEntity,
523                    weight: 1.0,
524                    created_at: now,
525                    namespace: namespace.to_string(),
526                });
527                edge_count += 1;
528            }
529        }
530    }
531}
532
533// ---------------------------------------------------------------------------
534// Helper functions
535// ---------------------------------------------------------------------------
536
537fn row_to_edge(row: &rusqlite::Row<'_>) -> rusqlite::Result<GraphEdge> {
538    Ok(GraphEdge {
539        from_id: row.get(0)?,
540        to_id: row.get(1)?,
541        edge_type: EdgeType::from_str(&row.get::<_, String>(2)?),
542        weight: row.get(3)?,
543        created_at: row.get::<_, i64>(4)? as u64,
544        namespace: row.get(5)?,
545    })
546}
547
/// Cosine similarity of `a` and `b`, clamped to `[-1.0, 1.0]`.
///
/// Returns `0.0` when the slices are empty, differ in length, or either one
/// has zero magnitude. Inputs need not be normalized.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    // Single pass accumulating (dot product, |a|^2, |b|^2).
    let (dot, sq_a, sq_b) = a
        .iter()
        .zip(b.iter())
        .fold((0.0f32, 0.0f32, 0.0f32), |(d, na, nb), (x, y)| {
            (d + x * y, na + x * x, nb + y * y)
        });

    let denom = sq_a.sqrt() * sq_b.sqrt();
    if denom == 0.0 {
        return 0.0;
    }
    (dot / denom).clamp(-1.0, 1.0)
}
568
569// ---------------------------------------------------------------------------
570// Environment-driven factory
571// ---------------------------------------------------------------------------
572
573/// Open a `MemoryGraphEngine` using `DAKERA_DATA_DIR` if set, otherwise `:memory:`.
574pub fn open_from_env() -> Arc<MemoryGraphEngine> {
575    let path = std::env::var("DAKERA_DATA_DIR")
576        .map(|dir| format!("{}/graph.db", dir))
577        .unwrap_or_else(|_| ":memory:".to_string());
578
579    match MemoryGraphEngine::open(&path) {
580        Ok(engine) => {
581            tracing::info!(path = %path, "CE-5: memory knowledge graph opened");
582            Arc::new(engine)
583        }
584        Err(e) => {
585            tracing::warn!(error = %e, "CE-5: failed to open graph.db, falling back to :memory:");
586            Arc::new(MemoryGraphEngine::open(":memory:").expect("in-memory sqlite must work"))
587        }
588    }
589}
590
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh, isolated in-memory engine per test.
    fn test_engine() -> MemoryGraphEngine {
        MemoryGraphEngine::open(":memory:").unwrap()
    }

    /// Unit vector with all components equal, so any two of these have cosine similarity 1.0.
    fn dummy_embedding(seed: f32, dim: usize) -> Vec<f32> {
        let raw = vec![seed / 10.0; dim];
        let magnitude: f32 = raw.iter().map(|x| x * x).sum::<f32>().sqrt();
        if magnitude == 0.0 {
            return raw;
        }
        raw.into_iter().map(|x| x / magnitude).collect()
    }

    /// Fixture: a `related_to` edge with the given endpoints and metadata.
    fn related_edge(from: &str, to: &str, weight: f32, created_at: u64, namespace: &str) -> GraphEdge {
        GraphEdge {
            from_id: from.into(),
            to_id: to.into(),
            edge_type: EdgeType::RelatedTo,
            weight,
            created_at,
            namespace: namespace.into(),
        }
    }

    /// Fixture: an audit event with every optional field unset.
    fn bare_audit(event_type: &str, agent_id: &str, timestamp: u64) -> AuditEventInsert {
        AuditEventInsert {
            event_type: event_type.to_string(),
            agent_id: agent_id.to_string(),
            memory_id: None,
            session_id: None,
            importance: None,
            timestamp,
        }
    }

    #[test]
    fn test_upsert_and_get_edges() {
        let g = test_engine();
        g.upsert_edge(&related_edge("mem_a", "mem_b", 0.9, 1000, "ns1"))
            .unwrap();

        let edges = g.get_edges("mem_a");
        assert_eq!(edges.len(), 1);
        let only = &edges[0];
        assert_eq!(only.to_id, "mem_b");
        assert_eq!(only.edge_type, EdgeType::RelatedTo);
    }

    #[test]
    fn test_bfs_traversal() {
        let g = test_engine();
        let ns = "test_ns";
        // Chain: a -> b -> c
        [("mem_a", "mem_b"), ("mem_b", "mem_c")]
            .iter()
            .for_each(|&(from, to)| g.upsert_edge(&related_edge(from, to, 0.9, 1000, ns)).unwrap());

        let nodes = g.traverse("mem_a", 3, ns);
        for expected in ["mem_a", "mem_b", "mem_c"] {
            assert!(nodes.iter().any(|n| n.memory_id == expected));
        }
    }

    #[test]
    fn test_shortest_path() {
        let g = test_engine();
        let ns = "test_ns2";
        // a-b, b-c, a-c (direct)
        for &(from, to) in &[("ma", "mb"), ("mb", "mc"), ("ma", "mc")] {
            g.upsert_edge(&related_edge(from, to, 0.9, 1000, ns)).unwrap();
        }

        let path = g.shortest_path("ma", "mc", ns).unwrap();
        // The direct edge ma-mc gives a two-node path.
        assert_eq!(path.len(), 2);
        assert_eq!(path[0], "ma");
        assert_eq!(path[1], "mc");
    }

    #[test]
    fn test_build_edges_for_new_memory() {
        let g = test_engine();
        let ns = "build_test";
        let dim = 4;

        // After normalisation every dummy embedding is [0.5, 0.5, 0.5, 0.5],
        // so every pair has cosine similarity 1.0.
        let existing: Vec<(String, Vec<f32>, Vec<String>, u64)> = vec![
            ("mem_a".to_string(), dummy_embedding(1.0, dim), Vec::new(), 1000),
            ("mem_b".to_string(), dummy_embedding(2.0, dim), Vec::new(), 1500),
        ];
        let new_embedding = dummy_embedding(1.5, dim);

        g.build_edges_for_new_memory("mem_new", &new_embedding, &[], 2000, ns, &existing);

        assert!(!g.get_edges("mem_new").is_empty());
    }

    #[test]
    fn test_remove_memory() {
        let g = test_engine();
        g.upsert_edge(&related_edge("del_me", "other", 0.9, 0, "ns"))
            .unwrap();

        g.remove_memory("del_me").unwrap();
        assert!(g.get_edges("del_me").is_empty());
    }

    // -----------------------------------------------------------------------
    // OBS-1: Audit event tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_audit_event_insert_and_query() {
        let g = test_engine();
        let insert = AuditEventInsert {
            memory_id: Some("mem_abc".to_string()),
            session_id: Some("sess_xyz".to_string()),
            importance: Some(0.8),
            ..bare_audit("memory.stored", "agent-1", 1_700_000_000_000)
        };
        g.insert_audit_event(&insert).unwrap();

        let events = g.query_audit_events(None, None, None, None, 10);
        assert_eq!(events.len(), 1);
        let ev = &events[0];
        assert_eq!(ev.event_type, "memory.stored");
        assert_eq!(ev.agent_id, "agent-1");
        assert_eq!(ev.memory_id.as_deref(), Some("mem_abc"));
        assert_eq!(ev.session_id.as_deref(), Some("sess_xyz"));
        assert!((ev.importance.unwrap() - 0.8).abs() < 1e-5);
        assert_eq!(ev.timestamp, 1_700_000_000_000);
    }

    #[test]
    fn test_audit_query_filter_by_agent() {
        let g = test_engine();
        for i in 0..5u64 {
            let agent = if i < 3 { "agent-a" } else { "agent-b" };
            g.insert_audit_event(&bare_audit("memory.recalled", agent, 1_000 + i))
                .unwrap();
        }
        assert_eq!(g.query_audit_events(Some("agent-a"), None, None, None, 100).len(), 3);
        assert_eq!(g.query_audit_events(Some("agent-b"), None, None, None, 100).len(), 2);
    }

    #[test]
    fn test_audit_query_filter_by_event_type() {
        let g = test_engine();
        for (event_type, ts) in [("memory.stored", 1), ("session.started", 2)] {
            g.insert_audit_event(&bare_audit(event_type, "ag", ts)).unwrap();
        }

        for event_type in ["memory.stored", "session.started"] {
            assert_eq!(g.query_audit_events(None, Some(event_type), None, None, 10).len(), 1);
        }
    }

    #[test]
    fn test_audit_query_time_range() {
        let g = test_engine();
        for ts in [100u64, 200, 300, 400, 500] {
            g.insert_audit_event(&bare_audit("ev", "ag", ts)).unwrap();
        }
        // The inclusive window [200, 400] covers three events.
        assert_eq!(g.query_audit_events(None, None, Some(200), Some(400), 100).len(), 3);
    }

    #[test]
    fn test_audit_query_limit() {
        let g = test_engine();
        (0..20u64).for_each(|i| g.insert_audit_event(&bare_audit("ev", "ag", i)).unwrap());
        assert_eq!(g.query_audit_events(None, None, None, None, 5).len(), 5);
    }
}