1#![deny(missing_docs)]
2use neuron_turn::context::ContextStrategy;
8use neuron_turn::types::{ContentPart, ProviderMessage};
9
/// A [`ContextStrategy`] that compacts a conversation with a sliding window.
///
/// Token counts are estimated heuristically from content length divided by a
/// characters-per-token ratio. Compaction keeps the first message plus the most
/// recent messages that fit within half of the current estimate; the older
/// messages in between are dropped.
#[derive(Debug, Clone)]
pub struct SlidingWindow {
    /// Characters assumed per token; the constructors keep this at least 1.
    chars_per_token: usize,
}
18
19impl SlidingWindow {
20 pub fn new() -> Self {
25 Self { chars_per_token: 4 }
26 }
27
28 pub fn with_ratio(chars_per_token: usize) -> Self {
30 Self {
31 chars_per_token: chars_per_token.max(1),
32 }
33 }
34
35 fn estimate_message_tokens(&self, msg: &ProviderMessage) -> usize {
36 msg.content
37 .iter()
38 .map(|part| match part {
39 ContentPart::Text { text } => text.len() / self.chars_per_token,
40 ContentPart::ToolUse { input, .. } => {
41 input.to_string().len() / self.chars_per_token
42 }
43 ContentPart::ToolResult { content, .. } => content.len() / self.chars_per_token,
44 ContentPart::Image { .. } => 1000,
45 })
46 .sum::<usize>()
47 + 4 }
49}
50
51impl Default for SlidingWindow {
52 fn default() -> Self {
53 Self::new()
54 }
55}
56
57impl ContextStrategy for SlidingWindow {
58 fn token_estimate(&self, messages: &[ProviderMessage]) -> usize {
59 messages
60 .iter()
61 .map(|m| self.estimate_message_tokens(m))
62 .sum()
63 }
64
65 fn should_compact(&self, messages: &[ProviderMessage], limit: usize) -> bool {
66 self.token_estimate(messages) > limit
67 }
68
69 fn compact(&self, messages: Vec<ProviderMessage>) -> Vec<ProviderMessage> {
70 if messages.len() <= 2 {
71 return messages;
72 }
73
74 let first = messages[0].clone();
76 let rest = &messages[1..];
77
78 let total_tokens: usize = messages
81 .iter()
82 .map(|m| self.estimate_message_tokens(m))
83 .sum();
84 let target = total_tokens / 2;
85
86 let mut kept = Vec::new();
87 let mut current_tokens = self.estimate_message_tokens(&first);
88
89 for msg in rest.iter().rev() {
90 let msg_tokens = self.estimate_message_tokens(msg);
91 if current_tokens + msg_tokens > target && !kept.is_empty() {
92 break;
93 }
94 kept.push(msg.clone());
95 current_tokens += msg_tokens;
96 }
97
98 kept.reverse();
99 let mut result = vec![first];
100 result.extend(kept);
101 result
102 }
103}
104
#[cfg(test)]
mod tests {
    use super::*;
    use neuron_turn::types::Role;

    /// Builds a message containing a single text part.
    fn text_message(role: Role, text: &str) -> ProviderMessage {
        ProviderMessage {
            role,
            content: vec![ContentPart::Text {
                text: text.to_string(),
            }],
        }
    }

    #[test]
    fn sliding_window_estimates_tokens() {
        let sw = SlidingWindow::new();
        let messages = vec![text_message(Role::User, &"a".repeat(400))];
        // 400 chars / 4 chars-per-token, plus 4 tokens of per-message overhead.
        assert_eq!(sw.token_estimate(&messages), 104);
    }

    #[test]
    fn sliding_window_empty_conversation_estimates_zero() {
        let sw = SlidingWindow::new();
        assert_eq!(sw.token_estimate(&[]), 0);
    }

    #[test]
    fn sliding_window_custom_ratio() {
        let sw = SlidingWindow::with_ratio(2);
        let messages = vec![text_message(Role::User, &"a".repeat(400))];
        assert_eq!(sw.token_estimate(&messages), 204);
    }

    #[test]
    fn sliding_window_zero_ratio_is_clamped_to_one() {
        let sw = SlidingWindow::with_ratio(0);
        let messages = vec![text_message(Role::User, &"a".repeat(400))];
        assert_eq!(sw.token_estimate(&messages), 404);
    }

    #[test]
    fn sliding_window_should_compact() {
        let sw = SlidingWindow::new();
        let messages = vec![text_message(Role::User, &"a".repeat(400))];
        assert!(sw.should_compact(&messages, 50));
        assert!(!sw.should_compact(&messages, 200));
    }

    #[test]
    fn sliding_window_compact_preserves_first_and_recent() {
        let sw = SlidingWindow::new();
        let messages = vec![
            text_message(Role::User, &"first ".repeat(100)),
            text_message(Role::Assistant, &"old ".repeat(100)),
            text_message(Role::User, &"middle ".repeat(100)),
            text_message(Role::Assistant, &"recent ".repeat(100)),
            text_message(Role::User, &"latest ".repeat(100)),
        ];

        let compacted = sw.compact(messages.clone());

        // The first message is always preserved verbatim.
        assert_eq!(compacted[0].role, Role::User);
        assert_eq!(compacted[0].content, messages[0].content);

        // Something was dropped, but at least one recent message survived.
        assert!(compacted.len() < messages.len());
        assert!(compacted.len() >= 2);

        assert_eq!(
            compacted.last().unwrap().content,
            messages.last().unwrap().content
        );

        // Everything after the first message is a contiguous, ordered suffix of the input.
        let tail = &messages[messages.len() - (compacted.len() - 1)..];
        for (kept, original) in compacted[1..].iter().zip(tail) {
            assert_eq!(kept.role, original.role);
            assert_eq!(kept.content, original.content);
        }
    }

    #[test]
    fn sliding_window_compact_always_keeps_latest() {
        let sw = SlidingWindow::new();
        // Each message is 104 tokens; the budget is 312 / 2 = 156, so only the
        // first message and the (unconditionally kept) newest one survive.
        let messages = vec![
            text_message(Role::User, &"a".repeat(400)),
            text_message(Role::Assistant, &"b".repeat(400)),
            text_message(Role::User, &"c".repeat(400)),
        ];

        let compacted = sw.compact(messages.clone());

        assert_eq!(compacted.len(), 2);
        assert_eq!(compacted[0].content, messages[0].content);
        assert_eq!(compacted[1].content, messages[2].content);
    }

    #[test]
    fn sliding_window_short_messages_unchanged() {
        let sw = SlidingWindow::new();
        let messages = vec![
            text_message(Role::User, "hi"),
            text_message(Role::Assistant, "hello"),
        ];

        let compacted = sw.compact(messages.clone());
        assert_eq!(compacted.len(), messages.len());
    }

    #[test]
    fn sliding_window_single_message_unchanged() {
        let sw = SlidingWindow::new();
        let messages = vec![text_message(Role::User, "hi")];
        let compacted = sw.compact(messages.clone());
        assert_eq!(compacted.len(), 1);
    }
}