Skip to main content

neuron_context/
strategies.rs

1//! Context compaction strategies implementing [`ContextStrategy`].
2
3use std::future::Future;
4use std::pin::Pin;
5use std::sync::Arc;
6
7use neuron_types::{ContextError, ContextStrategy, Message, Provider, Role};
8
9use crate::counter::TokenCounter;
10
// ---- Dyn-compatible wrapper for CompositeStrategy --------------------------

/// Type alias for a pinned, boxed, `Send` future returning compacted messages.
type CompactFuture<'a> = Pin<Box<dyn Future<Output = Result<Vec<Message>, ContextError>> + Send + 'a>>;

/// A dyn-compatible strategy object. Used internally by [`CompositeStrategy`].
///
/// Because `ContextStrategy::compact` returns `impl Future` (RPITIT), the trait
/// is not dyn-compatible. `ErasedStrategy` provides a vtable-friendly equivalent
/// that boxes the future.
trait ErasedStrategy: Send + Sync {
    /// Object-safe counterpart of [`ContextStrategy::compact`]; boxes the future.
    fn erased_compact<'a>(&'a self, messages: Vec<Message>) -> CompactFuture<'a>;
    /// Object-safe counterpart of [`ContextStrategy::token_estimate`].
    fn erased_token_estimate(&self, messages: &[Message]) -> usize;
    /// Object-safe counterpart of [`ContextStrategy::should_compact`].
    fn erased_should_compact(&self, messages: &[Message], token_count: usize) -> bool;
}
26
27impl<S: ContextStrategy> ErasedStrategy for S {
28    fn erased_compact<'a>(&'a self, messages: Vec<Message>) -> CompactFuture<'a> {
29        Box::pin(self.compact(messages))
30    }
31
32    fn erased_token_estimate(&self, messages: &[Message]) -> usize {
33        self.token_estimate(messages)
34    }
35
36    fn erased_should_compact(&self, messages: &[Message], token_count: usize) -> bool {
37        self.should_compact(messages, token_count)
38    }
39}
40
41/// A type-erased wrapper around a [`ContextStrategy`] for use in [`CompositeStrategy`].
42///
43/// Use `BoxedStrategy::new(strategy)` to wrap any strategy.
44///
45/// # Example
46///
47/// ```
48/// use neuron_context::{SlidingWindowStrategy, strategies::BoxedStrategy};
49///
50/// let boxed = BoxedStrategy::new(SlidingWindowStrategy::new(10, 100_000));
51/// ```
52pub struct BoxedStrategy(Arc<dyn ErasedStrategy>);
53
54impl BoxedStrategy {
55    /// Wrap any [`ContextStrategy`] into a type-erased `BoxedStrategy`.
56    #[must_use]
57    pub fn new<S: ContextStrategy + 'static>(strategy: S) -> Self {
58        BoxedStrategy(Arc::new(strategy))
59    }
60}
61
62impl ContextStrategy for BoxedStrategy {
63    fn should_compact(&self, messages: &[Message], token_count: usize) -> bool {
64        self.0.erased_should_compact(messages, token_count)
65    }
66
67    fn compact(
68        &self,
69        messages: Vec<Message>,
70    ) -> impl Future<Output = Result<Vec<Message>, ContextError>> + neuron_types::WasmCompatSend {
71        let inner = Arc::clone(&self.0);
72        async move { inner.erased_compact(messages).await }
73    }
74
75    fn token_estimate(&self, messages: &[Message]) -> usize {
76        self.0.erased_token_estimate(messages)
77    }
78}
79
80// ---- SlidingWindowStrategy --------------------------------------------------
81
82/// Keeps system messages plus the last `window_size` non-system messages.
83///
84/// Triggers compaction when the estimated token count exceeds `max_tokens`.
85///
86/// # Example
87///
88/// ```
89/// use neuron_context::SlidingWindowStrategy;
90///
91/// let strategy = SlidingWindowStrategy::new(10, 100_000);
92/// ```
93pub struct SlidingWindowStrategy {
94    window_size: usize,
95    counter: TokenCounter,
96    max_tokens: usize,
97}
98
99impl SlidingWindowStrategy {
100    /// Creates a new `SlidingWindowStrategy`.
101    ///
102    /// # Arguments
103    /// * `window_size` — maximum number of non-system messages to retain
104    /// * `max_tokens` — token threshold above which compaction is triggered
105    #[must_use]
106    pub fn new(window_size: usize, max_tokens: usize) -> Self {
107        Self { window_size, counter: TokenCounter::new(), max_tokens }
108    }
109
110    /// Creates a new `SlidingWindowStrategy` with a custom [`TokenCounter`].
111    #[must_use]
112    pub fn with_counter(window_size: usize, max_tokens: usize, counter: TokenCounter) -> Self {
113        Self { window_size, counter, max_tokens }
114    }
115}
116
117impl ContextStrategy for SlidingWindowStrategy {
118    fn should_compact(&self, messages: &[Message], token_count: usize) -> bool {
119        let _ = messages;
120        token_count > self.max_tokens
121    }
122
123    fn compact(
124        &self,
125        messages: Vec<Message>,
126    ) -> impl Future<Output = Result<Vec<Message>, ContextError>> + neuron_types::WasmCompatSend
127    {
128        let window_size = self.window_size;
129        async move {
130            let (system_msgs, non_system): (Vec<_>, Vec<_>) =
131                messages.into_iter().partition(|m| m.role == Role::System);
132
133            let recent: Vec<Message> = non_system
134                .into_iter()
135                .rev()
136                .take(window_size)
137                .collect::<Vec<_>>()
138                .into_iter()
139                .rev()
140                .collect();
141
142            let mut result = system_msgs;
143            result.extend(recent);
144            Ok(result)
145        }
146    }
147
148    fn token_estimate(&self, messages: &[Message]) -> usize {
149        self.counter.estimate_messages(messages)
150    }
151}
152
153// ---- ToolResultClearingStrategy ---------------------------------------------
154
155/// Replaces old tool result content with a placeholder to reduce token usage.
156///
157/// Keeps the most recent `keep_recent_n` tool results intact and replaces
158/// older ones with `[tool result cleared]` while preserving the `tool_use_id`
159/// so the conversation still makes semantic sense.
160///
161/// # Example
162///
163/// ```
164/// use neuron_context::ToolResultClearingStrategy;
165///
166/// let strategy = ToolResultClearingStrategy::new(2, 100_000);
167/// ```
168pub struct ToolResultClearingStrategy {
169    keep_recent_n: usize,
170    counter: TokenCounter,
171    max_tokens: usize,
172}
173
174impl ToolResultClearingStrategy {
175    /// Creates a new `ToolResultClearingStrategy`.
176    ///
177    /// # Arguments
178    /// * `keep_recent_n` — number of most-recent tool results to leave untouched
179    /// * `max_tokens` — token threshold above which compaction is triggered
180    #[must_use]
181    pub fn new(keep_recent_n: usize, max_tokens: usize) -> Self {
182        Self { keep_recent_n, counter: TokenCounter::new(), max_tokens }
183    }
184
185    /// Creates a new `ToolResultClearingStrategy` with a custom [`TokenCounter`].
186    #[must_use]
187    pub fn with_counter(
188        keep_recent_n: usize,
189        max_tokens: usize,
190        counter: TokenCounter,
191    ) -> Self {
192        Self { keep_recent_n, counter, max_tokens }
193    }
194}
195
196impl ContextStrategy for ToolResultClearingStrategy {
197    fn should_compact(&self, messages: &[Message], token_count: usize) -> bool {
198        let _ = messages;
199        token_count > self.max_tokens
200    }
201
202    fn compact(
203        &self,
204        messages: Vec<Message>,
205    ) -> impl Future<Output = Result<Vec<Message>, ContextError>> + neuron_types::WasmCompatSend
206    {
207        use neuron_types::{ContentBlock, ContentItem};
208
209        let keep_recent_n = self.keep_recent_n;
210        async move {
211            // Collect positions of all ToolResult blocks across all messages.
212            let mut tool_result_positions: Vec<(usize, usize)> = Vec::new();
213            for (msg_idx, msg) in messages.iter().enumerate() {
214                for (block_idx, block) in msg.content.iter().enumerate() {
215                    if matches!(block, ContentBlock::ToolResult { .. }) {
216                        tool_result_positions.push((msg_idx, block_idx));
217                    }
218                }
219            }
220
221            let total = tool_result_positions.len();
222            let to_clear_count = total.saturating_sub(keep_recent_n);
223
224            if to_clear_count == 0 {
225                return Ok(messages);
226            }
227
228            let to_clear = tool_result_positions[..to_clear_count].to_vec();
229            let mut messages = messages;
230            for (msg_idx, block_idx) in to_clear {
231                let block = &mut messages[msg_idx].content[block_idx];
232                if let ContentBlock::ToolResult { content, is_error, .. } = block {
233                    *content = vec![ContentItem::Text("[tool result cleared]".to_string())];
234                    *is_error = false;
235                }
236            }
237
238            Ok(messages)
239        }
240    }
241
242    fn token_estimate(&self, messages: &[Message]) -> usize {
243        self.counter.estimate_messages(messages)
244    }
245}
246
247// ---- SummarizationStrategy --------------------------------------------------
248
249/// Summarizes old messages using an LLM provider, preserving recent messages verbatim.
250///
251/// When compaction is triggered, messages older than `preserve_recent` are sent
252/// to the provider with a summarization prompt. The response replaces the old
253/// messages with a single `User` message containing the summary, followed by
254/// the preserved recent messages.
255///
256/// # Example
257///
258/// ```ignore
259/// use neuron_context::SummarizationStrategy;
260///
261/// let strategy = SummarizationStrategy::new(provider, 5, 100_000);
262/// ```
263pub struct SummarizationStrategy<P: Provider> {
264    provider: P,
265    preserve_recent: usize,
266    counter: TokenCounter,
267    max_tokens: usize,
268}
269
270impl<P: Provider> SummarizationStrategy<P> {
271    /// Creates a new `SummarizationStrategy`.
272    ///
273    /// # Arguments
274    /// * `provider` — the LLM provider used for summarization
275    /// * `preserve_recent` — number of most-recent messages to keep verbatim
276    /// * `max_tokens` — token threshold above which compaction is triggered
277    #[must_use]
278    pub fn new(provider: P, preserve_recent: usize, max_tokens: usize) -> Self {
279        Self { provider, preserve_recent, counter: TokenCounter::new(), max_tokens }
280    }
281
282    /// Creates a new `SummarizationStrategy` with a custom [`TokenCounter`].
283    #[must_use]
284    pub fn with_counter(
285        provider: P,
286        preserve_recent: usize,
287        max_tokens: usize,
288        counter: TokenCounter,
289    ) -> Self {
290        Self { provider, preserve_recent, counter, max_tokens }
291    }
292}
293
294impl<P: Provider> ContextStrategy for SummarizationStrategy<P> {
295    fn should_compact(&self, messages: &[Message], token_count: usize) -> bool {
296        let _ = messages;
297        token_count > self.max_tokens
298    }
299
300    fn compact(
301        &self,
302        messages: Vec<Message>,
303    ) -> impl Future<Output = Result<Vec<Message>, ContextError>> + neuron_types::WasmCompatSend
304    {
305        use neuron_types::{CompletionRequest, ContentBlock, Role, SystemPrompt};
306
307        let preserve_recent = self.preserve_recent;
308
309        // Partition before entering the async block so we don't borrow `messages`.
310        let (system_msgs, non_system): (Vec<Message>, Vec<Message>) =
311            messages.into_iter().partition(|m| m.role == Role::System);
312
313        let split_at = non_system.len().saturating_sub(preserve_recent);
314        let old_messages = non_system[..split_at].to_vec();
315        let recent_messages = non_system[split_at..].to_vec();
316
317        let summarize_request = CompletionRequest {
318            model: String::new(),
319            messages: old_messages,
320            system: Some(SystemPrompt::Text(
321                "Summarize the conversation above concisely. Focus on key information, \
322                 decisions made, and results from tool calls. Write in third person."
323                    .to_string(),
324            )),
325            tools: vec![],
326            max_tokens: Some(1024),
327            temperature: Some(0.0),
328            top_p: None,
329            stop_sequences: vec![],
330            tool_choice: None,
331            response_format: None,
332            thinking: None,
333            reasoning_effort: None,
334            extra: None,
335            context_management: None,
336        };
337
338        async move {
339            let response = self.provider.complete(summarize_request).await?;
340
341            let summary_text = response
342                .message
343                .content
344                .into_iter()
345                .filter_map(|block| {
346                    if let ContentBlock::Text(text) = block {
347                        Some(text)
348                    } else {
349                        None
350                    }
351                })
352                .collect::<Vec<_>>()
353                .join("\n");
354
355            let summary_message = Message {
356                role: Role::User,
357                content: vec![ContentBlock::Text(format!(
358                    "[Summary of earlier conversation]\n{summary_text}"
359                ))],
360            };
361
362            let mut result = system_msgs;
363            result.push(summary_message);
364            result.extend(recent_messages);
365            Ok(result)
366        }
367    }
368
369    fn token_estimate(&self, messages: &[Message]) -> usize {
370        self.counter.estimate_messages(messages)
371    }
372}
373
374// ---- CompositeStrategy ------------------------------------------------------
375
376/// Chains multiple strategies, applying each in order until token budget is met.
377///
378/// Each strategy is tried in sequence. After each strategy's `compact` runs,
379/// the resulting token count is re-estimated. If it falls below `max_tokens`,
380/// iteration stops early.
381///
382/// Use [`BoxedStrategy::new`] to wrap concrete strategies before collecting them.
383///
384/// # Example
385///
386/// ```
387/// use neuron_context::{CompositeStrategy, SlidingWindowStrategy, ToolResultClearingStrategy};
388/// use neuron_context::strategies::BoxedStrategy;
389///
390/// let strategy = CompositeStrategy::new(vec![
391///     BoxedStrategy::new(ToolResultClearingStrategy::new(2, 100_000)),
392///     BoxedStrategy::new(SlidingWindowStrategy::new(10, 100_000)),
393/// ], 100_000);
394/// ```
395pub struct CompositeStrategy {
396    strategies: Vec<BoxedStrategy>,
397    counter: TokenCounter,
398    max_tokens: usize,
399}
400
401impl CompositeStrategy {
402    /// Creates a new `CompositeStrategy`.
403    ///
404    /// # Arguments
405    /// * `strategies` — ordered list of type-erased strategies to apply
406    /// * `max_tokens` — token threshold above which compaction is triggered
407    #[must_use]
408    pub fn new(strategies: Vec<BoxedStrategy>, max_tokens: usize) -> Self {
409        Self { strategies, counter: TokenCounter::new(), max_tokens }
410    }
411}
412
413impl ContextStrategy for CompositeStrategy {
414    fn should_compact(&self, messages: &[Message], token_count: usize) -> bool {
415        let _ = messages;
416        token_count > self.max_tokens
417    }
418
419    fn compact(
420        &self,
421        messages: Vec<Message>,
422    ) -> impl Future<Output = Result<Vec<Message>, ContextError>> + neuron_types::WasmCompatSend
423    {
424        // Snapshot what we need before entering the async block.
425        let inner_refs: Vec<Arc<dyn ErasedStrategy>> =
426            self.strategies.iter().map(|b| Arc::clone(&b.0)).collect();
427        let max_tokens = self.max_tokens;
428        let counter = TokenCounter::new();
429
430        async move {
431            let mut current = messages;
432            for strategy in &inner_refs {
433                let token_count = counter.estimate_messages(&current);
434                if token_count <= max_tokens {
435                    break;
436                }
437                current = strategy.erased_compact(current).await?;
438            }
439            Ok(current)
440        }
441    }
442
443    fn token_estimate(&self, messages: &[Message]) -> usize {
444        self.counter.estimate_messages(messages)
445    }
446}