//! Sliding-window conversation memory (`agentlib_memory/sliding_window.rs`).

1use agentlib_core::{
2    MemoryProvider, MemoryReadOptions, MemoryWriteOptions, ModelMessage, Role, async_trait,
3    trim_to_token_budget,
4};
5use anyhow::Result;
6use std::collections::HashMap;
7use std::sync::Arc;
8use tokio::sync::Mutex;
9
10pub struct SlidingWindowMemory {
11    max_tokens: usize,
12    max_turns: usize,
13    store: Arc<Mutex<HashMap<String, Vec<ConversationTurn>>>>,
14}
15
16#[derive(Debug, Clone)]
17struct ConversationTurn {
18    messages: Vec<ModelMessage>,
19}
20
21impl SlidingWindowMemory {
22    pub fn new(max_tokens: usize, max_turns: usize) -> Self {
23        Self {
24            max_tokens,
25            max_turns,
26            store: Arc::new(Mutex::new(HashMap::new())),
27        }
28    }
29}
30
31#[async_trait]
32impl MemoryProvider for SlidingWindowMemory {
33    async fn read(&self, options: MemoryReadOptions) -> Result<Vec<ModelMessage>> {
34        let session_id = options.session_id.unwrap_or_else(|| "default".to_string());
35        let store = self.store.lock().await;
36        let turns = store.get(&session_id).cloned().unwrap_or_default();
37
38        // Flatten turns to messages
39        let messages: Vec<ModelMessage> = turns.into_iter().flat_map(|t| t.messages).collect();
40
41        // Apply token budget
42        let trimmed = trim_to_token_budget(messages, self.max_tokens);
43        Ok(trimmed)
44    }
45
46    async fn write(&self, messages: Vec<ModelMessage>, options: MemoryWriteOptions) -> Result<()> {
47        let session_id = options.session_id.unwrap_or_else(|| "default".to_string());
48
49        let non_system: Vec<ModelMessage> = messages
50            .into_iter()
51            .filter(|m| m.role != Role::System)
52            .collect();
53
54        let new_turns = group_into_turns(non_system);
55
56        let mut store = self.store.lock().await;
57        let existing = store.entry(session_id).or_default();
58
59        existing.extend(new_turns);
60
61        // Keep last max_turns
62        if existing.len() > self.max_turns {
63            *existing = existing.split_off(existing.len() - self.max_turns);
64        }
65
66        Ok(())
67    }
68
69    async fn clear(&self, session_id: Option<&str>) -> Result<()> {
70        let mut store = self.store.lock().await;
71        if let Some(sid) = session_id {
72            store.remove(sid);
73        } else {
74            store.clear();
75        }
76        Ok(())
77    }
78}
79
80fn group_into_turns(messages: Vec<ModelMessage>) -> Vec<ConversationTurn> {
81    let mut turns = Vec::new();
82    let mut current = Vec::new();
83
84    for msg in messages {
85        let role = msg.role;
86        let has_tool_calls = msg
87            .tool_calls
88            .as_ref()
89            .map(|tc| !tc.is_empty())
90            .unwrap_or(false);
91        current.push(msg);
92
93        // A turn ends when the assistant finishes (no pending tool calls)
94        if role == Role::Assistant && !has_tool_calls {
95            turns.push(ConversationTurn { messages: current });
96            current = Vec::new();
97        }
98    }
99
100    if !current.is_empty() {
101        turns.push(ConversationTurn { messages: current });
102    }
103
104    turns
105}