// neuron_context/strategies.rs
1//! Context compaction strategies implementing [`ContextStrategy`].
2
3use std::future::Future;
4use std::pin::Pin;
5use std::sync::Arc;
6
7use neuron_types::{ContextError, ContextStrategy, Message, Provider, Role};
8
9use crate::counter::TokenCounter;
10
11// ---- Dyn-compatible wrapper for CompositeStrategy --------------------------
12
/// Type alias for a pinned, boxed, `Send` future returning compacted messages.
type CompactFuture<'a> =
    Pin<Box<dyn Future<Output = Result<Vec<Message>, ContextError>> + Send + 'a>>;

/// A dyn-compatible strategy object. Used internally by [`CompositeStrategy`].
///
/// Because `ContextStrategy::compact` returns `impl Future` (RPITIT), the trait
/// is not dyn-compatible. `ErasedStrategy` provides a vtable-friendly equivalent
/// that boxes the future.
trait ErasedStrategy: Send + Sync {
    /// Object-safe mirror of [`ContextStrategy::compact`]; boxes the returned future.
    fn erased_compact<'a>(&'a self, messages: Vec<Message>) -> CompactFuture<'a>;
    /// Object-safe mirror of [`ContextStrategy::token_estimate`].
    fn erased_token_estimate(&self, messages: &[Message]) -> usize;
    /// Object-safe mirror of [`ContextStrategy::should_compact`].
    fn erased_should_compact(&self, messages: &[Message], token_count: usize) -> bool;
}
27
28impl<S: ContextStrategy> ErasedStrategy for S {
29    fn erased_compact<'a>(&'a self, messages: Vec<Message>) -> CompactFuture<'a> {
30        Box::pin(self.compact(messages))
31    }
32
33    fn erased_token_estimate(&self, messages: &[Message]) -> usize {
34        self.token_estimate(messages)
35    }
36
37    fn erased_should_compact(&self, messages: &[Message], token_count: usize) -> bool {
38        self.should_compact(messages, token_count)
39    }
40}
41
/// A type-erased wrapper around a [`ContextStrategy`] for use in [`CompositeStrategy`].
///
/// Use `BoxedStrategy::new(strategy)` to wrap any strategy.
///
/// Internally the strategy is held behind an [`Arc`], so `compact` can move a
/// cheap handle into its future instead of borrowing `self`.
///
/// # Example
///
/// ```
/// use neuron_context::{SlidingWindowStrategy, strategies::BoxedStrategy};
///
/// let boxed = BoxedStrategy::new(SlidingWindowStrategy::new(10, 100_000));
/// ```
pub struct BoxedStrategy(Arc<dyn ErasedStrategy>);
54
55impl BoxedStrategy {
56    /// Wrap any [`ContextStrategy`] into a type-erased `BoxedStrategy`.
57    #[must_use]
58    pub fn new<S: ContextStrategy + 'static>(strategy: S) -> Self {
59        BoxedStrategy(Arc::new(strategy))
60    }
61}
62
63impl ContextStrategy for BoxedStrategy {
64    fn should_compact(&self, messages: &[Message], token_count: usize) -> bool {
65        self.0.erased_should_compact(messages, token_count)
66    }
67
68    fn compact(
69        &self,
70        messages: Vec<Message>,
71    ) -> impl Future<Output = Result<Vec<Message>, ContextError>> + neuron_types::WasmCompatSend
72    {
73        let inner = Arc::clone(&self.0);
74        async move { inner.erased_compact(messages).await }
75    }
76
77    fn token_estimate(&self, messages: &[Message]) -> usize {
78        self.0.erased_token_estimate(messages)
79    }
80}
81
82// ---- SlidingWindowStrategy --------------------------------------------------
83
/// Keeps system messages plus the last `window_size` non-system messages.
///
/// Triggers compaction when the estimated token count exceeds `max_tokens`.
///
/// # Example
///
/// ```
/// use neuron_context::SlidingWindowStrategy;
///
/// let strategy = SlidingWindowStrategy::new(10, 100_000);
/// ```
pub struct SlidingWindowStrategy {
    // Maximum number of non-system messages retained by `compact`.
    window_size: usize,
    // Estimator backing `token_estimate`.
    counter: TokenCounter,
    // Threshold used by `should_compact`.
    max_tokens: usize,
}
100
101impl SlidingWindowStrategy {
102    /// Creates a new `SlidingWindowStrategy`.
103    ///
104    /// # Arguments
105    /// * `window_size` — maximum number of non-system messages to retain
106    /// * `max_tokens` — token threshold above which compaction is triggered
107    #[must_use]
108    pub fn new(window_size: usize, max_tokens: usize) -> Self {
109        Self {
110            window_size,
111            counter: TokenCounter::new(),
112            max_tokens,
113        }
114    }
115
116    /// Creates a new `SlidingWindowStrategy` with a custom [`TokenCounter`].
117    #[must_use]
118    pub fn with_counter(window_size: usize, max_tokens: usize, counter: TokenCounter) -> Self {
119        Self {
120            window_size,
121            counter,
122            max_tokens,
123        }
124    }
125}
126
127impl ContextStrategy for SlidingWindowStrategy {
128    fn should_compact(&self, messages: &[Message], token_count: usize) -> bool {
129        let _ = messages;
130        token_count > self.max_tokens
131    }
132
133    fn compact(
134        &self,
135        messages: Vec<Message>,
136    ) -> impl Future<Output = Result<Vec<Message>, ContextError>> + neuron_types::WasmCompatSend
137    {
138        let window_size = self.window_size;
139        async move {
140            let (system_msgs, non_system): (Vec<_>, Vec<_>) =
141                messages.into_iter().partition(|m| m.role == Role::System);
142
143            let recent: Vec<Message> = non_system
144                .into_iter()
145                .rev()
146                .take(window_size)
147                .collect::<Vec<_>>()
148                .into_iter()
149                .rev()
150                .collect();
151
152            let mut result = system_msgs;
153            result.extend(recent);
154            Ok(result)
155        }
156    }
157
158    fn token_estimate(&self, messages: &[Message]) -> usize {
159        self.counter.estimate_messages(messages)
160    }
161}
162
163// ---- ToolResultClearingStrategy ---------------------------------------------
164
/// Replaces old tool result content with a placeholder to reduce token usage.
///
/// Keeps the most recent `keep_recent_n` tool results intact and replaces
/// older ones with `[tool result cleared]` while preserving the `tool_use_id`
/// so the conversation still makes semantic sense.
///
/// # Example
///
/// ```
/// use neuron_context::ToolResultClearingStrategy;
///
/// let strategy = ToolResultClearingStrategy::new(2, 100_000);
/// ```
pub struct ToolResultClearingStrategy {
    // Number of most-recent tool results left untouched by `compact`.
    keep_recent_n: usize,
    // Estimator backing `token_estimate`.
    counter: TokenCounter,
    // Threshold used by `should_compact`.
    max_tokens: usize,
}
183
184impl ToolResultClearingStrategy {
185    /// Creates a new `ToolResultClearingStrategy`.
186    ///
187    /// # Arguments
188    /// * `keep_recent_n` — number of most-recent tool results to leave untouched
189    /// * `max_tokens` — token threshold above which compaction is triggered
190    #[must_use]
191    pub fn new(keep_recent_n: usize, max_tokens: usize) -> Self {
192        Self {
193            keep_recent_n,
194            counter: TokenCounter::new(),
195            max_tokens,
196        }
197    }
198
199    /// Creates a new `ToolResultClearingStrategy` with a custom [`TokenCounter`].
200    #[must_use]
201    pub fn with_counter(keep_recent_n: usize, max_tokens: usize, counter: TokenCounter) -> Self {
202        Self {
203            keep_recent_n,
204            counter,
205            max_tokens,
206        }
207    }
208}
209
210impl ContextStrategy for ToolResultClearingStrategy {
211    fn should_compact(&self, messages: &[Message], token_count: usize) -> bool {
212        let _ = messages;
213        token_count > self.max_tokens
214    }
215
216    fn compact(
217        &self,
218        messages: Vec<Message>,
219    ) -> impl Future<Output = Result<Vec<Message>, ContextError>> + neuron_types::WasmCompatSend
220    {
221        use neuron_types::{ContentBlock, ContentItem};
222
223        let keep_recent_n = self.keep_recent_n;
224        async move {
225            // Collect positions of all ToolResult blocks across all messages.
226            let mut tool_result_positions: Vec<(usize, usize)> = Vec::new();
227            for (msg_idx, msg) in messages.iter().enumerate() {
228                for (block_idx, block) in msg.content.iter().enumerate() {
229                    if matches!(block, ContentBlock::ToolResult { .. }) {
230                        tool_result_positions.push((msg_idx, block_idx));
231                    }
232                }
233            }
234
235            let total = tool_result_positions.len();
236            let to_clear_count = total.saturating_sub(keep_recent_n);
237
238            if to_clear_count == 0 {
239                return Ok(messages);
240            }
241
242            let to_clear = tool_result_positions[..to_clear_count].to_vec();
243            let mut messages = messages;
244            for (msg_idx, block_idx) in to_clear {
245                let block = &mut messages[msg_idx].content[block_idx];
246                if let ContentBlock::ToolResult {
247                    content, is_error, ..
248                } = block
249                {
250                    *content = vec![ContentItem::Text("[tool result cleared]".to_string())];
251                    *is_error = false;
252                }
253            }
254
255            Ok(messages)
256        }
257    }
258
259    fn token_estimate(&self, messages: &[Message]) -> usize {
260        self.counter.estimate_messages(messages)
261    }
262}
263
264// ---- SummarizationStrategy --------------------------------------------------
265
/// Summarizes old messages using an LLM provider, preserving recent messages verbatim.
///
/// When compaction is triggered, messages older than `preserve_recent` are sent
/// to the provider with a summarization prompt. The response replaces the old
/// messages with a single `User` message containing the summary, followed by
/// the preserved recent messages.
///
/// # Example
///
/// ```ignore
/// use neuron_context::SummarizationStrategy;
///
/// let strategy = SummarizationStrategy::new(provider, 5, 100_000);
/// ```
pub struct SummarizationStrategy<P: Provider> {
    // LLM provider that performs the summarization call.
    provider: P,
    // Number of most-recent non-system messages kept verbatim.
    preserve_recent: usize,
    // Estimator backing `token_estimate`.
    counter: TokenCounter,
    // Threshold used by `should_compact`.
    max_tokens: usize,
}
286
287impl<P: Provider> SummarizationStrategy<P> {
288    /// Creates a new `SummarizationStrategy`.
289    ///
290    /// # Arguments
291    /// * `provider` — the LLM provider used for summarization
292    /// * `preserve_recent` — number of most-recent messages to keep verbatim
293    /// * `max_tokens` — token threshold above which compaction is triggered
294    #[must_use]
295    pub fn new(provider: P, preserve_recent: usize, max_tokens: usize) -> Self {
296        Self {
297            provider,
298            preserve_recent,
299            counter: TokenCounter::new(),
300            max_tokens,
301        }
302    }
303
304    /// Creates a new `SummarizationStrategy` with a custom [`TokenCounter`].
305    #[must_use]
306    pub fn with_counter(
307        provider: P,
308        preserve_recent: usize,
309        max_tokens: usize,
310        counter: TokenCounter,
311    ) -> Self {
312        Self {
313            provider,
314            preserve_recent,
315            counter,
316            max_tokens,
317        }
318    }
319}
320
321impl<P: Provider> ContextStrategy for SummarizationStrategy<P> {
322    fn should_compact(&self, messages: &[Message], token_count: usize) -> bool {
323        let _ = messages;
324        token_count > self.max_tokens
325    }
326
327    fn compact(
328        &self,
329        messages: Vec<Message>,
330    ) -> impl Future<Output = Result<Vec<Message>, ContextError>> + neuron_types::WasmCompatSend
331    {
332        use neuron_types::{CompletionRequest, ContentBlock, Role, SystemPrompt};
333
334        let preserve_recent = self.preserve_recent;
335
336        // Partition before entering the async block so we don't borrow `messages`.
337        let (system_msgs, non_system): (Vec<Message>, Vec<Message>) =
338            messages.into_iter().partition(|m| m.role == Role::System);
339
340        let split_at = non_system.len().saturating_sub(preserve_recent);
341        let old_messages = non_system[..split_at].to_vec();
342        let recent_messages = non_system[split_at..].to_vec();
343
344        let summarize_request = CompletionRequest {
345            model: String::new(),
346            messages: old_messages,
347            system: Some(SystemPrompt::Text(
348                "Summarize the conversation above concisely. Focus on key information, \
349                 decisions made, and results from tool calls. Write in third person."
350                    .to_string(),
351            )),
352            tools: vec![],
353            max_tokens: Some(1024),
354            temperature: Some(0.0),
355            top_p: None,
356            stop_sequences: vec![],
357            tool_choice: None,
358            response_format: None,
359            thinking: None,
360            reasoning_effort: None,
361            extra: None,
362            context_management: None,
363        };
364
365        async move {
366            let response = self.provider.complete(summarize_request).await?;
367
368            let summary_text = response
369                .message
370                .content
371                .into_iter()
372                .filter_map(|block| {
373                    if let ContentBlock::Text(text) = block {
374                        Some(text)
375                    } else {
376                        None
377                    }
378                })
379                .collect::<Vec<_>>()
380                .join("\n");
381
382            let summary_message = Message {
383                role: Role::User,
384                content: vec![ContentBlock::Text(format!(
385                    "[Summary of earlier conversation]\n{summary_text}"
386                ))],
387            };
388
389            let mut result = system_msgs;
390            result.push(summary_message);
391            result.extend(recent_messages);
392            Ok(result)
393        }
394    }
395
396    fn token_estimate(&self, messages: &[Message]) -> usize {
397        self.counter.estimate_messages(messages)
398    }
399}
400
401// ---- CompositeStrategy ------------------------------------------------------
402
/// Chains multiple strategies, applying each in order until token budget is met.
///
/// Each strategy is tried in sequence. After each strategy's `compact` runs,
/// the resulting token count is re-estimated. If it falls below `max_tokens`,
/// iteration stops early.
///
/// Use [`BoxedStrategy::new`] to wrap concrete strategies before collecting them.
///
/// # Example
///
/// ```
/// use neuron_context::{CompositeStrategy, SlidingWindowStrategy, ToolResultClearingStrategy};
/// use neuron_context::strategies::BoxedStrategy;
///
/// let strategy = CompositeStrategy::new(vec![
///     BoxedStrategy::new(ToolResultClearingStrategy::new(2, 100_000)),
///     BoxedStrategy::new(SlidingWindowStrategy::new(10, 100_000)),
/// ], 100_000);
/// ```
pub struct CompositeStrategy {
    // Ordered list of type-erased strategies, applied first to last.
    strategies: Vec<BoxedStrategy>,
    // Estimator backing `token_estimate`.
    counter: TokenCounter,
    // Threshold used by `should_compact` and the early-exit check in `compact`.
    max_tokens: usize,
}
427
428impl CompositeStrategy {
429    /// Creates a new `CompositeStrategy`.
430    ///
431    /// # Arguments
432    /// * `strategies` — ordered list of type-erased strategies to apply
433    /// * `max_tokens` — token threshold above which compaction is triggered
434    #[must_use]
435    pub fn new(strategies: Vec<BoxedStrategy>, max_tokens: usize) -> Self {
436        Self {
437            strategies,
438            counter: TokenCounter::new(),
439            max_tokens,
440        }
441    }
442}
443
444impl ContextStrategy for CompositeStrategy {
445    fn should_compact(&self, messages: &[Message], token_count: usize) -> bool {
446        let _ = messages;
447        token_count > self.max_tokens
448    }
449
450    fn compact(
451        &self,
452        messages: Vec<Message>,
453    ) -> impl Future<Output = Result<Vec<Message>, ContextError>> + neuron_types::WasmCompatSend
454    {
455        // Snapshot what we need before entering the async block.
456        let inner_refs: Vec<Arc<dyn ErasedStrategy>> =
457            self.strategies.iter().map(|b| Arc::clone(&b.0)).collect();
458        let max_tokens = self.max_tokens;
459        let counter = TokenCounter::new();
460
461        async move {
462            let mut current = messages;
463            for strategy in &inner_refs {
464                let token_count = counter.estimate_messages(&current);
465                if token_count <= max_tokens {
466                    break;
467                }
468                current = strategy.erased_compact(current).await?;
469            }
470            Ok(current)
471        }
472    }
473
474    fn token_estimate(&self, messages: &[Message]) -> usize {
475        self.counter.estimate_messages(messages)
476    }
477}