1use std::ops::Range;
32use std::sync::Arc;
33
34use async_trait::async_trait;
35
36use katu_core::compaction::{
37 CompactTrigger, CompactionConfig, CompactionResult, PreserveConfig,
38};
39use katu_core::message::{AssistantBlock, ContentBlock, Message};
40
41use katu_llm::model::ModelRef;
42use katu_llm::Provider;
43
44use crate::error::Result;
45use crate::session::Session;
46
/// Summary of a single pruning pass.
///
/// Produced by [`Compactor::prune`]; both counters are zero when the pass
/// changed nothing.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PruneOutcome {
    /// Estimated number of tokens attributed to the pruned tool results.
    pub tokens_freed: u64,
    /// Number of tool results that were pruned.
    pub parts_pruned: usize,
}

impl PruneOutcome {
    /// Returns the outcome of a pass that pruned nothing.
    pub fn none() -> Self {
        PruneOutcome {
            parts_pruned: 0,
            tokens_freed: 0,
        }
    }

    /// Returns `true` when at least one part was pruned.
    ///
    /// Only `parts_pruned` is consulted; `tokens_freed` does not influence
    /// the answer.
    pub fn has_effect(&self) -> bool {
        self.parts_pruned != 0
    }
}
84
/// Split of a message history into the part to summarize and the part to
/// keep verbatim.
///
/// Both ranges are index ranges into the session's message list.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MessagePartition {
    /// Indices of the messages that will be replaced by a summary.
    pub to_summarize: Range<usize>,
    /// Indices of the messages that are carried over unchanged.
    pub to_preserve: Range<usize>,
    /// Text of an earlier summary found in the history, if any.
    pub previous_summary: Option<String>,
}

impl MessagePartition {
    /// Number of messages in the `to_summarize` range.
    pub fn summarize_count(&self) -> usize {
        ExactSizeIterator::len(&self.to_summarize)
    }

    /// Number of messages in the `to_preserve` range.
    pub fn preserve_count(&self) -> usize {
        ExactSizeIterator::len(&self.to_preserve)
    }

    /// Returns `true` when there is at least one message to summarize.
    pub fn has_work(&self) -> bool {
        self.summarize_count() > 0
    }
}
127
/// Bookkeeping for repeated compaction attempts.
///
/// Tracks how many attempts have failed in a row (for a circuit breaker) and
/// at which step, and with what resulting token count, the last compaction
/// was recorded.
#[derive(Debug, Clone)]
pub struct CompactionState {
    /// Failures recorded since the last success.
    consecutive_failures: u32,
    /// Step passed to the most recent `mark_compacted` call.
    last_compact_step: Option<u32>,
    /// Token count passed to the most recent `mark_compacted` call.
    last_compact_tokens: Option<u64>,
}

impl CompactionState {
    /// Creates a state with no failures and no recorded compaction.
    pub fn new() -> Self {
        Self {
            consecutive_failures: 0,
            last_compact_step: None,
            last_compact_tokens: None,
        }
    }

    /// Records a successful attempt, clearing the failure streak.
    pub fn record_success(&mut self) {
        self.consecutive_failures = 0;
    }

    /// Records a failed attempt.
    ///
    /// The counter saturates at `u32::MAX` instead of overflowing, so a very
    /// long failure streak can neither panic (debug builds) nor wrap back to
    /// zero and silently close the circuit breaker (release builds).
    pub fn record_failure(&mut self) {
        self.consecutive_failures = self.consecutive_failures.saturating_add(1);
    }

    /// Returns `true` once the failure streak has reached `max_failures`.
    ///
    /// A `max_failures` of `0` disables the breaker: the result is always
    /// `false`.
    pub fn is_circuit_broken(&self, max_failures: u32) -> bool {
        max_failures > 0 && self.consecutive_failures >= max_failures
    }

    /// Remembers that a compaction happened at `step`, together with the
    /// token count after it (if known).
    pub fn mark_compacted(&mut self, step: u32, tokens_after: Option<u64>) {
        self.last_compact_step = Some(step);
        self.last_compact_tokens = tokens_after;
    }

    /// Returns `true` when the most recently recorded compaction happened at
    /// exactly `step`.
    pub fn already_compacted_at(&self, step: u32) -> bool {
        self.last_compact_step == Some(step)
    }

    /// Current length of the failure streak.
    pub fn consecutive_failures(&self) -> u32 {
        self.consecutive_failures
    }

    /// Token count stored by the last `mark_compacted` call, if any.
    pub fn last_compact_tokens(&self) -> Option<u64> {
        self.last_compact_tokens
    }

    /// Restores the freshly constructed state.
    pub fn reset(&mut self) {
        *self = Self::new();
    }
}

impl Default for CompactionState {
    /// Equivalent to [`CompactionState::new`].
    fn default() -> Self {
        Self::new()
    }
}
221
222#[async_trait]
250pub trait Compactor: Send + Sync {
251 async fn prune(&self, session: &mut Session) -> Result<PruneOutcome>;
257
258 fn partition(&self, session: &Session) -> MessagePartition;
262
263 async fn compact(
271 &self,
272 session: &mut Session,
273 trigger: CompactTrigger,
274 ) -> Result<CompactionResult>;
275}
276
277pub struct DefaultCompactor {
297 provider: Arc<dyn Provider>,
299 model: ModelRef,
301}
302
303impl DefaultCompactor {
304 pub fn new(provider: Arc<dyn Provider>, model: ModelRef) -> Self {
306 Self { provider, model }
307 }
308
309 pub fn provider(&self) -> &Arc<dyn Provider> {
311 &self.provider
312 }
313
314 pub fn model(&self) -> &ModelRef {
316 &self.model
317 }
318}
319
320impl std::fmt::Debug for DefaultCompactor {
321 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
322 f.debug_struct("DefaultCompactor")
323 .field("model_id", &self.model.id)
324 .finish_non_exhaustive()
325 }
326}
327
328fn estimate_block_tokens(block: &ContentBlock) -> u64 {
334 match block {
335 ContentBlock::Text { text } => text.len() as u64 / 4,
336 ContentBlock::Image { .. } => 1_000, }
338}
339
/// Returns `true` when `tool_name` exactly matches (case-sensitively) one of
/// the entries in `protected`.
fn is_protected_tool(tool_name: &str, protected: &[String]) -> bool {
    for candidate in protected {
        if candidate.as_str() == tool_name {
            return true;
        }
    }
    false
}
344
345fn find_preserve_start(
356 messages: &[Message],
357 preserve: &PreserveConfig,
358 _context_window: u64,
359 _reserve_tokens: u64,
360) -> usize {
361 if messages.is_empty() {
362 return 0;
363 }
364
365 let recent_turns = preserve.recent_turns as usize;
366 if recent_turns == 0 {
367 return messages.len();
368 }
369
370 let mut user_count = 0;
372 let mut cut = messages.len();
373 for i in (0..messages.len()).rev() {
374 if matches!(&messages[i], Message::User(_)) {
375 user_count += 1;
376 if user_count >= recent_turns {
377 cut = i;
378 break;
379 }
380 }
381 }
382
383 while cut > 0 && matches!(&messages[cut], Message::ToolResult(_)) {
386 cut -= 1;
387 }
388
389 if cut >= messages.len() {
391 for i in (0..messages.len()).rev() {
393 if matches!(&messages[i], Message::User(_)) {
394 cut = i;
395 break;
396 }
397 }
398 }
399
400 cut
401}
402
403#[async_trait]
408impl Compactor for DefaultCompactor {
409 async fn prune(&self, session: &mut Session) -> Result<PruneOutcome> {
410 let config = session.compaction_config().prune.clone();
411 if !config.enabled {
412 return Ok(PruneOutcome::none());
413 }
414
415 let messages = session.message_slice();
416 let recent_turns = session.compaction_config().preserve.recent_turns as usize;
417
418 let mut user_count = 0;
420 let mut prune_boundary = messages.len();
421 for i in (0..messages.len()).rev() {
422 if matches!(&messages[i], Message::User(_)) {
423 user_count += 1;
424 if user_count >= recent_turns {
425 prune_boundary = i;
426 break;
427 }
428 }
429 }
430
431 let mut protected_tokens: u64 = 0;
433 let mut prunable_tokens: u64 = 0;
434 let mut to_prune: Vec<usize> = Vec::new();
435
436 for i in (0..prune_boundary).rev() {
437 if let Message::ToolResult(ref tr) = messages[i] {
438 if is_protected_tool(&tr.tool_name, &config.protected_tools) {
439 continue;
440 }
441
442 let msg_tokens: u64 = tr.content.iter().map(estimate_block_tokens).sum();
443
444 if protected_tokens < config.protect_tokens {
445 protected_tokens += msg_tokens;
446 } else {
447 let total_chars: usize = tr.content.iter().map(|b| match b {
449 ContentBlock::Text { text } => text.len(),
450 ContentBlock::Image { .. } => 0,
451 }).sum();
452
453 if total_chars > config.tool_output_max_chars {
454 prunable_tokens += msg_tokens;
455 to_prune.push(i);
456 }
457 }
458 }
459 }
460
461 if prunable_tokens < config.minimum_tokens {
463 return Ok(PruneOutcome::none());
464 }
465
466 let truncation_msg = format!(
468 "[内容已修剪 - 原文超过 {} 字符]",
469 config.tool_output_max_chars
470 );
471
472 for &idx in &to_prune {
473 session.truncate_tool_result(idx, &truncation_msg, config.tool_output_max_chars);
474 }
475
476 Ok(PruneOutcome {
477 tokens_freed: prunable_tokens,
478 parts_pruned: to_prune.len(),
479 })
480 }
481
482 fn partition(&self, session: &Session) -> MessagePartition {
483 let messages = session.message_slice();
484 let config = session.compaction_config();
485
486 let previous_summary = detect_previous_summary(messages);
488
489 let cut = find_preserve_start(
490 messages,
491 &config.preserve,
492 session.context_window(),
493 config.reserve_tokens,
494 );
495
496 MessagePartition {
497 to_summarize: 0..cut,
498 to_preserve: cut..messages.len(),
499 previous_summary,
500 }
501 }
502
503 async fn compact(
504 &self,
505 session: &mut Session,
506 trigger: CompactTrigger,
507 ) -> Result<CompactionResult> {
508 let partition = self.partition(session);
509
510 if !partition.has_work() {
511 return Ok(CompactionResult {
512 summary: String::new(),
513 short_summary: None,
514 trigger,
515 tokens_before: session.context_tokens(),
516 tokens_after: Some(session.context_tokens()),
517 messages_compacted: 0,
518 messages_kept: session.message_slice().len(),
519 success: true,
520 });
521 }
522
523 let tokens_before = session.context_tokens();
524 let messages_to_summarize = &session.message_slice()[partition.to_summarize.clone()];
525 let messages_compacted = partition.summarize_count();
526 let messages_kept = partition.preserve_count();
527
528 let summary_prompt = build_summary_prompt(
530 messages_to_summarize,
531 partition.previous_summary.as_deref(),
532 session.compaction_config(),
533 );
534
535 let summary_request = katu_llm::LlmRequest::new(self.model.clone())
537 .with_system(COMPACTION_SYSTEM_PROMPT)
538 .with_message(Message::user(summary_prompt));
539
540 let response = match self.provider.generate(summary_request).await {
541 Ok(resp) => resp,
542 Err(e) => {
543 return Ok(CompactionResult {
544 summary: format!("压缩失败: {e}"),
545 short_summary: None,
546 trigger,
547 tokens_before,
548 tokens_after: None,
549 messages_compacted: 0,
550 messages_kept: session.message_slice().len(),
551 success: false,
552 });
553 }
554 };
555
556 let summary = extract_text_from_message(&response.message);
558
559 let preserved = session.message_slice()[partition.to_preserve.clone()].to_vec();
561 let mut new_messages = Vec::with_capacity(1 + preserved.len());
562
563 let summary_content = format!(
565 "<context_summary>\n{}\n</context_summary>\n\n以上是之前对话的摘要,请基于此上下文继续。",
566 summary
567 );
568 new_messages.push(Message::user(summary_content));
569 new_messages.extend(preserved);
570
571 session.replace_messages(new_messages);
572
573 Ok(CompactionResult {
574 summary: summary.clone(),
575 short_summary: None, trigger,
577 tokens_before,
578 tokens_after: None, messages_compacted,
580 messages_kept,
581 success: true,
582 })
583 }
584}
585
/// System prompt sent with every summarization request.
///
/// The text (in Chinese) instructs the model to act as a conversation
/// summarizer: keep key decisions, code changes, file paths, user
/// preferences and unfinished tasks; drop repeated exploration; use a
/// structured format; report the final state of file operations.
const COMPACTION_SYSTEM_PROMPT: &str = concat!(
    "你是一个对话摘要助手。你的任务是将一段对话历史压缩为简洁但信息完整的摘要。\n",
    "\n",
    "要求:\n",
    "1. 保留所有关键决策、代码变更、文件路径和技术细节\n",
    "2. 保留用户的偏好和约束\n",
    "3. 保留未完成的任务和待办事项\n",
    "4. 省略重复的探索过程和已解决的中间问题\n",
    "5. 使用结构化格式(标题 + 要点列表)\n",
    "6. 如果有文件操作,列出最终状态而非中间步骤",
);
601
602fn detect_previous_summary(messages: &[Message]) -> Option<String> {
604 if let Some(Message::User(user_msg)) = messages.first() {
605 let text = user_msg.content.text();
606 if text.contains("<context_summary>") && text.contains("</context_summary>") {
607 if let Some(start) = text.find("<context_summary>") {
609 let content_start = start + "<context_summary>".len();
610 if let Some(end) = text[content_start..].find("</context_summary>") {
611 return Some(text[content_start..content_start + end].trim().to_string());
612 }
613 }
614 }
615 }
616 None
617}
618
619fn extract_text_from_message(message: &Message) -> String {
621 match message {
622 Message::Assistant(a) => a.text(),
623 Message::User(u) => u.content.text(),
624 Message::ToolResult(t) => {
625 t.content
626 .iter()
627 .filter_map(|b| match b {
628 ContentBlock::Text { text } => Some(text.as_str()),
629 _ => None,
630 })
631 .collect::<Vec<_>>()
632 .join("\n")
633 }
634 }
635}
636
637fn build_summary_prompt(
639 messages: &[Message],
640 previous_summary: Option<&str>,
641 _config: &CompactionConfig,
642) -> String {
643 let mut prompt = String::with_capacity(4096);
644
645 if let Some(prev) = previous_summary {
646 prompt.push_str("## 上次摘要\n\n");
647 prompt.push_str(prev);
648 prompt.push_str("\n\n## 新增对话(需要整合到摘要中)\n\n");
649 } else {
650 prompt.push_str("## 对话历史(需要压缩为摘要)\n\n");
651 }
652
653 for msg in messages {
654 match msg {
655 Message::User(u) => {
656 prompt.push_str("**User**: ");
657 prompt.push_str(&u.content.text());
658 prompt.push('\n');
659 }
660 Message::Assistant(a) => {
661 prompt.push_str("**Assistant**: ");
662 let text = a.text();
664 if text.len() > 2000 {
665 prompt.push_str(&text[..2000]);
666 prompt.push_str("...[截断]");
667 } else {
668 prompt.push_str(&text);
669 }
670 prompt.push('\n');
671
672 for block in a.tool_calls() {
674 if let AssistantBlock::ToolCall { name, arguments, .. } = block {
675 prompt.push_str(&format!(" → tool_call: {}({})\n", name, arguments));
676 }
677 }
678 }
679 Message::ToolResult(t) => {
680 let content = t.content.iter().filter_map(|b| match b {
681 ContentBlock::Text { text } => Some(text.as_str()),
682 _ => None,
683 }).collect::<Vec<_>>().join("");
684
685 if content.len() > 500 {
687 prompt.push_str(&format!(
688 " ← {}: {}...[截断]\n",
689 t.tool_name,
690 &content[..500]
691 ));
692 } else {
693 prompt.push_str(&format!(" ← {}: {}\n", t.tool_name, content));
694 }
695 }
696 }
697 }
698
699 prompt.push_str("\n请生成压缩摘要。");
700 prompt
701}
702
#[cfg(test)]
mod tests {
    use super::*;

    // ---- PruneOutcome ----

    /// `none()` reports zero work and no effect.
    #[test]
    fn test_prune_outcome_none() {
        let empty = PruneOutcome::none();
        assert_eq!(empty.tokens_freed, 0);
        assert_eq!(empty.parts_pruned, 0);
        assert!(!empty.has_effect());
    }

    /// Any pruned part counts as an effect.
    #[test]
    fn test_prune_outcome_with_effect() {
        let pruned = PruneOutcome {
            tokens_freed: 5000,
            parts_pruned: 3,
        };
        assert!(pruned.has_effect());
    }

    // ---- MessagePartition ----

    /// Counts mirror the lengths of the two ranges.
    #[test]
    fn test_partition_counts() {
        let split = MessagePartition {
            to_summarize: 0..10,
            to_preserve: 10..15,
            previous_summary: None,
        };
        assert_eq!(split.summarize_count(), 10);
        assert_eq!(split.preserve_count(), 5);
        assert!(split.has_work());
    }

    /// An empty summarize range means there is nothing to do.
    #[test]
    fn test_partition_empty_summarize() {
        let split = MessagePartition {
            to_summarize: 0..0,
            to_preserve: 0..5,
            previous_summary: None,
        };
        assert!(!split.has_work());
    }

    // ---- CompactionState ----

    /// A fresh state has no failures and no recorded compaction.
    #[test]
    fn test_compaction_state_new() {
        let fresh = CompactionState::new();
        assert_eq!(fresh.consecutive_failures(), 0);
        assert!(!fresh.is_circuit_broken(3));
        assert!(!fresh.already_compacted_at(0));
    }

    /// The breaker trips exactly when the streak reaches the limit.
    #[test]
    fn test_compaction_state_circuit_breaker() {
        let mut tracker = CompactionState::new();
        for _ in 0..2 {
            tracker.record_failure();
        }
        assert!(!tracker.is_circuit_broken(3));
        tracker.record_failure();
        assert!(tracker.is_circuit_broken(3));
    }

    /// A success clears the failure streak.
    #[test]
    fn test_compaction_state_reset_on_success() {
        let mut tracker = CompactionState::new();
        tracker.record_failure();
        tracker.record_failure();
        tracker.record_success();
        assert_eq!(tracker.consecutive_failures(), 0);
        assert!(!tracker.is_circuit_broken(3));
    }

    /// A limit of zero disables the breaker.
    #[test]
    fn test_compaction_state_no_limit() {
        let mut tracker = CompactionState::new();
        for _ in 0..100 {
            tracker.record_failure();
        }
        assert!(!tracker.is_circuit_broken(0));
    }

    /// `mark_compacted` stores both the step and the token count.
    #[test]
    fn test_compaction_state_mark_compacted() {
        let mut tracker = CompactionState::new();
        assert!(!tracker.already_compacted_at(5));
        tracker.mark_compacted(5, Some(50_000));
        assert!(tracker.already_compacted_at(5));
        assert!(!tracker.already_compacted_at(6));
        assert_eq!(tracker.last_compact_tokens(), Some(50_000));
    }

    // ---- detect_previous_summary ----

    /// A plain first message carries no summary.
    #[test]
    fn test_detect_no_summary() {
        let history = vec![Message::user("hello")];
        assert_eq!(detect_previous_summary(&history), None);
    }

    /// The tagged block is extracted and trimmed.
    #[test]
    fn test_detect_has_summary() {
        let tagged = "<context_summary>\nPrevious work done\n</context_summary>\n\n以上是之前对话的摘要,请基于此上下文继续。";
        let history = vec![Message::user(tagged)];
        assert_eq!(
            detect_previous_summary(&history),
            Some("Previous work done".to_string())
        );
    }

    // ---- find_preserve_start ----

    /// An empty history is cut at index 0.
    #[test]
    fn test_find_preserve_start_empty() {
        let history: Vec<Message> = Vec::new();
        let keep = PreserveConfig::default();
        assert_eq!(find_preserve_start(&history, &keep, 200_000, 16_384), 0);
    }

    /// With two recent turns kept, the cut lands on the second-to-last user
    /// message.
    #[test]
    fn test_find_preserve_start_keeps_recent_turns() {
        let history = vec![
            Message::user("q1"),
            Message::assistant("a1"),
            Message::user("q2"),
            Message::assistant("a2"),
            Message::user("q3"),
            Message::assistant("a3"),
        ];
        let keep = PreserveConfig::new(2, 100_000);
        let start = find_preserve_start(&history, &keep, 200_000, 16_384);
        assert_eq!(start, 2);
    }

    // ---- estimate_block_tokens ----

    /// 400 bytes of text are estimated as 100 tokens.
    #[test]
    fn test_estimate_block_tokens_text() {
        let text_block = ContentBlock::Text {
            text: "a".repeat(400),
        };
        assert_eq!(estimate_block_tokens(&text_block), 100);
    }
}