Skip to main content

mag/memory_core/
mod.rs

1use anyhow::Result;
2use async_trait::async_trait;
3use uuid::Uuid;
4
5mod domain;
6mod traits;
7
8pub use domain::*;
9pub use traits::*;
10
11pub mod embedder;
12pub mod reranker;
13pub mod scoring;
14pub mod storage;
15
16#[cfg(feature = "real-embeddings")]
17#[allow(unused_imports)]
18pub use embedder::OnnxEmbedder;
19#[allow(unused_imports)]
20pub use embedder::{Embedder, PlaceholderEmbedder};
21#[allow(unused_imports)]
22pub use scoring::{
23    ABSTENTION_MIN_TEXT, GRAPH_MIN_EDGE_WEIGHT, GRAPH_NEIGHBOR_FACTOR, RRF_WEIGHT_FTS,
24    RRF_WEIGHT_VEC, ScoringParams, feedback_factor, jaccard_pre, jaccard_similarity,
25    priority_factor, time_decay_et, type_weight_et, word_overlap_pre,
26};
27#[allow(unused_imports)]
28pub(crate) use scoring::{is_stopword, simple_stem, token_set};
29
30/// Orchestrates the memory pipeline by coordinating ingestors, processors, and storage.
31pub struct Pipeline {
32    ingestor: Box<dyn Ingestor>,
33    processor: Box<dyn Processor>,
34    storage: Box<dyn Storage>,
35    retriever: Box<dyn Retriever>,
36    searcher: Box<dyn Searcher>,
37    recents: Box<dyn Recents>,
38    semantic_searcher: Box<dyn SemanticSearcher>,
39}
40
41impl Pipeline {
42    /// Creates a new Pipeline with the provided components.
43    pub fn new(
44        ingestor: Box<dyn Ingestor>,
45        processor: Box<dyn Processor>,
46        storage: Box<dyn Storage>,
47        retriever: Box<dyn Retriever>,
48        searcher: Box<dyn Searcher>,
49        recents: Box<dyn Recents>,
50        semantic_searcher: Box<dyn SemanticSearcher>,
51    ) -> Self {
52        Self {
53            ingestor,
54            processor,
55            storage,
56            retriever,
57            searcher,
58            recents,
59            semantic_searcher,
60        }
61    }
62
63    /// Runs the full pipeline: ingest -> process -> store.
64    pub async fn run(&self, content: &str, input: &MemoryInput) -> Result<String> {
65        let id = input
66            .id
67            .clone()
68            .unwrap_or_else(|| Uuid::new_v4().to_string());
69        let mut store_input = input.clone();
70        if store_input.id.is_none() {
71            store_input.id = Some(id.clone());
72        }
73        let content_to_ingest = if content.is_empty() {
74            input.content.as_str()
75        } else {
76            content
77        };
78        let ingested = self.ingestor.ingest(content_to_ingest).await?;
79        let processed = self.processor.process(&ingested).await?;
80        self.storage.store(&id, &processed, &store_input).await?;
81        Ok(id)
82    }
83
84    /// Retrieves data from storage via the retriever.
85    pub async fn retrieve(&self, id: &str) -> Result<String> {
86        self.retriever.retrieve(id).await
87    }
88
89    /// Searches for stored memories matching the provided query.
90    pub async fn search(
91        &self,
92        query: &str,
93        limit: usize,
94        opts: &SearchOptions,
95    ) -> Result<Vec<SearchResult>> {
96        self.searcher.search(query, limit, opts).await
97    }
98
99    pub async fn recent(&self, limit: usize, opts: &SearchOptions) -> Result<Vec<SearchResult>> {
100        self.recents.recent(limit, opts).await
101    }
102
103    pub async fn semantic_search(
104        &self,
105        query: &str,
106        limit: usize,
107        opts: &SearchOptions,
108    ) -> Result<Vec<SemanticResult>> {
109        self.semantic_searcher
110            .semantic_search(query, limit, opts)
111            .await
112    }
113}
114
/// A placeholder implementation of the memory pipeline for development and testing.
///
/// This is a stateless unit type, so it derives the common value traits:
/// `Debug` for logging and assertion output, `Clone`/`Copy` because there is
/// nothing to duplicate, and `Default` so it can be built generically.
#[derive(Debug, Clone, Copy, Default)]
pub struct PlaceholderPipeline;
117
118#[async_trait]
119impl Ingestor for PlaceholderPipeline {
120    async fn ingest(&self, content: &str) -> Result<String> {
121        Ok(content.to_string())
122    }
123}
124
125#[async_trait]
126impl Processor for PlaceholderPipeline {
127    async fn process(&self, input: &str) -> Result<String> {
128        Ok(format!("processed: {}", input))
129    }
130}
131
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::anyhow;
    use serde_json::json;
    use std::sync::{Arc, Mutex};

    /// Test double that implements every pipeline stage with canned,
    /// deterministic output.
    struct MockPipeline;

    /// Builds the `SearchResult` shape shared by the mock searcher and the
    /// mock recents backend; only `id` and `content` vary between them.
    fn mock_search_result(id: &str, content: String) -> SearchResult {
        SearchResult {
            id: id.to_string(),
            content,
            tags: Vec::new(),
            importance: 0.5,
            metadata: json!({}),
            event_type: None,
            session_id: None,
            project: None,
            entity_id: None,
            agent_type: None,
        }
    }

    #[async_trait]
    impl Ingestor for MockPipeline {
        async fn ingest(&self, content: &str) -> Result<String> {
            Ok(content.to_string())
        }
    }

    #[async_trait]
    impl Processor for MockPipeline {
        async fn process(&self, input: &str) -> Result<String> {
            Ok(format!("processed: {}", input))
        }
    }

    #[async_trait]
    impl Storage for MockPipeline {
        async fn store(&self, _id: &str, _data: &str, _input: &MemoryInput) -> Result<()> {
            Ok(())
        }
    }

    #[async_trait]
    impl Retriever for MockPipeline {
        async fn retrieve(&self, id: &str) -> Result<String> {
            Ok(format!("retrieved: {}", id))
        }
    }

    #[async_trait]
    impl Searcher for MockPipeline {
        async fn search(
            &self,
            query: &str,
            _limit: usize,
            _opts: &SearchOptions,
        ) -> Result<Vec<SearchResult>> {
            Ok(vec![mock_search_result(
                "result-1",
                format!("match: {query}"),
            )])
        }
    }

    #[async_trait]
    impl Recents for MockPipeline {
        async fn recent(&self, _limit: usize, _opts: &SearchOptions) -> Result<Vec<SearchResult>> {
            Ok(vec![mock_search_result(
                "recent-1",
                "recent value".to_string(),
            )])
        }
    }

    #[async_trait]
    impl SemanticSearcher for MockPipeline {
        async fn semantic_search(
            &self,
            query: &str,
            _limit: usize,
            _opts: &SearchOptions,
        ) -> Result<Vec<SemanticResult>> {
            Ok(vec![SemanticResult {
                id: "semantic-1".to_string(),
                content: format!("semantic match: {query}"),
                tags: Vec::new(),
                importance: 0.5,
                metadata: json!({}),
                event_type: None,
                session_id: None,
                project: None,
                entity_id: None,
                agent_type: None,
                score: 0.99,
            }])
        }
    }

    /// Ingestor that always fails, used to check error propagation in `run`.
    struct FailingIngestor;

    #[async_trait]
    impl Ingestor for FailingIngestor {
        async fn ingest(&self, _content: &str) -> Result<String> {
            Err(anyhow!("Ingestion failed"))
        }
    }

    /// One observed `Storage::store` call: `(id, data, input.id)`.
    type StoreCall = (String, String, Option<String>);

    /// Storage that records every call so tests can inspect what `run` stored.
    struct RecordingStorage {
        calls: Arc<Mutex<Vec<StoreCall>>>,
    }

    #[async_trait]
    impl Storage for RecordingStorage {
        async fn store(&self, id: &str, data: &str, input: &MemoryInput) -> Result<()> {
            self.calls
                .lock()
                .expect("recording mutex should not be poisoned")
                .push((id.to_string(), data.to_string(), input.id.clone()));
            Ok(())
        }
    }

    /// Builds a pipeline with the given ingestor and storage; every other
    /// stage is backed by `MockPipeline`.
    fn pipeline_with(ingestor: Box<dyn Ingestor>, storage: Box<dyn Storage>) -> Pipeline {
        Pipeline::new(
            ingestor,
            Box::new(MockPipeline),
            storage,
            Box::new(MockPipeline),
            Box::new(MockPipeline),
            Box::new(MockPipeline),
            Box::new(MockPipeline),
        )
    }

    /// Builds a pipeline in which every stage is backed by `MockPipeline`.
    fn mock_pipeline() -> Pipeline {
        pipeline_with(Box::new(MockPipeline), Box::new(MockPipeline))
    }

    /// The input fixture shared by the `run` tests; `id` is left at its default.
    fn hello_input() -> MemoryInput {
        MemoryInput {
            content: "hello".to_string(),
            importance: 0.5,
            metadata: json!({}),
            ..Default::default()
        }
    }

    #[tokio::test]
    async fn test_ingestor_trait() {
        let ingestor: Box<dyn Ingestor> = Box::new(MockPipeline);
        let result = ingestor.ingest("test").await.unwrap();
        assert_eq!(result, "test");
    }

    #[tokio::test]
    async fn test_pipeline_run_success() {
        let pipeline = mock_pipeline();
        let input = MemoryInput {
            id: Some("custom_id".to_string()),
            ..hello_input()
        };

        let id = pipeline
            .run("hello", &input)
            .await
            .expect("run should succeed with mock stages");
        assert_eq!(id, "custom_id");
    }

    #[tokio::test]
    async fn test_pipeline_run_default_id() {
        let pipeline = mock_pipeline();

        let id = pipeline
            .run("hello", &hello_input())
            .await
            .expect("run should succeed with mock stages");
        assert!(uuid::Uuid::parse_str(&id).is_ok());
    }

    #[tokio::test]
    async fn test_pipeline_run_empty_content_uses_input_content() {
        let calls = Arc::new(Mutex::new(Vec::new()));
        let pipeline = pipeline_with(
            Box::new(MockPipeline),
            Box::new(RecordingStorage {
                calls: Arc::clone(&calls),
            }),
        );

        let id = pipeline
            .run("", &hello_input())
            .await
            .expect("run should succeed with mock stages");

        let recorded = calls
            .lock()
            .expect("recording mutex should not be poisoned");
        assert_eq!(recorded.len(), 1);
        let (stored_id, stored_data, stored_input_id) = &recorded[0];
        // The returned id is the one used for storage and is also written
        // into the input record handed to storage.
        assert_eq!(stored_id.as_str(), id.as_str());
        assert_eq!(stored_input_id.as_deref(), Some(id.as_str()));
        // Empty `content` falls back to `input.content`.
        assert_eq!(stored_data.as_str(), "processed: hello");
    }

    #[tokio::test]
    async fn test_pipeline_retrieve_success() {
        let pipeline = mock_pipeline();

        let retrieved = pipeline
            .retrieve("test_id")
            .await
            .expect("retrieve should succeed with mock retriever");
        assert_eq!(retrieved, "retrieved: test_id");
    }

    #[tokio::test]
    async fn test_pipeline_failure() {
        let pipeline = pipeline_with(Box::new(FailingIngestor), Box::new(MockPipeline));

        let err = pipeline
            .run("hello", &hello_input())
            .await
            .expect_err("run should fail when ingestion fails");
        assert_eq!(err.to_string(), "Ingestion failed");
    }

    #[tokio::test]
    async fn test_pipeline_search_success() {
        let pipeline = mock_pipeline();

        let results = pipeline
            .search("needle", 5, &SearchOptions::default())
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "result-1");
        assert_eq!(results[0].content, "match: needle");
        assert!(results[0].tags.is_empty());
        assert_eq!(results[0].importance, 0.5);
        assert_eq!(results[0].metadata, json!({}));
    }

    #[tokio::test]
    async fn test_pipeline_recent_success() {
        let pipeline = mock_pipeline();

        let results = pipeline.recent(3, &SearchOptions::default()).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "recent-1");
        assert_eq!(results[0].content, "recent value");
        assert!(results[0].tags.is_empty());
        assert_eq!(results[0].importance, 0.5);
        assert_eq!(results[0].metadata, json!({}));
    }

    #[tokio::test]
    async fn test_pipeline_semantic_search_success() {
        let pipeline = mock_pipeline();

        let results = pipeline
            .semantic_search("vector", 4, &SearchOptions::default())
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "semantic-1");
        assert_eq!(results[0].content, "semantic match: vector");
        assert!(results[0].tags.is_empty());
        assert_eq!(results[0].importance, 0.5);
        assert_eq!(results[0].metadata, json!({}));
        assert!(results[0].score > 0.9);
    }

    #[test]
    fn test_memory_kind_for_semantic_event_type() {
        assert_eq!(EventType::Decision.memory_kind(), MemoryKind::Semantic);
    }

    #[test]
    fn test_memory_kind_defaults_to_episodic_for_unknown_type() {
        assert_eq!(
            EventType::Unknown("totally_unknown".to_string()).memory_kind(),
            MemoryKind::Episodic
        );
    }
}
419}