agentlib_memory/
sliding_window.rs1use agentlib_core::{
2 MemoryProvider, MemoryReadOptions, MemoryWriteOptions, ModelMessage, Role, async_trait,
3 trim_to_token_budget,
4};
5use anyhow::Result;
6use std::collections::HashMap;
7use std::sync::Arc;
8use tokio::sync::Mutex;
9
10pub struct SlidingWindowMemory {
11 max_tokens: usize,
12 max_turns: usize,
13 store: Arc<Mutex<HashMap<String, Vec<ConversationTurn>>>>,
14}
15
16#[derive(Debug, Clone)]
17struct ConversationTurn {
18 messages: Vec<ModelMessage>,
19}
20
21impl SlidingWindowMemory {
22 pub fn new(max_tokens: usize, max_turns: usize) -> Self {
23 Self {
24 max_tokens,
25 max_turns,
26 store: Arc::new(Mutex::new(HashMap::new())),
27 }
28 }
29}
30
31#[async_trait]
32impl MemoryProvider for SlidingWindowMemory {
33 async fn read(&self, options: MemoryReadOptions) -> Result<Vec<ModelMessage>> {
34 let session_id = options.session_id.unwrap_or_else(|| "default".to_string());
35 let store = self.store.lock().await;
36 let turns = store.get(&session_id).cloned().unwrap_or_default();
37
38 let messages: Vec<ModelMessage> = turns.into_iter().flat_map(|t| t.messages).collect();
40
41 let trimmed = trim_to_token_budget(messages, self.max_tokens);
43 Ok(trimmed)
44 }
45
46 async fn write(&self, messages: Vec<ModelMessage>, options: MemoryWriteOptions) -> Result<()> {
47 let session_id = options.session_id.unwrap_or_else(|| "default".to_string());
48
49 let non_system: Vec<ModelMessage> = messages
50 .into_iter()
51 .filter(|m| m.role != Role::System)
52 .collect();
53
54 let new_turns = group_into_turns(non_system);
55
56 let mut store = self.store.lock().await;
57 let existing = store.entry(session_id).or_default();
58
59 existing.extend(new_turns);
60
61 if existing.len() > self.max_turns {
63 *existing = existing.split_off(existing.len() - self.max_turns);
64 }
65
66 Ok(())
67 }
68
69 async fn clear(&self, session_id: Option<&str>) -> Result<()> {
70 let mut store = self.store.lock().await;
71 if let Some(sid) = session_id {
72 store.remove(sid);
73 } else {
74 store.clear();
75 }
76 Ok(())
77 }
78}
79
80fn group_into_turns(messages: Vec<ModelMessage>) -> Vec<ConversationTurn> {
81 let mut turns = Vec::new();
82 let mut current = Vec::new();
83
84 for msg in messages {
85 let role = msg.role;
86 let has_tool_calls = msg
87 .tool_calls
88 .as_ref()
89 .map(|tc| !tc.is_empty())
90 .unwrap_or(false);
91 current.push(msg);
92
93 if role == Role::Assistant && !has_tool_calls {
95 turns.push(ConversationTurn { messages: current });
96 current = Vec::new();
97 }
98 }
99
100 if !current.is_empty() {
101 turns.push(ConversationTurn { messages: current });
102 }
103
104 turns
105}