Skip to main content

neuron_context/
lib.rs

1#![deny(missing_docs)]
2//! Context strategy implementations for neuron-turn.
3//!
4//! Provides [`SlidingWindow`] for dropping oldest messages when context
5//! exceeds a limit. `NoCompaction` is in neuron-turn itself.
6
7use neuron_turn::context::ContextStrategy;
8use neuron_turn::types::{ContentPart, ProviderMessage};
9
/// Sliding window context strategy.
///
/// When context exceeds the limit, drops the oldest messages
/// (keeping the first message, which is typically the initial user message).
///
/// Token counts are estimates derived from content length divided by a
/// configurable chars-per-token ratio; no real tokenizer is involved.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SlidingWindow {
    /// Approximate chars-per-token ratio for estimation.
    ///
    /// Used as a divisor, so the constructors keep it at 1 or above.
    chars_per_token: usize,
}
18
19impl SlidingWindow {
20    /// Create a new sliding window strategy.
21    ///
22    /// `chars_per_token` controls the token estimation granularity
23    /// (default: 4 chars per token).
24    pub fn new() -> Self {
25        Self { chars_per_token: 4 }
26    }
27
28    /// Create with a custom chars-per-token ratio.
29    pub fn with_ratio(chars_per_token: usize) -> Self {
30        Self {
31            chars_per_token: chars_per_token.max(1),
32        }
33    }
34
35    fn estimate_message_tokens(&self, msg: &ProviderMessage) -> usize {
36        msg.content
37            .iter()
38            .map(|part| match part {
39                ContentPart::Text { text } => text.len() / self.chars_per_token,
40                ContentPart::ToolUse { input, .. } => {
41                    input.to_string().len() / self.chars_per_token
42                }
43                ContentPart::ToolResult { content, .. } => content.len() / self.chars_per_token,
44                ContentPart::Image { .. } => 1000,
45            })
46            .sum::<usize>()
47            + 4 // overhead per message (role, formatting)
48    }
49}
50
51impl Default for SlidingWindow {
52    fn default() -> Self {
53        Self::new()
54    }
55}
56
57impl ContextStrategy for SlidingWindow {
58    fn token_estimate(&self, messages: &[ProviderMessage]) -> usize {
59        messages
60            .iter()
61            .map(|m| self.estimate_message_tokens(m))
62            .sum()
63    }
64
65    fn should_compact(&self, messages: &[ProviderMessage], limit: usize) -> bool {
66        self.token_estimate(messages) > limit
67    }
68
69    fn compact(&self, messages: Vec<ProviderMessage>) -> Vec<ProviderMessage> {
70        if messages.len() <= 2 {
71            return messages;
72        }
73
74        // Keep first message + most recent messages that fit
75        let first = messages[0].clone();
76        let rest = &messages[1..];
77
78        // Work backwards, accumulating messages until we hit roughly half the
79        // original size (heuristic: keep recent context, drop old)
80        let total_tokens: usize = messages
81            .iter()
82            .map(|m| self.estimate_message_tokens(m))
83            .sum();
84        let target = total_tokens / 2;
85
86        let mut kept = Vec::new();
87        let mut current_tokens = self.estimate_message_tokens(&first);
88
89        for msg in rest.iter().rev() {
90            let msg_tokens = self.estimate_message_tokens(msg);
91            if current_tokens + msg_tokens > target && !kept.is_empty() {
92                break;
93            }
94            kept.push(msg.clone());
95            current_tokens += msg_tokens;
96        }
97
98        kept.reverse();
99        let mut result = vec![first];
100        result.extend(kept);
101        result
102    }
103}
104
#[cfg(test)]
mod tests {
    use super::*;
    use neuron_turn::types::Role;

    /// Build a message that carries a single text part.
    fn text_message(role: Role, text: &str) -> ProviderMessage {
        let content = vec![ContentPart::Text {
            text: text.to_string(),
        }];
        ProviderMessage { role, content }
    }

    #[test]
    fn sliding_window_estimates_tokens() {
        let strategy = SlidingWindow::new();
        let body = "a".repeat(400);
        let history = vec![text_message(Role::User, &body)];

        // 400 chars / 4 = 100, + 4 overhead = 104
        let estimate = strategy.token_estimate(&history);
        assert_eq!(estimate, 104);
    }

    #[test]
    fn sliding_window_should_compact() {
        let strategy = SlidingWindow::new();
        let body = "a".repeat(400);
        let history = vec![text_message(Role::User, &body)];

        // Estimate is 104: above a limit of 50, below a limit of 200.
        assert!(strategy.should_compact(&history, 50));
        assert!(!strategy.should_compact(&history, 200));
    }

    #[test]
    fn sliding_window_compact_preserves_first_and_recent() {
        let strategy = SlidingWindow::new();
        let history: Vec<ProviderMessage> = vec![
            (Role::User, "first "),
            (Role::Assistant, "old "),
            (Role::User, "middle "),
            (Role::Assistant, "recent "),
            (Role::User, "latest "),
        ]
        .into_iter()
        .map(|(role, word)| text_message(role, &word.repeat(100)))
        .collect();

        let compacted = strategy.compact(history.clone());

        // Should keep first message
        let head = &compacted[0];
        assert_eq!(head.role, Role::User);
        assert!(head.content[0] == history[0].content[0]);

        // Should keep some recent messages
        assert!(compacted.len() < history.len());
        assert!(compacted.len() >= 2);

        // Last message should be the latest
        let newest_kept = compacted.last().unwrap();
        let newest_original = history.last().unwrap();
        assert_eq!(newest_kept.content[0], newest_original.content[0]);
    }

    #[test]
    fn sliding_window_short_messages_unchanged() {
        let strategy = SlidingWindow::new();
        let history = vec![
            text_message(Role::User, "hi"),
            text_message(Role::Assistant, "hello"),
        ];
        let expected_len = history.len();

        let compacted = strategy.compact(history);
        assert_eq!(compacted.len(), expected_len);
    }

    #[test]
    fn sliding_window_single_message_unchanged() {
        let strategy = SlidingWindow::new();
        let history = vec![text_message(Role::User, "hi")];

        let compacted = strategy.compact(history);
        assert_eq!(compacted.len(), 1);
    }
}