Skip to main content

bob_core/
context_trimmer.rs

1//! # Context Trimmer
2//!
3//! Pluggable context window management strategies for the Bob Agent Framework.
4//!
5//! ## Overview
6//!
7//! As conversations grow, the context window can exceed model token limits.
8//! The [`ContextTrimmer`] trait provides pluggable strategies for managing
9//! context size before sending to the LLM.
10//!
11//! ## Strategies
12//!
13//! - [`SlidingWindowTrimmer`] — Keeps the N most recent messages (simple, fast)
14//! - [`SummarizationTrimmer`] — Uses a small model to compress older messages
15//! - [`HybridTrimmer`] — Sliding window with periodic summarization
16//!
17//! ## Example
18//!
19//! ```rust,ignore
20//! use bob_core::context_trimmer::{ContextTrimmer, SlidingWindowTrimmer};
21//!
22//! let trimmer = SlidingWindowTrimmer::new(50, 4096);
23//! let trimmed = trimmer.trim(&session.messages, &session.total_usage).await?;
24//! ```
25
26use std::sync::Arc;
27
28use async_trait::async_trait;
29
30use crate::{
31    error::AgentError,
32    types::{Message, Role, TokenUsage},
33};
34
/// Tunable limits shared by the context trimming strategies.
#[derive(Debug, Clone)]
pub struct TrimConfig {
    /// Upper bound on how many non-system messages are retained.
    pub max_messages: usize,
    /// Approximate token budget the trimmed context should aim for.
    pub target_tokens: usize,
    /// When `true`, the first user message is kept even if it falls outside
    /// the retained window (it often carries the original task).
    pub preserve_first_user: bool,
    /// Fraction of `target_tokens` (0.0-1.0) at which summarization triggers.
    pub summarization_threshold: f64,
}

impl Default for TrimConfig {
    /// Returns the stock configuration: 50 messages, an 8192-token budget,
    /// first-user preservation enabled, and summarization at 80% of the budget.
    fn default() -> Self {
        let max_messages = 50;
        let target_tokens = 8192;
        let preserve_first_user = true;
        let summarization_threshold = 0.8;

        Self { max_messages, target_tokens, preserve_first_user, summarization_threshold }
    }
}
58
59/// Result of a context trimming operation.
60#[derive(Debug, Clone)]
61pub struct TrimResult {
62    /// The trimmed messages ready for LLM consumption.
63    pub messages: Vec<Message>,
64    /// Whether summarization was performed.
65    pub was_summarized: bool,
66    /// Estimated token count after trimming.
67    pub estimated_tokens: usize,
68    /// Number of messages dropped.
69    pub messages_dropped: usize,
70}
71
72/// Trait for pluggable context trimming strategies.
73///
74/// Implementors define how to reduce message history to fit within
75/// token or message count limits.
76#[async_trait]
77pub trait ContextTrimmer: Send + Sync {
78    /// Trim the message history according to the strategy.
79    ///
80    /// # Arguments
81    ///
82    /// * `messages` — Full message history (including system messages)
83    /// * `usage` — Current token usage statistics
84    ///
85    /// # Returns
86    ///
87    /// A [`TrimResult`] containing the trimmed messages and metadata.
88    async fn trim(
89        &self,
90        messages: &[Message],
91        usage: &TokenUsage,
92    ) -> Result<TrimResult, AgentError>;
93
94    /// Human-readable name of this trimming strategy.
95    fn strategy_name(&self) -> &'static str;
96}
97
98/// Simple sliding window trimmer.
99///
100/// Keeps the most recent N messages, always preserving system messages
101/// and optionally the first user message.
102///
103/// ## Example
104///
105/// ```rust,ignore
106/// use bob_core::context_trimmer::{ContextTrimmer, SlidingWindowTrimmer};
107///
108/// // Keep 50 most recent messages, target ~8K tokens
109/// let trimmer = SlidingWindowTrimmer::new(50, 8192);
110/// ```
111#[derive(Debug, Clone)]
112pub struct SlidingWindowTrimmer {
113    config: TrimConfig,
114}
115
116impl SlidingWindowTrimmer {
117    /// Create a new sliding window trimmer.
118    #[must_use]
119    pub fn new(max_messages: usize, target_tokens: usize) -> Self {
120        Self {
121            config: TrimConfig {
122                max_messages,
123                target_tokens,
124                preserve_first_user: true,
125                summarization_threshold: 1.0, // Never summarize
126            },
127        }
128    }
129
130    /// Create with full configuration.
131    #[must_use]
132    pub fn with_config(config: TrimConfig) -> Self {
133        Self { config }
134    }
135}
136
137#[async_trait]
138impl ContextTrimmer for SlidingWindowTrimmer {
139    async fn trim(
140        &self,
141        messages: &[Message],
142        _usage: &TokenUsage,
143    ) -> Result<TrimResult, AgentError> {
144        let original_count = messages.len();
145        let trimmed = sliding_window_trim(
146            messages,
147            self.config.max_messages,
148            self.config.preserve_first_user,
149        );
150        let dropped = original_count.saturating_sub(trimmed.len());
151        let estimated = estimate_tokens(&trimmed);
152
153        Ok(TrimResult {
154            messages: trimmed,
155            was_summarized: false,
156            estimated_tokens: estimated,
157            messages_dropped: dropped,
158        })
159    }
160
161    fn strategy_name(&self) -> &'static str {
162        "sliding_window"
163    }
164}
165
166/// Summarization-based trimmer.
167///
168/// When context approaches the limit, uses a callback to summarize
169/// older messages into a compact form.
170///
171/// ## Example
172///
173/// ```rust,ignore
174/// use bob_core::context_trimmer::{ContextTrimmer, SummarizationTrimmer};
175///
176/// let trimmer = SummarizationTrimmer::new(
177///     100,
178///     4096,
179///     Arc::new(my_summarizer), // impl MessageSummarizer
180/// );
181/// ```
182#[derive(Debug)]
183pub struct SummarizationTrimmer {
184    config: TrimConfig,
185    summarizer: Arc<dyn MessageSummarizer>,
186}
187
188impl SummarizationTrimmer {
189    /// Create a new summarization trimmer.
190    #[must_use]
191    pub fn new(
192        max_messages: usize,
193        target_tokens: usize,
194        summarizer: Arc<dyn MessageSummarizer>,
195    ) -> Self {
196        Self {
197            config: TrimConfig {
198                max_messages,
199                target_tokens,
200                preserve_first_user: true,
201                summarization_threshold: 0.8,
202            },
203            summarizer,
204        }
205    }
206
207    /// Create with full configuration.
208    #[must_use]
209    pub fn with_config(config: TrimConfig, summarizer: Arc<dyn MessageSummarizer>) -> Self {
210        Self { config, summarizer }
211    }
212}
213
214#[async_trait]
215impl ContextTrimmer for SummarizationTrimmer {
216    async fn trim(
217        &self,
218        messages: &[Message],
219        _usage: &TokenUsage,
220    ) -> Result<TrimResult, AgentError> {
221        let estimated = estimate_tokens(messages);
222        let threshold =
223            (self.config.target_tokens as f64 * self.config.summarization_threshold) as usize;
224
225        // If under threshold, use sliding window
226        if estimated < threshold {
227            let trimmed = sliding_window_trim(
228                messages,
229                self.config.max_messages,
230                self.config.preserve_first_user,
231            );
232            let trimmed_tokens = estimate_tokens(&trimmed);
233            let dropped = messages.len().saturating_sub(trimmed.len());
234            return Ok(TrimResult {
235                messages: trimmed,
236                was_summarized: false,
237                estimated_tokens: trimmed_tokens,
238                messages_dropped: dropped,
239            });
240        }
241
242        // Summarize older messages
243        let (old_messages, recent_messages) = split_at_threshold(
244            messages,
245            self.config.max_messages / 2,
246            self.config.preserve_first_user,
247        );
248
249        if old_messages.is_empty() {
250            let trimmed = sliding_window_trim(
251                messages,
252                self.config.max_messages,
253                self.config.preserve_first_user,
254            );
255            let trimmed_tokens = estimate_tokens(&trimmed);
256            let dropped = messages.len().saturating_sub(trimmed.len());
257            return Ok(TrimResult {
258                messages: trimmed,
259                was_summarized: false,
260                estimated_tokens: trimmed_tokens,
261                messages_dropped: dropped,
262            });
263        }
264
265        let summary = self.summarizer.summarize(&old_messages).await?;
266        let mut result_messages = Vec::with_capacity(1 + recent_messages.len());
267        result_messages.push(Message::text(
268            Role::System,
269            format!("Previous conversation summary:\n{summary}"),
270        ));
271        result_messages.extend(recent_messages);
272
273        Ok(TrimResult {
274            messages: result_messages.clone(),
275            was_summarized: true,
276            estimated_tokens: estimate_tokens(&result_messages),
277            messages_dropped: old_messages.len(),
278        })
279    }
280
281    fn strategy_name(&self) -> &'static str {
282        "summarization"
283    }
284}
285
286/// Trait for message summarization backends.
287///
288/// Implementations can use a small LLM, extractive summarization,
289/// or any other method to compress messages.
290#[async_trait]
291pub trait MessageSummarizer: Send + Sync + std::fmt::Debug {
292    /// Summarize a batch of messages into a concise text.
293    async fn summarize(&self, messages: &[Message]) -> Result<String, AgentError>;
294}
295
296/// Hybrid trimmer that combines sliding window with periodic summarization.
297///
298/// Maintains a sliding window but periodically summarizes old context
299/// when approaching token limits.
300#[derive(Debug)]
301pub struct HybridTrimmer {
302    sliding: SlidingWindowTrimmer,
303    summarizer: Arc<dyn MessageSummarizer>,
304    summarization_threshold: f64,
305    target_tokens: usize,
306}
307
308impl HybridTrimmer {
309    /// Create a new hybrid trimmer.
310    #[must_use]
311    pub fn new(
312        max_messages: usize,
313        target_tokens: usize,
314        summarizer: Arc<dyn MessageSummarizer>,
315    ) -> Self {
316        Self {
317            sliding: SlidingWindowTrimmer::new(max_messages, target_tokens),
318            summarizer,
319            summarization_threshold: 0.8,
320            target_tokens,
321        }
322    }
323}
324
325#[async_trait]
326impl ContextTrimmer for HybridTrimmer {
327    async fn trim(
328        &self,
329        messages: &[Message],
330        usage: &TokenUsage,
331    ) -> Result<TrimResult, AgentError> {
332        let estimated = estimate_tokens(messages);
333        let threshold = (self.target_tokens as f64 * self.summarization_threshold) as usize;
334
335        // Under threshold: use sliding window
336        if estimated < threshold {
337            return self.sliding.trim(messages, usage).await;
338        }
339
340        // Over threshold: summarize
341        let split_point = messages.len() / 2;
342        let (old_messages, recent_messages) = messages.split_at(split_point);
343
344        let summary = self.summarizer.summarize(old_messages).await?;
345        let mut result_messages = Vec::with_capacity(1 + recent_messages.len());
346        result_messages.push(Message::text(
347            Role::System,
348            format!("Previous conversation summary:\n{summary}"),
349        ));
350        result_messages.extend_from_slice(recent_messages);
351
352        Ok(TrimResult {
353            messages: result_messages.clone(),
354            was_summarized: true,
355            estimated_tokens: estimate_tokens(&result_messages),
356            messages_dropped: old_messages.len(),
357        })
358    }
359
360    fn strategy_name(&self) -> &'static str {
361        "hybrid"
362    }
363}
364
/// Trimmer that leaves the message history untouched.
///
/// Handy when an interface demands a `ContextTrimmer` but no trimming
/// should actually take place.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoOpTrimmer;
371
372#[async_trait]
373impl ContextTrimmer for NoOpTrimmer {
374    async fn trim(
375        &self,
376        messages: &[Message],
377        _usage: &TokenUsage,
378    ) -> Result<TrimResult, AgentError> {
379        Ok(TrimResult {
380            messages: messages.to_vec(),
381            was_summarized: false,
382            estimated_tokens: estimate_tokens(messages),
383            messages_dropped: 0,
384        })
385    }
386
387    fn strategy_name(&self) -> &'static str {
388        "noop"
389    }
390}
391
392// ── Helpers ──────────────────────────────────────────────────────────
393
394/// Sliding window: keep most recent non-system messages, preserve system.
395fn sliding_window_trim(
396    messages: &[Message],
397    max: usize,
398    preserve_first_user: bool,
399) -> Vec<Message> {
400    let non_system: Vec<(usize, &Message)> =
401        messages.iter().enumerate().filter(|(_, m)| m.role != Role::System).collect();
402
403    if non_system.len() <= max {
404        return messages.to_vec();
405    }
406
407    let first_user_idx =
408        if preserve_first_user { messages.iter().position(|m| m.role == Role::User) } else { None };
409
410    let recent_start = non_system.len().saturating_sub(max);
411    let mut to_keep: std::collections::HashSet<usize> =
412        non_system[recent_start..].iter().map(|(idx, _)| *idx).collect();
413
414    // Add first user message if it's not already in the kept set
415    if let Some(first_idx) = first_user_idx {
416        to_keep.insert(first_idx);
417    }
418
419    messages
420        .iter()
421        .enumerate()
422        .filter(|(idx, msg)| msg.role == Role::System || to_keep.contains(idx))
423        .map(|(_, msg)| msg.clone())
424        .collect()
425}
426
427/// Split messages into old and recent halves, preserving system messages and optionally the first
428/// user message.
429fn split_at_threshold(
430    messages: &[Message],
431    recent_count: usize,
432    preserve_first_user: bool,
433) -> (Vec<Message>, Vec<Message>) {
434    let non_system: Vec<(usize, &Message)> =
435        messages.iter().enumerate().filter(|(_, m)| m.role != Role::System).collect();
436
437    if non_system.len() <= recent_count {
438        return (Vec::new(), messages.to_vec());
439    }
440
441    let split_idx = non_system.len() - recent_count;
442    let split_at = non_system[split_idx].0;
443
444    let first_user_idx =
445        if preserve_first_user { messages.iter().position(|m| m.role == Role::User) } else { None };
446
447    let mut old = Vec::new();
448    let mut recent = Vec::new();
449
450    for (idx, msg) in messages.iter().enumerate() {
451        if msg.role == Role::System {
452            // System messages go to both (or just recent)
453            recent.push(msg.clone());
454        } else if idx < split_at {
455            if Some(idx) == first_user_idx {
456                recent.push(msg.clone()); // Keep first user in recent
457            } else {
458                old.push(msg.clone());
459            }
460        } else {
461            recent.push(msg.clone());
462        }
463    }
464
465    (old, recent)
466}
467
468/// Rough token estimation (4 chars ≈ 1 token).
469fn estimate_tokens(messages: &[Message]) -> usize {
470    messages.iter().map(|m| m.content.len() / 4).sum()
471}
472
473// ── Tests ────────────────────────────────────────────────────────────
474
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a plain text message in test fixtures.
    fn msg(role: Role, content: &str) -> Message {
        Message::text(role, content.to_owned())
    }

    /// Runs `trimmer` over `messages` with default usage statistics and
    /// unwraps the result.
    async fn run(trimmer: &dyn ContextTrimmer, messages: &[Message]) -> TrimResult {
        let usage = TokenUsage::default();
        trimmer.trim(messages, &usage).await.unwrap()
    }

    // ── Sliding Window ──────────────────────────────────────────────

    #[tokio::test]
    async fn sliding_window_noop_when_under_limit() {
        let messages = vec![msg(Role::User, "hello"), msg(Role::Assistant, "hi")];

        let result = run(&SlidingWindowTrimmer::new(50, 8192), &messages).await;

        assert_eq!(result.messages.len(), 2);
        assert!(!result.was_summarized);
        assert_eq!(result.messages_dropped, 0);
    }

    #[tokio::test]
    async fn sliding_window_drops_oldest() {
        let messages = vec![
            msg(Role::User, "msg-0"),
            msg(Role::Assistant, "msg-1"),
            msg(Role::User, "msg-2"),
            msg(Role::Assistant, "msg-3"),
            msg(Role::User, "msg-4"),
        ];

        let result = run(&SlidingWindowTrimmer::new(3, 8192), &messages).await;

        // 3 most recent messages + the preserved first user message.
        assert_eq!(result.messages.len(), 4);
        assert_eq!(result.messages_dropped, 1);
    }

    #[tokio::test]
    async fn sliding_window_preserves_system() {
        let messages = vec![
            msg(Role::System, "system instructions"),
            msg(Role::User, "old"),
            msg(Role::Assistant, "mid"),
            msg(Role::User, "new"),
        ];

        let result = run(&SlidingWindowTrimmer::new(2, 8192), &messages).await;

        assert_eq!(result.messages[0].role, Role::System);
        assert!(result.messages.iter().any(|m| m.content == "new"));
    }

    #[tokio::test]
    async fn sliding_window_preserves_first_user() {
        let messages = vec![
            msg(Role::User, "original-task"),
            msg(Role::Assistant, "response-1"),
            msg(Role::User, "follow-up"),
            msg(Role::Assistant, "response-2"),
        ];

        let result = run(&SlidingWindowTrimmer::new(2, 8192), &messages).await;

        assert!(result.messages.iter().any(|m| m.content == "original-task"));
    }

    // ── NoOp Trimmer ────────────────────────────────────────────────

    #[tokio::test]
    async fn noop_trimmer_passes_through() {
        let messages = vec![msg(Role::User, "a"), msg(Role::Assistant, "b")];

        let result = run(&NoOpTrimmer, &messages).await;

        assert_eq!(result.messages.len(), 2);
        assert_eq!(result.messages_dropped, 0);
        assert!(!result.was_summarized);
    }

    // ── Token Estimation ────────────────────────────────────────────

    #[test]
    fn estimate_tokens_basic() {
        // 11 chars / 4, rounded down, gives 2 tokens.
        let messages = vec![msg(Role::User, "hello world")];
        assert_eq!(estimate_tokens(&messages), 2);
    }

    #[test]
    fn estimate_tokens_empty() {
        assert_eq!(estimate_tokens(&[]), 0);
    }
}