1use crate::message::{Message, MessageContent};
10
11pub trait TokenCounter: Send + Sync {
13 fn count_messages(&self, messages: &[Message]) -> u32;
15 fn count_text(&self, text: &str) -> u32;
17}
18
19pub struct CharTokenCounter;
25
26impl TokenCounter for CharTokenCounter {
27 fn count_messages(&self, messages: &[Message]) -> u32 {
28 messages
29 .iter()
30 .map(|m| self.count_text(&message_to_text(m)))
31 .sum()
32 }
33
34 fn count_text(&self, text: &str) -> u32 {
35 (text.chars().count() / 4).max(1) as u32
36 }
37}
38
/// Which end of the conversation [`trim_messages`] keeps when the token
/// budget cannot hold every message.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum TrimStrategy {
    /// Keep the most recent messages: walk backwards from the end and stop at
    /// the first message that does not fit the remaining budget.
    Last,
    /// Keep the earliest messages: walk forwards from the start and stop at
    /// the first message that does not fit the remaining budget.
    First,
}
48
/// Role of a message, used by the `start_on` / `end_on` filters of
/// `TrimOptions`.
///
/// Each variant corresponds to the like-named `Message` variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum MessageRole {
    /// Matches `Message::Human`.
    Human,
    /// Matches `Message::Ai`.
    Ai,
    /// Matches `Message::System`.
    System,
    /// Matches `Message::Tool`.
    Tool,
}
58
59pub struct TrimOptions<'a> {
61 pub strategy: TrimStrategy,
63 pub max_tokens: u32,
65 pub token_counter: &'a dyn TokenCounter,
67 pub start_on: Option<MessageRole>,
69 pub end_on: Option<Vec<MessageRole>>,
71}
72
73pub fn trim_messages(messages: &[Message], opts: TrimOptions) -> Vec<Message> {
103 if messages.is_empty() {
104 return vec![];
105 }
106
107 let mut result: Vec<Message> = match opts.strategy {
109 TrimStrategy::Last => {
110 let mut selected = Vec::new();
112 let mut budget = opts.max_tokens;
113
114 for msg in messages.iter().rev() {
115 let cost = opts.token_counter.count_messages(std::slice::from_ref(msg));
116 if cost > budget {
117 break;
118 }
119 budget -= cost;
120 selected.push(msg.clone());
121 }
122 selected.reverse();
123 selected
124 }
125 TrimStrategy::First => {
126 let mut selected = Vec::new();
128 let mut budget = opts.max_tokens;
129
130 for msg in messages {
131 let cost = opts.token_counter.count_messages(std::slice::from_ref(msg));
132 if cost > budget {
133 break;
134 }
135 budget -= cost;
136 selected.push(msg.clone());
137 }
138 selected
139 }
140 };
141
142 if let Some(ref start_role) = opts.start_on {
144 if let Some(start_idx) = result.iter().position(|m| message_has_role(m, start_role)) {
145 result.drain(..start_idx);
146 } else {
147 result.clear();
148 }
149 }
150
151 if let Some(ref end_roles) = opts.end_on {
153 while !result.is_empty()
154 && !end_roles
155 .iter()
156 .any(|r| message_has_role(result.last().unwrap(), r))
157 {
158 result.pop();
159 }
160 }
161
162 result
163}
164
165fn message_has_role(msg: &Message, role: &MessageRole) -> bool {
166 matches!(
167 (msg, role),
168 (Message::Human(_), MessageRole::Human)
169 | (Message::Ai(_), MessageRole::Ai)
170 | (Message::System(_), MessageRole::System)
171 | (Message::Tool(_), MessageRole::Tool)
172 )
173}
174
175fn message_to_text(msg: &Message) -> String {
176 match msg {
177 Message::Human(m) => content_to_text(&m.content),
178 Message::Ai(m) => content_to_text(&m.content),
179 Message::System(m) => m.content.clone(),
180 Message::Tool(m) => m.content.clone(),
181 }
182}
183
184fn content_to_text(content: &MessageContent) -> String {
185 match content {
186 MessageContent::Text(t) => t.clone(),
187 MessageContent::Blocks(blocks) => blocks
188 .iter()
189 .map(|b| match b {
190 crate::message::ContentBlock::Text { text } => text.clone(),
191 _ => String::new(),
192 })
193 .collect::<Vec<_>>()
194 .join(" "),
195 }
196}
197
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds options that price messages with [`CharTokenCounter`]
    /// (one token per four characters, minimum one).
    fn char_opts(
        strategy: TrimStrategy,
        max_tokens: u32,
        start_on: Option<MessageRole>,
        end_on: Option<Vec<MessageRole>>,
    ) -> TrimOptions<'static> {
        TrimOptions {
            strategy,
            max_tokens,
            token_counter: &CharTokenCounter,
            start_on,
            end_on,
        }
    }

    #[test]
    fn test_char_token_counter() {
        let counter = CharTokenCounter;
        assert_eq!(counter.count_text("hello world"), 2);
        // Anything shorter than four characters, including the empty string,
        // is clamped to the one-token minimum.
        assert_eq!(counter.count_text("hi"), 1);
        assert_eq!(counter.count_text(""), 1);
    }

    #[test]
    fn test_trim_empty_input() {
        let result = trim_messages(&[], char_opts(TrimStrategy::Last, 100, None, None));
        assert!(result.is_empty());
    }

    #[test]
    fn test_trim_all_under_budget() {
        let messages = vec![Message::human("hello"), Message::ai("hi")];
        let result = trim_messages(&messages, char_opts(TrimStrategy::Last, 1000, None, None));
        assert_eq!(result.len(), 2);
        assert!(matches!(result[0], Message::Human(_)));
        assert!(matches!(result[1], Message::Ai(_)));
    }

    #[test]
    fn test_trim_zero_budget_keeps_nothing() {
        // Every message costs at least one token, so a zero budget admits none.
        let messages = vec![Message::human("hello"), Message::ai("hi")];
        for &strategy in &[TrimStrategy::Last, TrimStrategy::First] {
            let result = trim_messages(&messages, char_opts(strategy, 0, None, None));
            assert!(result.is_empty(), "{:?} should keep nothing", strategy);
        }
    }

    #[test]
    fn test_trim_last_strategy() {
        let messages = vec![
            Message::system("You are helpful. This is a long system prompt with many tokens."),
            Message::human("short q"),
            Message::ai("short a"),
        ];
        // The two short messages cost 1 token each; the 63-character system
        // prompt costs 15 and does not fit in the remaining 3.
        let result = trim_messages(&messages, char_opts(TrimStrategy::Last, 5, None, None));
        assert_eq!(result.len(), 2);
        assert!(matches!(result[0], Message::Human(_)));
        assert!(matches!(result[1], Message::Ai(_)));
    }

    #[test]
    fn test_trim_first_strategy() {
        let messages = vec![
            Message::human("first"),
            Message::ai("second"),
            Message::human("this is a much longer message that uses more tokens"),
        ];
        // "first" and "second" cost 1 token each; the 51-character message
        // costs 12 and does not fit in the remaining 3.
        let result = trim_messages(&messages, char_opts(TrimStrategy::First, 5, None, None));
        assert_eq!(result.len(), 2);
        assert!(matches!(result[0], Message::Human(_)));
        assert!(matches!(result[1], Message::Ai(_)));
    }

    #[test]
    fn test_trim_start_on_human() {
        let messages = vec![
            Message::system("sys"),
            Message::ai("ai response"),
            Message::human("question"),
        ];
        let result = trim_messages(
            &messages,
            char_opts(TrimStrategy::Last, 1000, Some(MessageRole::Human), None),
        );
        assert_eq!(result.len(), 1);
        assert!(matches!(result[0], Message::Human(_)));
    }

    #[test]
    fn test_trim_start_on_missing_role_keeps_nothing() {
        let messages = vec![Message::system("sys"), Message::ai("ai response")];
        let result = trim_messages(
            &messages,
            char_opts(TrimStrategy::Last, 1000, Some(MessageRole::Human), None),
        );
        assert!(result.is_empty());
    }

    #[test]
    fn test_trim_end_on_human_or_tool() {
        let messages = vec![Message::human("q"), Message::ai("response")];
        let result = trim_messages(
            &messages,
            char_opts(
                TrimStrategy::Last,
                1000,
                None,
                Some(vec![MessageRole::Human, MessageRole::Tool]),
            ),
        );
        assert_eq!(result.len(), 1);
        assert!(matches!(result[0], Message::Human(_)));
    }

    #[test]
    fn test_trim_end_on_with_first_strategy() {
        let messages = vec![
            Message::human("q"),
            Message::ai("a"),
            Message::human("q2"),
            Message::ai("a2"),
        ];
        let result = trim_messages(
            &messages,
            char_opts(TrimStrategy::First, 1000, None, Some(vec![MessageRole::Human])),
        );
        assert_eq!(result.len(), 3);
        assert!(matches!(result[2], Message::Human(_)));
    }
}