strands_agents/conversation/
mod.rs

1//! Conversation management for context window optimization.
2
use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::sync::{Mutex, MutexGuard};

use crate::types::content::{ContentBlock, Message, Role};
use crate::types::errors::StrandsError;
8
/// Default prompt text for conversation summarization.
///
/// Asks for a third-person, bullet-point summary that covers the topics discussed, significant
/// tools executed and their results, technical information shared, and key insights, followed by
/// an example of the expected layout. Built one line at a time with `concat!`; every line except
/// the last ends in `\n`.
pub const DEFAULT_SUMMARIZATION_PROMPT: &str = concat!(
    "You are a conversation summarizer. Provide a concise summary of the conversation history.\n",
    "\n",
    "Format Requirements:\n",
    "- You MUST create a structured and concise summary in bullet-point format.\n",
    "- You MUST NOT respond conversationally.\n",
    "- You MUST NOT address the user directly.\n",
    "- You MUST NOT comment on tool availability.\n",
    "\n",
    "Assumptions:\n",
    "- You MUST NOT assume tool executions failed unless otherwise stated.\n",
    "\n",
    "Task:\n",
    "Your task is to create a structured summary document:\n",
    "- It MUST contain bullet points with key topics and questions covered\n",
    "- It MUST contain bullet points for all significant tools executed and their results\n",
    "- It MUST contain bullet points for any code or technical information shared\n",
    "- It MUST contain a section of key insights gained\n",
    "- It MUST format the summary in the third person\n",
    "\n",
    "Example format:\n",
    "\n",
    "## Conversation Summary\n",
    "* Topic 1: Key information\n",
    "* Topic 2: Key information\n",
    "*\n",
    "## Tools Executed\n",
    "* Tool X: Result Y",
);
37
38/// Trait for implementing conversation managers.
39pub trait ConversationManager: Send + Sync {
40    /// Applies management strategy to the agent's messages.
41    fn apply_management(&self, messages: &mut Vec<Message>);
42
43    /// Reduces context when an error occurs (e.g., context overflow).
44    fn reduce_context(&self, messages: &mut Vec<Message>, error: &StrandsError);
45
46    /// Returns the current state for session persistence.
47    fn get_state(&self) -> HashMap<String, serde_json::Value> {
48        HashMap::new()
49    }
50
51    /// Restores state from a session. Returns optional prepend messages.
52    fn restore_from_session(&mut self, _state: HashMap<String, serde_json::Value>) -> Option<Vec<Message>> {
53        None
54    }
55
56    /// Returns the count of messages removed by this manager.
57    fn removed_message_count(&self) -> usize {
58        0
59    }
60}
61
62/// A no-op conversation manager.
63#[derive(Debug, Clone, Default)]
64pub struct NullConversationManager;
65
66impl ConversationManager for NullConversationManager {
67    fn apply_management(&self, _messages: &mut Vec<Message>) {}
68    fn reduce_context(&self, _messages: &mut Vec<Message>, _error: &StrandsError) {}
69}
70
71/// Sliding window conversation manager that keeps recent messages.
72#[derive(Debug, Clone)]
73pub struct SlidingWindowConversationManager {
74    pub window_size: usize,
75    removed_message_count: usize,
76}
77
78impl Default for SlidingWindowConversationManager {
79    fn default() -> Self {
80        Self {
81            window_size: 40,
82            removed_message_count: 0,
83        }
84    }
85}
86
87impl SlidingWindowConversationManager {
88    pub fn new(window_size: usize) -> Self {
89        Self {
90            window_size,
91            removed_message_count: 0,
92        }
93    }
94
95    fn adjust_split_point_for_tool_pairs(
96        &self,
97        messages: &[Message],
98        split_point: usize,
99    ) -> Result<usize, StrandsError> {
100        if split_point > messages.len() {
101            return Err(StrandsError::ContextWindowOverflow {
102                message: "Split point exceeds message array length".to_string(),
103            });
104        }
105
106        if split_point == messages.len() {
107            return Ok(split_point);
108        }
109
110        let mut adjusted = split_point;
111
112        while adjusted < messages.len() {
113            let msg = &messages[adjusted];
114            let has_tool_result = msg.content.iter().any(|c| c.tool_result.is_some());
115            let has_tool_use = msg.content.iter().any(|c| c.tool_use.is_some());
116
117            let next_has_tool_result = if adjusted + 1 < messages.len() {
118                messages[adjusted + 1]
119                    .content
120                    .iter()
121                    .any(|c| c.tool_result.is_some())
122            } else {
123                false
124            };
125
126            if has_tool_result || (has_tool_use && adjusted + 1 < messages.len() && !next_has_tool_result)
127            {
128                adjusted += 1;
129            } else {
130                break;
131            }
132        }
133
134        if adjusted >= messages.len() {
135            return Err(StrandsError::ContextWindowOverflow {
136                message: "Unable to trim conversation context!".to_string(),
137            });
138        }
139
140        Ok(adjusted)
141    }
142}
143
144impl ConversationManager for SlidingWindowConversationManager {
145    fn apply_management(&self, messages: &mut Vec<Message>) {
146        if messages.len() > self.window_size {
147            let to_remove = messages.len() - self.window_size;
148            if let Ok(adjusted) = self.adjust_split_point_for_tool_pairs(messages, to_remove) {
149                messages.drain(..adjusted);
150            }
151        }
152    }
153
154    fn reduce_context(&self, messages: &mut Vec<Message>, _error: &StrandsError) {
155        let keep = messages.len() / 2;
156        if keep > 0 {
157            let to_remove = messages.len() - keep;
158            if let Ok(adjusted) = self.adjust_split_point_for_tool_pairs(messages, to_remove) {
159                messages.drain(..adjusted);
160            }
161        }
162    }
163
164    fn get_state(&self) -> HashMap<String, serde_json::Value> {
165        let mut state = HashMap::new();
166        state.insert(
167            "removed_message_count".to_string(),
168            serde_json::json!(self.removed_message_count),
169        );
170        state.insert(
171            "window_size".to_string(),
172            serde_json::json!(self.window_size),
173        );
174        state
175    }
176
177    fn removed_message_count(&self) -> usize {
178        self.removed_message_count
179    }
180}
181
182/// Summarization function type for SummarizingConversationManager.
183pub type SummarizeFn = Arc<dyn Fn(&[Message]) -> Message + Send + Sync>;
184
185/// Summarizing conversation manager that summarizes older context.
186pub struct SummarizingConversationManager {
187    pub summary_ratio: f64,
188    pub preserve_recent_messages: usize,
189    pub summarization_prompt: String,
190    summarize_fn: Option<SummarizeFn>,
191    summary_message: Option<Message>,
192    removed_message_count: usize,
193}
194
195impl Default for SummarizingConversationManager {
196    fn default() -> Self {
197        Self {
198            summary_ratio: 0.3,
199            preserve_recent_messages: 10,
200            summarization_prompt: DEFAULT_SUMMARIZATION_PROMPT.to_string(),
201            summarize_fn: None,
202            summary_message: None,
203            removed_message_count: 0,
204        }
205    }
206}
207
208impl SummarizingConversationManager {
209    pub fn new(summary_ratio: f64, preserve_recent_messages: usize) -> Self {
210        Self {
211            summary_ratio: summary_ratio.clamp(0.1, 0.8),
212            preserve_recent_messages,
213            ..Default::default()
214        }
215    }
216
217    pub fn with_prompt(mut self, prompt: impl Into<String>) -> Self {
218        self.summarization_prompt = prompt.into();
219        self
220    }
221
222    pub fn with_summarize_fn(mut self, f: SummarizeFn) -> Self {
223        self.summarize_fn = Some(f);
224        self
225    }
226
227    fn adjust_split_point_for_tool_pairs(
228        &self,
229        messages: &[Message],
230        split_point: usize,
231    ) -> Result<usize, StrandsError> {
232        if split_point > messages.len() {
233            return Err(StrandsError::ContextWindowOverflow {
234                message: "Split point exceeds message array length".to_string(),
235            });
236        }
237
238        if split_point == messages.len() {
239            return Ok(split_point);
240        }
241
242        let mut adjusted = split_point;
243
244        while adjusted < messages.len() {
245            let msg = &messages[adjusted];
246            let has_tool_result = msg.content.iter().any(|c| c.tool_result.is_some());
247            let has_tool_use = msg.content.iter().any(|c| c.tool_use.is_some());
248
249            let next_has_tool_result = if adjusted + 1 < messages.len() {
250                messages[adjusted + 1]
251                    .content
252                    .iter()
253                    .any(|c| c.tool_result.is_some())
254            } else {
255                false
256            };
257
258            if has_tool_result || (has_tool_use && adjusted + 1 < messages.len() && !next_has_tool_result)
259            {
260                adjusted += 1;
261            } else {
262                break;
263            }
264        }
265
266        if adjusted >= messages.len() {
267            return Err(StrandsError::ContextWindowOverflow {
268                message: "Unable to trim conversation context!".to_string(),
269            });
270        }
271
272        Ok(adjusted)
273    }
274
275    fn generate_summary(&self, messages: &[Message]) -> Message {
276        if let Some(ref f) = self.summarize_fn {
277            f(messages)
278        } else {
279
280            let summary_text = messages
281                .iter()
282                .filter_map(|m| {
283                    m.content.iter().find_map(|c| c.text.clone())
284                })
285                .collect::<Vec<_>>()
286                .join("\n");
287
288            Message::new(
289                Role::User,
290                vec![ContentBlock::text(format!(
291                    "## Conversation Summary\n{}",
292                    summary_text
293                ))],
294            )
295        }
296    }
297}
298
299impl ConversationManager for SummarizingConversationManager {
300    fn apply_management(&self, _messages: &mut Vec<Message>) {
301
302    }
303
304    fn reduce_context(&self, messages: &mut Vec<Message>, _error: &StrandsError) {
305        let messages_to_summarize_count =
306            (messages.len() as f64 * self.summary_ratio).max(1.0) as usize;
307
308        let messages_to_summarize_count = messages_to_summarize_count
309            .min(messages.len().saturating_sub(self.preserve_recent_messages));
310
311        if messages_to_summarize_count == 0 {
312            return;
313        }
314
315        let adjusted = match self.adjust_split_point_for_tool_pairs(messages, messages_to_summarize_count) {
316            Ok(a) => a,
317            Err(_) => return,
318        };
319
320        if adjusted == 0 {
321            return;
322        }
323
324        let messages_to_summarize: Vec<_> = messages.drain(..adjusted).collect();
325        let summary = self.generate_summary(&messages_to_summarize);
326
327        messages.insert(0, summary);
328    }
329
330    fn get_state(&self) -> HashMap<String, serde_json::Value> {
331        let mut state = HashMap::new();
332        state.insert(
333            "removed_message_count".to_string(),
334            serde_json::json!(self.removed_message_count),
335        );
336        if let Some(ref summary) = self.summary_message {
337            if let Ok(v) = serde_json::to_value(summary) {
338                state.insert("summary_message".to_string(), v);
339            }
340        }
341        state
342    }
343
344    fn restore_from_session(&mut self, state: HashMap<String, serde_json::Value>) -> Option<Vec<Message>> {
345        if let Some(v) = state.get("removed_message_count") {
346            if let Some(count) = v.as_u64() {
347                self.removed_message_count = count as usize;
348            }
349        }
350
351        if let Some(v) = state.get("summary_message") {
352            if let Ok(msg) = serde_json::from_value(v.clone()) {
353                self.summary_message = Some(msg);
354                return self.summary_message.clone().map(|m| vec![m]);
355            }
356        }
357
358        None
359    }
360
361    fn removed_message_count(&self) -> usize {
362        self.removed_message_count
363    }
364}
365
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::content::Role;

    /// Builds `count` alternating user/assistant text messages whose texts are "1", "2", ...
    fn numbered_messages(count: usize) -> Vec<Message> {
        (1..=count)
            .map(|i| {
                let role = if i % 2 == 1 { Role::User } else { Role::Assistant };
                Message::new(role, vec![ContentBlock::text(i.to_string())])
            })
            .collect()
    }

    /// Returns the first text block of `message`, if it has one.
    fn first_text(message: &Message) -> Option<&str> {
        message.content.iter().find_map(|c| c.text.as_deref())
    }

    /// An arbitrary error to hand to `reduce_context`.
    fn overflow_error() -> StrandsError {
        StrandsError::ContextWindowOverflow {
            message: "context window overflow".to_string(),
        }
    }

    #[test]
    fn test_sliding_window_applies_management() {
        let manager = SlidingWindowConversationManager::new(3);
        let mut messages = numbered_messages(5);

        manager.apply_management(&mut messages);
        assert_eq!(messages.len(), 3);
        assert_eq!(first_text(&messages[0]), Some("3"));
    }

    #[test]
    fn test_sliding_window_keeps_short_history() {
        let manager = SlidingWindowConversationManager::new(10);
        let mut messages = numbered_messages(4);

        manager.apply_management(&mut messages);
        assert_eq!(messages.len(), 4);
        assert_eq!(first_text(&messages[0]), Some("1"));
    }

    #[test]
    fn test_sliding_window_reduce_context_keeps_newest_half() {
        let manager = SlidingWindowConversationManager::new(10);
        let mut messages = numbered_messages(6);

        manager.reduce_context(&mut messages, &overflow_error());
        assert_eq!(messages.len(), 3);
        assert_eq!(first_text(&messages[0]), Some("4"));
    }

    #[test]
    fn test_sliding_window_reduce_context_single_message_is_noop() {
        let manager = SlidingWindowConversationManager::new(10);
        let mut messages = numbered_messages(1);

        manager.reduce_context(&mut messages, &overflow_error());
        assert_eq!(messages.len(), 1);
    }

    #[test]
    fn test_null_conversation_manager() {
        let manager = NullConversationManager;
        let mut messages = vec![Message::new(Role::User, vec![ContentBlock::text("test")])];

        manager.apply_management(&mut messages);
        manager.reduce_context(&mut messages, &overflow_error());
        assert_eq!(messages.len(), 1);
    }

    #[test]
    fn test_summarizing_fallback_replaces_oldest_messages() {
        let manager = SummarizingConversationManager::new(0.3, 2);
        let mut messages = numbered_messages(12);

        manager.reduce_context(&mut messages, &overflow_error());
        // 12 * 0.3 truncates to 3 summarized messages, which are replaced by one summary.
        assert_eq!(messages.len(), 10);
        assert_eq!(
            first_text(&messages[0]),
            Some("## Conversation Summary\n1\n2\n3")
        );
        assert_eq!(first_text(&messages[1]), Some("4"));
    }

    #[test]
    fn test_summarizing_respects_preserved_recent_messages() {
        let manager = SummarizingConversationManager::new(0.5, 3);
        let mut messages = numbered_messages(3);

        manager.reduce_context(&mut messages, &overflow_error());
        assert_eq!(messages.len(), 3);
        assert_eq!(first_text(&messages[0]), Some("1"));
    }

    #[test]
    fn test_summarizing_uses_custom_summarize_fn() {
        let manager = SummarizingConversationManager::new(0.5, 0).with_summarize_fn(Arc::new(
            |summarized: &[Message]| {
                Message::new(
                    Role::User,
                    vec![ContentBlock::text(format!("summarized {}", summarized.len()))],
                )
            },
        ));
        let mut messages = numbered_messages(4);

        manager.reduce_context(&mut messages, &overflow_error());
        assert_eq!(messages.len(), 3);
        assert_eq!(first_text(&messages[0]), Some("summarized 2"));
        assert_eq!(first_text(&messages[1]), Some("3"));
    }

    #[test]
    fn test_summarizing_apply_management_is_noop() {
        let manager = SummarizingConversationManager::default();
        let mut messages = numbered_messages(20);

        manager.apply_management(&mut messages);
        assert_eq!(messages.len(), 20);
    }
}