Skip to main content

bamboo_engine/
llm_summarizer.rs

1//! LLM-backed conversation summarizer.
2//!
3//! `LlmSummarizer` is the infrastructure-coupled implementation of
4//! `bamboo_compression::Summarizer`: it calls the session model to produce a
5//! rich summary of compressed/removed messages, falling back to the pure
6//! `HeuristicSummarizer` on failure. It lives in the engine (not in
7//! bamboo-compression) so that the compression crate stays free of any
8//! LLM-provider dependency.
9
10use std::sync::Arc;
11
12use async_trait::async_trait;
13use futures::StreamExt;
14
15use bamboo_compression::{HeuristicSummarizer, Summarizer};
16use bamboo_domain::ReasoningEffort;
17use bamboo_domain::{
18    ContextBlock, ContextBlockPriority, ContextBlockStability, ContextBlockType, Message, Role,
19};
20use bamboo_llm::LLMChunk;
21use bamboo_llm::{LLMProvider, LLMRequestOptions};
22
/// Mode controlling how the LLM summarizer handles existing summaries.
///
/// The mode selects which system prompt is sent to the model. It is a plain
/// fieldless enum, so it is cheap to copy and can be compared directly.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum SummaryMode {
    /// Generate a complete summary from scratch (default).
    #[default]
    FullRewrite,
    /// Update an existing summary by incorporating new information incrementally.
    IncrementalMerge,
}
32
/// LLM-based summarizer that calls the current session's model to generate
/// a rich summary of compressed/removed messages.
///
/// Falls back to [`HeuristicSummarizer`] if the LLM call fails or returns an
/// empty summary.
pub struct LlmSummarizer {
    /// Provider used to issue the streaming summarization request.
    llm: Arc<dyn LLMProvider>,
    /// Model identifier passed to the provider on every summarization call.
    model: String,
    /// Optional existing summary to build upon (incremental summarization).
    existing_summary: Option<String>,
    /// Structured runtime context blocks that should inform summarization.
    context_blocks: Vec<ContextBlock>,
    /// Optional user-provided instructions that override/extend the default summary focus.
    custom_instructions: Option<String>,
    /// Controls how the summarizer handles existing summaries.
    summary_mode: SummaryMode,
}
49
50impl LlmSummarizer {
51    pub fn new(
52        llm: Arc<dyn LLMProvider>,
53        model: String,
54        existing_summary: Option<String>,
55        task_list_prompt: Option<String>,
56    ) -> Self {
57        let context_blocks = task_list_prompt
58            .as_deref()
59            .map(str::trim)
60            .filter(|value| !value.is_empty())
61            .map(|task_list| {
62                vec![ContextBlock::new(
63                    ContextBlockType::TaskSnapshot,
64                    ContextBlockPriority::High,
65                    ContextBlockStability::RoundDynamic,
66                    "Current Task List",
67                    task_list,
68                )]
69            })
70            .unwrap_or_default();
71
72        Self {
73            llm,
74            model,
75            existing_summary,
76            context_blocks,
77            custom_instructions: None,
78            summary_mode: SummaryMode::default(),
79        }
80    }
81
82    pub fn with_context_blocks(mut self, context_blocks: Vec<ContextBlock>) -> Self {
83        self.context_blocks = context_blocks;
84        self
85    }
86
87    pub fn with_custom_instructions(mut self, instructions: Option<String>) -> Self {
88        self.custom_instructions = instructions;
89        self
90    }
91
92    pub fn with_summary_mode(mut self, mode: SummaryMode) -> Self {
93        self.summary_mode = mode;
94        self
95    }
96
97    /// Build the summarization prompt for the LLM.
98    fn build_summarization_messages(&self, messages: &[Message]) -> Vec<Message> {
99        let mut prompt_messages = Vec::new();
100
101        let system_prompt = match self.summary_mode {
102            SummaryMode::FullRewrite => {
103                r#"You are a conversation summarizer. Your task is to create a concise but reliable working-memory summary for a conversation that was removed due to context window limits.
104
105Guidelines:
106- First capture the in-flight work right before compression (what was being done, where, and with which tool/file)
107- Distinguish clearly between CURRENT ACTIVE work, COMPLETED work, and OBSOLETE or superseded work
108- Do not restate old tasks as active unless they are still unresolved
109- The provided current task list is the source of truth for active work
110- Preserve key decisions, constraints, file paths, code changes, tool findings, blockers, and important outcomes
111- Preserve error messages, test results (pass/fail counts), and function/variable names that are relevant to active work
112- If earlier plans conflict with newer messages or the current task list, mark them as obsolete or completed
113- Explicitly evaluate each clear user requirement (e.g. requirement 1, requirement 2) with a status and evidence
114- Keep the next step specific and aligned with the active work only
115- Use structured sections
116- Write in the same language as the original conversation"#
117            }
118            SummaryMode::IncrementalMerge => {
119                r#"You are updating an existing conversation summary with new information from recent messages.
120
121Guidelines:
122- Incorporate new information into the existing summary structure
123- Mark previously active work as completed if the new messages confirm completion
124- Remove or condense information that is no longer relevant
125- Preserve all key decisions, file paths, and constraints that remain active
126- If new messages conflict with the existing summary, the new messages take precedence
127- Keep the summary focused on what is currently active and relevant
128- The provided current task list is the source of truth for active work
129- Maintain the same structured sections as the existing summary
130- Write in the same language as the original conversation
131- Be concise: avoid repeating information already well-captured in the existing summary"#
132            }
133        };
134
135        prompt_messages.push(Message::system(system_prompt));
136
137        let mut user_content = String::new();
138
139        if let Some(ref existing) = self.existing_summary {
140            user_content.push_str("## Previous Summary\n\n");
141            user_content.push_str(existing);
142            user_content.push_str("\n\n---\n\n");
143        }
144
145        if !self.context_blocks.is_empty() {
146            user_content.push_str("## Compression Context Blocks\n\n");
147            for block in &self.context_blocks {
148                user_content.push_str(&format!(
149                    "### {}\n- type: {}\n- priority: {}\n- stability: {}\n\n{}\n\n",
150                    block.title.trim(),
151                    block.block_type.as_str(),
152                    block.priority.as_str(),
153                    block.stability.as_str(),
154                    block.content.trim(),
155                ));
156            }
157            user_content.push_str("---\n\n");
158        }
159
160        if let Some(ref instructions) = self.custom_instructions {
161            if !instructions.trim().is_empty() {
162                user_content.push_str("## Custom Compression Instructions\n\n");
163                user_content.push_str(instructions.trim());
164                user_content.push_str("\n\n---\n\n");
165            }
166        }
167
168        user_content.push_str(
169            "## Required Output Sections\n1. Pre-compression in-flight work (what was being done immediately before compression)\n2. Current active objective\n3. Requirement checklist (Requirement | Status: completed/in_progress/pending/blocked/obsolete | Evidence)\n4. Active tasks\n5. Completed tasks\n6. Obsolete or superseded tasks\n7. Important context and constraints\n8. Files, code, and tool findings\n9. Open issues and next step\n\n",
170        );
171
172        user_content.push_str("## Messages to Summarize\n\n");
173
174        for message in messages {
175            let role_label = match message.role {
176                Role::User => "User",
177                Role::Assistant => "Assistant",
178                Role::Tool => "Tool Result",
179                Role::System => continue,
180            };
181
182            if let Some(ref tool_calls) = message.tool_calls {
183                if !tool_calls.is_empty() {
184                    let tool_names: Vec<&str> = tool_calls
185                        .iter()
186                        .map(|tc| tc.function.name.as_str())
187                        .collect();
188                    user_content.push_str(&format!(
189                        "**{}** [called tools: {}]:\n",
190                        role_label,
191                        tool_names.join(", ")
192                    ));
193                } else {
194                    user_content.push_str(&format!("**{}**:\n", role_label));
195                }
196            } else {
197                user_content.push_str(&format!("**{}**:\n", role_label));
198            }
199
200            if let Some(ref tool_call_id) = message.tool_call_id {
201                user_content.push_str(&format!("(tool_call_id: {})\n", tool_call_id));
202            }
203
204            let content = &message.content;
205            const MAX_CONTENT_CHARS: usize = 2000;
206            if content.chars().count() > MAX_CONTENT_CHARS {
207                let truncated: String = content.chars().take(MAX_CONTENT_CHARS).collect();
208                user_content.push_str(&truncated);
209                user_content.push_str("... [truncated]\n\n");
210            } else {
211                user_content.push_str(content);
212                user_content.push_str("\n\n");
213            }
214        }
215
216        user_content.push_str(
217            "\n---\n\nReturn only the summary text. Be explicit about what is active now versus what is already completed or no longer relevant.",
218        );
219
220        prompt_messages.push(Message::user(user_content));
221
222        prompt_messages
223    }
224
225    /// Consume an LLM stream and collect the full text response.
226    async fn collect_stream_response(
227        &self,
228        messages: &[Message],
229    ) -> Result<String, bamboo_compression::types::BudgetError> {
230        // Summarization is a lightweight auxiliary request; cap reasoning effort at `high`
231        // to stay compatible with fast models (e.g. gpt-5-mini).
232        let options = LLMRequestOptions {
233            session_id: None,
234            reasoning_effort: Some(ReasoningEffort::High),
235            parallel_tool_calls: None,
236            responses: None,
237            request_purpose: Some("compression".to_string()),
238            cache: None,
239        };
240        let stream = self
241            .llm
242            .chat_stream_with_options(messages, &[], Some(8192), &self.model, Some(&options))
243            .await
244            .map_err(|e| {
245                bamboo_compression::types::BudgetError::TokenCountError(format!(
246                    "LLM summarization call failed: {}",
247                    e
248                ))
249            })?;
250
251        let mut content = String::new();
252        let mut stream = stream;
253
254        while let Some(chunk_result) = stream.next().await {
255            match chunk_result {
256                Ok(LLMChunk::Token(text)) => content.push_str(&text),
257                Ok(LLMChunk::Done) => break,
258                Ok(_) => {} // Ignore reasoning tokens, tool calls, etc.
259                Err(e) => {
260                    tracing::warn!("LLM summarization stream error: {}", e);
261                    if !content.is_empty() {
262                        break;
263                    }
264                    return Err(bamboo_compression::types::BudgetError::TokenCountError(
265                        format!("LLM summarization stream failed: {}", e),
266                    ));
267                }
268            }
269        }
270
271        Ok(content)
272    }
273}
274
275impl std::fmt::Debug for LlmSummarizer {
276    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
277        f.debug_struct("LlmSummarizer")
278            .field("model", &self.model)
279            .field("has_existing_summary", &self.existing_summary.is_some())
280            .field("context_block_count", &self.context_blocks.len())
281            .finish()
282    }
283}
284
285#[async_trait]
286impl Summarizer for LlmSummarizer {
287    async fn summarize(
288        &self,
289        messages: &[Message],
290    ) -> Result<String, bamboo_compression::types::BudgetError> {
291        if messages.is_empty() {
292            return Ok("No conversation history to summarize.".to_string());
293        }
294
295        let prompt_messages = self.build_summarization_messages(messages);
296
297        tracing::info!(
298            "LlmSummarizer: summarizing {} messages using model '{}' (existing_summary={})",
299            messages.len(),
300            self.model,
301            self.existing_summary.is_some()
302        );
303
304        match self.collect_stream_response(&prompt_messages).await {
305            Ok(summary) if !summary.trim().is_empty() => {
306                tracing::info!("LlmSummarizer: generated summary ({} chars)", summary.len());
307                Ok(summary)
308            }
309            Ok(_) => {
310                tracing::warn!(
311                    "LlmSummarizer: LLM returned empty summary, falling back to heuristic"
312                );
313                HeuristicSummarizer::new().summarize(messages).await
314            }
315            Err(e) => {
316                tracing::warn!(
317                    "LlmSummarizer: LLM call failed ({}), falling back to heuristic",
318                    e
319                );
320                HeuristicSummarizer::new().summarize(messages).await
321            }
322        }
323    }
324
325    fn estimate_summary_tokens(&self, message_count: usize) -> u32 {
326        // LLM summaries tend to be more detailed; estimate higher than heuristic
327        (message_count * 80).min(2000) as u32
328    }
329}
330
331#[cfg(test)]
332mod tests {
333    use super::*;
334    use bamboo_domain::ReasoningEffort;
335    use bamboo_llm::{LLMChunk, LLMError, LLMRequestOptions, LLMStream};
336    use futures::stream;
337    use std::sync::Mutex;
338
339    struct DummyProvider;
340
341    #[async_trait]
342    impl LLMProvider for DummyProvider {
343        async fn chat_stream(
344            &self,
345            _messages: &[Message],
346            _tools: &[bamboo_domain::ToolSchema],
347            _max_output_tokens: Option<u32>,
348            _model: &str,
349        ) -> Result<LLMStream, LLMError> {
350            Ok(Box::pin(stream::iter(vec![
351                Ok::<LLMChunk, LLMError>(LLMChunk::Token("dummy summary".to_string())),
352                Ok::<LLMChunk, LLMError>(LLMChunk::Done),
353            ])))
354        }
355    }
356    fn llm_summarizer_prompt_includes_context_blocks_and_state_sections() {
357        let summarizer = LlmSummarizer::new(
358            Arc::new(DummyProvider),
359            "gpt-4o-mini".to_string(),
360            Some("Earlier summary".to_string()),
361            Some(
362                "## Current Task List\n[/] task_1: Fix compression bounce\n[x] task_0: Analyze bug"
363                    .to_string(),
364            ),
365        )
366        .with_context_blocks(vec![
367            ContextBlock::new(
368                ContextBlockType::TaskSnapshot,
369                ContextBlockPriority::High,
370                ContextBlockStability::RoundDynamic,
371                "Current Task List",
372                "[/] task_1: Fix compression bounce",
373            ),
374            ContextBlock::new(
375                ContextBlockType::ExternalMemory,
376                ContextBlockPriority::Medium,
377                ContextBlockStability::RoundDynamic,
378                "External Memory (Persistent)",
379                "Session note body",
380            ),
381        ]);
382        let messages = vec![
383            Message::user("继续做压缩修复"),
384            Message::assistant("我先检查 trigger 与 target", None),
385        ];
386
387        let prompt_messages = summarizer.build_summarization_messages(&messages);
388        assert_eq!(prompt_messages.len(), 2);
389        assert_eq!(prompt_messages[0].role, Role::System);
390        assert!(prompt_messages[1]
391            .content
392            .contains("## Compression Context Blocks"));
393        assert!(prompt_messages[1].content.contains("Current Task List"));
394        assert!(prompt_messages[1]
395            .content
396            .contains("External Memory (Persistent)"));
397        assert!(prompt_messages[1]
398            .content
399            .contains("Current active objective"));
400        assert!(prompt_messages[1].content.contains("Requirement checklist"));
401        assert!(prompt_messages[1].content.contains("Active tasks"));
402        assert!(prompt_messages[1].content.contains("Completed tasks"));
403        assert!(prompt_messages[1]
404            .content
405            .contains("Obsolete or superseded tasks"));
406        assert!(prompt_messages[1].content.contains("Earlier summary"));
407    }
408
409    #[derive(Default)]
410    struct ReasoningCaptureProvider {
411        captured_reasoning: Mutex<Vec<Option<ReasoningEffort>>>,
412    }
413
414    #[async_trait]
415    impl LLMProvider for ReasoningCaptureProvider {
416        async fn chat_stream(
417            &self,
418            _messages: &[Message],
419            _tools: &[bamboo_domain::ToolSchema],
420            _max_output_tokens: Option<u32>,
421            _model: &str,
422        ) -> Result<LLMStream, LLMError> {
423            Ok(Box::pin(stream::iter(vec![
424                Ok::<LLMChunk, LLMError>(LLMChunk::Token("captured summary".to_string())),
425                Ok::<LLMChunk, LLMError>(LLMChunk::Done),
426            ])))
427        }
428
429        async fn chat_stream_with_options(
430            &self,
431            messages: &[Message],
432            tools: &[bamboo_domain::ToolSchema],
433            max_output_tokens: Option<u32>,
434            model: &str,
435            options: Option<&LLMRequestOptions>,
436        ) -> Result<LLMStream, LLMError> {
437            self.captured_reasoning
438                .lock()
439                .expect("captured reasoning lock should not be poisoned")
440                .push(options.and_then(|o| o.reasoning_effort));
441            self.chat_stream(messages, tools, max_output_tokens, model)
442                .await
443        }
444    }
445
446    #[tokio::test]
447    async fn llm_summarizer_requests_high_reasoning_effort_for_summary_calls() {
448        let provider = Arc::new(ReasoningCaptureProvider::default());
449        let summarizer = LlmSummarizer::new(
450            provider.clone(),
451            "gpt-5-mini".to_string(),
452            None,
453            Some("task list".to_string()),
454        );
455        let messages = vec![
456            Message::user("请总结最近三轮"),
457            Message::assistant("已完成第一步并准备第二步", None),
458        ];
459
460        let summary = summarizer
461            .summarize(&messages)
462            .await
463            .expect("summary generation should succeed");
464        assert_eq!(summary, "captured summary");
465
466        let captured = provider
467            .captured_reasoning
468            .lock()
469            .expect("captured reasoning lock should not be poisoned");
470        assert_eq!(captured.as_slice(), [Some(ReasoningEffort::High)]);
471    }
472
473    /// Provider that captures both `reasoning_effort` and `max_output_tokens`.
474    #[derive(Default)]
475    struct RequestOptionsCaptureProvider {
476        captured_reasoning: Mutex<Vec<Option<ReasoningEffort>>>,
477        captured_max_tokens: Mutex<Vec<Option<u32>>>,
478    }
479
480    #[async_trait]
481    impl LLMProvider for RequestOptionsCaptureProvider {
482        async fn chat_stream(
483            &self,
484            _messages: &[Message],
485            _tools: &[bamboo_domain::ToolSchema],
486            _max_output_tokens: Option<u32>,
487            _model: &str,
488        ) -> Result<LLMStream, LLMError> {
489            Ok(Box::pin(stream::iter(vec![
490                Ok::<LLMChunk, LLMError>(LLMChunk::Token("captured summary".to_string())),
491                Ok::<LLMChunk, LLMError>(LLMChunk::Done),
492            ])))
493        }
494
495        async fn chat_stream_with_options(
496            &self,
497            messages: &[Message],
498            tools: &[bamboo_domain::ToolSchema],
499            max_output_tokens: Option<u32>,
500            model: &str,
501            options: Option<&LLMRequestOptions>,
502        ) -> Result<LLMStream, LLMError> {
503            self.captured_reasoning
504                .lock()
505                .expect("lock should not be poisoned")
506                .push(options.and_then(|o| o.reasoning_effort));
507            self.captured_max_tokens
508                .lock()
509                .expect("lock should not be poisoned")
510                .push(max_output_tokens);
511            self.chat_stream(messages, tools, max_output_tokens, model)
512                .await
513        }
514    }
515
516    #[tokio::test]
517    async fn llm_summarizer_sufficient_max_tokens_for_high_reasoning() {
518        let provider = Arc::new(RequestOptionsCaptureProvider::default());
519        let summarizer = LlmSummarizer::new(
520            provider.clone(),
521            "gpt-5-mini".to_string(),
522            None,
523            Some("task list".to_string()),
524        );
525        let messages = vec![
526            Message::user("请总结最近三轮"),
527            Message::assistant("已完成第一步并准备第二步", None),
528        ];
529
530        let summary = summarizer
531            .summarize(&messages)
532            .await
533            .expect("summary generation should succeed");
534        assert_eq!(summary, "captured summary");
535
536        let captured_reasoning = provider
537            .captured_reasoning
538            .lock()
539            .expect("lock should not be poisoned");
540        let captured_max_tokens = provider
541            .captured_max_tokens
542            .lock()
543            .expect("lock should not be poisoned");
544        assert_eq!(captured_reasoning.as_slice(), [Some(ReasoningEffort::High)]);
545        let max_tokens = captured_max_tokens[0].expect("max_output_tokens should be set");
546        // ReasoningEffort::High targets 4096 thinking budget; max_tokens must leave room for output.
547        assert!(
548            max_tokens > 4096,
549            "max_output_tokens ({}) must exceed thinking budget (4096) to avoid truncation",
550            max_tokens
551        );
552    }
553
554    #[test]
555    fn full_rewrite_mode_uses_default_system_prompt() {
556        let summarizer =
557            LlmSummarizer::new(Arc::new(DummyProvider), "model".to_string(), None, None)
558                .with_summary_mode(SummaryMode::FullRewrite);
559        let messages = vec![Message::user("hello"), Message::assistant("hi", None)];
560        let prompts = summarizer.build_summarization_messages(&messages);
561        let system = &prompts[0].content;
562        assert!(
563            system.contains("conversation summarizer"),
564            "FullRewrite prompt should contain 'conversation summarizer'"
565        );
566        assert!(
567            !system.contains("updating an existing"),
568            "FullRewrite prompt should not contain incremental language"
569        );
570    }
571
572    #[test]
573    fn incremental_merge_mode_uses_update_system_prompt() {
574        let summarizer = LlmSummarizer::new(
575            Arc::new(DummyProvider),
576            "model".to_string(),
577            Some("Previous summary content".to_string()),
578            None,
579        )
580        .with_summary_mode(SummaryMode::IncrementalMerge);
581        let messages = vec![Message::user("hello"), Message::assistant("hi", None)];
582        let prompts = summarizer.build_summarization_messages(&messages);
583        let system = &prompts[0].content;
584        assert!(
585            system.contains("updating an existing conversation summary"),
586            "IncrementalMerge prompt should contain 'updating an existing conversation summary'"
587        );
588        assert!(
589            system.contains("Incorporate new information"),
590            "IncrementalMerge prompt should mention incorporating new information"
591        );
592    }
593
594    #[test]
595    fn default_summary_mode_is_full_rewrite() {
596        assert!(matches!(SummaryMode::default(), SummaryMode::FullRewrite));
597    }
598
599    #[test]
600    fn incremental_merge_includes_existing_summary_in_user_content() {
601        let summarizer = LlmSummarizer::new(
602            Arc::new(DummyProvider),
603            "model".to_string(),
604            Some("Previous summary content".to_string()),
605            None,
606        )
607        .with_summary_mode(SummaryMode::IncrementalMerge);
608        let messages = vec![
609            Message::user("new work"),
610            Message::assistant("doing it", None),
611        ];
612        let prompts = summarizer.build_summarization_messages(&messages);
613        let user_content = &prompts[1].content;
614        assert!(
615            user_content.contains("Previous Summary"),
616            "IncrementalMerge user prompt should include the existing summary"
617        );
618        assert!(
619            user_content.contains("Previous summary content"),
620            "IncrementalMerge user prompt should include the actual summary text"
621        );
622    }
623}