1use std::future::Future;
4use std::pin::Pin;
5use std::sync::Arc;
6
7use neuron_types::{ContextError, ContextStrategy, Message, Provider, Role};
8
9use crate::counter::TokenCounter;
10
11type CompactFuture<'a> = Pin<Box<dyn Future<Output = Result<Vec<Message>, ContextError>> + Send + 'a>>;
15
16trait ErasedStrategy: Send + Sync {
22 fn erased_compact<'a>(&'a self, messages: Vec<Message>) -> CompactFuture<'a>;
23 fn erased_token_estimate(&self, messages: &[Message]) -> usize;
24 fn erased_should_compact(&self, messages: &[Message], token_count: usize) -> bool;
25}
26
27impl<S: ContextStrategy> ErasedStrategy for S {
28 fn erased_compact<'a>(&'a self, messages: Vec<Message>) -> CompactFuture<'a> {
29 Box::pin(self.compact(messages))
30 }
31
32 fn erased_token_estimate(&self, messages: &[Message]) -> usize {
33 self.token_estimate(messages)
34 }
35
36 fn erased_should_compact(&self, messages: &[Message], token_count: usize) -> bool {
37 self.should_compact(messages, token_count)
38 }
39}
40
41pub struct BoxedStrategy(Arc<dyn ErasedStrategy>);
53
54impl BoxedStrategy {
55 #[must_use]
57 pub fn new<S: ContextStrategy + 'static>(strategy: S) -> Self {
58 BoxedStrategy(Arc::new(strategy))
59 }
60}
61
62impl ContextStrategy for BoxedStrategy {
63 fn should_compact(&self, messages: &[Message], token_count: usize) -> bool {
64 self.0.erased_should_compact(messages, token_count)
65 }
66
67 fn compact(
68 &self,
69 messages: Vec<Message>,
70 ) -> impl Future<Output = Result<Vec<Message>, ContextError>> + neuron_types::WasmCompatSend {
71 let inner = Arc::clone(&self.0);
72 async move { inner.erased_compact(messages).await }
73 }
74
75 fn token_estimate(&self, messages: &[Message]) -> usize {
76 self.0.erased_token_estimate(messages)
77 }
78}
79
80pub struct SlidingWindowStrategy {
94 window_size: usize,
95 counter: TokenCounter,
96 max_tokens: usize,
97}
98
99impl SlidingWindowStrategy {
100 #[must_use]
106 pub fn new(window_size: usize, max_tokens: usize) -> Self {
107 Self { window_size, counter: TokenCounter::new(), max_tokens }
108 }
109
110 #[must_use]
112 pub fn with_counter(window_size: usize, max_tokens: usize, counter: TokenCounter) -> Self {
113 Self { window_size, counter, max_tokens }
114 }
115}
116
117impl ContextStrategy for SlidingWindowStrategy {
118 fn should_compact(&self, messages: &[Message], token_count: usize) -> bool {
119 let _ = messages;
120 token_count > self.max_tokens
121 }
122
123 fn compact(
124 &self,
125 messages: Vec<Message>,
126 ) -> impl Future<Output = Result<Vec<Message>, ContextError>> + neuron_types::WasmCompatSend
127 {
128 let window_size = self.window_size;
129 async move {
130 let (system_msgs, non_system): (Vec<_>, Vec<_>) =
131 messages.into_iter().partition(|m| m.role == Role::System);
132
133 let recent: Vec<Message> = non_system
134 .into_iter()
135 .rev()
136 .take(window_size)
137 .collect::<Vec<_>>()
138 .into_iter()
139 .rev()
140 .collect();
141
142 let mut result = system_msgs;
143 result.extend(recent);
144 Ok(result)
145 }
146 }
147
148 fn token_estimate(&self, messages: &[Message]) -> usize {
149 self.counter.estimate_messages(messages)
150 }
151}
152
153pub struct ToolResultClearingStrategy {
169 keep_recent_n: usize,
170 counter: TokenCounter,
171 max_tokens: usize,
172}
173
174impl ToolResultClearingStrategy {
175 #[must_use]
181 pub fn new(keep_recent_n: usize, max_tokens: usize) -> Self {
182 Self { keep_recent_n, counter: TokenCounter::new(), max_tokens }
183 }
184
185 #[must_use]
187 pub fn with_counter(
188 keep_recent_n: usize,
189 max_tokens: usize,
190 counter: TokenCounter,
191 ) -> Self {
192 Self { keep_recent_n, counter, max_tokens }
193 }
194}
195
196impl ContextStrategy for ToolResultClearingStrategy {
197 fn should_compact(&self, messages: &[Message], token_count: usize) -> bool {
198 let _ = messages;
199 token_count > self.max_tokens
200 }
201
202 fn compact(
203 &self,
204 messages: Vec<Message>,
205 ) -> impl Future<Output = Result<Vec<Message>, ContextError>> + neuron_types::WasmCompatSend
206 {
207 use neuron_types::{ContentBlock, ContentItem};
208
209 let keep_recent_n = self.keep_recent_n;
210 async move {
211 let mut tool_result_positions: Vec<(usize, usize)> = Vec::new();
213 for (msg_idx, msg) in messages.iter().enumerate() {
214 for (block_idx, block) in msg.content.iter().enumerate() {
215 if matches!(block, ContentBlock::ToolResult { .. }) {
216 tool_result_positions.push((msg_idx, block_idx));
217 }
218 }
219 }
220
221 let total = tool_result_positions.len();
222 let to_clear_count = total.saturating_sub(keep_recent_n);
223
224 if to_clear_count == 0 {
225 return Ok(messages);
226 }
227
228 let to_clear = tool_result_positions[..to_clear_count].to_vec();
229 let mut messages = messages;
230 for (msg_idx, block_idx) in to_clear {
231 let block = &mut messages[msg_idx].content[block_idx];
232 if let ContentBlock::ToolResult { content, is_error, .. } = block {
233 *content = vec![ContentItem::Text("[tool result cleared]".to_string())];
234 *is_error = false;
235 }
236 }
237
238 Ok(messages)
239 }
240 }
241
242 fn token_estimate(&self, messages: &[Message]) -> usize {
243 self.counter.estimate_messages(messages)
244 }
245}
246
247pub struct SummarizationStrategy<P: Provider> {
264 provider: P,
265 preserve_recent: usize,
266 counter: TokenCounter,
267 max_tokens: usize,
268}
269
270impl<P: Provider> SummarizationStrategy<P> {
271 #[must_use]
278 pub fn new(provider: P, preserve_recent: usize, max_tokens: usize) -> Self {
279 Self { provider, preserve_recent, counter: TokenCounter::new(), max_tokens }
280 }
281
282 #[must_use]
284 pub fn with_counter(
285 provider: P,
286 preserve_recent: usize,
287 max_tokens: usize,
288 counter: TokenCounter,
289 ) -> Self {
290 Self { provider, preserve_recent, counter, max_tokens }
291 }
292}
293
294impl<P: Provider> ContextStrategy for SummarizationStrategy<P> {
295 fn should_compact(&self, messages: &[Message], token_count: usize) -> bool {
296 let _ = messages;
297 token_count > self.max_tokens
298 }
299
300 fn compact(
301 &self,
302 messages: Vec<Message>,
303 ) -> impl Future<Output = Result<Vec<Message>, ContextError>> + neuron_types::WasmCompatSend
304 {
305 use neuron_types::{CompletionRequest, ContentBlock, Role, SystemPrompt};
306
307 let preserve_recent = self.preserve_recent;
308
309 let (system_msgs, non_system): (Vec<Message>, Vec<Message>) =
311 messages.into_iter().partition(|m| m.role == Role::System);
312
313 let split_at = non_system.len().saturating_sub(preserve_recent);
314 let old_messages = non_system[..split_at].to_vec();
315 let recent_messages = non_system[split_at..].to_vec();
316
317 let summarize_request = CompletionRequest {
318 model: String::new(),
319 messages: old_messages,
320 system: Some(SystemPrompt::Text(
321 "Summarize the conversation above concisely. Focus on key information, \
322 decisions made, and results from tool calls. Write in third person."
323 .to_string(),
324 )),
325 tools: vec![],
326 max_tokens: Some(1024),
327 temperature: Some(0.0),
328 top_p: None,
329 stop_sequences: vec![],
330 tool_choice: None,
331 response_format: None,
332 thinking: None,
333 reasoning_effort: None,
334 extra: None,
335 context_management: None,
336 };
337
338 async move {
339 let response = self.provider.complete(summarize_request).await?;
340
341 let summary_text = response
342 .message
343 .content
344 .into_iter()
345 .filter_map(|block| {
346 if let ContentBlock::Text(text) = block {
347 Some(text)
348 } else {
349 None
350 }
351 })
352 .collect::<Vec<_>>()
353 .join("\n");
354
355 let summary_message = Message {
356 role: Role::User,
357 content: vec![ContentBlock::Text(format!(
358 "[Summary of earlier conversation]\n{summary_text}"
359 ))],
360 };
361
362 let mut result = system_msgs;
363 result.push(summary_message);
364 result.extend(recent_messages);
365 Ok(result)
366 }
367 }
368
369 fn token_estimate(&self, messages: &[Message]) -> usize {
370 self.counter.estimate_messages(messages)
371 }
372}
373
374pub struct CompositeStrategy {
396 strategies: Vec<BoxedStrategy>,
397 counter: TokenCounter,
398 max_tokens: usize,
399}
400
401impl CompositeStrategy {
402 #[must_use]
408 pub fn new(strategies: Vec<BoxedStrategy>, max_tokens: usize) -> Self {
409 Self { strategies, counter: TokenCounter::new(), max_tokens }
410 }
411}
412
413impl ContextStrategy for CompositeStrategy {
414 fn should_compact(&self, messages: &[Message], token_count: usize) -> bool {
415 let _ = messages;
416 token_count > self.max_tokens
417 }
418
419 fn compact(
420 &self,
421 messages: Vec<Message>,
422 ) -> impl Future<Output = Result<Vec<Message>, ContextError>> + neuron_types::WasmCompatSend
423 {
424 let inner_refs: Vec<Arc<dyn ErasedStrategy>> =
426 self.strategies.iter().map(|b| Arc::clone(&b.0)).collect();
427 let max_tokens = self.max_tokens;
428 let counter = TokenCounter::new();
429
430 async move {
431 let mut current = messages;
432 for strategy in &inner_refs {
433 let token_count = counter.estimate_messages(¤t);
434 if token_count <= max_tokens {
435 break;
436 }
437 current = strategy.erased_compact(current).await?;
438 }
439 Ok(current)
440 }
441 }
442
443 fn token_estimate(&self, messages: &[Message]) -> usize {
444 self.counter.estimate_messages(messages)
445 }
446}