1use crate::context::token_estimator::TokenEstimator;
26use crate::context::types::{MessagePriority, PrioritizedMessage};
27use crate::conversation::message::{Message, MessageContent};
28
/// Relative position (`index / (total_messages - 1)`) at or above which a
/// message that is neither a summary nor a tool call is `High` priority.
const RECENT_THRESHOLD: f64 = 0.8;

/// Relative position at or above which a non-summary, non-tool message is
/// `Medium` priority.
const MEDIUM_THRESHOLD: f64 = 0.5;

/// Relative position at or above which a non-summary, non-tool message is
/// `Low` priority; anything below this is `Minimal`.
const LOW_THRESHOLD: f64 = 0.2;

/// Substrings that mark a text message as a conversation summary.
///
/// They are searched for (anywhere in the text) in the *lowercased* message
/// text, so every entry here must itself be lowercase to ever match.
const SUMMARY_KEYWORDS: &[&str] = &[
    "[summary]",
    "[conversation summary]",
    "summary:",
    "summarized:",
    "previous conversation:",
];
50
51pub struct PrioritySorter;
63
64impl PrioritySorter {
65 pub fn evaluate_priority(
93 message: &Message,
94 index: usize,
95 total_messages: usize,
96 ) -> MessagePriority {
97 if Self::is_system_or_summary(message) {
99 return MessagePriority::Critical;
100 }
101
102 if Self::has_tool_calls(message) {
104 return MessagePriority::High;
105 }
106
107 let position_ratio = if total_messages <= 1 {
109 1.0
110 } else {
111 index as f64 / (total_messages - 1) as f64
112 };
113
114 if position_ratio >= RECENT_THRESHOLD {
116 return MessagePriority::High;
117 }
118
119 if position_ratio >= MEDIUM_THRESHOLD {
121 return MessagePriority::Medium;
122 }
123
124 if position_ratio >= LOW_THRESHOLD {
126 return MessagePriority::Low;
127 }
128
129 MessagePriority::Minimal
131 }
132
133 pub fn sort_by_priority<F>(messages: &[Message], estimate_tokens: F) -> Vec<PrioritizedMessage>
156 where
157 F: Fn(&Message) -> usize,
158 {
159 let total_messages = messages.len();
160
161 let mut prioritized: Vec<PrioritizedMessage> = messages
162 .iter()
163 .enumerate()
164 .map(|(index, message)| {
165 let priority = Self::evaluate_priority(message, index, total_messages);
166 let tokens = estimate_tokens(message);
167
168 PrioritizedMessage::new(message.clone(), priority, message.created, tokens)
169 })
170 .collect();
171
172 prioritized.sort_by(|a, b| match b.priority.cmp(&a.priority) {
174 std::cmp::Ordering::Equal => b.timestamp.cmp(&a.timestamp),
175 other => other,
176 });
177
178 prioritized
179 }
180
181 pub fn sort_by_priority_default(messages: &[Message]) -> Vec<PrioritizedMessage> {
193 Self::sort_by_priority(messages, TokenEstimator::estimate_message_tokens)
194 }
195
196 pub fn is_system_or_summary(message: &Message) -> bool {
206 for content in &message.content {
212 if let MessageContent::Text(text_content) = content {
213 let text_lower = text_content.text.to_lowercase();
214 for keyword in SUMMARY_KEYWORDS {
215 if text_lower.contains(keyword) {
216 return true;
217 }
218 }
219 }
220 }
221
222 false
223 }
224
225 pub fn has_tool_calls(message: &Message) -> bool {
235 message.content.iter().any(|content| {
236 matches!(
237 content,
238 MessageContent::ToolRequest(_)
239 | MessageContent::ToolResponse(_)
240 | MessageContent::ToolConfirmationRequest(_)
241 | MessageContent::FrontendToolRequest(_)
242 )
243 })
244 }
245
246 pub fn filter_by_priority(
259 prioritized: &[PrioritizedMessage],
260 min_priority: MessagePriority,
261 ) -> Vec<PrioritizedMessage> {
262 prioritized
263 .iter()
264 .filter(|p| p.priority >= min_priority)
265 .cloned()
266 .collect()
267 }
268
269 pub fn select_within_budget(
280 prioritized: &[PrioritizedMessage],
281 max_tokens: usize,
282 ) -> Vec<PrioritizedMessage> {
283 let mut result = Vec::new();
284 let mut current_tokens = 0;
285
286 for pm in prioritized {
287 if current_tokens + pm.tokens <= max_tokens {
288 result.push(pm.clone());
289 current_tokens += pm.tokens;
290 }
291 }
292
293 result
294 }
295
296 pub fn get_priority_distribution(messages: &[Message]) -> (usize, usize, usize, usize, usize) {
306 let total = messages.len();
307 let mut critical = 0;
308 let mut high = 0;
309 let mut medium = 0;
310 let mut low = 0;
311 let mut minimal = 0;
312
313 for (index, message) in messages.iter().enumerate() {
314 match Self::evaluate_priority(message, index, total) {
315 MessagePriority::Critical => critical += 1,
316 MessagePriority::High => high += 1,
317 MessagePriority::Medium => medium += 1,
318 MessagePriority::Low => low += 1,
319 MessagePriority::Minimal => minimal += 1,
320 }
321 }
322
323 (critical, high, medium, low, minimal)
324 }
325}
326
#[cfg(test)]
mod tests {
    use super::*;
    use rmcp::model::{CallToolRequestParam, JsonObject, Role};

    /// Builds a plain text message with the given role.
    fn create_text_message(role: Role, text: &str) -> Message {
        match role {
            Role::User => Message::user().with_text(text),
            Role::Assistant => Message::assistant().with_text(text),
        }
    }

    /// Builds an assistant message carrying a single tool request.
    fn create_tool_call_message() -> Message {
        Message::assistant().with_tool_request(
            "tool_1",
            Ok(CallToolRequestParam {
                name: "test_tool".into(),
                arguments: Some(JsonObject::new()),
            }),
        )
    }

    /// Builds a message recognised as a summary via the `[summary]` marker.
    fn create_summary_message() -> Message {
        Message::user().with_text("[Summary] Previous conversation discussed file operations.")
    }

    #[test]
    fn test_evaluate_priority_summary_is_critical() {
        let message = create_summary_message();
        let priority = PrioritySorter::evaluate_priority(&message, 0, 10);
        assert_eq!(priority, MessagePriority::Critical);
    }

    #[test]
    fn test_evaluate_priority_tool_call_is_high() {
        let message = create_tool_call_message();
        let priority = PrioritySorter::evaluate_priority(&message, 0, 10);
        assert_eq!(priority, MessagePriority::High);
    }

    #[test]
    fn test_evaluate_priority_recent_is_high() {
        let message = create_text_message(Role::User, "Recent message");
        // 9 / 9 = 1.0 >= RECENT_THRESHOLD
        let priority = PrioritySorter::evaluate_priority(&message, 9, 10);
        assert_eq!(priority, MessagePriority::High);
    }

    #[test]
    fn test_evaluate_priority_middle_is_medium() {
        let message = create_text_message(Role::User, "Middle message");
        // 6 / 9 ≈ 0.67, between MEDIUM_THRESHOLD and RECENT_THRESHOLD
        let priority = PrioritySorter::evaluate_priority(&message, 6, 10);
        assert_eq!(priority, MessagePriority::Medium);
    }

    #[test]
    fn test_evaluate_priority_older_is_low() {
        let message = create_text_message(Role::User, "Older message");
        // 3 / 9 ≈ 0.33, between LOW_THRESHOLD and MEDIUM_THRESHOLD
        let priority = PrioritySorter::evaluate_priority(&message, 3, 10);
        assert_eq!(priority, MessagePriority::Low);
    }

    #[test]
    fn test_evaluate_priority_oldest_is_minimal() {
        let message = create_text_message(Role::User, "Oldest message");
        // 1 / 9 ≈ 0.11 < LOW_THRESHOLD
        let priority = PrioritySorter::evaluate_priority(&message, 1, 10);
        assert_eq!(priority, MessagePriority::Minimal);
    }

    #[test]
    fn test_is_system_or_summary_with_summary() {
        let message = create_summary_message();
        assert!(PrioritySorter::is_system_or_summary(&message));
    }

    #[test]
    fn test_is_system_or_summary_without_summary() {
        let message = create_text_message(Role::User, "Regular message");
        assert!(!PrioritySorter::is_system_or_summary(&message));
    }

    #[test]
    fn test_has_tool_calls_with_tool() {
        let message = create_tool_call_message();
        assert!(PrioritySorter::has_tool_calls(&message));
    }

    #[test]
    fn test_has_tool_calls_without_tool() {
        let message = create_text_message(Role::User, "No tools here");
        assert!(!PrioritySorter::has_tool_calls(&message));
    }

    #[test]
    fn test_sort_by_priority_ordering() {
        // Relative positions in a 5-message conversation are index / 4.
        let messages = vec![
            create_text_message(Role::User, "First message"), // 0.00 -> Minimal
            create_text_message(Role::Assistant, "Second message"), // 0.25 -> Low
            create_summary_message(),                         // Critical
            create_text_message(Role::User, "Fourth message"), // 0.75 -> Medium
            create_text_message(Role::Assistant, "Fifth message"), // 1.00 -> High
        ];

        let sorted = PrioritySorter::sort_by_priority_default(&messages);

        // Every priority is distinct, so the full order is determined.
        assert_eq!(sorted.len(), 5);
        assert_eq!(sorted[0].priority, MessagePriority::Critical);
        assert_eq!(sorted[1].priority, MessagePriority::High);
        assert_eq!(sorted[2].priority, MessagePriority::Medium);
        assert_eq!(sorted[3].priority, MessagePriority::Low);
        assert_eq!(sorted[4].priority, MessagePriority::Minimal);
    }

    #[test]
    fn test_filter_by_priority() {
        let messages = vec![
            create_text_message(Role::User, "First"),
            create_text_message(Role::Assistant, "Second"),
            create_text_message(Role::User, "Third"),
            create_text_message(Role::Assistant, "Fourth"),
            create_text_message(Role::User, "Fifth"),
        ];

        let prioritized = PrioritySorter::sort_by_priority_default(&messages);
        let high_and_above =
            PrioritySorter::filter_by_priority(&prioritized, MessagePriority::High);

        // Only the last message (4 / 4 = 1.0) reaches High; 3 / 4 = 0.75 is
        // Medium. Checking the count keeps the loop below from passing vacuously.
        assert_eq!(high_and_above.len(), 1);
        for pm in &high_and_above {
            assert!(pm.priority >= MessagePriority::High);
        }
    }

    #[test]
    fn test_select_within_budget() {
        let messages = vec![
            create_text_message(Role::User, "Short"),
            create_text_message(Role::Assistant, "Also short"),
            create_text_message(Role::User, "Another short one"),
        ];

        let prioritized = PrioritySorter::sort_by_priority_default(&messages);
        let selected = PrioritySorter::select_within_budget(&prioritized, 50);

        let total_tokens: usize = selected.iter().map(|p| p.tokens).sum();
        assert!(total_tokens <= 50);
    }

    #[test]
    fn test_select_within_unlimited_budget_keeps_everything() {
        let messages = vec![
            create_text_message(Role::User, "Short"),
            create_text_message(Role::Assistant, "Also short"),
            create_text_message(Role::User, "Another short one"),
        ];

        let prioritized = PrioritySorter::sort_by_priority_default(&messages);
        let selected = PrioritySorter::select_within_budget(&prioritized, usize::MAX);

        assert_eq!(selected.len(), prioritized.len());
    }

    #[test]
    fn test_get_priority_distribution() {
        let messages = vec![
            create_summary_message(),
            create_text_message(Role::User, "First"),
            create_text_message(Role::Assistant, "Second"),
            create_text_message(Role::User, "Third"),
            create_text_message(Role::Assistant, "Fourth"),
            create_text_message(Role::User, "Fifth"),
            create_text_message(Role::Assistant, "Sixth"),
            create_text_message(Role::User, "Seventh"),
            create_text_message(Role::Assistant, "Eighth"),
            create_tool_call_message(),
        ];

        let (critical, high, medium, low, minimal) =
            PrioritySorter::get_priority_distribution(&messages);

        // Index 0 is the summary (Critical), index 9 the tool call (High).
        // The rest use index / 9: 1 -> Minimal, 2..=4 -> Low, 5..=7 -> Medium,
        // 8 -> High.
        assert_eq!(critical, 1);
        assert_eq!(high, 2);
        assert_eq!(medium, 3);
        assert_eq!(low, 3);
        assert_eq!(minimal, 1);
    }

    #[test]
    fn test_single_message_is_high_priority() {
        let message = create_text_message(Role::User, "Only message");
        // A lone message is treated as the most recent one.
        let priority = PrioritySorter::evaluate_priority(&message, 0, 1);
        assert_eq!(priority, MessagePriority::High);
    }

    #[test]
    fn test_empty_messages() {
        let messages: Vec<Message> = vec![];
        let sorted = PrioritySorter::sort_by_priority_default(&messages);
        assert!(sorted.is_empty());
    }

    #[test]
    fn test_summary_keywords_case_insensitive() {
        let message = create_text_message(Role::User, "[SUMMARY] This is a summary");
        assert!(PrioritySorter::is_system_or_summary(&message));

        let message2 = create_text_message(Role::User, "Conversation Summary: blah blah");
        assert!(PrioritySorter::is_system_or_summary(&message2));
    }
}