// aster/context/summarizer.rs
1//! Summarizer Module
2//!
3//! This module provides intelligent message summarization functionality to compress
4//! old conversations while preserving key information. It supports:
5//!
6//! - AI-powered summarization using LLM
7//! - Simple text extraction fallback
8//! - Budget-aware message collection
9//! - Conversation turn formatting
10//!
11//! # Example
12//!
13//! ```rust,ignore
14//! use aster::context::summarizer::Summarizer;
15//! use aster::context::types::ConversationTurn;
16//!
17//! let turns: Vec<ConversationTurn> = vec![/* ... */];
18//! let summary = Summarizer::create_simple_summary(&turns);
19//! ```
20
21use crate::context::token_estimator::TokenEstimator;
22use crate::context::types::{ContextError, ConversationTurn, TokenUsage};
23use crate::conversation::message::{Message, MessageContent};
24use async_trait::async_trait;
25use rmcp::model::Content;
26use std::result::Result;
27
28// ============================================================================
29// Constants
30// ============================================================================
31
/// System prompt for AI summarization.
///
/// The character limit stated in the prompt matches [`MAX_SUMMARY_LENGTH`],
/// which is enforced by truncation after generation. (The previous "50
/// characters" contradicted both that constant and the amount of detail
/// the prompt requests.)
pub const SUMMARY_SYSTEM_PROMPT: &str =
    "Summarize this coding conversation in under 500 characters.\n\
     Capture the main task, key files, problems addressed, and current status.";

/// Default context budget for summarization (in tokens)
pub const DEFAULT_SUMMARY_BUDGET: usize = 4000;

/// Maximum summary length in characters
pub const MAX_SUMMARY_LENGTH: usize = 500;
42
43// ============================================================================
44// SummarizerClient Trait
45// ============================================================================
46
/// Response from the summarizer client
///
/// Wraps the raw content blocks returned by an LLM call together with the
/// token-usage statistics, when the backend reports them.
#[derive(Debug, Clone)]
pub struct SummarizerResponse {
    /// Content blocks from the response (may be empty if the model returned nothing)
    pub content: Vec<Content>,
    /// Token usage statistics; `None` when the backend does not report usage
    pub usage: Option<TokenUsage>,
}
55
56impl SummarizerResponse {
57    /// Create a new SummarizerResponse
58    pub fn new(content: Vec<Content>, usage: Option<TokenUsage>) -> Self {
59        Self { content, usage }
60    }
61
62    /// Extract text content from the response
63    pub fn text(&self) -> String {
64        self.content
65            .iter()
66            .filter_map(|c| c.as_text().map(|t| t.text.clone()))
67            .collect::<Vec<_>>()
68            .join("")
69    }
70}
71
/// Trait for clients that can generate AI summaries.
///
/// This trait abstracts the LLM client interface, allowing for different
/// implementations (e.g., Anthropic, OpenAI) or mock clients for testing.
#[async_trait]
pub trait SummarizerClient: Send + Sync {
    /// Create a message using the LLM.
    ///
    /// # Arguments
    ///
    /// * `messages` - The conversation messages to send
    /// * `system_prompt` - Optional system prompt to guide the response
    ///
    /// # Returns
    ///
    /// A `SummarizerResponse` containing the generated content and usage stats.
    ///
    /// # Errors
    ///
    /// Returns a [`ContextError`] when the underlying client call fails.
    async fn create_message(
        &self,
        messages: Vec<Message>,
        system_prompt: Option<&str>,
    ) -> Result<SummarizerResponse, ContextError>;
}
94
95// ============================================================================
96// Summarizer
97// ============================================================================
98
/// Intelligent summarizer for conversation turns.
///
/// Provides methods to generate concise summaries of conversation history,
/// either using AI or simple text extraction.
///
/// Stateless namespace type: all functionality is exposed as associated
/// functions, so no instance is ever constructed.
pub struct Summarizer;
104
105impl Summarizer {
106    /// Generate an AI-powered summary of conversation turns.
107    ///
108    /// Uses an LLM to create a concise summary capturing the main task,
109    /// key files, problems addressed, and current status.
110    ///
111    /// # Arguments
112    ///
113    /// * `turns` - The conversation turns to summarize
114    /// * `client` - The LLM client to use for summarization
115    /// * `context_budget` - Maximum tokens to include in the summarization request
116    ///
117    /// # Returns
118    ///
119    /// A summary string, or falls back to simple summary on failure.
120    pub async fn generate_ai_summary(
121        turns: &[ConversationTurn],
122        client: &dyn SummarizerClient,
123        context_budget: usize,
124    ) -> Result<String, ContextError> {
125        if turns.is_empty() {
126            return Ok(String::new());
127        }
128
129        // Collect turns within budget
130        let (collected_turns, _tokens_used) = Self::collect_within_budget(turns, context_budget);
131
132        if collected_turns.is_empty() {
133            return Ok(Self::create_simple_summary(turns));
134        }
135
136        // Format turns as text for summarization
137        let formatted_text = Self::format_turns_as_text(&collected_turns);
138
139        // Create the summarization request
140        let messages = vec![Message::user().with_text(formatted_text)];
141
142        // Call the LLM
143        match client
144            .create_message(messages, Some(SUMMARY_SYSTEM_PROMPT))
145            .await
146        {
147            Ok(response) => {
148                let summary = response.text();
149                if summary.is_empty() {
150                    // Fall back to simple summary if AI returns empty
151                    Ok(Self::create_simple_summary(turns))
152                } else {
153                    // Truncate if too long
154                    Ok(Self::truncate_summary(&summary, MAX_SUMMARY_LENGTH))
155                }
156            }
157            Err(_) => {
158                // Fall back to simple summary on error
159                Ok(Self::create_simple_summary(turns))
160            }
161        }
162    }
163
164    /// Create a simple summary without using AI.
165    ///
166    /// Extracts key information from conversation turns including:
167    /// - Number of turns
168    /// - Key topics mentioned
169    /// - Files referenced
170    /// - Tools used
171    ///
172    /// # Arguments
173    ///
174    /// * `turns` - The conversation turns to summarize
175    ///
176    /// # Returns
177    ///
178    /// A simple text summary.
179    pub fn create_simple_summary(turns: &[ConversationTurn]) -> String {
180        if turns.is_empty() {
181            return String::new();
182        }
183
184        let mut summary_parts: Vec<String> = Vec::new();
185
186        // Add turn count
187        summary_parts.push(format!("[{} turns]", turns.len()));
188
189        // Collect unique tools used
190        let mut tools_used: Vec<String> = Vec::new();
191        for turn in turns {
192            Self::collect_tools_from_message(&turn.user, &mut tools_used);
193            Self::collect_tools_from_message(&turn.assistant, &mut tools_used);
194        }
195        if !tools_used.is_empty() {
196            tools_used.sort();
197            tools_used.dedup();
198            let tools_str = tools_used
199                .iter()
200                .take(5)
201                .cloned()
202                .collect::<Vec<_>>()
203                .join(", ");
204            summary_parts.push(format!("Tools: {}", tools_str));
205        }
206
207        // Extract first user message as topic indicator
208        if let Some(first_turn) = turns.first() {
209            let first_text = Self::extract_message_text(&first_turn.user);
210            if !first_text.is_empty() {
211                let topic = Self::truncate_summary(&first_text, 100);
212                summary_parts.push(format!("Started: {}", topic));
213            }
214        }
215
216        // Extract last assistant response as status indicator
217        if let Some(last_turn) = turns.last() {
218            let last_text = Self::extract_message_text(&last_turn.assistant);
219            if !last_text.is_empty() {
220                let status = Self::truncate_summary(&last_text, 100);
221                summary_parts.push(format!("Last: {}", status));
222            }
223        }
224
225        summary_parts.join(" | ")
226    }
227
228    /// Collect conversation turns within a token budget.
229    ///
230    /// Iterates through turns from oldest to newest, collecting as many
231    /// as will fit within the specified token budget.
232    ///
233    /// # Arguments
234    ///
235    /// * `turns` - The conversation turns to collect from
236    /// * `budget` - Maximum tokens to collect
237    ///
238    /// # Returns
239    ///
240    /// A tuple of (collected turns, total tokens used).
241    pub fn collect_within_budget(
242        turns: &[ConversationTurn],
243        budget: usize,
244    ) -> (Vec<ConversationTurn>, usize) {
245        let mut collected: Vec<ConversationTurn> = Vec::new();
246        let mut tokens_used: usize = 0;
247
248        for turn in turns {
249            let turn_tokens = turn.token_estimate;
250            if tokens_used + turn_tokens <= budget {
251                collected.push(turn.clone());
252                tokens_used += turn_tokens;
253            } else {
254                // Budget exceeded, stop collecting
255                break;
256            }
257        }
258
259        (collected, tokens_used)
260    }
261
262    /// Format conversation turns as readable text for summarization.
263    ///
264    /// Creates a structured text representation of the conversation
265    /// suitable for sending to an LLM for summarization.
266    ///
267    /// # Arguments
268    ///
269    /// * `turns` - The conversation turns to format
270    ///
271    /// # Returns
272    ///
273    /// A formatted text string.
274    pub fn format_turns_as_text(turns: &[ConversationTurn]) -> String {
275        let mut parts: Vec<String> = Vec::new();
276
277        for (i, turn) in turns.iter().enumerate() {
278            parts.push(format!("--- Turn {} ---", i + 1));
279
280            // Format user message
281            let user_text = Self::extract_message_text(&turn.user);
282            if !user_text.is_empty() {
283                parts.push(format!("User: {}", user_text));
284            }
285
286            // Format assistant message
287            let assistant_text = Self::extract_message_text(&turn.assistant);
288            if !assistant_text.is_empty() {
289                parts.push(format!("Assistant: {}", assistant_text));
290            }
291
292            // Add summary if already summarized
293            if let Some(summary) = &turn.summary {
294                parts.push(format!("(Summary: {})", summary));
295            }
296
297            parts.push(String::new()); // Empty line between turns
298        }
299
300        parts.join("\n")
301    }
302
303    /// Extract text content from a message.
304    ///
305    /// Concatenates all text content blocks from the message,
306    /// ignoring non-text content like images or tool calls.
307    ///
308    /// # Arguments
309    ///
310    /// * `message` - The message to extract text from
311    ///
312    /// # Returns
313    ///
314    /// The concatenated text content.
315    pub fn extract_message_text(message: &Message) -> String {
316        message
317            .content
318            .iter()
319            .filter_map(|content| match content {
320                MessageContent::Text(text_content) => Some(text_content.text.clone()),
321                MessageContent::Thinking(thinking) => Some(thinking.thinking.clone()),
322                MessageContent::ToolRequest(req) => {
323                    // Include tool name for context
324                    req.tool_call
325                        .as_ref()
326                        .ok()
327                        .map(|call| format!("[Tool: {}]", call.name))
328                }
329                MessageContent::ToolResponse(resp) => {
330                    // Include brief tool result
331                    resp.tool_result.as_ref().ok().map(|result| {
332                        let text: String = result
333                            .content
334                            .iter()
335                            .filter_map(|c| c.as_text().map(|t| t.text.clone()))
336                            .take(1)
337                            .collect::<Vec<_>>()
338                            .join("");
339                        if text.len() > 100 {
340                            format!("[Tool result: {}...]", text.get(..100).unwrap_or(&text))
341                        } else if !text.is_empty() {
342                            format!("[Tool result: {}]", text)
343                        } else {
344                            String::new()
345                        }
346                    })
347                }
348                _ => None,
349            })
350            .filter(|s| !s.is_empty())
351            .collect::<Vec<_>>()
352            .join(" ")
353    }
354
355    /// Collect tool names from a message.
356    fn collect_tools_from_message(message: &Message, tools: &mut Vec<String>) {
357        for content in &message.content {
358            if let MessageContent::ToolRequest(req) = content {
359                if let Ok(call) = &req.tool_call {
360                    tools.push(call.name.to_string());
361                }
362            }
363        }
364    }
365
366    /// Truncate a summary to a maximum length.
367    fn truncate_summary(text: &str, max_len: usize) -> String {
368        let trimmed = text.trim();
369        if trimmed.len() <= max_len {
370            trimmed.to_string()
371        } else {
372            // Find a good break point (word boundary)
373            let truncated = trimmed.get(..max_len).unwrap_or(trimmed);
374            if let Some(last_space) = truncated.rfind(' ') {
375                format!("{}...", truncated.get(..last_space).unwrap_or(truncated))
376            } else {
377                format!("{}...", truncated)
378            }
379        }
380    }
381
    /// Estimate the token count for a summary.
    ///
    /// Delegates to [`TokenEstimator::estimate_tokens`].
    pub fn estimate_summary_tokens(summary: &str) -> usize {
        TokenEstimator::estimate_tokens(summary)
    }
386}
387
388// ============================================================================
389// Tests
390// ============================================================================
391
#[cfg(test)]
mod tests {
    use super::*;

    /// Build one user/assistant turn with a token estimate derived from
    /// the actual message contents.
    fn create_test_turn(user_text: &str, assistant_text: &str) -> ConversationTurn {
        let user = Message::user().with_text(user_text);
        let assistant = Message::assistant().with_text(assistant_text);
        let token_estimate = TokenEstimator::estimate_message_tokens(&user)
            + TokenEstimator::estimate_message_tokens(&assistant);
        ConversationTurn::new(user, assistant, token_estimate)
    }

    #[test]
    fn test_create_simple_summary_empty() {
        let turns: Vec<ConversationTurn> = vec![];
        let summary = Summarizer::create_simple_summary(&turns);
        assert!(summary.is_empty());
    }

    #[test]
    fn test_create_simple_summary_single_turn() {
        let turns = vec![create_test_turn(
            "How do I create a function in Rust?",
            "You can create a function using the fn keyword.",
        )];

        let summary = Summarizer::create_simple_summary(&turns);

        assert!(summary.contains("[1 turns]"));
        assert!(summary.contains("Started:"));
        assert!(summary.contains("Last:"));
    }

    #[test]
    fn test_create_simple_summary_multiple_turns() {
        let turns = vec![
            create_test_turn("Hello", "Hi there!"),
            create_test_turn("How are you?", "I'm doing well, thanks!"),
            create_test_turn("Goodbye", "See you later!"),
        ];

        let summary = Summarizer::create_simple_summary(&turns);

        assert!(summary.contains("[3 turns]"));
    }

    #[test]
    fn test_collect_within_budget_all_fit() {
        let turns = vec![
            create_test_turn("Short", "Reply"),
            create_test_turn("Another", "Response"),
        ];

        let (collected, tokens) = Summarizer::collect_within_budget(&turns, 10000);

        assert_eq!(collected.len(), 2);
        assert!(tokens > 0);
    }

    #[test]
    fn test_collect_within_budget_partial() {
        let turns = vec![
            create_test_turn("Short", "Reply"),
            create_test_turn("A".repeat(1000).as_str(), "B".repeat(1000).as_str()),
        ];

        // Very small budget should only fit first turn
        let (collected, _tokens) = Summarizer::collect_within_budget(&turns, 50);

        assert_eq!(collected.len(), 1);
    }

    #[test]
    fn test_collect_within_budget_none_fit() {
        let turns = vec![create_test_turn(
            "A".repeat(1000).as_str(),
            "B".repeat(1000).as_str(),
        )];

        // Budget too small for any turn
        let (collected, tokens) = Summarizer::collect_within_budget(&turns, 10);

        assert!(collected.is_empty());
        assert_eq!(tokens, 0);
    }

    #[test]
    fn test_format_turns_as_text() {
        let turns = vec![
            create_test_turn("Hello", "Hi there!"),
            create_test_turn("How are you?", "I'm fine."),
        ];

        let formatted = Summarizer::format_turns_as_text(&turns);

        assert!(formatted.contains("--- Turn 1 ---"));
        assert!(formatted.contains("--- Turn 2 ---"));
        assert!(formatted.contains("User: Hello"));
        assert!(formatted.contains("Assistant: Hi there!"));
        assert!(formatted.contains("User: How are you?"));
        assert!(formatted.contains("Assistant: I'm fine."));
    }

    #[test]
    fn test_extract_message_text_simple() {
        let message = Message::user().with_text("Hello, world!");
        let text = Summarizer::extract_message_text(&message);
        assert_eq!(text, "Hello, world!");
    }

    #[test]
    fn test_extract_message_text_multiple_blocks() {
        let message = Message::user()
            .with_text("First part")
            .with_text("Second part");
        let text = Summarizer::extract_message_text(&message);
        assert!(text.contains("First part"));
        assert!(text.contains("Second part"));
    }

    #[test]
    fn test_truncate_summary_short() {
        let text = "Short text";
        let result = Summarizer::truncate_summary(text, 100);
        assert_eq!(result, "Short text");
    }

    #[test]
    fn test_truncate_summary_long() {
        let text = "This is a very long text that needs to be truncated at a word boundary";
        let result = Summarizer::truncate_summary(text, 30);
        assert!(result.len() <= 33); // 30 + "..."
        assert!(result.ends_with("..."));
    }

    #[test]
    fn test_estimate_summary_tokens() {
        let summary = "This is a test summary";
        let tokens = Summarizer::estimate_summary_tokens(summary);
        assert!(tokens > 0);
    }

    #[test]
    fn test_summarizer_response_text() {
        use rmcp::model::{RawContent, RawTextContent};

        let content = vec![Content {
            raw: RawContent::Text(RawTextContent {
                text: "Summary text".to_string(),
                meta: None,
            }),
            annotations: None,
        }];

        let response = SummarizerResponse::new(content, None);
        assert_eq!(response.text(), "Summary text");
    }

    #[test]
    fn test_summarizer_response_empty() {
        let response = SummarizerResponse::new(vec![], None);
        assert!(response.text().is_empty());
    }

    // Mock client for testing AI summary
    //
    // `response: None` yields an empty content list (empty summary);
    // `should_fail: true` makes `create_message` return an error.
    struct MockSummarizerClient {
        response: Option<String>,
        should_fail: bool,
    }

    impl MockSummarizerClient {
        fn new(response: Option<String>) -> Self {
            Self {
                response,
                should_fail: false,
            }
        }

        fn failing() -> Self {
            Self {
                response: None,
                should_fail: true,
            }
        }
    }

    #[async_trait]
    impl SummarizerClient for MockSummarizerClient {
        async fn create_message(
            &self,
            _messages: Vec<Message>,
            _system_prompt: Option<&str>,
        ) -> Result<SummarizerResponse, ContextError> {
            if self.should_fail {
                return Err(ContextError::SummarizationFailed(
                    "Mock failure".to_string(),
                ));
            }

            let content = match &self.response {
                Some(text) => {
                    use rmcp::model::{RawContent, RawTextContent};
                    vec![Content {
                        raw: RawContent::Text(RawTextContent {
                            text: text.clone(),
                            meta: None,
                        }),
                        annotations: None,
                    }]
                }
                None => vec![],
            };

            Ok(SummarizerResponse::new(content, None))
        }
    }

    #[tokio::test]
    async fn test_generate_ai_summary_success() {
        let turns = vec![create_test_turn("Hello", "Hi there!")];
        let client = MockSummarizerClient::new(Some("AI generated summary".to_string()));

        let result = Summarizer::generate_ai_summary(&turns, &client, 10000).await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "AI generated summary");
    }

    #[tokio::test]
    async fn test_generate_ai_summary_empty_response_fallback() {
        let turns = vec![create_test_turn("Hello", "Hi there!")];
        let client = MockSummarizerClient::new(None); // Empty response

        let result = Summarizer::generate_ai_summary(&turns, &client, 10000).await;

        assert!(result.is_ok());
        let summary = result.unwrap();
        // Should fall back to simple summary
        assert!(summary.contains("[1 turns]"));
    }

    #[tokio::test]
    async fn test_generate_ai_summary_error_fallback() {
        let turns = vec![create_test_turn("Hello", "Hi there!")];
        let client = MockSummarizerClient::failing();

        let result = Summarizer::generate_ai_summary(&turns, &client, 10000).await;

        assert!(result.is_ok());
        let summary = result.unwrap();
        // Should fall back to simple summary
        assert!(summary.contains("[1 turns]"));
    }

    #[tokio::test]
    async fn test_generate_ai_summary_empty_turns() {
        let turns: Vec<ConversationTurn> = vec![];
        let client = MockSummarizerClient::new(Some("Should not be called".to_string()));

        let result = Summarizer::generate_ai_summary(&turns, &client, 10000).await;

        assert!(result.is_ok());
        assert!(result.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_generate_ai_summary_truncates_long_response() {
        let turns = vec![create_test_turn("Hello", "Hi there!")];
        let long_response = "A".repeat(1000);
        let client = MockSummarizerClient::new(Some(long_response));

        let result = Summarizer::generate_ai_summary(&turns, &client, 10000).await;

        assert!(result.is_ok());
        let summary = result.unwrap();
        // Should be truncated to MAX_SUMMARY_LENGTH
        assert!(summary.len() <= MAX_SUMMARY_LENGTH + 3); // +3 for "..."
    }
}