// strands_agents/conversation/mod.rs
use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};

use crate::types::content::{ContentBlock, Message, Role};
use crate::types::errors::StrandsError;

/// Default instructions given to a model asked to summarize conversation history.
///
/// This is the initial value of `SummarizingConversationManager::summarization_prompt`.
pub const DEFAULT_SUMMARIZATION_PROMPT: &str = r#"You are a conversation summarizer. Provide a concise summary of the conversation history.

Format Requirements:
- You MUST create a structured and concise summary in bullet-point format.
- You MUST NOT respond conversationally.
- You MUST NOT address the user directly.
- You MUST NOT comment on tool availability.

Assumptions:
- You MUST NOT assume tool executions failed unless otherwise stated.

Task:
Your task is to create a structured summary document:
- It MUST contain bullet points with key topics and questions covered
- It MUST contain bullet points for all significant tools executed and their results
- It MUST contain bullet points for any code or technical information shared
- It MUST contain a section of key insights gained
- It MUST format the summary in the third person

Example format:

## Conversation Summary
* Topic 1: Key information
* Topic 2: Key information
*
## Tools Executed
* Tool X: Result Y"#;

38pub trait ConversationManager: Send + Sync {
40 fn apply_management(&self, messages: &mut Vec<Message>);
42
43 fn reduce_context(&self, messages: &mut Vec<Message>, error: &StrandsError);
45
46 fn get_state(&self) -> HashMap<String, serde_json::Value> {
48 HashMap::new()
49 }
50
51 fn restore_from_session(&mut self, _state: HashMap<String, serde_json::Value>) -> Option<Vec<Message>> {
53 None
54 }
55
56 fn removed_message_count(&self) -> usize {
58 0
59 }
60}
61
62#[derive(Debug, Clone, Default)]
64pub struct NullConversationManager;
65
66impl ConversationManager for NullConversationManager {
67 fn apply_management(&self, _messages: &mut Vec<Message>) {}
68 fn reduce_context(&self, _messages: &mut Vec<Message>, _error: &StrandsError) {}
69}
70
71#[derive(Debug, Clone)]
73pub struct SlidingWindowConversationManager {
74 pub window_size: usize,
75 removed_message_count: usize,
76}
77
78impl Default for SlidingWindowConversationManager {
79 fn default() -> Self {
80 Self {
81 window_size: 40,
82 removed_message_count: 0,
83 }
84 }
85}
86
87impl SlidingWindowConversationManager {
88 pub fn new(window_size: usize) -> Self {
89 Self {
90 window_size,
91 removed_message_count: 0,
92 }
93 }
94
95 fn adjust_split_point_for_tool_pairs(
96 &self,
97 messages: &[Message],
98 split_point: usize,
99 ) -> Result<usize, StrandsError> {
100 if split_point > messages.len() {
101 return Err(StrandsError::ContextWindowOverflow {
102 message: "Split point exceeds message array length".to_string(),
103 });
104 }
105
106 if split_point == messages.len() {
107 return Ok(split_point);
108 }
109
110 let mut adjusted = split_point;
111
112 while adjusted < messages.len() {
113 let msg = &messages[adjusted];
114 let has_tool_result = msg.content.iter().any(|c| c.tool_result.is_some());
115 let has_tool_use = msg.content.iter().any(|c| c.tool_use.is_some());
116
117 let next_has_tool_result = if adjusted + 1 < messages.len() {
118 messages[adjusted + 1]
119 .content
120 .iter()
121 .any(|c| c.tool_result.is_some())
122 } else {
123 false
124 };
125
126 if has_tool_result || (has_tool_use && adjusted + 1 < messages.len() && !next_has_tool_result)
127 {
128 adjusted += 1;
129 } else {
130 break;
131 }
132 }
133
134 if adjusted >= messages.len() {
135 return Err(StrandsError::ContextWindowOverflow {
136 message: "Unable to trim conversation context!".to_string(),
137 });
138 }
139
140 Ok(adjusted)
141 }
142}
143
144impl ConversationManager for SlidingWindowConversationManager {
145 fn apply_management(&self, messages: &mut Vec<Message>) {
146 if messages.len() > self.window_size {
147 let to_remove = messages.len() - self.window_size;
148 if let Ok(adjusted) = self.adjust_split_point_for_tool_pairs(messages, to_remove) {
149 messages.drain(..adjusted);
150 }
151 }
152 }
153
154 fn reduce_context(&self, messages: &mut Vec<Message>, _error: &StrandsError) {
155 let keep = messages.len() / 2;
156 if keep > 0 {
157 let to_remove = messages.len() - keep;
158 if let Ok(adjusted) = self.adjust_split_point_for_tool_pairs(messages, to_remove) {
159 messages.drain(..adjusted);
160 }
161 }
162 }
163
164 fn get_state(&self) -> HashMap<String, serde_json::Value> {
165 let mut state = HashMap::new();
166 state.insert(
167 "removed_message_count".to_string(),
168 serde_json::json!(self.removed_message_count),
169 );
170 state.insert(
171 "window_size".to_string(),
172 serde_json::json!(self.window_size),
173 );
174 state
175 }
176
177 fn removed_message_count(&self) -> usize {
178 self.removed_message_count
179 }
180}
181
182pub type SummarizeFn = Arc<dyn Fn(&[Message]) -> Message + Send + Sync>;
184
185pub struct SummarizingConversationManager {
187 pub summary_ratio: f64,
188 pub preserve_recent_messages: usize,
189 pub summarization_prompt: String,
190 summarize_fn: Option<SummarizeFn>,
191 summary_message: Option<Message>,
192 removed_message_count: usize,
193}
194
195impl Default for SummarizingConversationManager {
196 fn default() -> Self {
197 Self {
198 summary_ratio: 0.3,
199 preserve_recent_messages: 10,
200 summarization_prompt: DEFAULT_SUMMARIZATION_PROMPT.to_string(),
201 summarize_fn: None,
202 summary_message: None,
203 removed_message_count: 0,
204 }
205 }
206}
207
208impl SummarizingConversationManager {
209 pub fn new(summary_ratio: f64, preserve_recent_messages: usize) -> Self {
210 Self {
211 summary_ratio: summary_ratio.clamp(0.1, 0.8),
212 preserve_recent_messages,
213 ..Default::default()
214 }
215 }
216
217 pub fn with_prompt(mut self, prompt: impl Into<String>) -> Self {
218 self.summarization_prompt = prompt.into();
219 self
220 }
221
222 pub fn with_summarize_fn(mut self, f: SummarizeFn) -> Self {
223 self.summarize_fn = Some(f);
224 self
225 }
226
227 fn adjust_split_point_for_tool_pairs(
228 &self,
229 messages: &[Message],
230 split_point: usize,
231 ) -> Result<usize, StrandsError> {
232 if split_point > messages.len() {
233 return Err(StrandsError::ContextWindowOverflow {
234 message: "Split point exceeds message array length".to_string(),
235 });
236 }
237
238 if split_point == messages.len() {
239 return Ok(split_point);
240 }
241
242 let mut adjusted = split_point;
243
244 while adjusted < messages.len() {
245 let msg = &messages[adjusted];
246 let has_tool_result = msg.content.iter().any(|c| c.tool_result.is_some());
247 let has_tool_use = msg.content.iter().any(|c| c.tool_use.is_some());
248
249 let next_has_tool_result = if adjusted + 1 < messages.len() {
250 messages[adjusted + 1]
251 .content
252 .iter()
253 .any(|c| c.tool_result.is_some())
254 } else {
255 false
256 };
257
258 if has_tool_result || (has_tool_use && adjusted + 1 < messages.len() && !next_has_tool_result)
259 {
260 adjusted += 1;
261 } else {
262 break;
263 }
264 }
265
266 if adjusted >= messages.len() {
267 return Err(StrandsError::ContextWindowOverflow {
268 message: "Unable to trim conversation context!".to_string(),
269 });
270 }
271
272 Ok(adjusted)
273 }
274
275 fn generate_summary(&self, messages: &[Message]) -> Message {
276 if let Some(ref f) = self.summarize_fn {
277 f(messages)
278 } else {
279
280 let summary_text = messages
281 .iter()
282 .filter_map(|m| {
283 m.content.iter().find_map(|c| c.text.clone())
284 })
285 .collect::<Vec<_>>()
286 .join("\n");
287
288 Message::new(
289 Role::User,
290 vec![ContentBlock::text(format!(
291 "## Conversation Summary\n{}",
292 summary_text
293 ))],
294 )
295 }
296 }
297}
298
299impl ConversationManager for SummarizingConversationManager {
300 fn apply_management(&self, _messages: &mut Vec<Message>) {
301
302 }
303
304 fn reduce_context(&self, messages: &mut Vec<Message>, _error: &StrandsError) {
305 let messages_to_summarize_count =
306 (messages.len() as f64 * self.summary_ratio).max(1.0) as usize;
307
308 let messages_to_summarize_count = messages_to_summarize_count
309 .min(messages.len().saturating_sub(self.preserve_recent_messages));
310
311 if messages_to_summarize_count == 0 {
312 return;
313 }
314
315 let adjusted = match self.adjust_split_point_for_tool_pairs(messages, messages_to_summarize_count) {
316 Ok(a) => a,
317 Err(_) => return,
318 };
319
320 if adjusted == 0 {
321 return;
322 }
323
324 let messages_to_summarize: Vec<_> = messages.drain(..adjusted).collect();
325 let summary = self.generate_summary(&messages_to_summarize);
326
327 messages.insert(0, summary);
328 }
329
330 fn get_state(&self) -> HashMap<String, serde_json::Value> {
331 let mut state = HashMap::new();
332 state.insert(
333 "removed_message_count".to_string(),
334 serde_json::json!(self.removed_message_count),
335 );
336 if let Some(ref summary) = self.summary_message {
337 if let Ok(v) = serde_json::to_value(summary) {
338 state.insert("summary_message".to_string(), v);
339 }
340 }
341 state
342 }
343
344 fn restore_from_session(&mut self, state: HashMap<String, serde_json::Value>) -> Option<Vec<Message>> {
345 if let Some(v) = state.get("removed_message_count") {
346 if let Some(count) = v.as_u64() {
347 self.removed_message_count = count as usize;
348 }
349 }
350
351 if let Some(v) = state.get("summary_message") {
352 if let Ok(msg) = serde_json::from_value(v.clone()) {
353 self.summary_message = Some(msg);
354 return self.summary_message.clone().map(|m| vec![m]);
355 }
356 }
357
358 None
359 }
360
361 fn removed_message_count(&self) -> usize {
362 self.removed_message_count
363 }
364}
365
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::content::Role;

    /// Builds `count` alternating user/assistant text messages labelled "1", "2", ...
    fn alternating_messages(count: usize) -> Vec<Message> {
        (1..=count)
            .map(|n| {
                let role = if n % 2 == 1 { Role::User } else { Role::Assistant };
                Message::new(role, vec![ContentBlock::text(n.to_string().as_str())])
            })
            .collect()
    }

    #[test]
    fn test_sliding_window_applies_management() {
        let mut history = alternating_messages(5);
        SlidingWindowConversationManager::new(3).apply_management(&mut history);
        assert_eq!(history.len(), 3);
    }

    #[test]
    fn test_null_conversation_manager() {
        let mut history = vec![Message::new(Role::User, vec![ContentBlock::text("test")])];
        NullConversationManager.apply_management(&mut history);
        assert_eq!(history.len(), 1);
    }
}