// neuron_context/strategies.rs
//! Context compaction strategies implementing [`ContextStrategy`].

use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;

use neuron_types::{ContextError, ContextStrategy, Message, Provider, Role};

use crate::counter::TokenCounter;

// ---- Dyn-compatible wrapper for CompositeStrategy --------------------------

/// Type alias for a pinned, boxed, `Send` future returning compacted messages.
///
/// Returned by [`ErasedStrategy::erased_compact`]; `'a` ties the future to the
/// borrow of the strategy that produced it, so the strategy must outlive the
/// in-flight compaction.
type CompactFuture<'a> = Pin<Box<dyn Future<Output = Result<Vec<Message>, ContextError>> + Send + 'a>>;

/// A dyn-compatible strategy object. Used internally by [`CompositeStrategy`].
///
/// Because `ContextStrategy::compact` returns `impl Future` (RPITIT), the trait
/// is not dyn-compatible. `ErasedStrategy` provides a vtable-friendly equivalent
/// that boxes the future.
trait ErasedStrategy: Send + Sync {
    /// Object-safe counterpart of [`ContextStrategy::compact`]; boxes the future.
    fn erased_compact<'a>(&'a self, messages: Vec<Message>) -> CompactFuture<'a>;
    /// Object-safe counterpart of [`ContextStrategy::token_estimate`].
    fn erased_token_estimate(&self, messages: &[Message]) -> usize;
    /// Object-safe counterpart of [`ContextStrategy::should_compact`].
    fn erased_should_compact(&self, messages: &[Message], token_count: usize) -> bool;
}

27impl<S: ContextStrategy> ErasedStrategy for S {
28    fn erased_compact<'a>(&'a self, messages: Vec<Message>) -> CompactFuture<'a> {
29        Box::pin(self.compact(messages))
30    }
31
32    fn erased_token_estimate(&self, messages: &[Message]) -> usize {
33        self.token_estimate(messages)
34    }
35
36    fn erased_should_compact(&self, messages: &[Message], token_count: usize) -> bool {
37        self.should_compact(messages, token_count)
38    }
39}
40
/// A type-erased wrapper around a [`ContextStrategy`] for use in [`CompositeStrategy`].
///
/// Use `BoxedStrategy::new(strategy)` to wrap any strategy. The inner strategy
/// is held behind an `Arc`, so `compact` can hand an owned handle to the
/// future it returns.
///
/// # Example
///
/// ```
/// use neuron_context::{SlidingWindowStrategy, strategies::BoxedStrategy};
///
/// let boxed = BoxedStrategy::new(SlidingWindowStrategy::new(10, 100_000));
/// ```
pub struct BoxedStrategy(Arc<dyn ErasedStrategy>);

54impl BoxedStrategy {
55    /// Wrap any [`ContextStrategy`] into a type-erased `BoxedStrategy`.
56    #[must_use]
57    pub fn new<S: ContextStrategy + 'static>(strategy: S) -> Self {
58        BoxedStrategy(Arc::new(strategy))
59    }
60}
61
62impl ContextStrategy for BoxedStrategy {
63    fn should_compact(&self, messages: &[Message], token_count: usize) -> bool {
64        self.0.erased_should_compact(messages, token_count)
65    }
66
67    fn compact(
68        &self,
69        messages: Vec<Message>,
70    ) -> impl Future<Output = Result<Vec<Message>, ContextError>> + neuron_types::WasmCompatSend {
71        let inner = Arc::clone(&self.0);
72        async move { inner.erased_compact(messages).await }
73    }
74
75    fn token_estimate(&self, messages: &[Message]) -> usize {
76        self.0.erased_token_estimate(messages)
77    }
78}

// ---- SlidingWindowStrategy --------------------------------------------------

/// Keeps system messages plus the last `window_size` non-system messages.
///
/// Triggers compaction when the estimated token count exceeds `max_tokens`.
///
/// # Example
///
/// ```
/// use neuron_context::SlidingWindowStrategy;
///
/// let strategy = SlidingWindowStrategy::new(10, 100_000);
/// ```
pub struct SlidingWindowStrategy {
    /// Maximum number of non-system messages `compact` retains.
    window_size: usize,
    /// Estimator backing `token_estimate`.
    counter: TokenCounter,
    /// Threshold above which `should_compact` returns `true`.
    max_tokens: usize,
}

99impl SlidingWindowStrategy {
100    /// Creates a new `SlidingWindowStrategy`.
101    ///
102    /// # Arguments
103    /// * `window_size` — maximum number of non-system messages to retain
104    /// * `max_tokens` — token threshold above which compaction is triggered
105    #[must_use]
106    pub fn new(window_size: usize, max_tokens: usize) -> Self {
107        Self { window_size, counter: TokenCounter::new(), max_tokens }
108    }
109
110    /// Creates a new `SlidingWindowStrategy` with a custom [`TokenCounter`].
111    #[must_use]
112    pub fn with_counter(window_size: usize, max_tokens: usize, counter: TokenCounter) -> Self {
113        Self { window_size, counter, max_tokens }
114    }
115}
116
117impl ContextStrategy for SlidingWindowStrategy {
118    fn should_compact(&self, messages: &[Message], token_count: usize) -> bool {
119        let _ = messages;
120        token_count > self.max_tokens
121    }
122
123    fn compact(
124        &self,
125        messages: Vec<Message>,
126    ) -> impl Future<Output = Result<Vec<Message>, ContextError>> + neuron_types::WasmCompatSend
127    {
128        let window_size = self.window_size;
129        async move {
130            let (system_msgs, non_system): (Vec<_>, Vec<_>) =
131                messages.into_iter().partition(|m| m.role == Role::System);
132
133            let recent: Vec<Message> = non_system
134                .into_iter()
135                .rev()
136                .take(window_size)
137                .collect::<Vec<_>>()
138                .into_iter()
139                .rev()
140                .collect();
141
142            let mut result = system_msgs;
143            result.extend(recent);
144            Ok(result)
145        }
146    }
147
148    fn token_estimate(&self, messages: &[Message]) -> usize {
149        self.counter.estimate_messages(messages)
150    }
151}

// ---- ToolResultClearingStrategy ---------------------------------------------

/// Replaces old tool result content with a placeholder to reduce token usage.
///
/// Keeps the most recent `keep_recent_n` tool results intact and replaces
/// older ones with `[tool result cleared]` while preserving the `tool_use_id`
/// so the conversation still makes semantic sense.
///
/// # Example
///
/// ```
/// use neuron_context::ToolResultClearingStrategy;
///
/// let strategy = ToolResultClearingStrategy::new(2, 100_000);
/// ```
pub struct ToolResultClearingStrategy {
    /// Number of most-recent tool results left untouched by `compact`.
    keep_recent_n: usize,
    /// Estimator backing `token_estimate`.
    counter: TokenCounter,
    /// Threshold above which `should_compact` returns `true`.
    max_tokens: usize,
}

174impl ToolResultClearingStrategy {
175    /// Creates a new `ToolResultClearingStrategy`.
176    ///
177    /// # Arguments
178    /// * `keep_recent_n` — number of most-recent tool results to leave untouched
179    /// * `max_tokens` — token threshold above which compaction is triggered
180    #[must_use]
181    pub fn new(keep_recent_n: usize, max_tokens: usize) -> Self {
182        Self { keep_recent_n, counter: TokenCounter::new(), max_tokens }
183    }
184
185    /// Creates a new `ToolResultClearingStrategy` with a custom [`TokenCounter`].
186    #[must_use]
187    pub fn with_counter(
188        keep_recent_n: usize,
189        max_tokens: usize,
190        counter: TokenCounter,
191    ) -> Self {
192        Self { keep_recent_n, counter, max_tokens }
193    }
194}
195
196impl ContextStrategy for ToolResultClearingStrategy {
197    fn should_compact(&self, messages: &[Message], token_count: usize) -> bool {
198        let _ = messages;
199        token_count > self.max_tokens
200    }
201
202    fn compact(
203        &self,
204        messages: Vec<Message>,
205    ) -> impl Future<Output = Result<Vec<Message>, ContextError>> + neuron_types::WasmCompatSend
206    {
207        use neuron_types::{ContentBlock, ContentItem};
208
209        let keep_recent_n = self.keep_recent_n;
210        async move {
211            // Collect positions of all ToolResult blocks across all messages.
212            let mut tool_result_positions: Vec<(usize, usize)> = Vec::new();
213            for (msg_idx, msg) in messages.iter().enumerate() {
214                for (block_idx, block) in msg.content.iter().enumerate() {
215                    if matches!(block, ContentBlock::ToolResult { .. }) {
216                        tool_result_positions.push((msg_idx, block_idx));
217                    }
218                }
219            }
220
221            let total = tool_result_positions.len();
222            let to_clear_count = total.saturating_sub(keep_recent_n);
223
224            if to_clear_count == 0 {
225                return Ok(messages);
226            }
227
228            let to_clear = tool_result_positions[..to_clear_count].to_vec();
229            let mut messages = messages;
230            for (msg_idx, block_idx) in to_clear {
231                let block = &mut messages[msg_idx].content[block_idx];
232                if let ContentBlock::ToolResult { content, is_error, .. } = block {
233                    *content = vec![ContentItem::Text("[tool result cleared]".to_string())];
234                    *is_error = false;
235                }
236            }
237
238            Ok(messages)
239        }
240    }
241
242    fn token_estimate(&self, messages: &[Message]) -> usize {
243        self.counter.estimate_messages(messages)
244    }
245}

// ---- SummarizationStrategy --------------------------------------------------

/// Summarizes old messages using an LLM provider, preserving recent messages verbatim.
///
/// When compaction is triggered, messages older than `preserve_recent` are sent
/// to the provider with a summarization prompt. The response replaces the old
/// messages with a single `User` message containing the summary, followed by
/// the preserved recent messages.
///
/// # Example
///
/// ```ignore
/// use neuron_context::SummarizationStrategy;
///
/// let strategy = SummarizationStrategy::new(provider, 5, 100_000);
/// ```
pub struct SummarizationStrategy<P: Provider> {
    /// Provider that performs the summarization completion call.
    provider: P,
    /// Number of most-recent non-system messages kept verbatim.
    preserve_recent: usize,
    /// Estimator backing `token_estimate`.
    counter: TokenCounter,
    /// Threshold above which `should_compact` returns `true`.
    max_tokens: usize,
}

270impl<P: Provider> SummarizationStrategy<P> {
271    /// Creates a new `SummarizationStrategy`.
272    ///
273    /// # Arguments
274    /// * `provider` — the LLM provider used for summarization
275    /// * `preserve_recent` — number of most-recent messages to keep verbatim
276    /// * `max_tokens` — token threshold above which compaction is triggered
277    #[must_use]
278    pub fn new(provider: P, preserve_recent: usize, max_tokens: usize) -> Self {
279        Self { provider, preserve_recent, counter: TokenCounter::new(), max_tokens }
280    }
281
282    /// Creates a new `SummarizationStrategy` with a custom [`TokenCounter`].
283    #[must_use]
284    pub fn with_counter(
285        provider: P,
286        preserve_recent: usize,
287        max_tokens: usize,
288        counter: TokenCounter,
289    ) -> Self {
290        Self { provider, preserve_recent, counter, max_tokens }
291    }
292}
293
294impl<P: Provider> ContextStrategy for SummarizationStrategy<P> {
295    fn should_compact(&self, messages: &[Message], token_count: usize) -> bool {
296        let _ = messages;
297        token_count > self.max_tokens
298    }
299
300    fn compact(
301        &self,
302        messages: Vec<Message>,
303    ) -> impl Future<Output = Result<Vec<Message>, ContextError>> + neuron_types::WasmCompatSend
304    {
305        use neuron_types::{CompletionRequest, ContentBlock, Role, SystemPrompt};
306
307        let preserve_recent = self.preserve_recent;
308
309        // Partition before entering the async block so we don't borrow `messages`.
310        let (system_msgs, non_system): (Vec<Message>, Vec<Message>) =
311            messages.into_iter().partition(|m| m.role == Role::System);
312
313        let split_at = non_system.len().saturating_sub(preserve_recent);
314        let old_messages = non_system[..split_at].to_vec();
315        let recent_messages = non_system[split_at..].to_vec();
316
317        let summarize_request = CompletionRequest {
318            model: String::new(),
319            messages: old_messages,
320            system: Some(SystemPrompt::Text(
321                "Summarize the conversation above concisely. Focus on key information, \
322                 decisions made, and results from tool calls. Write in third person."
323                    .to_string(),
324            )),
325            tools: vec![],
326            max_tokens: Some(1024),
327            temperature: Some(0.0),
328            top_p: None,
329            stop_sequences: vec![],
330            tool_choice: None,
331            response_format: None,
332            thinking: None,
333            reasoning_effort: None,
334            extra: None,
335        };
336
337        async move {
338            let response = self.provider.complete(summarize_request).await?;
339
340            let summary_text = response
341                .message
342                .content
343                .into_iter()
344                .filter_map(|block| {
345                    if let ContentBlock::Text(text) = block {
346                        Some(text)
347                    } else {
348                        None
349                    }
350                })
351                .collect::<Vec<_>>()
352                .join("\n");
353
354            let summary_message = Message {
355                role: Role::User,
356                content: vec![ContentBlock::Text(format!(
357                    "[Summary of earlier conversation]\n{summary_text}"
358                ))],
359            };
360
361            let mut result = system_msgs;
362            result.push(summary_message);
363            result.extend(recent_messages);
364            Ok(result)
365        }
366    }
367
368    fn token_estimate(&self, messages: &[Message]) -> usize {
369        self.counter.estimate_messages(messages)
370    }
371}

// ---- CompositeStrategy ------------------------------------------------------

/// Chains multiple strategies, applying each in order until token budget is met.
///
/// Each strategy is tried in sequence. After each strategy's `compact` runs,
/// the resulting token count is re-estimated. If it falls below `max_tokens`,
/// iteration stops early.
///
/// Use [`BoxedStrategy::new`] to wrap concrete strategies before collecting them.
///
/// # Example
///
/// ```
/// use neuron_context::{CompositeStrategy, SlidingWindowStrategy, ToolResultClearingStrategy};
/// use neuron_context::strategies::BoxedStrategy;
///
/// let strategy = CompositeStrategy::new(vec![
///     BoxedStrategy::new(ToolResultClearingStrategy::new(2, 100_000)),
///     BoxedStrategy::new(SlidingWindowStrategy::new(10, 100_000)),
/// ], 100_000);
/// ```
pub struct CompositeStrategy {
    /// Ordered, type-erased strategies, applied front to back.
    strategies: Vec<BoxedStrategy>,
    /// Estimator backing `token_estimate` and the per-pass re-estimation.
    counter: TokenCounter,
    /// Token budget; iteration stops once the estimate drops to or below it.
    max_tokens: usize,
}

400impl CompositeStrategy {
401    /// Creates a new `CompositeStrategy`.
402    ///
403    /// # Arguments
404    /// * `strategies` — ordered list of type-erased strategies to apply
405    /// * `max_tokens` — token threshold above which compaction is triggered
406    #[must_use]
407    pub fn new(strategies: Vec<BoxedStrategy>, max_tokens: usize) -> Self {
408        Self { strategies, counter: TokenCounter::new(), max_tokens }
409    }
410}
411
412impl ContextStrategy for CompositeStrategy {
413    fn should_compact(&self, messages: &[Message], token_count: usize) -> bool {
414        let _ = messages;
415        token_count > self.max_tokens
416    }
417
418    fn compact(
419        &self,
420        messages: Vec<Message>,
421    ) -> impl Future<Output = Result<Vec<Message>, ContextError>> + neuron_types::WasmCompatSend
422    {
423        // Snapshot what we need before entering the async block.
424        let inner_refs: Vec<Arc<dyn ErasedStrategy>> =
425            self.strategies.iter().map(|b| Arc::clone(&b.0)).collect();
426        let max_tokens = self.max_tokens;
427        let counter = TokenCounter::new();
428
429        async move {
430            let mut current = messages;
431            for strategy in &inner_refs {
432                let token_count = counter.estimate_messages(&current);
433                if token_count <= max_tokens {
434                    break;
435                }
436                current = strategy.erased_compact(current).await?;
437            }
438            Ok(current)
439        }
440    }
441
442    fn token_estimate(&self, messages: &[Message]) -> usize {
443        self.counter.estimate_messages(messages)
444    }
445}