Skip to main content

cersei_memory/
graph.rs

1//! Graph-backed memory using Grafeo embedded graph database.
2//!
3//! Optional feature: enable with `features = ["graph"]` in Cargo.toml.
4//!
5//! ## Schema (v2)
6//! ```text
7//! (:Memory {id, content, mem_type, confidence, created_at, updated_at,
8//!           last_validated_at, decay_rate, embedding_model_version})
9//!   -[:RELATES_TO {relationship, weight}]-> (:Memory)
10//!
11//! (:Session {session_id, started_at, model, turns})
12//!   -[:PRODUCED]-> (:Memory)
13//!
14//! (:Topic {name})
15//!   -[:TAGGED]-> (:Memory)
16//!
17//! (:SchemaVersion {singleton, version, migrated_at, code_version})
18//! ```
19
20#[cfg(feature = "graph")]
21use grafeo::GrafeoDB;
22
23use crate::memdir::MemoryType;
24use cersei_types::*;
25use std::path::Path;
26
27// Re-export migration utilities
28pub use crate::graph_migrate::{self, effective_confidence, VersionCheck, CURRENT_SCHEMA_VERSION};
29
/// Handle to the graph-backed memory store.
///
/// With the `graph` feature enabled this owns a Grafeo database handle.
/// Without the feature the struct only carries a unit placeholder field.
pub struct GraphMemory {
    /// Underlying Grafeo database (only present with the `graph` feature).
    #[cfg(feature = "graph")]
    db: GrafeoDB,
    /// Placeholder field used when the `graph` feature is disabled.
    #[cfg(not(feature = "graph"))]
    _phantom: (),
}
37
/// Node and relationship counts for a graph memory store.
///
/// Every counter defaults to zero.
#[derive(Debug, Clone, Default)]
pub struct GraphStats {
    /// Number of `Memory` nodes.
    pub memory_count: usize,
    /// Number of `Session` nodes.
    pub session_count: usize,
    /// Number of `Topic` nodes.
    pub topic_count: usize,
    /// Number of `RELATES_TO` relationships.
    pub relationship_count: usize,
}
46
47// ─── Centralized GQL queries ───────────────────────────────────────────────
48
49#[cfg(feature = "graph")]
50mod gql {
/// Escape `s` for embedding inside a single-quoted GQL string literal.
///
/// Every backslash is doubled and every single quote is prefixed with a
/// backslash; all other characters are copied through unchanged.
pub fn escape(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for ch in s.chars() {
        // Both special characters are escaped the same way: a leading backslash.
        if matches!(ch, '\\' | '\'') {
            out.push('\\');
        }
        out.push(ch);
    }
    out
}
54
/// Build the `INSERT` statement that creates one `Memory` node.
///
/// `now` is written to `created_at`, `updated_at` and `last_validated_at`.
/// `decay_rate` is fixed at `0.01` and `embedding_model_version` at the empty
/// string. Arguments are interpolated verbatim; no escaping is done here.
pub fn insert_memory(
    id: &str,
    content: &str,
    mem_type: &str,
    confidence: f32,
    now: &str,
) -> String {
    let props = [
        format!("id: '{id}'"),
        format!("content: '{content}'"),
        format!("mem_type: '{mem_type}'"),
        format!("confidence: {confidence}"),
        format!("created_at: '{now}'"),
        format!("updated_at: '{now}'"),
        format!("last_validated_at: '{now}'"),
        String::from("decay_rate: 0.01"),
        String::from("embedding_model_version: ''"),
    ];
    format!("INSERT (:Memory {{{}}})", props.join(", "))
}
68
/// Build the statement that links two `Memory` nodes with a `RELATES_TO` edge.
///
/// The edge runs from the node with `from_id` to the node with `to_id` and
/// carries `relationship` as a property. Arguments are interpolated verbatim.
pub fn link_memories(from_id: &str, to_id: &str, relationship: &str) -> String {
    let match_clause =
        format!("MATCH (a:Memory {{id: '{from_id}'}}), (b:Memory {{id: '{to_id}'}})");
    let insert_clause =
        format!("INSERT (a)-[:RELATES_TO {{relationship: '{relationship}'}}]->(b)");
    [match_clause, insert_clause].join(" ")
}
75
/// Build the statement that attaches a `Topic` node to a `Memory` node.
///
/// The statement inserts a `Topic` named `topic` with a `TAGGED` edge pointing
/// at the memory whose id is `memory_id`. Arguments are interpolated verbatim.
pub fn tag_memory(memory_id: &str, topic: &str) -> String {
    let match_clause = format!("MATCH (m:Memory {{id: '{memory_id}'}})");
    let insert_clause = format!("INSERT (:Topic {{name: '{topic}'}})-[:TAGGED]->(m)");
    [match_clause, insert_clause].join(" ")
}
82
/// Build the `INSERT` statement that creates one `Session` node.
///
/// `now` is stored as `started_at`; `turns` is written as a bare number.
/// String arguments are interpolated verbatim.
pub fn insert_session(session_id: &str, now: &str, model: &str, turns: u32) -> String {
    let props = [
        format!("session_id: '{session_id}'"),
        format!("started_at: '{now}'"),
        format!("model: '{model}'"),
        format!("turns: {turns}"),
    ];
    format!("INSERT (:Session {{{}}})", props.join(", "))
}
89
/// Build the substring-search query over `Memory.content`.
///
/// `escaped_query` is placed inside a single-quoted literal as-is, so it must
/// already be escaped. At most `limit` rows are requested.
pub fn recall(escaped_query: &str, limit: usize) -> String {
    let mut query = String::from("MATCH (m:Memory) WHERE m.content CONTAINS '");
    query.push_str(escaped_query);
    query.push_str("' RETURN m.content LIMIT ");
    query.push_str(&limit.to_string());
    query
}
95
/// Build the query returning the content of every `Memory` whose `mem_type`
/// equals `type_str`. The argument is interpolated verbatim.
pub fn by_type(type_str: &str) -> String {
    ["MATCH (m:Memory {mem_type: '", type_str, "'}) RETURN m.content"].concat()
}
99
/// Build the query returning the content of every `Memory` reached by a
/// `TAGGED` edge from a `Topic` named `topic`. The argument is interpolated
/// verbatim.
pub fn by_topic(topic: &str) -> String {
    [
        "MATCH (:Topic {name: '",
        topic,
        "'})-[:TAGGED]->(m:Memory) RETURN m.content",
    ]
    .concat()
}
103
/// Build the query used when revalidating a memory.
///
/// The query is read-only: it returns `m.id` for the `Memory` node whose id is
/// `memory_id`, so a non-empty result means the node exists. Nothing is
/// written. `memory_id` is interpolated verbatim.
///
/// `now` is accepted for the timestamp a real revalidation would record, but
/// it is currently unused.
pub fn revalidate(memory_id: &str, now: &str) -> String {
    // Grafeo may not support SET, and delete + re-insert would lose data, so
    // for now this only verifies that the node exists.
    // Explicitly discard the timestamp so the unused parameter does not warn.
    let _ = now;
    format!("MATCH (m:Memory {{id: '{memory_id}'}}) RETURN m.id")
}
111
/// Counts all `Memory` nodes.
pub const COUNT_MEMORIES: &str = "MATCH (m:Memory) RETURN count(m)";
/// Counts all `Session` nodes.
pub const COUNT_SESSIONS: &str = "MATCH (s:Session) RETURN count(s)";
/// Counts all `Topic` nodes.
pub const COUNT_TOPICS: &str = "MATCH (t:Topic) RETURN count(t)";
/// Counts all `RELATES_TO` relationships between any two nodes.
pub const COUNT_RELATIONSHIPS: &str = "MATCH ()-[r:RELATES_TO]->() RETURN count(r)";
116}
117
118impl GraphMemory {
/// Open a persistent graph database at `path`.
///
/// After opening, the stored schema version is checked:
/// - up to date: nothing further happens;
/// - needs migration: migrations are run from the stored version to the target;
/// - stored schema newer than this code: a warning is logged and the store is
///   still returned.
///
/// # Errors
/// Returns `CerseiError::Config` if the database cannot be opened, and
/// propagates any error from running migrations.
#[cfg(feature = "graph")]
pub fn open(path: &Path) -> Result<Self> {
    let db = match GrafeoDB::open(path) {
        Ok(handle) => handle,
        Err(e) => {
            return Err(CerseiError::Config(format!(
                "Failed to open graph DB: {}",
                e
            )));
        }
    };

    let status = graph_migrate::check_version(&db);
    match status {
        VersionCheck::NeedsMigration { from, to } => {
            graph_migrate::run_migrations(&db, from, to)?;
        }
        VersionCheck::CodeBehind {
            graph_version,
            code_version,
        } => {
            tracing::warn!(
                "Graph schema v{} is newer than code v{}. Forward-compatible reads will be used.",
                graph_version, code_version
            );
        }
        VersionCheck::UpToDate => {}
    }

    Ok(Self { db })
}
145
/// Create an in-memory graph database (no persistence).
///
/// The schema version is checked right after creation and migrations are run
/// when the check reports that they are needed; any other outcome is ignored.
///
/// # Errors
/// Propagates any error from running migrations.
#[cfg(feature = "graph")]
pub fn open_in_memory() -> Result<Self> {
    let db = GrafeoDB::new_in_memory();

    // Only the "needs migration" outcome requires action here.
    if let VersionCheck::NeedsMigration { from, to } = graph_migrate::check_version(&db) {
        graph_migrate::run_migrations(&db, from, to)?;
    }

    Ok(Self { db })
}
163
164    /// Fallback: graph feature not enabled.
165    #[cfg(not(feature = "graph"))]
166    pub fn open(_path: &Path) -> Result<Self> {
167        Err(CerseiError::Config(
168            "Graph memory requires the 'graph' feature. Enable it in Cargo.toml.".into(),
169        ))
170    }
171
172    /// Fallback: graph feature not enabled.
173    #[cfg(not(feature = "graph"))]
174    pub fn open_in_memory() -> Result<Self> {
175        Err(CerseiError::Config(
176            "Graph memory requires the 'graph' feature. Enable it in Cargo.toml.".into(),
177        ))
178    }
179
180    // ─── Write operations ────────────────────────────────────────────────
181
/// Store `content` as a new `Memory` node and return its generated id.
///
/// The id is a fresh v4 UUID, the type label is the `Debug` rendering of
/// `mem_type`, and the current UTC time (RFC 3339) is used for all timestamp
/// fields. `content` is escaped before being placed in the query.
///
/// # Errors
/// Returns `CerseiError::Config` if the insert statement fails.
#[cfg(feature = "graph")]
pub fn store_memory(
    &self,
    content: &str,
    mem_type: MemoryType,
    confidence: f32,
) -> Result<String> {
    let session = self.db.session();
    let type_label = format!("{:?}", mem_type);
    let timestamp = chrono::Utc::now().to_rfc3339();
    let id = uuid::Uuid::new_v4().to_string();

    let statement = gql::insert_memory(
        &id,
        &gql::escape(content),
        &type_label,
        confidence,
        &timestamp,
    );

    match session.execute(&statement) {
        Ok(_) => Ok(id),
        Err(e) => Err(CerseiError::Config(format!(
            "Graph insert failed: {}",
            e
        ))),
    }
}
203
/// Link two memories with a named relationship.
///
/// Adds a `RELATES_TO` edge from the memory with `from_id` to the memory with
/// `to_id`, carrying `relationship` as a property. All three values are
/// escaped before being embedded in the query, so quotes or backslashes in
/// them cannot break out of the string literal.
///
/// # Errors
/// Returns `CerseiError::Config` if the statement fails.
#[cfg(feature = "graph")]
pub fn link_memories(&self, from_id: &str, to_id: &str, relationship: &str) -> Result<()> {
    let session = self.db.session();
    // The query builder interpolates verbatim; escape here as `store_memory` does.
    let query = gql::link_memories(
        &gql::escape(from_id),
        &gql::escape(to_id),
        &gql::escape(relationship),
    );
    session
        .execute(&query)
        .map_err(|e| CerseiError::Config(format!("Graph link failed: {}", e)))?;
    Ok(())
}
214
/// Tag a memory with a topic.
///
/// Inserts a `Topic` node named `topic` with a `TAGGED` edge to the memory
/// whose id is `memory_id`. Both values are escaped before being embedded in
/// the query.
///
/// NOTE(review): the statement inserts a `Topic` node on every call, so tagging
/// with the same topic name repeatedly presumably creates duplicate `Topic`
/// nodes — confirm against Grafeo's semantics.
///
/// # Errors
/// Returns `CerseiError::Config` if the statement fails.
#[cfg(feature = "graph")]
pub fn tag_memory(&self, memory_id: &str, topic: &str) -> Result<()> {
    let session = self.db.session();
    // The query builder interpolates verbatim; escape here as `store_memory` does.
    let query = gql::tag_memory(&gql::escape(memory_id), &gql::escape(topic));
    session
        .execute(&query)
        .map_err(|e| CerseiError::Config(format!("Graph tag failed: {}", e)))?;
    Ok(())
}
225
/// Record a session in the graph.
///
/// Inserts a `Session` node with the current UTC time (RFC 3339) as
/// `started_at`. When `model` is `None` the literal `"unknown"` is stored.
/// `session_id` and the model name are escaped before being embedded in the
/// query.
///
/// # Errors
/// Returns `CerseiError::Config` if the statement fails.
#[cfg(feature = "graph")]
pub fn record_session(&self, session_id: &str, model: Option<&str>, turns: u32) -> Result<()> {
    let session = self.db.session();
    let now = chrono::Utc::now().to_rfc3339();
    let model_str = model.unwrap_or("unknown");
    // The query builder interpolates verbatim; escape caller-supplied text.
    let query = gql::insert_session(
        &gql::escape(session_id),
        &now,
        &gql::escape(model_str),
        turns,
    );
    session
        .execute(&query)
        .map_err(|e| CerseiError::Config(format!("Graph session record failed: {}", e)))?;
    Ok(())
}
238
/// Revalidate a memory.
///
/// Currently this only checks that a memory with `memory_id` exists: the
/// query built by `gql::revalidate` is read-only, so no timestamp or
/// confidence value is updated yet.
///
/// Returns `Ok(true)` if the memory was found and `Ok(false)` if not.
/// `memory_id` is escaped before being embedded in the query.
///
/// # Errors
/// Returns `CerseiError::Config` if the query fails.
#[cfg(feature = "graph")]
pub fn revalidate_memory(&self, memory_id: &str) -> Result<bool> {
    let session = self.db.session();
    let now = chrono::Utc::now().to_rfc3339();
    let query = gql::revalidate(&gql::escape(memory_id), &now);
    match session.execute(&query) {
        // Any returned row means the node exists.
        Ok(result) => Ok(result.iter().next().is_some()),
        Err(e) => Err(CerseiError::Config(format!(
            "Graph revalidate failed: {}",
            e
        ))),
    }
}
253
254    // ─── Query operations ────────────────────────────────────────────────
255
/// Recall memories whose content contains `query_text` (substring match).
///
/// `query_text` is escaped before being embedded in the query and at most
/// `limit` rows are requested. The first column of each row is rendered to a
/// `String`; rows without a first column are skipped.
///
/// This is best-effort: if the query fails, the error is logged as a warning
/// and an empty vector is returned.
#[cfg(feature = "graph")]
pub fn recall(&self, query_text: &str, limit: usize) -> Vec<String> {
    let session = self.db.session();
    let escaped = gql::escape(query_text);
    let query = gql::recall(&escaped, limit);
    match session.execute(&query) {
        Ok(result) => result
            .iter()
            .filter_map(|row| row.first().map(|v| v.to_string()))
            .collect(),
        Err(e) => {
            // Keep the best-effort contract, but do not hide the failure entirely.
            tracing::warn!("Graph recall failed: {}", e);
            Vec::new()
        }
    }
}
270
/// Recall memories matching `query_text`, each paired with a relevance score.
///
/// Candidates come from `recall` (substring match) with a fetch size of
/// `4 * limit`, but never fewer than 16. The query is split on whitespace;
/// each word is trimmed of leading/trailing non-alphanumeric characters,
/// lowercased, and dropped if shorter than two bytes. A candidate's score is
/// the fraction of those words found in its lowercased text, in `[0, 1]`.
///
/// Results are sorted by score descending (ties keep candidate order) and cut
/// to `limit`. If no usable words remain, the first `limit` candidates are
/// returned with a score of `1.0`. A `limit` of zero or a blank query yields
/// an empty vector.
#[cfg(feature = "graph")]
pub fn recall_top_k(&self, query_text: &str, limit: usize) -> Vec<(String, f32)> {
    if limit == 0 || query_text.trim().is_empty() {
        return Vec::new();
    }

    // Over-fetch so the word-overlap re-ranking has room to reorder.
    let fetch_size = limit.saturating_mul(4).max(16);
    let candidates = self.recall(query_text, fetch_size);

    let terms: Vec<String> = query_text
        .split_whitespace()
        .map(|raw| {
            raw.trim_matches(|c: char| !c.is_alphanumeric())
                .to_lowercase()
        })
        // A length check of >= 2 also rejects the empty string.
        .filter(|term| term.len() >= 2)
        .collect();

    if terms.is_empty() {
        // No usable words: return the leading candidates with a uniform score.
        let mut uniform: Vec<(String, f32)> =
            candidates.into_iter().map(|text| (text, 1.0)).collect();
        uniform.truncate(limit);
        return uniform;
    }

    let mut scored: Vec<(String, f32)> = Vec::with_capacity(candidates.len());
    for text in candidates {
        let haystack = text.to_lowercase();
        let matched = terms
            .iter()
            .filter(|term| haystack.contains(term.as_str()))
            .count();
        scored.push((text, matched as f32 / terms.len() as f32));
    }

    // Stable sort, so equal scores keep their original relative order.
    scored.sort_by(|lhs, rhs| {
        rhs.1
            .partial_cmp(&lhs.1)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    scored.truncate(limit);
    scored
}
315
/// Get the content of all memories of a specific type.
///
/// The type label used in the query is the `Debug` rendering of `mem_type`.
/// The first column of each row is rendered to a `String`; rows without a
/// first column are skipped.
///
/// This is best-effort: if the query fails, the error is logged as a warning
/// and an empty vector is returned.
#[cfg(feature = "graph")]
pub fn by_type(&self, mem_type: MemoryType) -> Vec<String> {
    let session = self.db.session();
    let type_str = format!("{:?}", mem_type);
    let query = gql::by_type(&type_str);
    match session.execute(&query) {
        Ok(result) => result
            .iter()
            .filter_map(|row| row.first().map(|v| v.to_string()))
            .collect(),
        Err(e) => {
            // Keep the best-effort contract, but do not hide the failure entirely.
            tracing::warn!("Graph by_type query failed: {}", e);
            Vec::new()
        }
    }
}
330
/// Get the content of memories tagged with a specific topic.
///
/// `topic` is escaped before being embedded in the query. The first column of
/// each row is rendered to a `String`; rows without a first column are
/// skipped.
///
/// This is best-effort: if the query fails, the error is logged as a warning
/// and an empty vector is returned.
#[cfg(feature = "graph")]
pub fn by_topic(&self, topic: &str) -> Vec<String> {
    let session = self.db.session();
    // The query builder interpolates verbatim; escape caller-supplied text.
    let query = gql::by_topic(&gql::escape(topic));
    match session.execute(&query) {
        Ok(result) => result
            .iter()
            .filter_map(|row| row.first().map(|v| v.to_string()))
            .collect(),
        Err(e) => {
            // Keep the best-effort contract, but do not hide the failure entirely.
            tracing::warn!("Graph by_topic query failed: {}", e);
            Vec::new()
        }
    }
}
344
/// Collect node and relationship counts for the graph.
///
/// Runs one count query per statistic on a single session. A query that
/// fails, or whose result cannot be read as a single `i64`, counts as zero.
#[cfg(feature = "graph")]
pub fn stats(&self) -> GraphStats {
    let session = self.db.session();

    // Run one count query; any failure collapses to 0.
    let run_count = |gql_text: &str| -> usize {
        match session.execute(gql_text) {
            Ok(rows) => match rows.scalar::<i64>() {
                Ok(n) => n as usize,
                Err(_) => 0,
            },
            Err(_) => 0,
        }
    };

    let memory_count = run_count(gql::COUNT_MEMORIES);
    let session_count = run_count(gql::COUNT_SESSIONS);
    let topic_count = run_count(gql::COUNT_TOPICS);
    let relationship_count = run_count(gql::COUNT_RELATIONSHIPS);

    GraphStats {
        memory_count,
        session_count,
        topic_count,
        relationship_count,
    }
}
365
/// Report how the graph's stored schema version compares with this code.
///
/// Delegates to the migration module's version check on the open database.
#[cfg(feature = "graph")]
pub fn schema_version(&self) -> VersionCheck {
    let db = &self.db;
    graph_migrate::check_version(db)
}
371
372    // ─── Fallback implementations (no graph feature) ─────────────────────
373
374    #[cfg(not(feature = "graph"))]
375    pub fn store_memory(&self, _: &str, _: MemoryType, _: f32) -> Result<String> {
376        Err(CerseiError::Config("Graph feature not enabled".into()))
377    }
378
379    #[cfg(not(feature = "graph"))]
380    pub fn recall_top_k(&self, _: &str, _: usize) -> Vec<(String, f32)> {
381        Vec::new()
382    }
383
384    #[cfg(not(feature = "graph"))]
385    pub fn link_memories(&self, _: &str, _: &str, _: &str) -> Result<()> {
386        Err(CerseiError::Config("Graph feature not enabled".into()))
387    }
388
389    #[cfg(not(feature = "graph"))]
390    pub fn tag_memory(&self, _: &str, _: &str) -> Result<()> {
391        Err(CerseiError::Config("Graph feature not enabled".into()))
392    }
393
394    #[cfg(not(feature = "graph"))]
395    pub fn record_session(&self, _: &str, _: Option<&str>, _: u32) -> Result<()> {
396        Err(CerseiError::Config("Graph feature not enabled".into()))
397    }
398
399    #[cfg(not(feature = "graph"))]
400    pub fn revalidate_memory(&self, _: &str) -> Result<bool> {
401        Err(CerseiError::Config("Graph feature not enabled".into()))
402    }
403
404    #[cfg(not(feature = "graph"))]
405    pub fn recall(&self, _: &str, _: usize) -> Vec<String> {
406        Vec::new()
407    }
408
409    #[cfg(not(feature = "graph"))]
410    pub fn by_type(&self, _: MemoryType) -> Vec<String> {
411        Vec::new()
412    }
413
414    #[cfg(not(feature = "graph"))]
415    pub fn by_topic(&self, _: &str) -> Vec<String> {
416        Vec::new()
417    }
418
419    #[cfg(not(feature = "graph"))]
420    pub fn stats(&self) -> GraphStats {
421        GraphStats::default()
422    }
423
424    #[cfg(not(feature = "graph"))]
425    pub fn schema_version(&self) -> VersionCheck {
426        VersionCheck::UpToDate
427    }
428}
429
/// Report whether this build was compiled with the `graph` feature.
///
/// The answer is fixed at compile time.
pub fn is_graph_available() -> bool {
    const GRAPH_ENABLED: bool = cfg!(feature = "graph");
    GRAPH_ENABLED
}