Skip to main content

pe_core/
token.rs

//! Token counting and message trimming utilities.
//!
//! Provides the `TokenCounter` trait for pluggable token counting,
//! `CharTokenCounter` as a tokenizer-free approximation for tests, and
//! `trim_messages` for fitting a conversation into a context window.
//!
//! Based on Group 16.6 and Group 19 of the pre-plan.
8
9use crate::message::{Message, MessageContent};
10
11/// How messages should be counted for trimming.
12pub trait TokenCounter: Send + Sync {
13    /// Count tokens in a slice of messages.
14    fn count_messages(&self, messages: &[Message]) -> u32;
15    /// Count tokens in a text string.
16    fn count_text(&self, text: &str) -> u32;
17}
18
/// Approximate token counter that estimates roughly one token per four characters.
///
/// It needs no tokenizer, which makes it suitable for tests. For production,
/// implement [`TokenCounter`] on top of a real tokenizer (tiktoken, etc.).
///
/// It is a stateless unit struct, so the standard value traits are derived.
/// That makes it printable, comparable, and constructible via `Default`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct CharTokenCounter;
25
26impl TokenCounter for CharTokenCounter {
27    fn count_messages(&self, messages: &[Message]) -> u32 {
28        messages
29            .iter()
30            .map(|m| self.count_text(&message_to_text(m)))
31            .sum()
32    }
33
34    fn count_text(&self, text: &str) -> u32 {
35        (text.chars().count() / 4).max(1) as u32
36    }
37}
38
/// Which end of the conversation [`trim_messages`] keeps when it has to drop messages.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TrimStrategy {
    /// Keep the newest messages; the oldest ones are discarded first.
    Last,
    /// Keep the oldest messages; the newest ones are discarded first.
    First,
}
48
/// Role of a message.
///
/// [`TrimOptions`] uses it to constrain which messages may open and close a
/// trimmed window.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MessageRole {
    /// A `Message::Human` (user) message.
    Human,
    /// A `Message::Ai` (assistant) message.
    Ai,
    /// A `Message::System` message.
    System,
    /// A `Message::Tool` message.
    Tool,
}
58
59/// Options for `trim_messages`.
60pub struct TrimOptions<'a> {
61    /// Which end to keep messages from.
62    pub strategy: TrimStrategy,
63    /// Hard token cap.
64    pub max_tokens: u32,
65    /// Token counting implementation.
66    pub token_counter: &'a dyn TokenCounter,
67    /// If set, the trimmed window must start with this role.
68    pub start_on: Option<MessageRole>,
69    /// If set, the trimmed window must end with one of these roles.
70    pub end_on: Option<Vec<MessageRole>>,
71}
72
73/// Trim a message list to fit within a token budget.
74///
75/// Messages are selected according to the strategy (Last = keep newest,
76/// First = keep oldest), then boundary constraints (start_on, end_on)
77/// are enforced by dropping messages at the boundary.
78///
79/// # Example
80///
81/// ```
82/// use pe_core::token::{trim_messages, TrimOptions, TrimStrategy, CharTokenCounter};
83/// use pe_core::message::Message;
84///
85/// let messages = vec![
86///     Message::system("You are helpful"),
87///     Message::human("What is Rust?"),
88///     Message::ai("Rust is a systems programming language."),
89/// ];
90///
91/// let trimmed = trim_messages(&messages, TrimOptions {
92///     strategy: TrimStrategy::Last,
93///     max_tokens: 20,
94///     token_counter: &CharTokenCounter,
95///     start_on: None,
96///     end_on: None,
97/// });
98///
99/// // Some messages will be trimmed to fit the budget
100/// assert!(trimmed.len() <= messages.len());
101/// ```
102pub fn trim_messages(messages: &[Message], opts: TrimOptions) -> Vec<Message> {
103    if messages.is_empty() {
104        return vec![];
105    }
106
107    // Build candidate list based on strategy and budget
108    let mut result: Vec<Message> = match opts.strategy {
109        TrimStrategy::Last => {
110            // Iterate from end, accumulate until budget
111            let mut selected = Vec::new();
112            let mut budget = opts.max_tokens;
113
114            for msg in messages.iter().rev() {
115                let cost = opts.token_counter.count_messages(std::slice::from_ref(msg));
116                if cost > budget {
117                    break;
118                }
119                budget -= cost;
120                selected.push(msg.clone());
121            }
122            selected.reverse();
123            selected
124        }
125        TrimStrategy::First => {
126            // Iterate from start, accumulate until budget
127            let mut selected = Vec::new();
128            let mut budget = opts.max_tokens;
129
130            for msg in messages {
131                let cost = opts.token_counter.count_messages(std::slice::from_ref(msg));
132                if cost > budget {
133                    break;
134                }
135                budget -= cost;
136                selected.push(msg.clone());
137            }
138            selected
139        }
140    };
141
142    // Enforce start_on: drop leading messages that don't match
143    if let Some(ref start_role) = opts.start_on {
144        if let Some(start_idx) = result.iter().position(|m| message_has_role(m, start_role)) {
145            result.drain(..start_idx);
146        } else {
147            result.clear();
148        }
149    }
150
151    // Enforce end_on: drop trailing messages that don't match
152    if let Some(ref end_roles) = opts.end_on {
153        while !result.is_empty()
154            && !end_roles
155                .iter()
156                .any(|r| message_has_role(result.last().unwrap(), r))
157        {
158            result.pop();
159        }
160    }
161
162    result
163}
164
165fn message_has_role(msg: &Message, role: &MessageRole) -> bool {
166    matches!(
167        (msg, role),
168        (Message::Human(_), MessageRole::Human)
169            | (Message::Ai(_), MessageRole::Ai)
170            | (Message::System(_), MessageRole::System)
171            | (Message::Tool(_), MessageRole::Tool)
172    )
173}
174
175fn message_to_text(msg: &Message) -> String {
176    match msg {
177        Message::Human(m) => content_to_text(&m.content),
178        Message::Ai(m) => content_to_text(&m.content),
179        Message::System(m) => m.content.clone(),
180        Message::Tool(m) => m.content.clone(),
181    }
182}
183
184fn content_to_text(content: &MessageContent) -> String {
185    match content {
186        MessageContent::Text(t) => t.clone(),
187        MessageContent::Blocks(blocks) => blocks
188            .iter()
189            .map(|b| match b {
190                crate::message::ContentBlock::Text { text } => text.clone(),
191                _ => String::new(),
192            })
193            .collect::<Vec<_>>()
194            .join(" "),
195    }
196}
197
#[cfg(test)]
mod tests {
    use super::*;

    /// Options with no boundary constraints, using the character-based counter.
    fn opts(strategy: TrimStrategy, max_tokens: u32) -> TrimOptions<'static> {
        TrimOptions {
            strategy,
            max_tokens,
            token_counter: &CharTokenCounter,
            start_on: None,
            end_on: None,
        }
    }

    #[test]
    fn test_char_token_counter() {
        let counter = CharTokenCounter;
        // "hello world" = 11 chars / 4 = 2 (rounded down)
        assert_eq!(counter.count_text("hello world"), 2);
        assert_eq!(counter.count_text("hi"), 1); // 2 / 4 = 0, raised to the minimum of 1
        assert_eq!(counter.count_text(""), 1); // minimum of 1
    }

    #[test]
    fn test_char_token_counter_messages() {
        let counter = CharTokenCounter;
        // "hello world" -> 2, "hi" -> 1
        let messages = vec![Message::human("hello world"), Message::ai("hi")];
        assert_eq!(counter.count_messages(&messages), 3);
        assert_eq!(counter.count_messages(&[]), 0);
    }

    #[test]
    fn test_trim_empty_input() {
        assert!(trim_messages(&[], opts(TrimStrategy::Last, 100)).is_empty());
        assert!(trim_messages(&[], opts(TrimStrategy::First, 100)).is_empty());
    }

    #[test]
    fn test_trim_zero_budget() {
        // Every message costs at least 1 token with CharTokenCounter.
        let messages = vec![Message::human("hello")];
        assert!(trim_messages(&messages, opts(TrimStrategy::Last, 0)).is_empty());
        assert!(trim_messages(&messages, opts(TrimStrategy::First, 0)).is_empty());
    }

    #[test]
    fn test_trim_all_under_budget() {
        let messages = vec![Message::human("hello"), Message::ai("hi")];
        let result = trim_messages(&messages, opts(TrimStrategy::Last, 1000));
        assert_eq!(result.len(), 2);
        assert!(matches!(result[0], Message::Human(_)));
        assert!(matches!(result[1], Message::Ai(_)));
    }

    #[test]
    fn test_trim_last_strategy() {
        let messages = vec![
            Message::system("You are helpful. This is a long system prompt with many tokens."),
            Message::human("short q"),
            Message::ai("short a"),
        ];
        // Costs: system prompt 63 chars -> 15, "short q" -> 1, "short a" -> 1.
        let result = trim_messages(&messages, opts(TrimStrategy::Last, 5));
        // The two newest messages fit (2 tokens); the system prompt does not.
        assert_eq!(result.len(), 2);
        assert!(matches!(result[0], Message::Human(_)));
        assert!(matches!(result[1], Message::Ai(_)));
    }

    #[test]
    fn test_trim_first_strategy() {
        let messages = vec![
            Message::human("first"),
            Message::ai("second"),
            Message::human("this is a much longer message that uses more tokens"),
        ];
        // Costs: "first" -> 1, "second" -> 1, long message 51 chars -> 12.
        let result = trim_messages(&messages, opts(TrimStrategy::First, 5));
        // The two oldest messages fit; the long one does not.
        assert_eq!(result.len(), 2);
        assert!(matches!(result[0], Message::Human(_)));
        assert!(matches!(result[1], Message::Ai(_)));
    }

    #[test]
    fn test_trim_stops_at_first_message_that_does_not_fit() {
        let messages = vec![
            Message::human("a"),
            Message::ai("this is a much longer message that uses more tokens"),
            Message::human("b"),
        ];
        // Costs: 1, 12, 1. The long message ends the window even though "b" would fit.
        let result = trim_messages(&messages, opts(TrimStrategy::First, 5));
        assert_eq!(result.len(), 1);
        assert!(matches!(result[0], Message::Human(_)));
    }

    #[test]
    fn test_trim_start_on_human() {
        let messages = vec![
            Message::system("sys"),
            Message::ai("ai response"),
            Message::human("question"),
        ];
        let mut options = opts(TrimStrategy::Last, 1000);
        options.start_on = Some(MessageRole::Human);
        let result = trim_messages(&messages, options);
        // Leading system and AI messages are dropped.
        assert_eq!(result.len(), 1);
        assert!(matches!(result[0], Message::Human(_)));
    }

    #[test]
    fn test_trim_start_on_without_match_is_empty() {
        let messages = vec![Message::system("sys"), Message::ai("ai response")];
        let mut options = opts(TrimStrategy::Last, 1000);
        options.start_on = Some(MessageRole::Human);
        assert!(trim_messages(&messages, options).is_empty());
    }

    #[test]
    fn test_trim_end_on_human_or_tool() {
        let messages = vec![Message::human("q"), Message::ai("response")];
        let mut options = opts(TrimStrategy::Last, 1000);
        options.end_on = Some(vec![MessageRole::Human, MessageRole::Tool]);
        let result = trim_messages(&messages, options);
        // The trailing AI message is dropped; the human message remains.
        assert_eq!(result.len(), 1);
        assert!(matches!(result[0], Message::Human(_)));
    }

    #[test]
    fn test_trim_end_on_without_match_is_empty() {
        let messages = vec![Message::ai("response")];
        let mut options = opts(TrimStrategy::Last, 1000);
        options.end_on = Some(vec![MessageRole::Human]);
        assert!(trim_messages(&messages, options).is_empty());
    }

    #[test]
    fn test_trim_start_on_and_end_on_together() {
        let messages = vec![
            Message::system("sys"),
            Message::human("q1"),
            Message::ai("a1"),
            Message::human("q2"),
            Message::ai("a2"),
        ];
        let mut options = opts(TrimStrategy::Last, 1000);
        options.start_on = Some(MessageRole::Human);
        options.end_on = Some(vec![MessageRole::Human]);
        let result = trim_messages(&messages, options);
        // System prompt dropped from the front, final AI reply dropped from the back.
        assert_eq!(result.len(), 3);
        assert!(matches!(result[0], Message::Human(_)));
        assert!(matches!(result[1], Message::Ai(_)));
        assert!(matches!(result[2], Message::Human(_)));
    }
}