1use async_trait::async_trait;
7use bamboo_agent_core::{
8 ContextBlock, ContextBlockPriority, ContextBlockStability, ContextBlockType, Message, Role,
9};
10use bamboo_domain::ReasoningEffort;
11use bamboo_infrastructure::LLMChunk;
12use bamboo_infrastructure::{LLMProvider, LLMRequestOptions};
13use futures::StreamExt;
14use std::collections::HashSet;
15use std::sync::Arc;
16
17#[async_trait]
19pub trait Summarizer: Send + Sync {
20 async fn summarize(&self, messages: &[Message]) -> Result<String, crate::types::BudgetError>;
24
25 fn estimate_summary_tokens(&self, message_count: usize) -> u32 {
29 (message_count * 50).min(1000) as u32
31 }
32}
33
/// Stateless summarizer that works without a model call.
///
/// Its summary is assembled from the user messages, the names of the tools that were called and
/// the most recent assistant replies found in the conversation.
#[derive(Default, Debug)]
pub struct HeuristicSummarizer;
44
45impl HeuristicSummarizer {
46 pub fn new() -> Self {
48 Self
49 }
50
51 fn extract_user_questions<'a>(&self, messages: &'a [Message]) -> Vec<&'a str> {
53 messages
54 .iter()
55 .filter(|m| m.role == Role::User)
56 .filter(|m| !m.content.is_empty())
57 .take(10) .map(|m| m.content.as_str())
59 .collect()
60 }
61
62 fn extract_tools_used(&self, messages: &[Message]) -> Vec<String> {
64 let mut tools = HashSet::new();
65
66 for message in messages {
67 if let Some(ref tool_calls) = message.tool_calls {
68 for call in tool_calls {
69 tools.insert(call.function.name.clone());
70 }
71 }
72 }
73
74 let mut result: Vec<String> = tools.into_iter().collect();
75 result.sort();
76 result
77 }
78
79 fn extract_key_responses<'a>(&self, messages: &'a [Message]) -> Vec<&'a str> {
81 messages
82 .iter()
83 .filter(|m| m.role == Role::Assistant)
84 .filter(|m| !m.content.is_empty())
85 .rev() .take(3)
87 .map(|m| m.content.as_str())
88 .collect()
89 }
90
91 fn safe_truncate(&self, s: &str, max_chars: usize) -> String {
94 if s.chars().count() <= max_chars {
95 return s.to_string();
96 }
97
98 let truncated: String = s.chars().take(max_chars).collect();
100 format!("{}...", truncated)
101 }
102}
103
104#[async_trait]
105impl Summarizer for HeuristicSummarizer {
106 async fn summarize(&self, messages: &[Message]) -> Result<String, crate::types::BudgetError> {
107 if messages.is_empty() {
108 return Ok("No conversation history.".to_string());
109 }
110
111 let questions = self.extract_user_questions(messages);
112 let tools = self.extract_tools_used(messages);
113 let responses = self.extract_key_responses(messages);
114
115 let mut summary_parts = Vec::new();
116
117 if !questions.is_empty() {
119 summary_parts.push("## User Requests".to_string());
120 for (i, q) in questions.iter().enumerate() {
121 let truncated = self.safe_truncate(q, 200);
123 summary_parts.push(format!("{}. {}", i + 1, truncated));
124 }
125 }
126
127 if !tools.is_empty() {
129 summary_parts.push("\n## Tools Used".to_string());
130 for tool in tools {
131 summary_parts.push(format!("- {}", tool));
132 }
133 }
134
135 if !responses.is_empty() {
137 summary_parts.push("\n## Key Outcomes".to_string());
138 for (i, r) in responses.iter().enumerate() {
139 let truncated = self.safe_truncate(r, 300);
141 summary_parts.push(format!("{}. {}", i + 1, truncated));
142 }
143 }
144
145 if summary_parts.is_empty() {
146 Ok("Previous conversation context available.".to_string())
147 } else {
148 Ok(summary_parts.join("\n"))
149 }
150 }
151}
152
/// Policy that decides when a conversation should be summarized.
#[derive(Clone, Debug)]
pub enum SummaryTrigger {
    /// Summarize when the caller reports that truncation occurred.
    OnTruncation,
    /// Summarize once the conversation holds at least `interval` messages.
    Periodic {
        /// Message count at which summarization starts to trigger.
        interval: usize,
    },
    /// Summarize once the current token count reaches `threshold`.
    TokenThreshold {
        /// Token count at which summarization starts to trigger.
        threshold: u32,
    },
}
163
164pub struct SummaryManager {
166 summarizer: Box<dyn Summarizer>,
167 trigger: SummaryTrigger,
168}
169
170impl std::fmt::Debug for SummaryManager {
171 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
172 f.debug_struct("SummaryManager")
173 .field("trigger", &self.trigger)
174 .finish_non_exhaustive()
175 }
176}
177
178impl SummaryManager {
179 pub fn new(summarizer: impl Summarizer + 'static, trigger: SummaryTrigger) -> Self {
181 Self {
182 summarizer: Box::new(summarizer),
183 trigger,
184 }
185 }
186
187 pub fn should_summarize(
189 &self,
190 messages: &[Message],
191 _truncation_occurred: bool,
192 current_token_count: u32,
193 ) -> bool {
194 match &self.trigger {
195 SummaryTrigger::OnTruncation => _truncation_occurred,
196 SummaryTrigger::Periodic { interval } => messages.len() >= *interval,
197 SummaryTrigger::TokenThreshold { threshold } => current_token_count >= *threshold,
198 }
199 }
200
201 pub async fn summarize(
203 &self,
204 messages: &[Message],
205 ) -> Result<String, crate::types::BudgetError> {
206 self.summarizer.summarize(messages).await
207 }
208
209 pub fn estimate_summary_tokens(&self, message_count: usize) -> u32 {
211 self.summarizer.estimate_summary_tokens(message_count)
212 }
213}
214
/// Selects which system prompt the LLM summarizer uses.
#[derive(Clone, Debug, Default)]
pub enum SummaryMode {
    /// Write a fresh summary of the removed conversation. This is the default.
    #[default]
    FullRewrite,
    /// Update an existing summary with information from recent messages.
    IncrementalMerge,
}
224
225pub struct LlmSummarizer {
230 llm: Arc<dyn LLMProvider>,
231 model: String,
232 existing_summary: Option<String>,
234 context_blocks: Vec<ContextBlock>,
236 custom_instructions: Option<String>,
238 summary_mode: SummaryMode,
240}
241
242impl LlmSummarizer {
243 pub fn new(
244 llm: Arc<dyn LLMProvider>,
245 model: String,
246 existing_summary: Option<String>,
247 task_list_prompt: Option<String>,
248 ) -> Self {
249 let context_blocks = task_list_prompt
250 .as_deref()
251 .map(str::trim)
252 .filter(|value| !value.is_empty())
253 .map(|task_list| {
254 vec![ContextBlock::new(
255 ContextBlockType::TaskSnapshot,
256 ContextBlockPriority::High,
257 ContextBlockStability::RoundDynamic,
258 "Current Task List",
259 task_list,
260 )]
261 })
262 .unwrap_or_default();
263
264 Self {
265 llm,
266 model,
267 existing_summary,
268 context_blocks,
269 custom_instructions: None,
270 summary_mode: SummaryMode::default(),
271 }
272 }
273
274 pub fn with_context_blocks(mut self, context_blocks: Vec<ContextBlock>) -> Self {
275 self.context_blocks = context_blocks;
276 self
277 }
278
279 pub fn with_custom_instructions(mut self, instructions: Option<String>) -> Self {
280 self.custom_instructions = instructions;
281 self
282 }
283
284 pub fn with_summary_mode(mut self, mode: SummaryMode) -> Self {
285 self.summary_mode = mode;
286 self
287 }
288
289 fn build_summarization_messages(&self, messages: &[Message]) -> Vec<Message> {
291 let mut prompt_messages = Vec::new();
292
293 let system_prompt = match self.summary_mode {
294 SummaryMode::FullRewrite => {
295 r#"You are a conversation summarizer. Your task is to create a concise but reliable working-memory summary for a conversation that was removed due to context window limits.
296
297Guidelines:
298- First capture the in-flight work right before compression (what was being done, where, and with which tool/file)
299- Distinguish clearly between CURRENT ACTIVE work, COMPLETED work, and OBSOLETE or superseded work
300- Do not restate old tasks as active unless they are still unresolved
301- The provided current task list is the source of truth for active work
302- Preserve key decisions, constraints, file paths, code changes, tool findings, blockers, and important outcomes
303- Preserve error messages, test results (pass/fail counts), and function/variable names that are relevant to active work
304- If earlier plans conflict with newer messages or the current task list, mark them as obsolete or completed
305- Explicitly evaluate each clear user requirement (e.g. requirement 1, requirement 2) with a status and evidence
306- Keep the next step specific and aligned with the active work only
307- Use structured sections
308- Write in the same language as the original conversation"#
309 }
310 SummaryMode::IncrementalMerge => {
311 r#"You are updating an existing conversation summary with new information from recent messages.
312
313Guidelines:
314- Incorporate new information into the existing summary structure
315- Mark previously active work as completed if the new messages confirm completion
316- Remove or condense information that is no longer relevant
317- Preserve all key decisions, file paths, and constraints that remain active
318- If new messages conflict with the existing summary, the new messages take precedence
319- Keep the summary focused on what is currently active and relevant
320- The provided current task list is the source of truth for active work
321- Maintain the same structured sections as the existing summary
322- Write in the same language as the original conversation
323- Be concise: avoid repeating information already well-captured in the existing summary"#
324 }
325 };
326
327 prompt_messages.push(Message::system(system_prompt));
328
329 let mut user_content = String::new();
330
331 if let Some(ref existing) = self.existing_summary {
332 user_content.push_str("## Previous Summary\n\n");
333 user_content.push_str(existing);
334 user_content.push_str("\n\n---\n\n");
335 }
336
337 if !self.context_blocks.is_empty() {
338 user_content.push_str("## Compression Context Blocks\n\n");
339 for block in &self.context_blocks {
340 user_content.push_str(&format!(
341 "### {}\n- type: {}\n- priority: {}\n- stability: {}\n\n{}\n\n",
342 block.title.trim(),
343 block.block_type.as_str(),
344 block.priority.as_str(),
345 block.stability.as_str(),
346 block.content.trim(),
347 ));
348 }
349 user_content.push_str("---\n\n");
350 }
351
352 if let Some(ref instructions) = self.custom_instructions {
353 if !instructions.trim().is_empty() {
354 user_content.push_str("## Custom Compression Instructions\n\n");
355 user_content.push_str(instructions.trim());
356 user_content.push_str("\n\n---\n\n");
357 }
358 }
359
360 user_content.push_str(
361 "## Required Output Sections\n1. Pre-compression in-flight work (what was being done immediately before compression)\n2. Current active objective\n3. Requirement checklist (Requirement | Status: completed/in_progress/pending/blocked/obsolete | Evidence)\n4. Active tasks\n5. Completed tasks\n6. Obsolete or superseded tasks\n7. Important context and constraints\n8. Files, code, and tool findings\n9. Open issues and next step\n\n",
362 );
363
364 user_content.push_str("## Messages to Summarize\n\n");
365
366 for message in messages {
367 let role_label = match message.role {
368 Role::User => "User",
369 Role::Assistant => "Assistant",
370 Role::Tool => "Tool Result",
371 Role::System => continue,
372 };
373
374 if let Some(ref tool_calls) = message.tool_calls {
375 if !tool_calls.is_empty() {
376 let tool_names: Vec<&str> = tool_calls
377 .iter()
378 .map(|tc| tc.function.name.as_str())
379 .collect();
380 user_content.push_str(&format!(
381 "**{}** [called tools: {}]:\n",
382 role_label,
383 tool_names.join(", ")
384 ));
385 } else {
386 user_content.push_str(&format!("**{}**:\n", role_label));
387 }
388 } else {
389 user_content.push_str(&format!("**{}**:\n", role_label));
390 }
391
392 if let Some(ref tool_call_id) = message.tool_call_id {
393 user_content.push_str(&format!("(tool_call_id: {})\n", tool_call_id));
394 }
395
396 let content = &message.content;
397 const MAX_CONTENT_CHARS: usize = 2000;
398 if content.chars().count() > MAX_CONTENT_CHARS {
399 let truncated: String = content.chars().take(MAX_CONTENT_CHARS).collect();
400 user_content.push_str(&truncated);
401 user_content.push_str("... [truncated]\n\n");
402 } else {
403 user_content.push_str(content);
404 user_content.push_str("\n\n");
405 }
406 }
407
408 user_content.push_str(
409 "\n---\n\nReturn only the summary text. Be explicit about what is active now versus what is already completed or no longer relevant.",
410 );
411
412 prompt_messages.push(Message::user(user_content));
413
414 prompt_messages
415 }
416
417 async fn collect_stream_response(
419 &self,
420 messages: &[Message],
421 ) -> Result<String, crate::types::BudgetError> {
422 let options = LLMRequestOptions {
425 session_id: None,
426 reasoning_effort: Some(ReasoningEffort::High),
427 parallel_tool_calls: None,
428 responses: None,
429 request_purpose: Some("compression".to_string()),
430 cache: None,
431 };
432 let stream = self
433 .llm
434 .chat_stream_with_options(messages, &[], Some(8192), &self.model, Some(&options))
435 .await
436 .map_err(|e| {
437 crate::types::BudgetError::TokenCountError(format!(
438 "LLM summarization call failed: {}",
439 e
440 ))
441 })?;
442
443 let mut content = String::new();
444 let mut stream = stream;
445
446 while let Some(chunk_result) = stream.next().await {
447 match chunk_result {
448 Ok(LLMChunk::Token(text)) => content.push_str(&text),
449 Ok(LLMChunk::Done) => break,
450 Ok(_) => {} Err(e) => {
452 tracing::warn!("LLM summarization stream error: {}", e);
453 if !content.is_empty() {
454 break;
455 }
456 return Err(crate::types::BudgetError::TokenCountError(format!(
457 "LLM summarization stream failed: {}",
458 e
459 )));
460 }
461 }
462 }
463
464 Ok(content)
465 }
466}
467
468impl std::fmt::Debug for LlmSummarizer {
469 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
470 f.debug_struct("LlmSummarizer")
471 .field("model", &self.model)
472 .field("has_existing_summary", &self.existing_summary.is_some())
473 .field("context_block_count", &self.context_blocks.len())
474 .finish()
475 }
476}
477
478#[async_trait]
479impl Summarizer for LlmSummarizer {
480 async fn summarize(&self, messages: &[Message]) -> Result<String, crate::types::BudgetError> {
481 if messages.is_empty() {
482 return Ok("No conversation history to summarize.".to_string());
483 }
484
485 let prompt_messages = self.build_summarization_messages(messages);
486
487 tracing::info!(
488 "LlmSummarizer: summarizing {} messages using model '{}' (existing_summary={})",
489 messages.len(),
490 self.model,
491 self.existing_summary.is_some()
492 );
493
494 match self.collect_stream_response(&prompt_messages).await {
495 Ok(summary) if !summary.trim().is_empty() => {
496 tracing::info!("LlmSummarizer: generated summary ({} chars)", summary.len());
497 Ok(summary)
498 }
499 Ok(_) => {
500 tracing::warn!(
501 "LlmSummarizer: LLM returned empty summary, falling back to heuristic"
502 );
503 HeuristicSummarizer::new().summarize(messages).await
504 }
505 Err(e) => {
506 tracing::warn!(
507 "LlmSummarizer: LLM call failed ({}), falling back to heuristic",
508 e
509 );
510 HeuristicSummarizer::new().summarize(messages).await
511 }
512 }
513 }
514
515 fn estimate_summary_tokens(&self, message_count: usize) -> u32 {
516 (message_count * 80).min(2000) as u32
518 }
519}
520
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use bamboo_domain::ReasoningEffort;
    use bamboo_infrastructure::{LLMChunk, LLMError, LLMRequestOptions, LLMStream};
    use futures::stream;
    use std::sync::Mutex;

    /// Builds a stream that yields `text` as a single token followed by `Done`.
    fn canned_stream(text: &str) -> LLMStream {
        Box::pin(stream::iter(vec![
            Ok::<LLMChunk, LLMError>(LLMChunk::Token(text.to_string())),
            Ok::<LLMChunk, LLMError>(LLMChunk::Done),
        ]))
    }

    /// Provider that always answers with the fixed text "dummy summary".
    struct DummyProvider;

    #[async_trait]
    impl LLMProvider for DummyProvider {
        async fn chat_stream(
            &self,
            _messages: &[Message],
            _tools: &[bamboo_agent_core::ToolSchema],
            _max_output_tokens: Option<u32>,
            _model: &str,
        ) -> Result<LLMStream, LLMError> {
            Ok(canned_stream("dummy summary"))
        }
    }

    /// Provider that records the reasoning effort and output-token cap of every request made
    /// through `chat_stream_with_options`, then answers with "captured summary".
    #[derive(Default)]
    struct RequestOptionsCaptureProvider {
        captured_reasoning: Mutex<Vec<Option<ReasoningEffort>>>,
        captured_max_tokens: Mutex<Vec<Option<u32>>>,
    }

    #[async_trait]
    impl LLMProvider for RequestOptionsCaptureProvider {
        async fn chat_stream(
            &self,
            _messages: &[Message],
            _tools: &[bamboo_agent_core::ToolSchema],
            _max_output_tokens: Option<u32>,
            _model: &str,
        ) -> Result<LLMStream, LLMError> {
            Ok(canned_stream("captured summary"))
        }

        async fn chat_stream_with_options(
            &self,
            messages: &[Message],
            tools: &[bamboo_agent_core::ToolSchema],
            max_output_tokens: Option<u32>,
            model: &str,
            options: Option<&LLMRequestOptions>,
        ) -> Result<LLMStream, LLMError> {
            self.captured_reasoning
                .lock()
                .expect("lock should not be poisoned")
                .push(options.and_then(|o| o.reasoning_effort));
            self.captured_max_tokens
                .lock()
                .expect("lock should not be poisoned")
                .push(max_output_tokens);
            self.chat_stream(messages, tools, max_output_tokens, model)
                .await
        }
    }

    #[test]
    fn heuristic_summarizer_extracts_user_questions() {
        let summarizer = HeuristicSummarizer::new();
        let messages = vec![
            Message::user("What is the weather?"),
            Message::assistant("It's sunny.", None),
            Message::user("What about tomorrow?"),
        ];

        let questions = summarizer.extract_user_questions(&messages);
        assert_eq!(questions.len(), 2);
        assert!(questions[0].contains("weather"));
        assert!(questions[1].contains("tomorrow"));
    }

    #[test]
    fn heuristic_summarizer_extracts_tools_used() {
        use bamboo_agent_core::{FunctionCall, ToolCall};

        let summarizer = HeuristicSummarizer::new();
        let tool_call = ToolCall {
            id: "call_1".to_string(),
            tool_type: "function".to_string(),
            function: FunctionCall {
                name: "search".to_string(),
                arguments: "{}".to_string(),
            },
        };

        let messages = vec![
            Message::user("Search for something"),
            Message::assistant("I'll search", Some(vec![tool_call])),
        ];

        let tools = summarizer.extract_tools_used(&messages);
        assert_eq!(tools, vec!["search"]);
    }

    #[test]
    fn heuristic_summarizer_extracts_key_responses() {
        let summarizer = HeuristicSummarizer::new();
        let messages = vec![
            Message::user("Hello"),
            Message::assistant("First response", None),
            Message::user("How are you?"),
            Message::assistant("Most recent response", None),
        ];

        // Responses are reported most recent first.
        let responses = summarizer.extract_key_responses(&messages);
        assert_eq!(responses, vec!["Most recent response", "First response"]);
    }

    #[tokio::test]
    async fn heuristic_summarizer_generates_summary() {
        let summarizer = HeuristicSummarizer::new();
        let messages = vec![
            Message::user("What is Rust?"),
            Message::assistant("Rust is a systems programming language.", None),
        ];

        let summary = summarizer.summarize(&messages).await.unwrap();
        assert!(summary.contains("User Requests"));
        assert!(summary.contains("What is Rust?"));
        assert!(summary.contains("## Key Outcomes"));
    }

    #[test]
    fn summary_trigger_on_truncation() {
        let manager =
            SummaryManager::new(HeuristicSummarizer::new(), SummaryTrigger::OnTruncation);

        // Only the truncation flag matters; message and token counts are ignored.
        assert!(manager.should_summarize(&[], true, 0));
        assert!(!manager.should_summarize(&[], false, u32::MAX));
    }

    #[test]
    fn summary_trigger_periodic() {
        let manager = SummaryManager::new(
            HeuristicSummarizer::new(),
            SummaryTrigger::Periodic { interval: 5 },
        );
        let messages: Vec<Message> = (0..5).map(|_| Message::user("Test")).collect();

        assert!(manager.should_summarize(&messages, false, 0));
        assert!(!manager.should_summarize(&messages[..4], true, u32::MAX));
    }

    #[test]
    fn summary_trigger_token_threshold() {
        let manager = SummaryManager::new(
            HeuristicSummarizer::new(),
            SummaryTrigger::TokenThreshold { threshold: 1000 },
        );

        assert!(manager.should_summarize(&[], false, 1000));
        assert!(!manager.should_summarize(&[], true, 999));
    }

    #[test]
    fn safe_truncate_handles_ascii() {
        let summarizer = HeuristicSummarizer::new();
        let text = "Hello world this is a test";
        let truncated = summarizer.safe_truncate(text, 10);

        // Ten kept characters plus the three-character ellipsis.
        assert_eq!(truncated, "Hello worl...");
    }

    #[test]
    fn safe_truncate_handles_unicode() {
        let summarizer = HeuristicSummarizer::new();

        let text = "Hello 😀🎉🚀 World with emoji";
        let truncated = summarizer.safe_truncate(text, 10);

        assert!(truncated.ends_with("..."));
        assert_eq!(truncated.chars().count(), 13);
    }

    #[test]
    fn safe_truncate_handles_cjk() {
        let summarizer = HeuristicSummarizer::new();

        let text = "这是一个中文测试消息用于验证截断";
        let truncated = summarizer.safe_truncate(text, 10);

        assert!(truncated.ends_with("..."));
        assert_eq!(truncated.chars().count(), 13);
    }

    #[test]
    fn safe_truncate_handles_mixed_unicode() {
        let summarizer = HeuristicSummarizer::new();

        let text = "Hello 世界 🌍 test message";
        let truncated = summarizer.safe_truncate(text, 8);

        assert!(truncated.ends_with("..."));
        assert_eq!(truncated.chars().count(), 11);
    }

    #[tokio::test]
    async fn summarizer_handles_unicode_messages() {
        let summarizer = HeuristicSummarizer::new();

        let long_unicode =
            "这是一段很长的中文消息需要被截断以测试我们的安全截断功能 😀🎉🚀".repeat(10);
        let messages = vec![
            Message::user(&long_unicode),
            Message::assistant("Response", None),
        ];

        // Must not panic on a multi-byte boundary while truncating.
        let summary = summarizer.summarize(&messages).await.unwrap();
        assert!(summary.contains("User Requests"));
    }

    #[test]
    fn safe_truncate_returns_short_text_unchanged() {
        let summarizer = HeuristicSummarizer::new();
        let text = "Short";
        let truncated = summarizer.safe_truncate(text, 100);

        assert_eq!(truncated, text);
    }

    #[test]
    fn llm_summarizer_prompt_includes_context_blocks_and_state_sections() {
        let summarizer = LlmSummarizer::new(
            Arc::new(DummyProvider),
            "gpt-4o-mini".to_string(),
            Some("Earlier summary".to_string()),
            Some(
                "## Current Task List\n[/] task_1: Fix compression bounce\n[x] task_0: Analyze bug"
                    .to_string(),
            ),
        )
        .with_context_blocks(vec![
            ContextBlock::new(
                ContextBlockType::TaskSnapshot,
                ContextBlockPriority::High,
                ContextBlockStability::RoundDynamic,
                "Current Task List",
                "[/] task_1: Fix compression bounce",
            ),
            ContextBlock::new(
                ContextBlockType::ExternalMemory,
                ContextBlockPriority::Medium,
                ContextBlockStability::RoundDynamic,
                "External Memory (Persistent)",
                "Session note body",
            ),
        ]);
        let messages = vec![
            Message::user("继续做压缩修复"),
            Message::assistant("我先检查 trigger 与 target", None),
        ];

        let prompt_messages = summarizer.build_summarization_messages(&messages);
        assert_eq!(prompt_messages.len(), 2);
        assert_eq!(prompt_messages[0].role, Role::System);

        let user_content = &prompt_messages[1].content;
        let expected_fragments = [
            "## Compression Context Blocks",
            "Current Task List",
            "External Memory (Persistent)",
            "Current active objective",
            "Requirement checklist",
            "Active tasks",
            "Completed tasks",
            "Obsolete or superseded tasks",
            "Earlier summary",
        ];
        for fragment in expected_fragments {
            assert!(
                user_content.contains(fragment),
                "user prompt should contain {:?}",
                fragment
            );
        }
    }

    #[tokio::test]
    async fn llm_summarizer_requests_high_reasoning_effort_for_summary_calls() {
        let provider = Arc::new(RequestOptionsCaptureProvider::default());
        let summarizer = LlmSummarizer::new(
            provider.clone(),
            "gpt-5-mini".to_string(),
            None,
            Some("task list".to_string()),
        );
        let messages = vec![
            Message::user("请总结最近三轮"),
            Message::assistant("已完成第一步并准备第二步", None),
        ];

        let summary = summarizer
            .summarize(&messages)
            .await
            .expect("summary generation should succeed");
        assert_eq!(summary, "captured summary");

        let captured = provider
            .captured_reasoning
            .lock()
            .expect("lock should not be poisoned");
        assert_eq!(captured.as_slice(), [Some(ReasoningEffort::High)]);
    }

    #[tokio::test]
    async fn llm_summarizer_sufficient_max_tokens_for_high_reasoning() {
        let provider = Arc::new(RequestOptionsCaptureProvider::default());
        let summarizer = LlmSummarizer::new(
            provider.clone(),
            "gpt-5-mini".to_string(),
            None,
            Some("task list".to_string()),
        );
        let messages = vec![
            Message::user("请总结最近三轮"),
            Message::assistant("已完成第一步并准备第二步", None),
        ];

        let summary = summarizer
            .summarize(&messages)
            .await
            .expect("summary generation should succeed");
        assert_eq!(summary, "captured summary");

        let captured_reasoning = provider
            .captured_reasoning
            .lock()
            .expect("lock should not be poisoned");
        let captured_max_tokens = provider
            .captured_max_tokens
            .lock()
            .expect("lock should not be poisoned");
        assert_eq!(captured_reasoning.as_slice(), [Some(ReasoningEffort::High)]);
        assert_eq!(captured_max_tokens.len(), 1);
        let max_tokens = captured_max_tokens[0].expect("max_output_tokens should be set");
        assert!(
            max_tokens > 4096,
            "max_output_tokens ({}) must exceed thinking budget (4096) to avoid truncation",
            max_tokens
        );
    }

    #[test]
    fn full_rewrite_mode_uses_default_system_prompt() {
        let summarizer =
            LlmSummarizer::new(Arc::new(DummyProvider), "model".to_string(), None, None)
                .with_summary_mode(SummaryMode::FullRewrite);
        let messages = vec![Message::user("hello"), Message::assistant("hi", None)];
        let prompts = summarizer.build_summarization_messages(&messages);
        let system = &prompts[0].content;
        assert!(
            system.contains("conversation summarizer"),
            "FullRewrite prompt should contain 'conversation summarizer'"
        );
        assert!(
            !system.contains("updating an existing"),
            "FullRewrite prompt should not contain incremental language"
        );
    }

    #[test]
    fn incremental_merge_mode_uses_update_system_prompt() {
        let summarizer = LlmSummarizer::new(
            Arc::new(DummyProvider),
            "model".to_string(),
            Some("Previous summary content".to_string()),
            None,
        )
        .with_summary_mode(SummaryMode::IncrementalMerge);
        let messages = vec![Message::user("hello"), Message::assistant("hi", None)];
        let prompts = summarizer.build_summarization_messages(&messages);
        let system = &prompts[0].content;
        assert!(
            system.contains("updating an existing conversation summary"),
            "IncrementalMerge prompt should contain 'updating an existing conversation summary'"
        );
        assert!(
            system.contains("Incorporate new information"),
            "IncrementalMerge prompt should mention incorporating new information"
        );
    }

    #[test]
    fn default_summary_mode_is_full_rewrite() {
        assert!(matches!(SummaryMode::default(), SummaryMode::FullRewrite));
    }

    #[test]
    fn incremental_merge_includes_existing_summary_in_user_content() {
        let summarizer = LlmSummarizer::new(
            Arc::new(DummyProvider),
            "model".to_string(),
            Some("Previous summary content".to_string()),
            None,
        )
        .with_summary_mode(SummaryMode::IncrementalMerge);
        let messages = vec![
            Message::user("new work"),
            Message::assistant("doing it", None),
        ];
        let prompts = summarizer.build_summarization_messages(&messages);
        let user_content = &prompts[1].content;
        assert!(
            user_content.contains("Previous Summary"),
            "IncrementalMerge user prompt should include the existing summary"
        );
        assert!(
            user_content.contains("Previous summary content"),
            "IncrementalMerge user prompt should include the actual summary text"
        );
    }
}