Skip to main content

agentic_memory/
contracts.rs

1//! Contracts bridge — implements agentic-contracts v0.2.0 traits for Memory.
2//!
3//! This module provides `MemorySister`, a contracts-compliant wrapper
4//! around the core `MemoryGraph` + engines. It implements:
5//!
6//! - `Sister` — lifecycle management
7//! - `SessionManagement` — append-only sequential sessions
8//! - `Grounding` — BM25-based claim verification
9//! - `Queryable` — unified query interface
10//! - `FileFormatReader/FileFormatWriter` — .amem file I/O
11//!
12//! The MCP server can use `MemorySister` instead of raw graph + engines
13//! to get compile-time contracts compliance.
14
15use agentic_contracts::prelude::*;
16use std::path::{Path, PathBuf};
17use std::time::Instant;
18
19use crate::engine::text_search::TextSearchParams;
20use crate::engine::{QueryEngine, WriteEngine};
21use crate::graph::MemoryGraph;
22use crate::types::{AmemError, CognitiveEvent, DEFAULT_DIMENSION};
23
24// ═══════════════════════════════════════════════════════════════════
25// ERROR BRIDGE: AmemError → SisterError
26// ═══════════════════════════════════════════════════════════════════
27
28impl From<AmemError> for SisterError {
29    fn from(e: AmemError) -> Self {
30        match &e {
31            AmemError::NodeNotFound(id) => {
32                SisterError::not_found(format!("node {}", id))
33            }
34            AmemError::InvalidMagic => {
35                SisterError::new(ErrorCode::VersionMismatch, "Invalid .amem magic bytes")
36            }
37            AmemError::UnsupportedVersion(v) => {
38                SisterError::new(ErrorCode::VersionMismatch, format!("Unsupported .amem version: {}", v))
39            }
40            AmemError::ContentTooLarge { size, max } => {
41                SisterError::new(ErrorCode::InvalidInput, format!("Content too large: {} > {} bytes", size, max))
42            }
43            AmemError::DimensionMismatch { expected, got } => {
44                SisterError::new(ErrorCode::InvalidInput, format!("Dimension mismatch: expected {}, got {}", expected, got))
45            }
46            AmemError::InvalidConfidence(v) => {
47                SisterError::new(ErrorCode::InvalidInput, format!("Confidence must be [0.0, 1.0], got {}", v))
48            }
49            AmemError::Io(io_err) => {
50                SisterError::new(ErrorCode::StorageError, format!("I/O error: {}", io_err))
51            }
52            AmemError::Truncated => {
53                SisterError::new(ErrorCode::StorageError, "File is empty or truncated")
54            }
55            AmemError::Corrupt(offset) => {
56                SisterError::new(ErrorCode::ChecksumMismatch, format!("Corrupt data at offset {}", offset))
57            }
58            _ => {
59                SisterError::new(ErrorCode::MemoryError, e.to_string())
60            }
61        }
62    }
63}
64
65// ═══════════════════════════════════════════════════════════════════
66// SESSION STATE
67// ═══════════════════════════════════════════════════════════════════
68
/// Session record for tracking sessions in MemorySister.
///
/// One record is created per `start_session` call. A clone is kept as the
/// current session, and the original is appended to the session history.
#[derive(Debug, Clone)]
struct SessionRecord {
    /// Contracts-level identifier handed back to `SessionManagement` callers.
    id: ContextId,
    /// Numeric session ID carried by nodes (`CognitiveEvent::session_id`);
    /// export selects nodes by it and import re-tags nodes with it.
    session_id: u32,
    /// Human-readable name passed to `start_session`.
    name: String,
    /// Wall-clock (UTC) time at which the session was started.
    created_at: chrono::DateTime<chrono::Utc>,
    /// Graph node count when the session started; used to estimate how many
    /// nodes were added during the session.
    node_count_at_start: usize,
}
78
79// ═══════════════════════════════════════════════════════════════════
80// MEMORY SISTER — The contracts-compliant facade
81// ═══════════════════════════════════════════════════════════════════
82
/// Contracts-compliant Memory sister.
///
/// Wraps `MemoryGraph` + engines and implements all v0.2.0 traits.
/// This is the canonical "Memory as a sister" interface.
pub struct MemorySister {
    /// In-memory cognitive graph that every trait implementation reads and writes.
    graph: MemoryGraph,
    /// Engine used for BM25 text search (grounding, evidence, `search` queries).
    query_engine: QueryEngine,
    /// Engine used to ingest events into `graph` (e.g. when importing a session).
    write_engine: WriteEngine,
    /// Backing `.amem` file, if any. With the `format` feature the graph is
    /// written here on `shutdown`.
    file_path: Option<PathBuf>,
    /// Construction time; `health` reports uptime relative to it.
    start_time: Instant,

    // Session state
    /// Session currently in progress, if any.
    current_session: Option<SessionRecord>,
    /// Every session started on this instance, in start order (oldest first).
    sessions: Vec<SessionRecord>,
    /// Numeric ID that the next `start_session` call will assign.
    next_session_id: u32,
}
99
100impl MemorySister {
101    /// Create from an existing graph (for migration from SessionManager).
102    pub fn from_graph(graph: MemoryGraph, file_path: Option<PathBuf>) -> Self {
103        let dimension = graph.dimension();
104        Self {
105            graph,
106            query_engine: QueryEngine::new(),
107            write_engine: WriteEngine::new(dimension),
108            file_path,
109            start_time: Instant::now(),
110            current_session: None,
111            sessions: vec![],
112            next_session_id: 1,
113        }
114    }
115
116    /// Get a reference to the underlying graph.
117    pub fn graph(&self) -> &MemoryGraph {
118        &self.graph
119    }
120
121    /// Get a mutable reference to the underlying graph.
122    pub fn graph_mut(&mut self) -> &mut MemoryGraph {
123        &mut self.graph
124    }
125
126    /// Get the query engine.
127    pub fn query_engine(&self) -> &QueryEngine {
128        &self.query_engine
129    }
130
131    /// Get the write engine.
132    pub fn write_engine(&self) -> &WriteEngine {
133        &self.write_engine
134    }
135
136    /// Get the current u32 session ID (for interop with existing code).
137    pub fn current_session_id(&self) -> Option<u32> {
138        self.current_session.as_ref().map(|s| s.session_id)
139    }
140}
141
142// ═══════════════════════════════════════════════════════════════════
143// SISTER TRAIT
144// ═══════════════════════════════════════════════════════════════════
145
146impl Sister for MemorySister {
147    const SISTER_TYPE: SisterType = SisterType::Memory;
148    const FILE_EXTENSION: &'static str = "amem";
149
150    fn init(config: SisterConfig) -> SisterResult<Self>
151    where
152        Self: Sized,
153    {
154        let dimension = config
155            .get_option::<usize>("dimension")
156            .unwrap_or(DEFAULT_DIMENSION);
157
158        let file_path = config.data_path.clone();
159
160        let graph = if let Some(ref path) = file_path {
161            if path.exists() {
162                #[cfg(feature = "format")]
163                {
164                    crate::format::AmemReader::read_from_file(path)
165                        .map_err(SisterError::from)?
166                }
167                #[cfg(not(feature = "format"))]
168                {
169                    MemoryGraph::new(dimension)
170                }
171            } else if config.create_if_missing {
172                MemoryGraph::new(dimension)
173            } else {
174                return Err(SisterError::new(
175                    ErrorCode::NotFound,
176                    format!("Memory file not found: {}", path.display()),
177                ));
178            }
179        } else {
180            MemoryGraph::new(dimension)
181        };
182
183        Ok(Self::from_graph(graph, file_path))
184    }
185
186    fn health(&self) -> HealthStatus {
187        HealthStatus {
188            healthy: true,
189            status: Status::Ready,
190            uptime: self.start_time.elapsed(),
191            resources: ResourceUsage {
192                memory_bytes: self.graph.node_count() * 256, // rough estimate
193                disk_bytes: 0,
194                open_handles: if self.file_path.is_some() { 1 } else { 0 },
195            },
196            warnings: vec![],
197            last_error: None,
198        }
199    }
200
201    fn version(&self) -> Version {
202        Version::new(0, 4, 1) // matches agentic-memory crate version
203    }
204
205    fn shutdown(&mut self) -> SisterResult<()> {
206        // End current session if active
207        if self.current_session.is_some() {
208            let _ = SessionManagement::end_session(self);
209        }
210
211        // Save to file if path is set
212        #[cfg(feature = "format")]
213        if let Some(ref path) = self.file_path {
214            let writer = crate::format::AmemWriter::new(self.graph.dimension());
215            writer.write_to_file(&self.graph, path)
216                .map_err(SisterError::from)?;
217        }
218
219        Ok(())
220    }
221
222    fn capabilities(&self) -> Vec<Capability> {
223        vec![
224            Capability::new("memory_add", "Add cognitive events to graph"),
225            Capability::new("memory_query", "Query memory by filters"),
226            Capability::new("memory_ground", "Verify claims against stored memories"),
227            Capability::new("memory_evidence", "Get detailed evidence for a query"),
228            Capability::new("memory_suggest", "Find similar memories when exact match fails"),
229            Capability::new("memory_similar", "Find semantically similar memories"),
230            Capability::new("memory_traverse", "Walk the graph following edges"),
231            Capability::new("memory_temporal", "Compare knowledge across time periods"),
232            Capability::new("memory_correct", "Record corrections to previous beliefs"),
233            Capability::new("conversation_log", "Log conversation context"),
234        ]
235    }
236}
237
238// ═══════════════════════════════════════════════════════════════════
239// SESSION MANAGEMENT
240// ═══════════════════════════════════════════════════════════════════
241
242impl SessionManagement for MemorySister {
243    fn start_session(&mut self, name: &str) -> SisterResult<ContextId> {
244        // End current session if active
245        if self.current_session.is_some() {
246            self.end_session()?;
247        }
248
249        let session_id = self.next_session_id;
250        self.next_session_id += 1;
251        let context_id = ContextId::new();
252
253        let record = SessionRecord {
254            id: context_id,
255            session_id,
256            name: name.to_string(),
257            created_at: chrono::Utc::now(),
258            node_count_at_start: self.graph.node_count(),
259        };
260
261        self.current_session = Some(record.clone());
262        self.sessions.push(record);
263
264        Ok(context_id)
265    }
266
267    fn end_session(&mut self) -> SisterResult<()> {
268        if self.current_session.is_none() {
269            return Err(SisterError::new(
270                ErrorCode::InvalidState,
271                "No active session to end",
272            ));
273        }
274        self.current_session = None;
275        Ok(())
276    }
277
278    fn current_session(&self) -> Option<ContextId> {
279        self.current_session.as_ref().map(|s| s.id)
280    }
281
282    fn current_session_info(&self) -> SisterResult<ContextInfo> {
283        let session = self.current_session.as_ref().ok_or_else(|| {
284            SisterError::new(ErrorCode::InvalidState, "No active session")
285        })?;
286
287        let nodes_in_session = self.graph.node_count() - session.node_count_at_start;
288
289        Ok(ContextInfo {
290            id: session.id,
291            name: session.name.clone(),
292            created_at: session.created_at,
293            updated_at: chrono::Utc::now(),
294            item_count: nodes_in_session,
295            size_bytes: nodes_in_session * 256,
296            metadata: Metadata::new(),
297        })
298    }
299
300    fn list_sessions(&self) -> SisterResult<Vec<ContextSummary>> {
301        Ok(self
302            .sessions
303            .iter()
304            .rev() // most recent first
305            .map(|s| ContextSummary {
306                id: s.id,
307                name: s.name.clone(),
308                created_at: s.created_at,
309                updated_at: s.created_at, // approximate
310                item_count: 0,            // would need per-session tracking
311                size_bytes: 0,
312            })
313            .collect())
314    }
315
316    fn export_session(&self, id: ContextId) -> SisterResult<ContextSnapshot> {
317        let session = self
318            .sessions
319            .iter()
320            .find(|s| s.id == id)
321            .ok_or_else(|| SisterError::context_not_found(id.to_string()))?;
322
323        // Export all nodes from this session
324        let session_nodes: Vec<&CognitiveEvent> = self
325            .graph
326            .nodes()
327            .iter()
328            .filter(|n| n.session_id == session.session_id)
329            .collect();
330
331        let data = serde_json::to_vec(&session_nodes)
332            .map_err(|e| SisterError::new(ErrorCode::Internal, e.to_string()))?;
333        let checksum = *blake3::hash(&data).as_bytes();
334
335        Ok(ContextSnapshot {
336            sister_type: SisterType::Memory,
337            version: Version::new(0, 4, 1),
338            context_info: ContextInfo {
339                id,
340                name: session.name.clone(),
341                created_at: session.created_at,
342                updated_at: chrono::Utc::now(),
343                item_count: session_nodes.len(),
344                size_bytes: data.len(),
345                metadata: Metadata::new(),
346            },
347            data,
348            checksum,
349            snapshot_at: chrono::Utc::now(),
350        })
351    }
352
353    fn import_session(&mut self, snapshot: ContextSnapshot) -> SisterResult<ContextId> {
354        if !snapshot.verify() {
355            return Err(SisterError::new(
356                ErrorCode::ChecksumMismatch,
357                "Session snapshot checksum verification failed",
358            ));
359        }
360
361        // Start a new session for the imported data
362        let context_id = self.start_session(&snapshot.context_info.name)?;
363
364        // Deserialize and ingest the nodes
365        let nodes: Vec<CognitiveEvent> = serde_json::from_slice(&snapshot.data)
366            .map_err(|e| SisterError::new(ErrorCode::InvalidInput, e.to_string()))?;
367
368        let session_id = self.current_session_id().unwrap_or(0);
369        // Re-tag nodes with the new session ID
370        let retagged: Vec<CognitiveEvent> = nodes
371            .into_iter()
372            .map(|mut n| {
373                n.session_id = session_id;
374                n
375            })
376            .collect();
377
378        self.write_engine
379            .ingest(&mut self.graph, retagged, vec![])
380            .map_err(SisterError::from)?;
381
382        Ok(context_id)
383    }
384}
385
386// ═══════════════════════════════════════════════════════════════════
387// GROUNDING
388// ═══════════════════════════════════════════════════════════════════
389
390impl Grounding for MemorySister {
391    fn ground(&self, claim: &str) -> SisterResult<GroundingResult> {
392        let params = TextSearchParams {
393            query: claim.to_string(),
394            max_results: 10,
395            event_types: vec![],
396            session_ids: vec![],
397            min_score: 0.3,
398        };
399
400        let matches = self
401            .query_engine
402            .text_search(
403                &self.graph,
404                self.graph.term_index.as_ref(),
405                self.graph.doc_lengths.as_ref(),
406                params,
407            )
408            .map_err(SisterError::from)?;
409
410        if matches.is_empty() {
411            return Ok(
412                GroundingResult::ungrounded(claim, "No matching memories found")
413                    .with_suggestions(
414                        self.graph
415                            .nodes()
416                            .iter()
417                            .rev()
418                            .take(3)
419                            .map(|n| n.content.clone())
420                            .collect(),
421                    ),
422            );
423        }
424
425        let best_score = matches
426            .iter()
427            .map(|m| m.score)
428            .fold(0.0f32, f32::max);
429
430        let evidence: Vec<GroundingEvidence> = matches
431            .iter()
432            .filter_map(|m| {
433                self.graph.get_node(m.node_id).map(|node| {
434                    GroundingEvidence::new(
435                        "memory_node",
436                        format!("node_{}", node.id),
437                        m.score as f64,
438                        &node.content,
439                    )
440                    .with_data("event_type", format!("{:?}", node.event_type))
441                    .with_data("session_id", node.session_id)
442                    .with_data("confidence", node.confidence)
443                    .with_data("matched_terms", m.matched_terms.clone())
444                })
445            })
446            .collect();
447
448        let confidence = best_score as f64;
449
450        if confidence > 0.5 {
451            Ok(GroundingResult::verified(claim, confidence)
452                .with_evidence(evidence)
453                .with_reason("Found matching memories via BM25 search"))
454        } else {
455            Ok(GroundingResult::partial(claim, confidence)
456                .with_evidence(evidence)
457                .with_reason("Some evidence found but low relevance"))
458        }
459    }
460
461    fn evidence(&self, query: &str, max_results: usize) -> SisterResult<Vec<EvidenceDetail>> {
462        let params = TextSearchParams {
463            query: query.to_string(),
464            max_results,
465            event_types: vec![],
466            session_ids: vec![],
467            min_score: 0.0,
468        };
469
470        let matches = self
471            .query_engine
472            .text_search(
473                &self.graph,
474                self.graph.term_index.as_ref(),
475                self.graph.doc_lengths.as_ref(),
476                params,
477            )
478            .map_err(SisterError::from)?;
479
480        Ok(matches
481            .iter()
482            .filter_map(|m| {
483                self.graph.get_node(m.node_id).map(|node| {
484                    let created_at = chrono::DateTime::from_timestamp_micros(node.created_at as i64)
485                        .unwrap_or_default();
486
487                    EvidenceDetail {
488                        evidence_type: "memory_node".to_string(),
489                        id: format!("node_{}", node.id),
490                        score: m.score as f64,
491                        created_at,
492                        source_sister: SisterType::Memory,
493                        content: node.content.clone(),
494                        data: {
495                            let mut meta = Metadata::new();
496                            if let Ok(v) = serde_json::to_value(format!("{:?}", node.event_type)) {
497                                meta.insert("event_type".to_string(), v);
498                            }
499                            if let Ok(v) = serde_json::to_value(node.session_id) {
500                                meta.insert("session_id".to_string(), v);
501                            }
502                            if let Ok(v) = serde_json::to_value(node.confidence) {
503                                meta.insert("confidence".to_string(), v);
504                            }
505                            meta
506                        },
507                    }
508                })
509            })
510            .collect())
511    }
512
513    fn suggest(&self, query: &str, limit: usize) -> SisterResult<Vec<GroundingSuggestion>> {
514        // Word-overlap fallback (similar to existing memory_suggest tool)
515        let query_lower = query.to_lowercase();
516        let query_words: Vec<&str> = query_lower.split_whitespace().collect();
517
518        let mut scored: Vec<(f64, &CognitiveEvent)> = self
519            .graph
520            .nodes()
521            .iter()
522            .map(|node| {
523                let content_lower = node.content.to_lowercase();
524                let matched = query_words
525                    .iter()
526                    .filter(|w| content_lower.contains(**w))
527                    .count();
528                let score = if query_words.is_empty() {
529                    0.0
530                } else {
531                    matched as f64 / query_words.len() as f64
532                };
533                (score, node)
534            })
535            .filter(|(score, _)| *score > 0.0)
536            .collect();
537
538        scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
539
540        Ok(scored
541            .into_iter()
542            .take(limit)
543            .map(|(score, node)| GroundingSuggestion {
544                item_type: "memory_node".to_string(),
545                id: format!("node_{}", node.id),
546                relevance_score: score,
547                description: node.content.clone(),
548                data: Metadata::new(),
549            })
550            .collect())
551    }
552}
553
554// ═══════════════════════════════════════════════════════════════════
555// QUERYABLE
556// ═══════════════════════════════════════════════════════════════════
557
558impl Queryable for MemorySister {
559    fn query(&self, query: Query) -> SisterResult<QueryResult> {
560        let start = Instant::now();
561
562        let results: Vec<serde_json::Value> = match query.query_type.as_str() {
563            "list" => {
564                let limit = query.limit.unwrap_or(50);
565                let offset = query.offset.unwrap_or(0);
566                self.graph
567                    .nodes()
568                    .iter()
569                    .skip(offset)
570                    .take(limit)
571                    .map(|n| {
572                        serde_json::json!({
573                            "id": n.id,
574                            "event_type": format!("{:?}", n.event_type),
575                            "content": n.content,
576                            "confidence": n.confidence,
577                            "session_id": n.session_id,
578                            "created_at": n.created_at,
579                        })
580                    })
581                    .collect()
582            }
583            "search" => {
584                let text = query.get_string("text").unwrap_or_default();
585                let max = query.limit.unwrap_or(20);
586
587                let params = TextSearchParams {
588                    query: text,
589                    max_results: max,
590                    event_types: vec![],
591                    session_ids: vec![],
592                    min_score: 0.0,
593                };
594
595                let matches = self
596                    .query_engine
597                    .text_search(
598                        &self.graph,
599                        self.graph.term_index.as_ref(),
600                        self.graph.doc_lengths.as_ref(),
601                        params,
602                    )
603                    .map_err(SisterError::from)?;
604
605                matches
606                    .iter()
607                    .filter_map(|m| {
608                        self.graph.get_node(m.node_id).map(|n| {
609                            serde_json::json!({
610                                "id": n.id,
611                                "content": n.content,
612                                "score": m.score,
613                                "matched_terms": m.matched_terms,
614                            })
615                        })
616                    })
617                    .collect()
618            }
619            "recent" => {
620                let count = query.limit.unwrap_or(10);
621                self.graph
622                    .nodes()
623                    .iter()
624                    .rev()
625                    .take(count)
626                    .map(|n| {
627                        serde_json::json!({
628                            "id": n.id,
629                            "event_type": format!("{:?}", n.event_type),
630                            "content": n.content,
631                            "confidence": n.confidence,
632                            "session_id": n.session_id,
633                            "created_at": n.created_at,
634                        })
635                    })
636                    .collect()
637            }
638            "get" => {
639                let id_str = query.get_string("id").unwrap_or_default();
640                let id: u64 = id_str.parse().unwrap_or(0);
641                if let Some(n) = self.graph.get_node(id) {
642                    vec![serde_json::json!({
643                        "id": n.id,
644                        "event_type": format!("{:?}", n.event_type),
645                        "content": n.content,
646                        "confidence": n.confidence,
647                        "session_id": n.session_id,
648                        "created_at": n.created_at,
649                    })]
650                } else {
651                    vec![]
652                }
653            }
654            _ => vec![],
655        };
656
657        let total = self.graph.node_count();
658        let has_more = results.len() < total;
659
660        Ok(QueryResult::new(query, results, start.elapsed()).with_pagination(total, has_more))
661    }
662
663    fn supports_query(&self, query_type: &str) -> bool {
664        matches!(
665            query_type,
666            "list" | "search" | "recent" | "get" | "related" | "temporal"
667        )
668    }
669
670    fn query_types(&self) -> Vec<QueryTypeInfo> {
671        vec![
672            QueryTypeInfo::new("list", "List all memory nodes with pagination")
673                .optional(vec!["limit", "offset"]),
674            QueryTypeInfo::new("search", "Search memories by text (BM25)")
675                .required(vec!["text"])
676                .optional(vec!["limit"]),
677            QueryTypeInfo::new("recent", "Get most recent memories")
678                .optional(vec!["limit"]),
679            QueryTypeInfo::new("get", "Get a specific memory node by ID")
680                .required(vec!["id"]),
681        ]
682    }
683}
684
685// ═══════════════════════════════════════════════════════════════════
686// FILE FORMAT
687// ═══════════════════════════════════════════════════════════════════
688
#[cfg(feature = "format")]
impl FileFormatReader for MemorySister {
    /// Loads a complete `.amem` file and wraps it in a sister bound to `path`.
    fn read_file(path: &Path) -> SisterResult<Self> {
        let graph = crate::format::AmemReader::read_from_file(path)
            .map_err(SisterError::from)?;
        Ok(Self::from_graph(graph, Some(path.to_path_buf())))
    }

    /// Inspects the fixed-size header of an `.amem` file without loading its body.
    ///
    /// Only the first 64 bytes are read, so probing stays cheap for large files.
    /// `.amem` headers carry no creation timestamp. `created_at` therefore uses
    /// the filesystem birth time where the platform records one, and "now"
    /// otherwise.
    ///
    /// # Errors
    /// - `StorageError` if the file cannot be opened or read, or is shorter than
    ///   the header.
    /// - Header decoding failures, converted from `AmemError`.
    fn can_read(path: &Path) -> SisterResult<FileInfo> {
        // Size of the fixed `.amem` file header in bytes.
        const HEADER_LEN: usize = 64;

        let storage_error =
            |e: std::io::Error| SisterError::new(ErrorCode::StorageError, e.to_string());

        let mut file = std::fs::File::open(path).map_err(storage_error)?;
        let mut header_bytes = [0u8; HEADER_LEN];
        std::io::Read::read_exact(&mut file, &mut header_bytes).map_err(|e| {
            if e.kind() == std::io::ErrorKind::UnexpectedEof {
                SisterError::new(ErrorCode::StorageError, "File too small for .amem format")
            } else {
                storage_error(e)
            }
        })?;

        let header =
            crate::types::FileHeader::read_from(&mut std::io::Cursor::new(&header_bytes[..]))
                .map_err(SisterError::from)?;

        let metadata = file.metadata().map_err(storage_error)?;
        let created_at = metadata
            .created()
            .map(chrono::DateTime::<chrono::Utc>::from)
            .unwrap_or_else(|_| chrono::Utc::now());
        let updated_at: chrono::DateTime<chrono::Utc> = chrono::DateTime::from(
            metadata.modified().unwrap_or(std::time::SystemTime::UNIX_EPOCH),
        );

        Ok(FileInfo {
            sister_type: SisterType::Memory,
            version: Version::new(header.version as u8, 0, 0),
            created_at,
            updated_at,
            content_length: metadata.len(),
            needs_migration: header.version < crate::types::FORMAT_VERSION,
            format_id: "AMEM".to_string(),
        })
    }

    /// Format version recorded in the file header (major component only).
    ///
    /// Shares `can_read`'s header-only probe, so it has the same errors.
    fn file_version(path: &Path) -> SisterResult<Version> {
        Self::can_read(path).map(|info| info.version)
    }

    /// Memory format v1 is the only version, so there is no migration path yet.
    fn migrate(_data: &[u8], _from_version: Version) -> SisterResult<Vec<u8>> {
        Err(SisterError::new(
            ErrorCode::NotImplemented,
            "No migration path available (only v1 exists)",
        ))
    }
}
752
#[cfg(feature = "format")]
impl FileFormatWriter for MemorySister {
    /// Serialises the whole graph to `path` in `.amem` format, using the graph's
    /// own embedding dimension.
    fn write_file(&self, path: &Path) -> SisterResult<()> {
        crate::format::AmemWriter::new(self.graph.dimension())
            .write_to_file(&self.graph, path)
            .map_err(SisterError::from)
    }

    /// Serialises the whole graph into an in-memory `.amem` byte buffer.
    fn to_bytes(&self) -> SisterResult<Vec<u8>> {
        let mut bytes: Vec<u8> = Vec::new();
        crate::format::AmemWriter::new(self.graph.dimension())
            .write_to(&self.graph, &mut bytes)
            .map_err(SisterError::from)?;
        Ok(bytes)
    }
}
769
770// ═══════════════════════════════════════════════════════════════════
771// TESTS
772// ═══════════════════════════════════════════════════════════════════
773
774#[cfg(test)]
775mod tests {
776    use super::*;
777    use crate::types::{CognitiveEventBuilder, EventType};
778
779    fn make_test_sister() -> MemorySister {
780        let config = SisterConfig::stateless()
781            .option("dimension", DEFAULT_DIMENSION);
782        MemorySister::init(config).unwrap()
783    }
784
785    fn add_test_nodes(sister: &mut MemorySister) {
786        let session_id = sister.current_session_id().unwrap_or(0);
787        let events = vec![
788            CognitiveEventBuilder::new(EventType::Fact, "The sky is blue")
789                .confidence(0.95)
790                .session_id(session_id)
791                .build(),
792            CognitiveEventBuilder::new(EventType::Fact, "Rust is fast and memory safe")
793                .confidence(0.9)
794                .session_id(session_id)
795                .build(),
796            CognitiveEventBuilder::new(EventType::Decision, "Use BM25 for text search")
797                .confidence(0.85)
798                .session_id(session_id)
799                .build(),
800        ];
801        sister.write_engine.ingest(&mut sister.graph, events, vec![]).unwrap();
802    }
803
804    /// Helper to build BM25 term index and doc lengths for text search tests.
805    fn build_indexes(sister: &mut MemorySister) {
806        use crate::engine::Tokenizer;
807        use crate::index::{DocLengths, TermIndex};
808        let tokenizer = Tokenizer::new();
809        let term_index = TermIndex::build(&sister.graph, &tokenizer);
810        sister.graph.set_term_index(term_index);
811        let doc_lengths = DocLengths::build(&sister.graph, &tokenizer);
812        sister.graph.set_doc_lengths(doc_lengths);
813    }
814
815    #[test]
816    fn test_sister_trait() {
817        let sister = make_test_sister();
818        assert_eq!(sister.sister_type(), SisterType::Memory);
819        assert_eq!(sister.file_extension(), "amem");
820        assert_eq!(sister.mcp_prefix(), "memory");
821        assert!(sister.is_healthy());
822        assert_eq!(sister.version(), Version::new(0, 4, 1));
823        assert!(!sister.capabilities().is_empty());
824    }
825
826    #[test]
827    fn test_sister_info() {
828        let sister = make_test_sister();
829        let info = SisterInfo::from_sister(&sister);
830        assert_eq!(info.sister_type, SisterType::Memory);
831        assert_eq!(info.file_extension, "amem");
832        assert_eq!(info.mcp_prefix, "memory");
833    }
834
835    #[test]
836    fn test_session_management() {
837        let mut sister = make_test_sister();
838
839        // No session initially
840        assert!(sister.current_session().is_none());
841        assert!(sister.current_session_info().is_err());
842
843        // Start session
844        let sid = sister.start_session("test_session").unwrap();
845        assert!(sister.current_session().is_some());
846        assert_eq!(sister.current_session().unwrap(), sid);
847
848        // Session info
849        let info = sister.current_session_info().unwrap();
850        assert_eq!(info.name, "test_session");
851
852        // List sessions
853        let sessions = sister.list_sessions().unwrap();
854        assert_eq!(sessions.len(), 1);
855        assert_eq!(sessions[0].name, "test_session");
856
857        // End session
858        sister.end_session().unwrap();
859        assert!(sister.current_session().is_none());
860
861        // Can't end twice
862        assert!(sister.end_session().is_err());
863    }
864
865    #[test]
866    fn test_grounding_with_data() {
867        let mut sister = make_test_sister();
868        sister.start_session("grounding_test").unwrap();
869        add_test_nodes(&mut sister);
870
871        // Ensure term index is built for BM25
872        build_indexes(&mut sister);
873
874        // Ground a claim that should match
875        let result = sister.ground("sky is blue").unwrap();
876        assert!(
877            result.status == GroundingStatus::Verified
878                || result.status == GroundingStatus::Partial,
879            "Expected verified or partial, got {:?}",
880            result.status
881        );
882        assert!(!result.evidence.is_empty());
883
884        // Ground a claim that should NOT match
885        let result = sister.ground("cats can teleport").unwrap();
886        assert_eq!(result.status, GroundingStatus::Ungrounded);
887    }
888
889    #[test]
890    fn test_evidence_query() {
891        let mut sister = make_test_sister();
892        sister.start_session("evidence_test").unwrap();
893        add_test_nodes(&mut sister);
894        build_indexes(&mut sister);
895
896        let evidence = sister.evidence("rust", 10).unwrap();
897        // BM25 should find the "Rust is fast" node
898        assert!(
899            !evidence.is_empty(),
900            "Expected evidence for 'rust' query"
901        );
902        assert_eq!(evidence[0].source_sister, SisterType::Memory);
903    }
904
905    #[test]
906    fn test_suggest_fallback() {
907        let mut sister = make_test_sister();
908        sister.start_session("suggest_test").unwrap();
909        add_test_nodes(&mut sister);
910
911        let suggestions = sister.suggest("blue sky", 5).unwrap();
912        assert!(!suggestions.is_empty());
913        assert!(suggestions[0].relevance_score > 0.0);
914    }
915
916    #[test]
917    fn test_queryable_list() {
918        let mut sister = make_test_sister();
919        sister.start_session("query_test").unwrap();
920        add_test_nodes(&mut sister);
921
922        let result = sister.query(Query::list().limit(2)).unwrap();
923        assert_eq!(result.len(), 2);
924        assert!(result.has_more);
925    }
926
927    #[test]
928    fn test_queryable_recent() {
929        let mut sister = make_test_sister();
930        sister.start_session("recent_test").unwrap();
931        add_test_nodes(&mut sister);
932
933        let result = sister.recent(2).unwrap();
934        assert_eq!(result.len(), 2);
935    }
936
937    #[test]
938    fn test_queryable_search() {
939        let mut sister = make_test_sister();
940        sister.start_session("search_test").unwrap();
941        add_test_nodes(&mut sister);
942        build_indexes(&mut sister);
943
944        let result = sister.search("rust").unwrap();
945        assert!(
946            !result.is_empty(),
947            "Expected search results for 'rust'"
948        );
949    }
950
951    #[test]
952    fn test_queryable_types() {
953        let sister = make_test_sister();
954        assert!(sister.supports_query("list"));
955        assert!(sister.supports_query("search"));
956        assert!(sister.supports_query("recent"));
957        assert!(sister.supports_query("get"));
958        assert!(!sister.supports_query("nonexistent"));
959
960        let types = sister.query_types();
961        assert_eq!(types.len(), 4);
962    }
963
964    #[test]
965    fn test_error_bridge() {
966        let amem_err = AmemError::NodeNotFound(42);
967        let sister_err: SisterError = amem_err.into();
968        assert_eq!(sister_err.code, ErrorCode::NotFound);
969        assert!(sister_err.message.contains("42"));
970
971        let amem_err2 = AmemError::InvalidMagic;
972        let sister_err2: SisterError = amem_err2.into();
973        assert_eq!(sister_err2.code, ErrorCode::VersionMismatch);
974    }
975
976    #[test]
977    fn test_session_export_import() {
978        let mut sister = make_test_sister();
979        let sid = sister.start_session("export_test").unwrap();
980        add_test_nodes(&mut sister);
981
982        // Export
983        let snapshot = sister.export_session(sid).unwrap();
984        assert!(snapshot.verify());
985        assert_eq!(snapshot.sister_type, SisterType::Memory);
986
987        // Import into fresh sister
988        let mut sister2 = make_test_sister();
989        let _imported_sid = sister2.import_session(snapshot).unwrap();
990        assert!(sister2.current_session().is_some());
991        // Imported session should have nodes
992        assert!(sister2.graph().node_count() > 0);
993    }
994
995    #[test]
996    fn test_config_patterns() {
997        // Single path config
998        let config = SisterConfig::new("/tmp/test.amem");
999        let sister = MemorySister::init(config).unwrap();
1000        assert!(sister.is_healthy());
1001
1002        // Stateless config
1003        let config2 = SisterConfig::stateless();
1004        let sister2 = MemorySister::init(config2).unwrap();
1005        assert!(sister2.is_healthy());
1006    }
1007
1008    #[test]
1009    fn test_shutdown() {
1010        let mut sister = make_test_sister();
1011        sister.start_session("shutdown_test").unwrap();
1012        sister.shutdown().unwrap();
1013        // Session should be ended after shutdown
1014        assert!(sister.current_session().is_none());
1015    }
1016}