Skip to main content

agentlib_memory/
buffer.rs

1use agentlib_core::{
2    MemoryProvider, MemoryReadOptions, MemoryWriteOptions, ModelMessage, Role, async_trait,
3};
4use anyhow::Result;
5use std::collections::HashMap;
6use std::sync::Arc;
7use tokio::sync::Mutex;
8
9pub struct BufferMemory {
10    max_messages: usize,
11    store: Arc<Mutex<HashMap<String, Vec<ModelMessage>>>>,
12}
13
14impl BufferMemory {
15    pub fn new(max_messages: usize) -> Self {
16        Self {
17            max_messages,
18            store: Arc::new(Mutex::new(HashMap::new())),
19        }
20    }
21}
22
23#[async_trait]
24impl MemoryProvider for BufferMemory {
25    async fn read(&self, options: MemoryReadOptions) -> Result<Vec<ModelMessage>> {
26        let session_id = options.session_id.unwrap_or_else(|| "default".to_string());
27        let store = self.store.lock().await;
28        let messages = store.get(&session_id).cloned().unwrap_or_default();
29        Ok(messages)
30    }
31
32    async fn write(&self, messages: Vec<ModelMessage>, options: MemoryWriteOptions) -> Result<()> {
33        let session_id = options.session_id.unwrap_or_else(|| "default".to_string());
34
35        // Filter out system messages (they are usually re-injected by the agent)
36        let mut to_store: Vec<ModelMessage> = messages
37            .into_iter()
38            .filter(|m| m.role != Role::System)
39            .collect();
40
41        // Trim to max_messages (keep newest)
42        if to_store.len() > self.max_messages {
43            to_store = to_store.split_off(to_store.len() - self.max_messages);
44        }
45
46        let mut store = self.store.lock().await;
47        store.insert(session_id, to_store);
48        Ok(())
49    }
50
51    async fn clear(&self, session_id: Option<&str>) -> Result<()> {
52        let mut store = self.store.lock().await;
53        if let Some(sid) = session_id {
54            store.remove(sid);
55        } else {
56            store.clear();
57        }
58        Ok(())
59    }
60}