1use crate::context::token_estimator::TokenEstimator;
22use crate::context::types::{ContextError, ConversationTurn, TokenUsage};
23use crate::conversation::message::{Message, MessageContent};
24use async_trait::async_trait;
25use rmcp::model::Content;
26use std::result::Result;
27
28pub const SUMMARY_SYSTEM_PROMPT: &str =
34 "Summarize this coding conversation in under 50 characters.\n\
35 Capture the main task, key files, problems addressed, and current status.";
36
37pub const DEFAULT_SUMMARY_BUDGET: usize = 4000;
39
40pub const MAX_SUMMARY_LENGTH: usize = 500;
42
43#[derive(Debug, Clone)]
49pub struct SummarizerResponse {
50 pub content: Vec<Content>,
52 pub usage: Option<TokenUsage>,
54}
55
56impl SummarizerResponse {
57 pub fn new(content: Vec<Content>, usage: Option<TokenUsage>) -> Self {
59 Self { content, usage }
60 }
61
62 pub fn text(&self) -> String {
64 self.content
65 .iter()
66 .filter_map(|c| c.as_text().map(|t| t.text.clone()))
67 .collect::<Vec<_>>()
68 .join("")
69 }
70}
71
72#[async_trait]
77pub trait SummarizerClient: Send + Sync {
78 async fn create_message(
89 &self,
90 messages: Vec<Message>,
91 system_prompt: Option<&str>,
92 ) -> Result<SummarizerResponse, ContextError>;
93}
94
95pub struct Summarizer;
104
105impl Summarizer {
106 pub async fn generate_ai_summary(
121 turns: &[ConversationTurn],
122 client: &dyn SummarizerClient,
123 context_budget: usize,
124 ) -> Result<String, ContextError> {
125 if turns.is_empty() {
126 return Ok(String::new());
127 }
128
129 let (collected_turns, _tokens_used) = Self::collect_within_budget(turns, context_budget);
131
132 if collected_turns.is_empty() {
133 return Ok(Self::create_simple_summary(turns));
134 }
135
136 let formatted_text = Self::format_turns_as_text(&collected_turns);
138
139 let messages = vec![Message::user().with_text(formatted_text)];
141
142 match client
144 .create_message(messages, Some(SUMMARY_SYSTEM_PROMPT))
145 .await
146 {
147 Ok(response) => {
148 let summary = response.text();
149 if summary.is_empty() {
150 Ok(Self::create_simple_summary(turns))
152 } else {
153 Ok(Self::truncate_summary(&summary, MAX_SUMMARY_LENGTH))
155 }
156 }
157 Err(_) => {
158 Ok(Self::create_simple_summary(turns))
160 }
161 }
162 }
163
164 pub fn create_simple_summary(turns: &[ConversationTurn]) -> String {
180 if turns.is_empty() {
181 return String::new();
182 }
183
184 let mut summary_parts: Vec<String> = Vec::new();
185
186 summary_parts.push(format!("[{} turns]", turns.len()));
188
189 let mut tools_used: Vec<String> = Vec::new();
191 for turn in turns {
192 Self::collect_tools_from_message(&turn.user, &mut tools_used);
193 Self::collect_tools_from_message(&turn.assistant, &mut tools_used);
194 }
195 if !tools_used.is_empty() {
196 tools_used.sort();
197 tools_used.dedup();
198 let tools_str = tools_used
199 .iter()
200 .take(5)
201 .cloned()
202 .collect::<Vec<_>>()
203 .join(", ");
204 summary_parts.push(format!("Tools: {}", tools_str));
205 }
206
207 if let Some(first_turn) = turns.first() {
209 let first_text = Self::extract_message_text(&first_turn.user);
210 if !first_text.is_empty() {
211 let topic = Self::truncate_summary(&first_text, 100);
212 summary_parts.push(format!("Started: {}", topic));
213 }
214 }
215
216 if let Some(last_turn) = turns.last() {
218 let last_text = Self::extract_message_text(&last_turn.assistant);
219 if !last_text.is_empty() {
220 let status = Self::truncate_summary(&last_text, 100);
221 summary_parts.push(format!("Last: {}", status));
222 }
223 }
224
225 summary_parts.join(" | ")
226 }
227
228 pub fn collect_within_budget(
242 turns: &[ConversationTurn],
243 budget: usize,
244 ) -> (Vec<ConversationTurn>, usize) {
245 let mut collected: Vec<ConversationTurn> = Vec::new();
246 let mut tokens_used: usize = 0;
247
248 for turn in turns {
249 let turn_tokens = turn.token_estimate;
250 if tokens_used + turn_tokens <= budget {
251 collected.push(turn.clone());
252 tokens_used += turn_tokens;
253 } else {
254 break;
256 }
257 }
258
259 (collected, tokens_used)
260 }
261
262 pub fn format_turns_as_text(turns: &[ConversationTurn]) -> String {
275 let mut parts: Vec<String> = Vec::new();
276
277 for (i, turn) in turns.iter().enumerate() {
278 parts.push(format!("--- Turn {} ---", i + 1));
279
280 let user_text = Self::extract_message_text(&turn.user);
282 if !user_text.is_empty() {
283 parts.push(format!("User: {}", user_text));
284 }
285
286 let assistant_text = Self::extract_message_text(&turn.assistant);
288 if !assistant_text.is_empty() {
289 parts.push(format!("Assistant: {}", assistant_text));
290 }
291
292 if let Some(summary) = &turn.summary {
294 parts.push(format!("(Summary: {})", summary));
295 }
296
297 parts.push(String::new()); }
299
300 parts.join("\n")
301 }
302
303 pub fn extract_message_text(message: &Message) -> String {
316 message
317 .content
318 .iter()
319 .filter_map(|content| match content {
320 MessageContent::Text(text_content) => Some(text_content.text.clone()),
321 MessageContent::Thinking(thinking) => Some(thinking.thinking.clone()),
322 MessageContent::ToolRequest(req) => {
323 req.tool_call
325 .as_ref()
326 .ok()
327 .map(|call| format!("[Tool: {}]", call.name))
328 }
329 MessageContent::ToolResponse(resp) => {
330 resp.tool_result.as_ref().ok().map(|result| {
332 let text: String = result
333 .content
334 .iter()
335 .filter_map(|c| c.as_text().map(|t| t.text.clone()))
336 .take(1)
337 .collect::<Vec<_>>()
338 .join("");
339 if text.len() > 100 {
340 format!("[Tool result: {}...]", text.get(..100).unwrap_or(&text))
341 } else if !text.is_empty() {
342 format!("[Tool result: {}]", text)
343 } else {
344 String::new()
345 }
346 })
347 }
348 _ => None,
349 })
350 .filter(|s| !s.is_empty())
351 .collect::<Vec<_>>()
352 .join(" ")
353 }
354
355 fn collect_tools_from_message(message: &Message, tools: &mut Vec<String>) {
357 for content in &message.content {
358 if let MessageContent::ToolRequest(req) = content {
359 if let Ok(call) = &req.tool_call {
360 tools.push(call.name.to_string());
361 }
362 }
363 }
364 }
365
366 fn truncate_summary(text: &str, max_len: usize) -> String {
368 let trimmed = text.trim();
369 if trimmed.len() <= max_len {
370 trimmed.to_string()
371 } else {
372 let truncated = trimmed.get(..max_len).unwrap_or(trimmed);
374 if let Some(last_space) = truncated.rfind(' ') {
375 format!("{}...", truncated.get(..last_space).unwrap_or(truncated))
376 } else {
377 format!("{}...", truncated)
378 }
379 }
380 }
381
382 pub fn estimate_summary_tokens(summary: &str) -> usize {
384 TokenEstimator::estimate_tokens(summary)
385 }
386}
387
388#[cfg(test)]
393mod tests {
394 use super::*;
395
396 fn create_test_turn(user_text: &str, assistant_text: &str) -> ConversationTurn {
397 let user = Message::user().with_text(user_text);
398 let assistant = Message::assistant().with_text(assistant_text);
399 let token_estimate = TokenEstimator::estimate_message_tokens(&user)
400 + TokenEstimator::estimate_message_tokens(&assistant);
401 ConversationTurn::new(user, assistant, token_estimate)
402 }
403
404 #[test]
405 fn test_create_simple_summary_empty() {
406 let turns: Vec<ConversationTurn> = vec![];
407 let summary = Summarizer::create_simple_summary(&turns);
408 assert!(summary.is_empty());
409 }
410
411 #[test]
412 fn test_create_simple_summary_single_turn() {
413 let turns = vec![create_test_turn(
414 "How do I create a function in Rust?",
415 "You can create a function using the fn keyword.",
416 )];
417
418 let summary = Summarizer::create_simple_summary(&turns);
419
420 assert!(summary.contains("[1 turns]"));
421 assert!(summary.contains("Started:"));
422 assert!(summary.contains("Last:"));
423 }
424
425 #[test]
426 fn test_create_simple_summary_multiple_turns() {
427 let turns = vec![
428 create_test_turn("Hello", "Hi there!"),
429 create_test_turn("How are you?", "I'm doing well, thanks!"),
430 create_test_turn("Goodbye", "See you later!"),
431 ];
432
433 let summary = Summarizer::create_simple_summary(&turns);
434
435 assert!(summary.contains("[3 turns]"));
436 }
437
438 #[test]
439 fn test_collect_within_budget_all_fit() {
440 let turns = vec![
441 create_test_turn("Short", "Reply"),
442 create_test_turn("Another", "Response"),
443 ];
444
445 let (collected, tokens) = Summarizer::collect_within_budget(&turns, 10000);
446
447 assert_eq!(collected.len(), 2);
448 assert!(tokens > 0);
449 }
450
451 #[test]
452 fn test_collect_within_budget_partial() {
453 let turns = vec![
454 create_test_turn("Short", "Reply"),
455 create_test_turn("A".repeat(1000).as_str(), "B".repeat(1000).as_str()),
456 ];
457
458 let (collected, _tokens) = Summarizer::collect_within_budget(&turns, 50);
460
461 assert_eq!(collected.len(), 1);
462 }
463
464 #[test]
465 fn test_collect_within_budget_none_fit() {
466 let turns = vec![create_test_turn(
467 "A".repeat(1000).as_str(),
468 "B".repeat(1000).as_str(),
469 )];
470
471 let (collected, tokens) = Summarizer::collect_within_budget(&turns, 10);
473
474 assert!(collected.is_empty());
475 assert_eq!(tokens, 0);
476 }
477
478 #[test]
479 fn test_format_turns_as_text() {
480 let turns = vec![
481 create_test_turn("Hello", "Hi there!"),
482 create_test_turn("How are you?", "I'm fine."),
483 ];
484
485 let formatted = Summarizer::format_turns_as_text(&turns);
486
487 assert!(formatted.contains("--- Turn 1 ---"));
488 assert!(formatted.contains("--- Turn 2 ---"));
489 assert!(formatted.contains("User: Hello"));
490 assert!(formatted.contains("Assistant: Hi there!"));
491 assert!(formatted.contains("User: How are you?"));
492 assert!(formatted.contains("Assistant: I'm fine."));
493 }
494
495 #[test]
496 fn test_extract_message_text_simple() {
497 let message = Message::user().with_text("Hello, world!");
498 let text = Summarizer::extract_message_text(&message);
499 assert_eq!(text, "Hello, world!");
500 }
501
502 #[test]
503 fn test_extract_message_text_multiple_blocks() {
504 let message = Message::user()
505 .with_text("First part")
506 .with_text("Second part");
507 let text = Summarizer::extract_message_text(&message);
508 assert!(text.contains("First part"));
509 assert!(text.contains("Second part"));
510 }
511
512 #[test]
513 fn test_truncate_summary_short() {
514 let text = "Short text";
515 let result = Summarizer::truncate_summary(text, 100);
516 assert_eq!(result, "Short text");
517 }
518
519 #[test]
520 fn test_truncate_summary_long() {
521 let text = "This is a very long text that needs to be truncated at a word boundary";
522 let result = Summarizer::truncate_summary(text, 30);
523 assert!(result.len() <= 33); assert!(result.ends_with("..."));
525 }
526
527 #[test]
528 fn test_estimate_summary_tokens() {
529 let summary = "This is a test summary";
530 let tokens = Summarizer::estimate_summary_tokens(summary);
531 assert!(tokens > 0);
532 }
533
534 #[test]
535 fn test_summarizer_response_text() {
536 use rmcp::model::{RawContent, RawTextContent};
537
538 let content = vec![Content {
539 raw: RawContent::Text(RawTextContent {
540 text: "Summary text".to_string(),
541 meta: None,
542 }),
543 annotations: None,
544 }];
545
546 let response = SummarizerResponse::new(content, None);
547 assert_eq!(response.text(), "Summary text");
548 }
549
550 #[test]
551 fn test_summarizer_response_empty() {
552 let response = SummarizerResponse::new(vec![], None);
553 assert!(response.text().is_empty());
554 }
555
556 struct MockSummarizerClient {
558 response: Option<String>,
559 should_fail: bool,
560 }
561
562 impl MockSummarizerClient {
563 fn new(response: Option<String>) -> Self {
564 Self {
565 response,
566 should_fail: false,
567 }
568 }
569
570 fn failing() -> Self {
571 Self {
572 response: None,
573 should_fail: true,
574 }
575 }
576 }
577
578 #[async_trait]
579 impl SummarizerClient for MockSummarizerClient {
580 async fn create_message(
581 &self,
582 _messages: Vec<Message>,
583 _system_prompt: Option<&str>,
584 ) -> Result<SummarizerResponse, ContextError> {
585 if self.should_fail {
586 return Err(ContextError::SummarizationFailed(
587 "Mock failure".to_string(),
588 ));
589 }
590
591 let content = match &self.response {
592 Some(text) => {
593 use rmcp::model::{RawContent, RawTextContent};
594 vec![Content {
595 raw: RawContent::Text(RawTextContent {
596 text: text.clone(),
597 meta: None,
598 }),
599 annotations: None,
600 }]
601 }
602 None => vec![],
603 };
604
605 Ok(SummarizerResponse::new(content, None))
606 }
607 }
608
609 #[tokio::test]
610 async fn test_generate_ai_summary_success() {
611 let turns = vec![create_test_turn("Hello", "Hi there!")];
612 let client = MockSummarizerClient::new(Some("AI generated summary".to_string()));
613
614 let result = Summarizer::generate_ai_summary(&turns, &client, 10000).await;
615
616 assert!(result.is_ok());
617 assert_eq!(result.unwrap(), "AI generated summary");
618 }
619
620 #[tokio::test]
621 async fn test_generate_ai_summary_empty_response_fallback() {
622 let turns = vec![create_test_turn("Hello", "Hi there!")];
623 let client = MockSummarizerClient::new(None); let result = Summarizer::generate_ai_summary(&turns, &client, 10000).await;
626
627 assert!(result.is_ok());
628 let summary = result.unwrap();
629 assert!(summary.contains("[1 turns]"));
631 }
632
633 #[tokio::test]
634 async fn test_generate_ai_summary_error_fallback() {
635 let turns = vec![create_test_turn("Hello", "Hi there!")];
636 let client = MockSummarizerClient::failing();
637
638 let result = Summarizer::generate_ai_summary(&turns, &client, 10000).await;
639
640 assert!(result.is_ok());
641 let summary = result.unwrap();
642 assert!(summary.contains("[1 turns]"));
644 }
645
646 #[tokio::test]
647 async fn test_generate_ai_summary_empty_turns() {
648 let turns: Vec<ConversationTurn> = vec![];
649 let client = MockSummarizerClient::new(Some("Should not be called".to_string()));
650
651 let result = Summarizer::generate_ai_summary(&turns, &client, 10000).await;
652
653 assert!(result.is_ok());
654 assert!(result.unwrap().is_empty());
655 }
656
657 #[tokio::test]
658 async fn test_generate_ai_summary_truncates_long_response() {
659 let turns = vec![create_test_turn("Hello", "Hi there!")];
660 let long_response = "A".repeat(1000);
661 let client = MockSummarizerClient::new(Some(long_response));
662
663 let result = Summarizer::generate_ai_summary(&turns, &client, 10000).await;
664
665 assert!(result.is_ok());
666 let summary = result.unwrap();
667 assert!(summary.len() <= MAX_SUMMARY_LENGTH + 3); }
670}