1use anyhow::Result;
2use async_trait::async_trait;
3use uuid::Uuid;
4
5mod domain;
6mod traits;
7
8pub use domain::*;
9pub use traits::*;
10
11pub mod embedder;
12pub mod reranker;
13pub mod scoring;
14pub mod storage;
15
16#[cfg(feature = "real-embeddings")]
17#[allow(unused_imports)]
18pub use embedder::OnnxEmbedder;
19#[allow(unused_imports)]
20pub use embedder::{Embedder, PlaceholderEmbedder};
21#[allow(unused_imports)]
22pub use scoring::{
23 ABSTENTION_MIN_TEXT, GRAPH_MIN_EDGE_WEIGHT, GRAPH_NEIGHBOR_FACTOR, RRF_WEIGHT_FTS,
24 RRF_WEIGHT_VEC, ScoringParams, feedback_factor, jaccard_pre, jaccard_similarity,
25 priority_factor, time_decay_et, type_weight_et, word_overlap_pre,
26};
27#[allow(unused_imports)]
28pub(crate) use scoring::{is_stopword, simple_stem, token_set};
29
30pub struct Pipeline {
32 ingestor: Box<dyn Ingestor>,
33 processor: Box<dyn Processor>,
34 storage: Box<dyn Storage>,
35 retriever: Box<dyn Retriever>,
36 searcher: Box<dyn Searcher>,
37 recents: Box<dyn Recents>,
38 semantic_searcher: Box<dyn SemanticSearcher>,
39}
40
41impl Pipeline {
42 pub fn new(
44 ingestor: Box<dyn Ingestor>,
45 processor: Box<dyn Processor>,
46 storage: Box<dyn Storage>,
47 retriever: Box<dyn Retriever>,
48 searcher: Box<dyn Searcher>,
49 recents: Box<dyn Recents>,
50 semantic_searcher: Box<dyn SemanticSearcher>,
51 ) -> Self {
52 Self {
53 ingestor,
54 processor,
55 storage,
56 retriever,
57 searcher,
58 recents,
59 semantic_searcher,
60 }
61 }
62
63 pub async fn run(&self, content: &str, input: &MemoryInput) -> Result<String> {
65 let id = input
66 .id
67 .clone()
68 .unwrap_or_else(|| Uuid::new_v4().to_string());
69 let mut store_input = input.clone();
70 if store_input.id.is_none() {
71 store_input.id = Some(id.clone());
72 }
73 let content_to_ingest = if content.is_empty() {
74 input.content.as_str()
75 } else {
76 content
77 };
78 let ingested = self.ingestor.ingest(content_to_ingest).await?;
79 let processed = self.processor.process(&ingested).await?;
80 self.storage.store(&id, &processed, &store_input).await?;
81 Ok(id)
82 }
83
84 pub async fn retrieve(&self, id: &str) -> Result<String> {
86 self.retriever.retrieve(id).await
87 }
88
89 pub async fn search(
91 &self,
92 query: &str,
93 limit: usize,
94 opts: &SearchOptions,
95 ) -> Result<Vec<SearchResult>> {
96 self.searcher.search(query, limit, opts).await
97 }
98
99 pub async fn recent(&self, limit: usize, opts: &SearchOptions) -> Result<Vec<SearchResult>> {
100 self.recents.recent(limit, opts).await
101 }
102
103 pub async fn semantic_search(
104 &self,
105 query: &str,
106 limit: usize,
107 opts: &SearchOptions,
108 ) -> Result<Vec<SemanticResult>> {
109 self.semantic_searcher
110 .semantic_search(query, limit, opts)
111 .await
112 }
113}
114
/// Stateless stand-in stage.
///
/// In this file it implements [`Ingestor`] (returns the content as-is) and
/// [`Processor`] (prefixes the input with `"processed: "`).
pub struct PlaceholderPipeline;
117
118#[async_trait]
119impl Ingestor for PlaceholderPipeline {
120 async fn ingest(&self, content: &str) -> Result<String> {
121 Ok(content.to_string())
122 }
123}
124
125#[async_trait]
126impl Processor for PlaceholderPipeline {
127 async fn process(&self, input: &str) -> Result<String> {
128 Ok(format!("processed: {}", input))
129 }
130}
131
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::anyhow;
    use serde_json::json;

    /// One stub type that implements every pipeline trait with canned output.
    struct MockPipeline;

    /// Ingestor stub that always fails with "Ingestion failed".
    struct FailingIngestor;

    /// Builds a `SearchResult` with the given id/content and the fixed filler
    /// values shared by the search and recents stubs.
    fn stub_hit(id: &str, content: String) -> SearchResult {
        SearchResult {
            id: id.to_string(),
            content,
            tags: Vec::new(),
            importance: 0.5,
            metadata: json!({}),
            event_type: None,
            session_id: None,
            project: None,
            entity_id: None,
            agent_type: None,
        }
    }

    /// Pipeline with `ingestor` in the first slot and `MockPipeline` in every
    /// other slot.
    fn pipeline_with(ingestor: Box<dyn Ingestor>) -> Pipeline {
        Pipeline::new(
            ingestor,
            Box::new(MockPipeline),
            Box::new(MockPipeline),
            Box::new(MockPipeline),
            Box::new(MockPipeline),
            Box::new(MockPipeline),
            Box::new(MockPipeline),
        )
    }

    /// Pipeline whose seven stages are all `MockPipeline`.
    fn mock_pipeline() -> Pipeline {
        pipeline_with(Box::new(MockPipeline))
    }

    /// The "hello" input used by the `run` tests; `id` is left at its default.
    fn hello_input() -> MemoryInput {
        MemoryInput {
            content: "hello".to_string(),
            importance: 0.5,
            metadata: json!({}),
            ..Default::default()
        }
    }

    #[async_trait]
    impl Ingestor for MockPipeline {
        async fn ingest(&self, content: &str) -> Result<String> {
            Ok(String::from(content))
        }
    }

    #[async_trait]
    impl Processor for MockPipeline {
        async fn process(&self, input: &str) -> Result<String> {
            let tagged = format!("processed: {}", input);
            Ok(tagged)
        }
    }

    #[async_trait]
    impl Storage for MockPipeline {
        async fn store(&self, _id: &str, _data: &str, _input: &MemoryInput) -> Result<()> {
            Ok(())
        }
    }

    #[async_trait]
    impl Retriever for MockPipeline {
        async fn retrieve(&self, id: &str) -> Result<String> {
            let found = format!("retrieved: {}", id);
            Ok(found)
        }
    }

    #[async_trait]
    impl Searcher for MockPipeline {
        async fn search(
            &self,
            query: &str,
            _limit: usize,
            _opts: &SearchOptions,
        ) -> Result<Vec<SearchResult>> {
            Ok(vec![stub_hit("result-1", format!("match: {query}"))])
        }
    }

    #[async_trait]
    impl Recents for MockPipeline {
        async fn recent(&self, _limit: usize, _opts: &SearchOptions) -> Result<Vec<SearchResult>> {
            Ok(vec![stub_hit("recent-1", "recent value".to_string())])
        }
    }

    #[async_trait]
    impl SemanticSearcher for MockPipeline {
        async fn semantic_search(
            &self,
            query: &str,
            _limit: usize,
            _opts: &SearchOptions,
        ) -> Result<Vec<SemanticResult>> {
            let hit = SemanticResult {
                id: "semantic-1".to_string(),
                content: format!("semantic match: {query}"),
                tags: Vec::new(),
                importance: 0.5,
                metadata: json!({}),
                event_type: None,
                session_id: None,
                project: None,
                entity_id: None,
                agent_type: None,
                score: 0.99,
            };
            Ok(vec![hit])
        }
    }

    #[async_trait]
    impl Ingestor for FailingIngestor {
        async fn ingest(&self, _content: &str) -> Result<String> {
            Err(anyhow!("Ingestion failed"))
        }
    }

    #[tokio::test]
    async fn test_ingestor_trait() {
        let ingestor: Box<dyn Ingestor> = Box::new(MockPipeline);
        let echoed = ingestor.ingest("test").await.unwrap();
        assert_eq!(echoed, "test");
    }

    #[tokio::test]
    async fn test_pipeline_run_success() {
        let pipeline = mock_pipeline();
        let input = MemoryInput {
            id: Some("custom_id".to_string()),
            ..hello_input()
        };

        // A caller-supplied id must come back untouched.
        let outcome = pipeline.run("hello", &input).await;
        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), "custom_id");
    }

    #[tokio::test]
    async fn test_pipeline_run_default_id() {
        let pipeline = mock_pipeline();
        let input = hello_input();

        // Without a supplied id the returned one must parse as a UUID.
        let outcome = pipeline.run("hello", &input).await;
        assert!(outcome.is_ok());
        let generated = outcome.unwrap();
        assert!(uuid::Uuid::parse_str(&generated).is_ok());
    }

    #[tokio::test]
    async fn test_pipeline_retrieve_success() {
        let pipeline = mock_pipeline();

        let outcome = pipeline.retrieve("test_id").await;
        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), "retrieved: test_id");
    }

    #[tokio::test]
    async fn test_pipeline_failure() {
        let pipeline = pipeline_with(Box::new(FailingIngestor));
        let input = hello_input();

        // The ingestor's error must surface from `run` with its message intact.
        let outcome = pipeline.run("hello", &input).await;
        assert!(outcome.is_err());
        assert_eq!(outcome.unwrap_err().to_string(), "Ingestion failed");
    }

    #[tokio::test]
    async fn test_pipeline_search_success() {
        let pipeline = mock_pipeline();
        let opts = SearchOptions::default();

        let hits = pipeline.search("needle", 5, &opts).await.unwrap();
        assert_eq!(hits.len(), 1);

        let first = &hits[0];
        assert_eq!(first.id, "result-1");
        assert_eq!(first.content, "match: needle");
        assert!(first.tags.is_empty());
        assert_eq!(first.importance, 0.5);
        assert_eq!(first.metadata, json!({}));
    }

    #[tokio::test]
    async fn test_pipeline_recent_success() {
        let pipeline = mock_pipeline();
        let opts = SearchOptions::default();

        let hits = pipeline.recent(3, &opts).await.unwrap();
        assert_eq!(hits.len(), 1);

        let first = &hits[0];
        assert_eq!(first.id, "recent-1");
        assert_eq!(first.content, "recent value");
        assert!(first.tags.is_empty());
        assert_eq!(first.importance, 0.5);
        assert_eq!(first.metadata, json!({}));
    }

    #[tokio::test]
    async fn test_pipeline_semantic_search_success() {
        let pipeline = mock_pipeline();
        let opts = SearchOptions::default();

        let hits = pipeline.semantic_search("vector", 4, &opts).await.unwrap();
        assert_eq!(hits.len(), 1);

        let first = &hits[0];
        assert_eq!(first.id, "semantic-1");
        assert_eq!(first.content, "semantic match: vector");
        assert!(first.tags.is_empty());
        assert_eq!(first.importance, 0.5);
        assert_eq!(first.metadata, json!({}));
        assert!(first.score > 0.9);
    }

    #[test]
    fn test_memory_kind_for_semantic_event_type() {
        let kind = EventType::Decision.memory_kind();
        assert_eq!(kind, MemoryKind::Semantic);
    }

    #[test]
    fn test_memory_kind_defaults_to_episodic_for_unknown_type() {
        let unknown = EventType::Unknown("totally_unknown".to_string());
        assert_eq!(unknown.memory_kind(), MemoryKind::Episodic);
    }
}