1use std::sync::Arc;
27
28use async_trait::async_trait;
29
30use crate::{
31 error::AgentError,
32 types::{Message, Role, TokenUsage},
33};
34
/// Tuning knobs shared by the context trimmers in this module.
#[derive(Debug, Clone)]
pub struct TrimConfig {
    /// Maximum number of non-system messages the sliding window keeps. System messages are
    /// always kept, and a preserved first user message may add one more. `SummarizationTrimmer`
    /// keeps `max_messages / 2` recent messages verbatim when it summarizes.
    pub max_messages: usize,
    /// Token budget that the (estimated) history size is compared against.
    pub target_tokens: usize,
    /// When `true`, the first `Role::User` message is kept even if it falls outside the window.
    pub preserve_first_user: bool,
    /// Fraction of `target_tokens` at which `SummarizationTrimmer` starts summarizing instead
    /// of only applying the sliding window.
    pub summarization_threshold: f64,
}

impl Default for TrimConfig {
    /// 50 messages, an 8192-token budget, first user message preserved, and summarization
    /// starting at 80% of the budget.
    fn default() -> Self {
        let max_messages = 50;
        let target_tokens = 8 * 1024;
        let preserve_first_user = true;
        let summarization_threshold = 0.8;

        Self {
            max_messages,
            target_tokens,
            preserve_first_user,
            summarization_threshold,
        }
    }
}
58
/// Outcome of a single [`ContextTrimmer::trim`] call.
#[derive(Debug, Clone)]
pub struct TrimResult {
    /// The messages to send on, in order. When `was_summarized` is `true`, the first entry is
    /// the system message carrying the summary.
    pub messages: Vec<Message>,
    /// Whether older messages were replaced by a summary.
    pub was_summarized: bool,
    /// Rough token estimate for `messages` (content byte length / 4 per message, summed).
    pub estimated_tokens: usize,
    /// Number of input messages not carried over verbatim (dropped, or folded into the summary).
    pub messages_dropped: usize,
}
71
/// A strategy for shrinking a conversation history before it is sent to the model.
///
/// Implementors must be `Send + Sync` so a trimmer can be shared across tasks.
#[async_trait]
pub trait ContextTrimmer: Send + Sync {
    /// Returns a trimmed copy of `messages`; the input slice is not modified.
    ///
    /// NOTE(review): none of the implementations in this file read `usage` (they bind it as
    /// `_usage`, or only forward it to another trimmer); token counts are re-estimated from
    /// `messages` instead. Confirm the intended use of this parameter.
    ///
    /// # Errors
    /// Returns an [`AgentError`] if the strategy fails. In this file, that only happens when a
    /// [`MessageSummarizer`] returns an error.
    async fn trim(
        &self,
        messages: &[Message],
        usage: &TokenUsage,
    ) -> Result<TrimResult, AgentError>;

    /// Short, static identifier for the strategy (e.g. `"sliding_window"`).
    fn strategy_name(&self) -> &'static str;
}
97
98#[derive(Debug, Clone)]
112pub struct SlidingWindowTrimmer {
113 config: TrimConfig,
114}
115
116impl SlidingWindowTrimmer {
117 #[must_use]
119 pub fn new(max_messages: usize, target_tokens: usize) -> Self {
120 Self {
121 config: TrimConfig {
122 max_messages,
123 target_tokens,
124 preserve_first_user: true,
125 summarization_threshold: 1.0, },
127 }
128 }
129
130 #[must_use]
132 pub fn with_config(config: TrimConfig) -> Self {
133 Self { config }
134 }
135}
136
137#[async_trait]
138impl ContextTrimmer for SlidingWindowTrimmer {
139 async fn trim(
140 &self,
141 messages: &[Message],
142 _usage: &TokenUsage,
143 ) -> Result<TrimResult, AgentError> {
144 let original_count = messages.len();
145 let trimmed = sliding_window_trim(
146 messages,
147 self.config.max_messages,
148 self.config.preserve_first_user,
149 );
150 let dropped = original_count.saturating_sub(trimmed.len());
151 let estimated = estimate_tokens(&trimmed);
152
153 Ok(TrimResult {
154 messages: trimmed,
155 was_summarized: false,
156 estimated_tokens: estimated,
157 messages_dropped: dropped,
158 })
159 }
160
161 fn strategy_name(&self) -> &'static str {
162 "sliding_window"
163 }
164}
165
166#[derive(Debug)]
183pub struct SummarizationTrimmer {
184 config: TrimConfig,
185 summarizer: Arc<dyn MessageSummarizer>,
186}
187
188impl SummarizationTrimmer {
189 #[must_use]
191 pub fn new(
192 max_messages: usize,
193 target_tokens: usize,
194 summarizer: Arc<dyn MessageSummarizer>,
195 ) -> Self {
196 Self {
197 config: TrimConfig {
198 max_messages,
199 target_tokens,
200 preserve_first_user: true,
201 summarization_threshold: 0.8,
202 },
203 summarizer,
204 }
205 }
206
207 #[must_use]
209 pub fn with_config(config: TrimConfig, summarizer: Arc<dyn MessageSummarizer>) -> Self {
210 Self { config, summarizer }
211 }
212}
213
214#[async_trait]
215impl ContextTrimmer for SummarizationTrimmer {
216 async fn trim(
217 &self,
218 messages: &[Message],
219 _usage: &TokenUsage,
220 ) -> Result<TrimResult, AgentError> {
221 let estimated = estimate_tokens(messages);
222 let threshold =
223 (self.config.target_tokens as f64 * self.config.summarization_threshold) as usize;
224
225 if estimated < threshold {
227 let trimmed = sliding_window_trim(
228 messages,
229 self.config.max_messages,
230 self.config.preserve_first_user,
231 );
232 let trimmed_tokens = estimate_tokens(&trimmed);
233 let dropped = messages.len().saturating_sub(trimmed.len());
234 return Ok(TrimResult {
235 messages: trimmed,
236 was_summarized: false,
237 estimated_tokens: trimmed_tokens,
238 messages_dropped: dropped,
239 });
240 }
241
242 let (old_messages, recent_messages) = split_at_threshold(
244 messages,
245 self.config.max_messages / 2,
246 self.config.preserve_first_user,
247 );
248
249 if old_messages.is_empty() {
250 let trimmed = sliding_window_trim(
251 messages,
252 self.config.max_messages,
253 self.config.preserve_first_user,
254 );
255 let trimmed_tokens = estimate_tokens(&trimmed);
256 let dropped = messages.len().saturating_sub(trimmed.len());
257 return Ok(TrimResult {
258 messages: trimmed,
259 was_summarized: false,
260 estimated_tokens: trimmed_tokens,
261 messages_dropped: dropped,
262 });
263 }
264
265 let summary = self.summarizer.summarize(&old_messages).await?;
266 let mut result_messages = Vec::with_capacity(1 + recent_messages.len());
267 result_messages.push(Message::text(
268 Role::System,
269 format!("Previous conversation summary:\n{summary}"),
270 ));
271 result_messages.extend(recent_messages);
272
273 Ok(TrimResult {
274 messages: result_messages.clone(),
275 was_summarized: true,
276 estimated_tokens: estimate_tokens(&result_messages),
277 messages_dropped: old_messages.len(),
278 })
279 }
280
281 fn strategy_name(&self) -> &'static str {
282 "summarization"
283 }
284}
285
/// Produces a plain-text summary of a run of messages for the summarizing trimmers.
///
/// The `Debug` supertrait lets trimmers that hold an `Arc<dyn MessageSummarizer>` derive
/// `Debug`.
#[async_trait]
pub trait MessageSummarizer: Send + Sync + std::fmt::Debug {
    /// Summarizes `messages` into text.
    ///
    /// The trimmers in this file wrap the returned text in a system message that starts with
    /// `"Previous conversation summary:"`.
    ///
    /// # Errors
    /// Any error is propagated unchanged (via `?`) out of the calling trimmer's `trim`.
    async fn summarize(&self, messages: &[Message]) -> Result<String, AgentError>;
}
295
296#[derive(Debug)]
301pub struct HybridTrimmer {
302 sliding: SlidingWindowTrimmer,
303 summarizer: Arc<dyn MessageSummarizer>,
304 summarization_threshold: f64,
305 target_tokens: usize,
306}
307
308impl HybridTrimmer {
309 #[must_use]
311 pub fn new(
312 max_messages: usize,
313 target_tokens: usize,
314 summarizer: Arc<dyn MessageSummarizer>,
315 ) -> Self {
316 Self {
317 sliding: SlidingWindowTrimmer::new(max_messages, target_tokens),
318 summarizer,
319 summarization_threshold: 0.8,
320 target_tokens,
321 }
322 }
323}
324
325#[async_trait]
326impl ContextTrimmer for HybridTrimmer {
327 async fn trim(
328 &self,
329 messages: &[Message],
330 usage: &TokenUsage,
331 ) -> Result<TrimResult, AgentError> {
332 let estimated = estimate_tokens(messages);
333 let threshold = (self.target_tokens as f64 * self.summarization_threshold) as usize;
334
335 if estimated < threshold {
337 return self.sliding.trim(messages, usage).await;
338 }
339
340 let split_point = messages.len() / 2;
342 let (old_messages, recent_messages) = messages.split_at(split_point);
343
344 let summary = self.summarizer.summarize(old_messages).await?;
345 let mut result_messages = Vec::with_capacity(1 + recent_messages.len());
346 result_messages.push(Message::text(
347 Role::System,
348 format!("Previous conversation summary:\n{summary}"),
349 ));
350 result_messages.extend_from_slice(recent_messages);
351
352 Ok(TrimResult {
353 messages: result_messages.clone(),
354 was_summarized: true,
355 estimated_tokens: estimate_tokens(&result_messages),
356 messages_dropped: old_messages.len(),
357 })
358 }
359
360 fn strategy_name(&self) -> &'static str {
361 "hybrid"
362 }
363}
364
365#[derive(Debug, Clone, Copy, Default)]
370pub struct NoOpTrimmer;
371
372#[async_trait]
373impl ContextTrimmer for NoOpTrimmer {
374 async fn trim(
375 &self,
376 messages: &[Message],
377 _usage: &TokenUsage,
378 ) -> Result<TrimResult, AgentError> {
379 Ok(TrimResult {
380 messages: messages.to_vec(),
381 was_summarized: false,
382 estimated_tokens: estimate_tokens(messages),
383 messages_dropped: 0,
384 })
385 }
386
387 fn strategy_name(&self) -> &'static str {
388 "noop"
389 }
390}
391
392fn sliding_window_trim(
396 messages: &[Message],
397 max: usize,
398 preserve_first_user: bool,
399) -> Vec<Message> {
400 let non_system: Vec<(usize, &Message)> =
401 messages.iter().enumerate().filter(|(_, m)| m.role != Role::System).collect();
402
403 if non_system.len() <= max {
404 return messages.to_vec();
405 }
406
407 let first_user_idx =
408 if preserve_first_user { messages.iter().position(|m| m.role == Role::User) } else { None };
409
410 let recent_start = non_system.len().saturating_sub(max);
411 let mut to_keep: std::collections::HashSet<usize> =
412 non_system[recent_start..].iter().map(|(idx, _)| *idx).collect();
413
414 if let Some(first_idx) = first_user_idx {
416 to_keep.insert(first_idx);
417 }
418
419 messages
420 .iter()
421 .enumerate()
422 .filter(|(idx, msg)| msg.role == Role::System || to_keep.contains(idx))
423 .map(|(_, msg)| msg.clone())
424 .collect()
425}
426
427fn split_at_threshold(
430 messages: &[Message],
431 recent_count: usize,
432 preserve_first_user: bool,
433) -> (Vec<Message>, Vec<Message>) {
434 let non_system: Vec<(usize, &Message)> =
435 messages.iter().enumerate().filter(|(_, m)| m.role != Role::System).collect();
436
437 if non_system.len() <= recent_count {
438 return (Vec::new(), messages.to_vec());
439 }
440
441 let split_idx = non_system.len() - recent_count;
442 let split_at = non_system[split_idx].0;
443
444 let first_user_idx =
445 if preserve_first_user { messages.iter().position(|m| m.role == Role::User) } else { None };
446
447 let mut old = Vec::new();
448 let mut recent = Vec::new();
449
450 for (idx, msg) in messages.iter().enumerate() {
451 if msg.role == Role::System {
452 recent.push(msg.clone());
454 } else if idx < split_at {
455 if Some(idx) == first_user_idx {
456 recent.push(msg.clone()); } else {
458 old.push(msg.clone());
459 }
460 } else {
461 recent.push(msg.clone());
462 }
463 }
464
465 (old, recent)
466}
467
468fn estimate_tokens(messages: &[Message]) -> usize {
470 messages.iter().map(|m| m.content.len() / 4).sum()
471}
472
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a plain-text message with the given role.
    fn msg(role: Role, content: &str) -> Message {
        Message::text(role, content.to_string())
    }

    /// Runs `trimmer` over `messages` with default token usage, panicking on error.
    async fn run<T: ContextTrimmer>(trimmer: &T, messages: &[Message]) -> TrimResult {
        let usage = TokenUsage::default();
        trimmer.trim(messages, &usage).await.unwrap()
    }

    /// Whether any message in `result` has exactly `content`.
    fn contains(result: &TrimResult, content: &str) -> bool {
        result.messages.iter().any(|m| m.content == content)
    }

    #[tokio::test]
    async fn sliding_window_noop_when_under_limit() {
        let history = vec![msg(Role::User, "hello"), msg(Role::Assistant, "hi")];
        let result = run(&SlidingWindowTrimmer::new(50, 8192), &history).await;

        assert_eq!(result.messages.len(), 2);
        assert!(!result.was_summarized);
        assert_eq!(result.messages_dropped, 0);
    }

    #[tokio::test]
    async fn sliding_window_drops_oldest() {
        let history: Vec<Message> = (0..5)
            .map(|i| {
                let role = if i % 2 == 0 { Role::User } else { Role::Assistant };
                msg(role, &format!("msg-{i}"))
            })
            .collect();
        let result = run(&SlidingWindowTrimmer::new(3, 8192), &history).await;

        // The three newest messages survive, plus the preserved first user message.
        assert_eq!(result.messages.len(), 4);
        assert_eq!(result.messages_dropped, 1);
    }

    #[tokio::test]
    async fn sliding_window_preserves_system() {
        let history = vec![
            msg(Role::System, "system instructions"),
            msg(Role::User, "old"),
            msg(Role::Assistant, "mid"),
            msg(Role::User, "new"),
        ];
        let result = run(&SlidingWindowTrimmer::new(2, 8192), &history).await;

        assert_eq!(result.messages[0].role, Role::System);
        assert!(contains(&result, "new"));
    }

    #[tokio::test]
    async fn sliding_window_preserves_first_user() {
        let history = vec![
            msg(Role::User, "original-task"),
            msg(Role::Assistant, "response-1"),
            msg(Role::User, "follow-up"),
            msg(Role::Assistant, "response-2"),
        ];
        let result = run(&SlidingWindowTrimmer::new(2, 8192), &history).await;

        assert!(contains(&result, "original-task"));
    }

    #[tokio::test]
    async fn noop_trimmer_passes_through() {
        let history = vec![msg(Role::User, "a"), msg(Role::Assistant, "b")];
        let result = run(&NoOpTrimmer, &history).await;

        assert_eq!(result.messages.len(), 2);
        assert_eq!(result.messages_dropped, 0);
        assert!(!result.was_summarized);
    }

    #[test]
    fn estimate_tokens_basic() {
        // "hello world" is 11 bytes; 11 / 4 rounds down to 2.
        assert_eq!(estimate_tokens(&[msg(Role::User, "hello world")]), 2);
    }

    #[test]
    fn estimate_tokens_empty() {
        assert_eq!(estimate_tokens(&[]), 0);
    }
}