1use std::future::Future;
4use std::pin::Pin;
5use std::sync::Arc;
6
7use neuron_types::{ContextError, ContextStrategy, Message, Provider, Role};
8
9use crate::counter::TokenCounter;
10
11type CompactFuture<'a> =
15 Pin<Box<dyn Future<Output = Result<Vec<Message>, ContextError>> + Send + 'a>>;
16
17trait ErasedStrategy: Send + Sync {
23 fn erased_compact<'a>(&'a self, messages: Vec<Message>) -> CompactFuture<'a>;
24 fn erased_token_estimate(&self, messages: &[Message]) -> usize;
25 fn erased_should_compact(&self, messages: &[Message], token_count: usize) -> bool;
26}
27
28impl<S: ContextStrategy> ErasedStrategy for S {
29 fn erased_compact<'a>(&'a self, messages: Vec<Message>) -> CompactFuture<'a> {
30 Box::pin(self.compact(messages))
31 }
32
33 fn erased_token_estimate(&self, messages: &[Message]) -> usize {
34 self.token_estimate(messages)
35 }
36
37 fn erased_should_compact(&self, messages: &[Message], token_count: usize) -> bool {
38 self.should_compact(messages, token_count)
39 }
40}
41
42pub struct BoxedStrategy(Arc<dyn ErasedStrategy>);
54
55impl BoxedStrategy {
56 #[must_use]
58 pub fn new<S: ContextStrategy + 'static>(strategy: S) -> Self {
59 BoxedStrategy(Arc::new(strategy))
60 }
61}
62
63impl ContextStrategy for BoxedStrategy {
64 fn should_compact(&self, messages: &[Message], token_count: usize) -> bool {
65 self.0.erased_should_compact(messages, token_count)
66 }
67
68 fn compact(
69 &self,
70 messages: Vec<Message>,
71 ) -> impl Future<Output = Result<Vec<Message>, ContextError>> + neuron_types::WasmCompatSend
72 {
73 let inner = Arc::clone(&self.0);
74 async move { inner.erased_compact(messages).await }
75 }
76
77 fn token_estimate(&self, messages: &[Message]) -> usize {
78 self.0.erased_token_estimate(messages)
79 }
80}
81
82pub struct SlidingWindowStrategy {
96 window_size: usize,
97 counter: TokenCounter,
98 max_tokens: usize,
99}
100
101impl SlidingWindowStrategy {
102 #[must_use]
108 pub fn new(window_size: usize, max_tokens: usize) -> Self {
109 Self {
110 window_size,
111 counter: TokenCounter::new(),
112 max_tokens,
113 }
114 }
115
116 #[must_use]
118 pub fn with_counter(window_size: usize, max_tokens: usize, counter: TokenCounter) -> Self {
119 Self {
120 window_size,
121 counter,
122 max_tokens,
123 }
124 }
125}
126
127impl ContextStrategy for SlidingWindowStrategy {
128 fn should_compact(&self, messages: &[Message], token_count: usize) -> bool {
129 let _ = messages;
130 token_count > self.max_tokens
131 }
132
133 fn compact(
134 &self,
135 messages: Vec<Message>,
136 ) -> impl Future<Output = Result<Vec<Message>, ContextError>> + neuron_types::WasmCompatSend
137 {
138 let window_size = self.window_size;
139 async move {
140 let (system_msgs, non_system): (Vec<_>, Vec<_>) =
141 messages.into_iter().partition(|m| m.role == Role::System);
142
143 let recent: Vec<Message> = non_system
144 .into_iter()
145 .rev()
146 .take(window_size)
147 .collect::<Vec<_>>()
148 .into_iter()
149 .rev()
150 .collect();
151
152 let mut result = system_msgs;
153 result.extend(recent);
154 Ok(result)
155 }
156 }
157
158 fn token_estimate(&self, messages: &[Message]) -> usize {
159 self.counter.estimate_messages(messages)
160 }
161}
162
163pub struct ToolResultClearingStrategy {
179 keep_recent_n: usize,
180 counter: TokenCounter,
181 max_tokens: usize,
182}
183
184impl ToolResultClearingStrategy {
185 #[must_use]
191 pub fn new(keep_recent_n: usize, max_tokens: usize) -> Self {
192 Self {
193 keep_recent_n,
194 counter: TokenCounter::new(),
195 max_tokens,
196 }
197 }
198
199 #[must_use]
201 pub fn with_counter(keep_recent_n: usize, max_tokens: usize, counter: TokenCounter) -> Self {
202 Self {
203 keep_recent_n,
204 counter,
205 max_tokens,
206 }
207 }
208}
209
210impl ContextStrategy for ToolResultClearingStrategy {
211 fn should_compact(&self, messages: &[Message], token_count: usize) -> bool {
212 let _ = messages;
213 token_count > self.max_tokens
214 }
215
216 fn compact(
217 &self,
218 messages: Vec<Message>,
219 ) -> impl Future<Output = Result<Vec<Message>, ContextError>> + neuron_types::WasmCompatSend
220 {
221 use neuron_types::{ContentBlock, ContentItem};
222
223 let keep_recent_n = self.keep_recent_n;
224 async move {
225 let mut tool_result_positions: Vec<(usize, usize)> = Vec::new();
227 for (msg_idx, msg) in messages.iter().enumerate() {
228 for (block_idx, block) in msg.content.iter().enumerate() {
229 if matches!(block, ContentBlock::ToolResult { .. }) {
230 tool_result_positions.push((msg_idx, block_idx));
231 }
232 }
233 }
234
235 let total = tool_result_positions.len();
236 let to_clear_count = total.saturating_sub(keep_recent_n);
237
238 if to_clear_count == 0 {
239 return Ok(messages);
240 }
241
242 let to_clear = tool_result_positions[..to_clear_count].to_vec();
243 let mut messages = messages;
244 for (msg_idx, block_idx) in to_clear {
245 let block = &mut messages[msg_idx].content[block_idx];
246 if let ContentBlock::ToolResult {
247 content, is_error, ..
248 } = block
249 {
250 *content = vec![ContentItem::Text("[tool result cleared]".to_string())];
251 *is_error = false;
252 }
253 }
254
255 Ok(messages)
256 }
257 }
258
259 fn token_estimate(&self, messages: &[Message]) -> usize {
260 self.counter.estimate_messages(messages)
261 }
262}
263
264pub struct SummarizationStrategy<P: Provider> {
281 provider: P,
282 preserve_recent: usize,
283 counter: TokenCounter,
284 max_tokens: usize,
285}
286
287impl<P: Provider> SummarizationStrategy<P> {
288 #[must_use]
295 pub fn new(provider: P, preserve_recent: usize, max_tokens: usize) -> Self {
296 Self {
297 provider,
298 preserve_recent,
299 counter: TokenCounter::new(),
300 max_tokens,
301 }
302 }
303
304 #[must_use]
306 pub fn with_counter(
307 provider: P,
308 preserve_recent: usize,
309 max_tokens: usize,
310 counter: TokenCounter,
311 ) -> Self {
312 Self {
313 provider,
314 preserve_recent,
315 counter,
316 max_tokens,
317 }
318 }
319}
320
321impl<P: Provider> ContextStrategy for SummarizationStrategy<P> {
322 fn should_compact(&self, messages: &[Message], token_count: usize) -> bool {
323 let _ = messages;
324 token_count > self.max_tokens
325 }
326
327 fn compact(
328 &self,
329 messages: Vec<Message>,
330 ) -> impl Future<Output = Result<Vec<Message>, ContextError>> + neuron_types::WasmCompatSend
331 {
332 use neuron_types::{CompletionRequest, ContentBlock, Role, SystemPrompt};
333
334 let preserve_recent = self.preserve_recent;
335
336 let (system_msgs, non_system): (Vec<Message>, Vec<Message>) =
338 messages.into_iter().partition(|m| m.role == Role::System);
339
340 let split_at = non_system.len().saturating_sub(preserve_recent);
341 let old_messages = non_system[..split_at].to_vec();
342 let recent_messages = non_system[split_at..].to_vec();
343
344 let summarize_request = CompletionRequest {
345 model: String::new(),
346 messages: old_messages,
347 system: Some(SystemPrompt::Text(
348 "Summarize the conversation above concisely. Focus on key information, \
349 decisions made, and results from tool calls. Write in third person."
350 .to_string(),
351 )),
352 tools: vec![],
353 max_tokens: Some(1024),
354 temperature: Some(0.0),
355 top_p: None,
356 stop_sequences: vec![],
357 tool_choice: None,
358 response_format: None,
359 thinking: None,
360 reasoning_effort: None,
361 extra: None,
362 context_management: None,
363 };
364
365 async move {
366 let response = self.provider.complete(summarize_request).await?;
367
368 let summary_text = response
369 .message
370 .content
371 .into_iter()
372 .filter_map(|block| {
373 if let ContentBlock::Text(text) = block {
374 Some(text)
375 } else {
376 None
377 }
378 })
379 .collect::<Vec<_>>()
380 .join("\n");
381
382 let summary_message = Message {
383 role: Role::User,
384 content: vec![ContentBlock::Text(format!(
385 "[Summary of earlier conversation]\n{summary_text}"
386 ))],
387 };
388
389 let mut result = system_msgs;
390 result.push(summary_message);
391 result.extend(recent_messages);
392 Ok(result)
393 }
394 }
395
396 fn token_estimate(&self, messages: &[Message]) -> usize {
397 self.counter.estimate_messages(messages)
398 }
399}
400
401pub struct CompositeStrategy {
423 strategies: Vec<BoxedStrategy>,
424 counter: TokenCounter,
425 max_tokens: usize,
426}
427
428impl CompositeStrategy {
429 #[must_use]
435 pub fn new(strategies: Vec<BoxedStrategy>, max_tokens: usize) -> Self {
436 Self {
437 strategies,
438 counter: TokenCounter::new(),
439 max_tokens,
440 }
441 }
442}
443
444impl ContextStrategy for CompositeStrategy {
445 fn should_compact(&self, messages: &[Message], token_count: usize) -> bool {
446 let _ = messages;
447 token_count > self.max_tokens
448 }
449
450 fn compact(
451 &self,
452 messages: Vec<Message>,
453 ) -> impl Future<Output = Result<Vec<Message>, ContextError>> + neuron_types::WasmCompatSend
454 {
455 let inner_refs: Vec<Arc<dyn ErasedStrategy>> =
457 self.strategies.iter().map(|b| Arc::clone(&b.0)).collect();
458 let max_tokens = self.max_tokens;
459 let counter = TokenCounter::new();
460
461 async move {
462 let mut current = messages;
463 for strategy in &inner_refs {
464 let token_count = counter.estimate_messages(¤t);
465 if token_count <= max_tokens {
466 break;
467 }
468 current = strategy.erased_compact(current).await?;
469 }
470 Ok(current)
471 }
472 }
473
474 fn token_estimate(&self, messages: &[Message]) -> usize {
475 self.counter.estimate_messages(messages)
476 }
477}