Skip to main content

agentlib_memory/
composite.rs

1use agentlib_core::{
2    MemoryProvider, MemoryReadOptions, MemoryWriteOptions, ModelMessage, async_trait,
3};
4use anyhow::Result;
5use std::collections::HashSet;
6
/// How [`CompositeMemory`] combines the results of its providers when reading.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReadStrategy {
    /// Query providers in order and return the first non-empty result.
    FirstHit,
    /// Query every provider and concatenate their results in provider order,
    /// dropping duplicate messages.
    Merge,
}
11
12pub struct CompositeMemory {
13    providers: Vec<Box<dyn MemoryProvider>>,
14    read_strategy: ReadStrategy,
15}
16
17impl CompositeMemory {
18    pub fn new(providers: Vec<Box<dyn MemoryProvider>>, read_strategy: ReadStrategy) -> Self {
19        Self {
20            providers,
21            read_strategy,
22        }
23    }
24}
25
26#[async_trait]
27impl MemoryProvider for CompositeMemory {
28    async fn read(&self, options: MemoryReadOptions) -> Result<Vec<ModelMessage>> {
29        match self.read_strategy {
30            ReadStrategy::FirstHit => {
31                for provider in &self.providers {
32                    let messages = provider.read(options.clone()).await?;
33                    if !messages.is_empty() {
34                        return Ok(messages);
35                    }
36                }
37                Ok(Vec::new())
38            }
39            ReadStrategy::Merge => {
40                let mut merged = Vec::new();
41                let mut seen = HashSet::new();
42
43                for provider in &self.providers {
44                    let messages = provider.read(options.clone()).await?;
45                    for msg in messages {
46                        let key = format!(
47                            "{:?}:{}",
48                            msg.role,
49                            msg.content.chars().take(100).collect::<String>()
50                        );
51                        if !seen.contains(&key) {
52                            seen.insert(key);
53                            merged.push(msg);
54                        }
55                    }
56                }
57                Ok(merged)
58            }
59        }
60    }
61
62    async fn write(&self, messages: Vec<ModelMessage>, options: MemoryWriteOptions) -> Result<()> {
63        for provider in &self.providers {
64            provider.write(messages.clone(), options.clone()).await?;
65        }
66        Ok(())
67    }
68
69    async fn clear(&self, session_id: Option<&str>) -> Result<()> {
70        for provider in &self.providers {
71            provider.clear(session_id).await?;
72        }
73        Ok(())
74    }
75}