agentlib_memory/
buffer.rs1use agentlib_core::{
2 MemoryProvider, MemoryReadOptions, MemoryWriteOptions, ModelMessage, Role, async_trait,
3};
4use anyhow::Result;
5use std::collections::HashMap;
6use std::sync::Arc;
7use tokio::sync::Mutex;
8
9pub struct BufferMemory {
10 max_messages: usize,
11 store: Arc<Mutex<HashMap<String, Vec<ModelMessage>>>>,
12}
13
14impl BufferMemory {
15 pub fn new(max_messages: usize) -> Self {
16 Self {
17 max_messages,
18 store: Arc::new(Mutex::new(HashMap::new())),
19 }
20 }
21}
22
23#[async_trait]
24impl MemoryProvider for BufferMemory {
25 async fn read(&self, options: MemoryReadOptions) -> Result<Vec<ModelMessage>> {
26 let session_id = options.session_id.unwrap_or_else(|| "default".to_string());
27 let store = self.store.lock().await;
28 let messages = store.get(&session_id).cloned().unwrap_or_default();
29 Ok(messages)
30 }
31
32 async fn write(&self, messages: Vec<ModelMessage>, options: MemoryWriteOptions) -> Result<()> {
33 let session_id = options.session_id.unwrap_or_else(|| "default".to_string());
34
35 let mut to_store: Vec<ModelMessage> = messages
37 .into_iter()
38 .filter(|m| m.role != Role::System)
39 .collect();
40
41 if to_store.len() > self.max_messages {
43 to_store = to_store.split_off(to_store.len() - self.max_messages);
44 }
45
46 let mut store = self.store.lock().await;
47 store.insert(session_id, to_store);
48 Ok(())
49 }
50
51 async fn clear(&self, session_id: Option<&str>) -> Result<()> {
52 let mut store = self.store.lock().await;
53 if let Some(sid) = session_id {
54 store.remove(sid);
55 } else {
56 store.clear();
57 }
58 Ok(())
59 }
60}