1use roboticus_llm::format::UnifiedMessage;
2
/// Progressive history-compaction levels, from least to most aggressive.
///
/// The derived `Ord` follows declaration order, so stages compare as
/// `Verbatim < SelectiveTrim < SemanticCompress < TopicExtract < Skeleton`.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub enum CompactionStage {
    /// Messages are passed through unchanged.
    Verbatim,
    /// Short filler messages (greetings, acknowledgements) are dropped.
    SelectiveTrim,
    /// Long non-system messages are run through the prompt compressor.
    SemanticCompress,
    /// Every non-system message is reduced to its first sentence.
    TopicExtract,
    /// Non-system messages are collapsed into a single outline message.
    Skeleton,
}
24
25impl CompactionStage {
26 pub fn from_excess(excess_ratio: f64) -> Self {
31 if excess_ratio <= 1.0 {
32 Self::Verbatim
33 } else if excess_ratio <= 1.5 {
34 Self::SelectiveTrim
35 } else if excess_ratio <= 2.5 {
36 Self::SemanticCompress
37 } else if excess_ratio <= 4.0 {
38 Self::TopicExtract
39 } else {
40 Self::Skeleton
41 }
42 }
43}
44
45pub fn compact_to_stage(
49 messages: &[UnifiedMessage],
50 stage: CompactionStage,
51) -> Vec<UnifiedMessage> {
52 match stage {
53 CompactionStage::Verbatim => messages.to_vec(),
54 CompactionStage::SelectiveTrim => selective_trim(messages),
55 CompactionStage::SemanticCompress => semantic_compress(messages),
56 CompactionStage::TopicExtract => topic_extract(messages),
57 CompactionStage::Skeleton => skeleton_compress(messages),
58 }
59}
60
61fn selective_trim(messages: &[UnifiedMessage]) -> Vec<UnifiedMessage> {
63 const FILLER: &[&str] = &[
64 "hello",
65 "hi",
66 "hey",
67 "thanks",
68 "thank you",
69 "ok",
70 "okay",
71 "sure",
72 "got it",
73 "sounds good",
74 "no problem",
75 "np",
76 "ack",
77 "roger",
78 ];
79 messages
80 .iter()
81 .filter(|m| {
82 if m.role == "system" {
83 return true;
84 }
85 if m.content.len() >= 40 {
87 return true;
88 }
89 let lower = m.content.trim().to_lowercase();
90 !FILLER.contains(&lower.as_str())
93 })
94 .cloned()
95 .collect()
96}
97
98fn semantic_compress(messages: &[UnifiedMessage]) -> Vec<UnifiedMessage> {
100 use roboticus_llm::compression::PromptCompressor;
101 let compressor = PromptCompressor::new(0.6);
102 messages
103 .iter()
104 .map(|m| {
105 if m.role == "system" || m.content.len() < 100 {
106 m.clone()
107 } else {
108 UnifiedMessage {
109 role: m.role.clone(),
110 content: compressor.compress(&m.content),
111 parts: None,
112 }
113 }
114 })
115 .collect()
116}
117
118fn topic_extract(messages: &[UnifiedMessage]) -> Vec<UnifiedMessage> {
120 messages
121 .iter()
122 .map(|m| {
123 if m.role == "system" {
124 m.clone()
125 } else {
126 UnifiedMessage {
127 role: m.role.clone(),
128 content: extract_topic_sentence(&m.content),
129 parts: None,
130 }
131 }
132 })
133 .collect()
134}
135
136fn skeleton_compress(messages: &[UnifiedMessage]) -> Vec<UnifiedMessage> {
138 let topics: Vec<String> = messages
139 .iter()
140 .filter(|m| m.role != "system")
141 .map(|m| {
142 let topic = extract_topic_sentence(&m.content);
143 format!("[{}] {}", m.role, topic)
144 })
145 .filter(|line| line.len() > 10)
146 .collect();
147
148 if topics.is_empty() {
149 return messages
150 .iter()
151 .filter(|m| m.role == "system")
152 .cloned()
153 .collect();
154 }
155
156 let mut result: Vec<UnifiedMessage> = messages
157 .iter()
158 .filter(|m| m.role == "system")
159 .cloned()
160 .collect();
161 result.push(UnifiedMessage {
162 role: "assistant".into(),
163 content: format!("[Conversation Skeleton]\n{}", topics.join("\n")),
164 parts: None,
165 });
166 result
167}
168
/// Returns the first sentence of `text`, trimmed of surrounding whitespace.
///
/// The sentence ends at the *earliest* of `". "`, `".\n"`, `'?'` or `'!'`,
/// and the terminating punctuation is kept. When none of these occur, the
/// text is cut to at most `MAX_FALLBACK_BYTES` bytes. The cut backs off to
/// the previous UTF-8 character boundary, so multi-byte characters are never
/// split; slicing mid-character would panic.
fn extract_topic_sentence(text: &str) -> String {
    const MAX_FALLBACK_BYTES: usize = 120;

    let terminator = [
        text.find(". "),
        text.find(".\n"),
        text.find('?'),
        text.find('!'),
    ]
    .iter()
    .flatten()
    .min()
    .copied();

    let end = match terminator {
        // Every terminator starts with a one-byte ASCII char, so `i + 1`
        // is always a char boundary.
        Some(i) => i + 1,
        None => {
            let mut cut = text.len().min(MAX_FALLBACK_BYTES);
            while !text.is_char_boundary(cut) {
                cut -= 1;
            }
            cut
        }
    };
    text[..end].trim().to_string()
}
180
/// Coarse task-complexity tiers used to size the context window.
///
/// `determine_level` maps a complexity score onto a tier, and `token_budget`
/// assigns each tier its token budget.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ComplexityLevel {
    /// Score below 0.3; 4,000-token budget.
    L0,
    /// Score in [0.3, 0.6); 8,000-token budget.
    L1,
    /// Score in [0.6, 0.9); 16,000-token budget.
    L2,
    /// Score of 0.9 or above; 32,000-token budget.
    L3,
}
190
191pub fn determine_level(complexity_score: f64) -> ComplexityLevel {
192 if complexity_score < 0.3 {
193 ComplexityLevel::L0
194 } else if complexity_score < 0.6 {
195 ComplexityLevel::L1
196 } else if complexity_score < 0.9 {
197 ComplexityLevel::L2
198 } else {
199 ComplexityLevel::L3
200 }
201}
202
203pub fn token_budget(level: ComplexityLevel) -> usize {
204 match level {
205 ComplexityLevel::L0 => 4_000,
206 ComplexityLevel::L1 => 8_000,
207 ComplexityLevel::L2 => 16_000,
208 ComplexityLevel::L3 => 32_000,
209 }
210}
211
/// Cheap token estimate: one token per four bytes of UTF-8, rounded up.
fn estimate_tokens(text: &str) -> usize {
    let bytes = text.len();
    bytes / 4 + usize::from(bytes % 4 != 0)
}
216
217pub fn build_context(
219 level: ComplexityLevel,
220 system_prompt: &str,
221 memories: &str,
222 history: &[UnifiedMessage],
223) -> Vec<UnifiedMessage> {
224 let budget = token_budget(level);
225 let mut used = 0usize;
226 let mut messages = Vec::new();
227
228 let sys_tokens = estimate_tokens(system_prompt);
233 if sys_tokens <= budget {
234 messages.push(UnifiedMessage {
235 role: "system".into(),
236 content: system_prompt.to_string(),
237 parts: None,
238 });
239 used += sys_tokens;
240 } else {
241 let max_chars = budget.saturating_mul(4);
245 let truncated: String = system_prompt.chars().take(max_chars).collect();
246 let truncated_tokens = estimate_tokens(&truncated);
247 messages.push(UnifiedMessage {
248 role: "system".into(),
249 content: truncated,
250 parts: None,
251 });
252 used += truncated_tokens;
253 tracing::warn!(
254 sys_tokens,
255 budget,
256 "system prompt exceeds budget — truncated to fit"
257 );
258 }
259
260 if !memories.is_empty() {
261 let mem_tokens = estimate_tokens(memories);
262 if used + mem_tokens <= budget {
263 messages.push(UnifiedMessage {
264 role: "system".into(),
265 content: memories.to_string(),
266 parts: None,
267 });
268 used += mem_tokens;
269 }
270 }
271
272 let mut history_buf: Vec<&UnifiedMessage> = Vec::new();
273 let mut history_tokens = 0usize;
274
275 for msg in history.iter().rev() {
276 let msg_tokens = estimate_tokens(&msg.content);
277 if used + history_tokens + msg_tokens > budget {
278 break;
279 }
280 history_tokens += msg_tokens;
281 history_buf.push(msg);
282 }
283
284 history_buf.reverse();
285 for msg in history_buf {
286 messages.push(msg.clone());
287 }
288
289 let prune_cfg = PruningConfig {
292 max_tokens: budget,
293 soft_trim_ratio: 1.0,
294 ..PruningConfig::default()
295 };
296 if needs_pruning(&messages, &prune_cfg) {
297 return soft_trim(&messages, &prune_cfg).messages;
298 }
299
300 messages
301}
302
303pub fn inject_instruction_reminder(messages: &mut Vec<UnifiedMessage>, reminder: &str) -> bool {
318 let non_system_turns = messages.iter().filter(|m| m.role != "system").count();
319 if non_system_turns < crate::prompt::ANTI_FADE_TURN_THRESHOLD {
320 return false;
321 }
322
323 let insert_pos = messages
328 .iter()
329 .rposition(|m| m.role == "user")
330 .unwrap_or(messages.len());
331
332 messages.insert(
333 insert_pos,
334 UnifiedMessage {
335 role: "user".into(),
336 content: format!("[System Note] {reminder}"),
337 parts: None,
338 },
339 );
340 true
341}
342
/// Thresholds controlling when and how far the history is pruned.
#[derive(Clone, Debug)]
pub struct PruningConfig {
    /// Context window size, in estimated tokens.
    pub max_tokens: usize,
    /// Fraction of `max_tokens` above which `needs_pruning` reports `true`.
    /// `soft_trim` also uses it as its target size.
    pub soft_trim_ratio: f64,
    /// Fraction of `max_tokens` above which `needs_hard_clear` reports `true`.
    pub hard_clear_ratio: f64,
    /// How many of the most recent non-system messages `soft_trim` may keep.
    /// Older ones are what `extract_trimmable` returns.
    pub preserve_recent: usize,
}
350
351impl Default for PruningConfig {
352 fn default() -> Self {
353 Self {
354 max_tokens: 128_000,
355 soft_trim_ratio: 0.8,
356 hard_clear_ratio: 0.95,
357 preserve_recent: 10,
358 }
359 }
360}
361
362#[derive(Debug, Clone)]
363pub struct PruningResult {
364 pub messages: Vec<UnifiedMessage>,
365 pub trimmed_count: usize,
366 pub compaction_summary: Option<String>,
367 pub total_tokens: usize,
368}
369
370pub fn count_tokens(messages: &[UnifiedMessage]) -> usize {
371 messages.iter().map(|m| estimate_tokens(&m.content)).sum()
372}
373
374pub fn needs_pruning(messages: &[UnifiedMessage], config: &PruningConfig) -> bool {
375 let tokens = count_tokens(messages);
376 tokens > ((config.max_tokens as f64 * config.soft_trim_ratio) as usize)
377}
378
379pub fn needs_hard_clear(messages: &[UnifiedMessage], config: &PruningConfig) -> bool {
380 let tokens = count_tokens(messages);
381 tokens > ((config.max_tokens as f64 * config.hard_clear_ratio) as usize)
382}
383
384pub fn soft_trim(messages: &[UnifiedMessage], config: &PruningConfig) -> PruningResult {
386 let target_tokens = (config.max_tokens as f64 * config.soft_trim_ratio) as usize;
387
388 let system_msgs: Vec<_> = messages
389 .iter()
390 .filter(|m| m.role == "system")
391 .cloned()
392 .collect();
393
394 let non_system: Vec<_> = messages
395 .iter()
396 .filter(|m| m.role != "system")
397 .cloned()
398 .collect();
399
400 let preserve_count = config.preserve_recent.min(non_system.len());
401 let preserved = &non_system[non_system.len().saturating_sub(preserve_count)..];
402
403 let mut result: Vec<UnifiedMessage> = system_msgs;
404 let system_tokens = count_tokens(&result);
405
406 let mut available = target_tokens.saturating_sub(system_tokens);
407 let mut kept = Vec::new();
408
409 for msg in preserved.iter().rev() {
410 let msg_tokens = estimate_tokens(&msg.content);
411 if msg_tokens <= available {
412 kept.push(msg.clone());
413 available = available.saturating_sub(msg_tokens);
414 }
415 }
418 kept.reverse();
419
420 let trimmed_count = non_system.len() - kept.len();
421 result.extend(kept);
422
423 let total_tokens = count_tokens(&result);
424
425 PruningResult {
426 messages: result,
427 trimmed_count,
428 compaction_summary: None,
429 total_tokens,
430 }
431}
432
433pub fn extract_trimmable(
435 messages: &[UnifiedMessage],
436 config: &PruningConfig,
437) -> Vec<UnifiedMessage> {
438 let non_system: Vec<_> = messages
439 .iter()
440 .filter(|m| m.role != "system")
441 .cloned()
442 .collect();
443
444 let preserve_count = config.preserve_recent.min(non_system.len());
445 let trim_end = non_system.len().saturating_sub(preserve_count);
446
447 non_system[..trim_end].to_vec()
448}
449
450pub fn build_compaction_prompt(trimmed: &[UnifiedMessage]) -> String {
452 let mut prompt = String::from(
453 "Summarize the following conversation history into a concise paragraph. \
454 Capture key facts, decisions, and context. Do not include greetings or filler.\n\n",
455 );
456
457 for msg in trimmed {
458 prompt.push_str(&format!("{}: {}\n", msg.role, msg.content));
459 }
460
461 prompt
462}
463
464pub fn compress_context(messages: &mut [UnifiedMessage], target_ratio: f64) {
470 use roboticus_llm::compression::PromptCompressor;
471
472 let compressor = PromptCompressor::new(target_ratio);
473
474 let last_user_idx = messages.iter().rposition(|m| m.role == "user");
476
477 for (i, msg) in messages.iter_mut().enumerate() {
478 if Some(i) == last_user_idx {
479 continue; }
481 if msg.content.len() < 200 {
483 continue;
484 }
485 msg.content = compressor.compress(&msg.content);
486 }
487}
488
489pub fn insert_compaction_summary(messages: &mut Vec<UnifiedMessage>, summary: String) {
491 let insert_pos = messages
492 .iter()
493 .position(|m| m.role != "system")
494 .unwrap_or(messages.len());
495
496 messages.insert(
497 insert_pos,
498 UnifiedMessage {
499 role: "system".into(),
500 content: format!("[Conversation Summary] {summary}"),
501 parts: None,
502 },
503 );
504}
505
// Unit tests for level selection, context assembly, pruning, staged
// compaction and instruction-reminder injection in this module.
#[cfg(test)]
mod tests {
    use super::*;

    // ---- Complexity levels and token budgets ----

    #[test]
    fn level_determination() {
        // Each threshold (0.3, 0.6, 0.9) belongs to the higher tier.
        assert_eq!(determine_level(0.0), ComplexityLevel::L0);
        assert_eq!(determine_level(0.29), ComplexityLevel::L0);
        assert_eq!(determine_level(0.3), ComplexityLevel::L1);
        assert_eq!(determine_level(0.59), ComplexityLevel::L1);
        assert_eq!(determine_level(0.6), ComplexityLevel::L2);
        assert_eq!(determine_level(0.89), ComplexityLevel::L2);
        assert_eq!(determine_level(0.9), ComplexityLevel::L3);
        assert_eq!(determine_level(1.0), ComplexityLevel::L3);
    }

    #[test]
    fn budget_values() {
        assert_eq!(token_budget(ComplexityLevel::L0), 4_000);
        assert_eq!(token_budget(ComplexityLevel::L1), 8_000);
        assert_eq!(token_budget(ComplexityLevel::L2), 16_000);
        assert_eq!(token_budget(ComplexityLevel::L3), 32_000);
    }

    // ---- build_context ----

    #[test]
    fn context_assembly_respects_budget() {
        let sys = "You are a helpful agent.";
        let mem = "User prefers concise answers.";
        let history = vec![
            UnifiedMessage {
                role: "user".into(),
                content: "Hello".into(),
                parts: None,
            },
            UnifiedMessage {
                role: "assistant".into(),
                content: "Hi there!".into(),
                parts: None,
            },
        ];

        let ctx = build_context(ComplexityLevel::L0, sys, mem, &history);

        // The untruncated system prompt must come first.
        assert!(!ctx.is_empty());
        assert_eq!(ctx[0].role, "system");
        assert_eq!(ctx[0].content, sys);

        // Re-derive the bytes/4 estimate and check it stays within the L0 budget.
        let total_chars: usize = ctx.iter().map(|m| m.content.len()).sum();
        let total_tokens = total_chars.div_ceil(4);
        assert!(total_tokens <= token_budget(ComplexityLevel::L0));
    }

    #[test]
    fn context_truncates_old_history() {
        // NOTE(review): with the L0 budget (4,000 tokens) the 8,000-byte
        // message (~2,000 estimated tokens) still fits, so nothing is actually
        // truncated here. Consider a larger message to exercise the cut-off.
        let sys = "System prompt";
        let mem = "";
        let big_msg = "x".repeat(8000);
        let history = vec![
            UnifiedMessage {
                role: "user".into(),
                content: big_msg,
                parts: None,
            },
            UnifiedMessage {
                role: "user".into(),
                content: "recent message".into(),
                parts: None,
            },
        ];

        let ctx = build_context(ComplexityLevel::L0, sys, mem, &history);
        // The newest history message must always be last.
        assert!(ctx.len() >= 2);
        assert_eq!(ctx.last().unwrap().content, "recent message");
    }

    // ---- Pruning configuration and token counting ----

    #[test]
    fn pruning_config_defaults() {
        let cfg = PruningConfig::default();
        assert_eq!(cfg.max_tokens, 128_000);
        assert_eq!(cfg.soft_trim_ratio, 0.8);
        assert_eq!(cfg.hard_clear_ratio, 0.95);
        assert_eq!(cfg.preserve_recent, 10);
    }

    #[test]
    fn count_tokens_basic() {
        let msgs = vec![UnifiedMessage {
            role: "user".into(),
            content: "hello world".into(),
            parts: None,
        }];
        let tokens = count_tokens(&msgs);
        assert!(tokens > 0);
        assert_eq!(tokens, estimate_tokens("hello world"));
    }

    #[test]
    fn needs_pruning_under_threshold() {
        let msgs = vec![UnifiedMessage {
            role: "user".into(),
            content: "short".into(),
            parts: None,
        }];
        let cfg = PruningConfig::default();
        assert!(!needs_pruning(&msgs, &cfg));
    }

    #[test]
    fn needs_pruning_over_threshold() {
        // 500,000 bytes ≈ 125,000 tokens > 0.8 * 128,000 = 102,400.
        let big = "x".repeat(500_000);
        let msgs = vec![UnifiedMessage {
            role: "user".into(),
            content: big,
            parts: None,
        }];
        let cfg = PruningConfig::default();
        assert!(needs_pruning(&msgs, &cfg));
    }

    // ---- soft_trim / extract_trimmable / compaction summary ----

    #[test]
    fn soft_trim_preserves_recent() {
        let mut msgs = Vec::new();
        msgs.push(UnifiedMessage {
            role: "system".into(),
            content: "sys".into(),
            parts: None,
        });
        for i in 0..20 {
            msgs.push(UnifiedMessage {
                role: if i % 2 == 0 { "user" } else { "assistant" }.into(),
                content: format!("message {i}"),
                parts: None,
            });
        }

        let cfg = PruningConfig {
            max_tokens: 200,
            soft_trim_ratio: 0.8,
            preserve_recent: 5,
            ..Default::default()
        };

        // System message stays first; messages outside the 5-message window
        // are dropped; the newest message is still last.
        let result = soft_trim(&msgs, &cfg);
        assert!(result.messages[0].role == "system");
        assert!(result.trimmed_count > 0);
        let last = result.messages.last().unwrap();
        assert_eq!(last.content, "message 19");
    }

    #[test]
    fn extract_trimmable_gets_old_messages() {
        let mut msgs = Vec::new();
        msgs.push(UnifiedMessage {
            role: "system".into(),
            content: "sys".into(),
            parts: None,
        });
        for i in 0..10 {
            msgs.push(UnifiedMessage {
                role: "user".into(),
                content: format!("msg {i}"),
                parts: None,
            });
        }

        // 10 non-system messages, 3 preserved → the oldest 7 are trimmable.
        let cfg = PruningConfig {
            preserve_recent: 3,
            ..Default::default()
        };
        let trimmed = extract_trimmable(&msgs, &cfg);
        assert_eq!(trimmed.len(), 7);
        assert_eq!(trimmed[0].content, "msg 0");
    }

    #[test]
    fn build_compaction_prompt_format() {
        let msgs = vec![
            UnifiedMessage {
                role: "user".into(),
                content: "hi".into(),
                parts: None,
            },
            UnifiedMessage {
                role: "assistant".into(),
                content: "hello".into(),
                parts: None,
            },
        ];
        let prompt = build_compaction_prompt(&msgs);
        assert!(prompt.contains("Summarize"));
        assert!(prompt.contains("user: hi"));
        assert!(prompt.contains("assistant: hello"));
    }

    #[test]
    fn insert_compaction_summary_placement() {
        let mut msgs = vec![
            UnifiedMessage {
                role: "system".into(),
                content: "sys".into(),
                parts: None,
            },
            UnifiedMessage {
                role: "user".into(),
                content: "hi".into(),
                parts: None,
            },
        ];
        // The summary goes right after the leading system message.
        insert_compaction_summary(&mut msgs, "summary here".into());
        assert_eq!(msgs.len(), 3);
        assert_eq!(msgs[0].role, "system");
        assert_eq!(msgs[1].role, "system");
        assert!(msgs[1].content.contains("summary here"));
        assert_eq!(msgs[2].role, "user");
    }

    #[test]
    fn needs_hard_clear_under_threshold() {
        let msgs = vec![UnifiedMessage {
            role: "user".into(),
            content: "short".into(),
            parts: None,
        }];
        let cfg = PruningConfig::default();
        assert!(!needs_hard_clear(&msgs, &cfg));
    }

    #[test]
    fn needs_hard_clear_over_threshold() {
        // 500,000 bytes ≈ 125,000 tokens > 0.95 * 128,000 = 121,600.
        let big = "y".repeat(500_000);
        let msgs = vec![UnifiedMessage {
            role: "user".into(),
            content: big,
            parts: None,
        }];
        let cfg = PruningConfig::default();
        assert!(needs_hard_clear(&msgs, &cfg));
    }

    #[test]
    fn insert_compaction_summary_no_system_messages() {
        let mut msgs = vec![
            UnifiedMessage {
                role: "user".into(),
                content: "hello".into(),
                parts: None,
            },
            UnifiedMessage {
                role: "assistant".into(),
                content: "hi".into(),
                parts: None,
            },
        ];
        // With no leading system messages the summary becomes the first message.
        insert_compaction_summary(&mut msgs, "compacted info".into());
        assert_eq!(msgs.len(), 3);
        assert_eq!(msgs[0].role, "system");
        assert!(msgs[0].content.contains("compacted info"));
        assert_eq!(msgs[1].role, "user");
    }

    #[test]
    fn insert_compaction_summary_all_system_messages() {
        let mut msgs = vec![
            UnifiedMessage {
                role: "system".into(),
                content: "sys1".into(),
                parts: None,
            },
            UnifiedMessage {
                role: "system".into(),
                content: "sys2".into(),
                parts: None,
            },
        ];
        // When every message is a system message the summary is appended.
        insert_compaction_summary(&mut msgs, "final summary".into());
        assert_eq!(msgs.len(), 3);
        assert_eq!(msgs[2].role, "system");
        assert!(msgs[2].content.contains("final summary"));
    }

    #[test]
    fn build_context_sys_prompt_exceeds_budget() {
        // 20,000 bytes ≈ 5,000 tokens > the 4,000-token L0 budget.
        let big_sys = "z".repeat(20_000);
        let mem = "";
        let history = vec![UnifiedMessage {
            role: "user".into(),
            content: "hi".into(),
            parts: None,
        }];

        let ctx = build_context(ComplexityLevel::L0, &big_sys, mem, &history);
        assert!(!ctx.is_empty());
        // The prompt is truncated, not dropped.
        assert_eq!(ctx[0].role, "system");
        assert!(ctx[0].content.len() < big_sys.len());
        assert!(!ctx[0].content.is_empty());
    }

    #[test]
    fn build_context_empty_history() {
        let sys = "Agent prompt";
        let mem = "Memory info";
        let history: Vec<UnifiedMessage> = vec![];

        // Only the system prompt and the memories message are produced.
        let ctx = build_context(ComplexityLevel::L1, sys, mem, &history);
        assert_eq!(ctx.len(), 2);
        assert_eq!(ctx[0].content, sys);
        assert_eq!(ctx[1].content, mem);
    }

    #[test]
    fn soft_trim_no_non_system_messages() {
        let msgs = vec![UnifiedMessage {
            role: "system".into(),
            content: "sys".into(),
            parts: None,
        }];
        let cfg = PruningConfig {
            max_tokens: 200,
            preserve_recent: 5,
            ..Default::default()
        };
        let result = soft_trim(&msgs, &cfg);
        assert_eq!(result.messages.len(), 1);
        assert_eq!(result.trimmed_count, 0);
    }

    #[test]
    fn extract_trimmable_fewer_than_preserve() {
        let msgs = vec![UnifiedMessage {
            role: "user".into(),
            content: "only one".into(),
            parts: None,
        }];
        let cfg = PruningConfig {
            preserve_recent: 5,
            ..Default::default()
        };
        let trimmed = extract_trimmable(&msgs, &cfg);
        assert!(
            trimmed.is_empty(),
            "nothing to trim if fewer than preserve_recent"
        );
    }

    #[test]
    fn count_tokens_empty() {
        assert_eq!(count_tokens(&[]), 0);
    }

    // ---- Staged compaction ----

    #[test]
    fn compaction_stage_from_excess_boundaries() {
        // Upper bounds are inclusive: 1.0, 1.5, 2.5 and 4.0 stay in the lower stage.
        assert_eq!(CompactionStage::from_excess(0.5), CompactionStage::Verbatim);
        assert_eq!(CompactionStage::from_excess(1.0), CompactionStage::Verbatim);
        assert_eq!(
            CompactionStage::from_excess(1.01),
            CompactionStage::SelectiveTrim
        );
        assert_eq!(
            CompactionStage::from_excess(1.5),
            CompactionStage::SelectiveTrim
        );
        assert_eq!(
            CompactionStage::from_excess(1.51),
            CompactionStage::SemanticCompress
        );
        assert_eq!(
            CompactionStage::from_excess(2.5),
            CompactionStage::SemanticCompress
        );
        assert_eq!(
            CompactionStage::from_excess(2.51),
            CompactionStage::TopicExtract
        );
        assert_eq!(
            CompactionStage::from_excess(4.0),
            CompactionStage::TopicExtract
        );
        assert_eq!(
            CompactionStage::from_excess(4.01),
            CompactionStage::Skeleton
        );
        assert_eq!(
            CompactionStage::from_excess(100.0),
            CompactionStage::Skeleton
        );
    }

    #[test]
    fn compaction_stage_ordering() {
        // Derived Ord follows declaration order.
        assert!(CompactionStage::Verbatim < CompactionStage::SelectiveTrim);
        assert!(CompactionStage::SelectiveTrim < CompactionStage::SemanticCompress);
        assert!(CompactionStage::SemanticCompress < CompactionStage::TopicExtract);
        assert!(CompactionStage::TopicExtract < CompactionStage::Skeleton);
    }

    #[test]
    fn selective_trim_removes_filler() {
        let msgs = vec![
            UnifiedMessage {
                role: "system".into(),
                content: "sys prompt".into(),
                parts: None,
            },
            UnifiedMessage {
                role: "user".into(),
                content: "hello".into(),
                parts: None,
            },
            UnifiedMessage {
                role: "assistant".into(),
                content: "ok".into(),
                parts: None,
            },
            UnifiedMessage {
                role: "user".into(),
                content: "Please analyze the data and find anomalies in the revenue stream".into(),
                parts: None,
            },
            UnifiedMessage {
                role: "assistant".into(),
                content: "thanks".into(),
                parts: None,
            },
        ];
        // "hello", "ok" and "thanks" are filler; the system prompt and the
        // substantive request survive.
        let result = selective_trim(&msgs);
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].role, "system");
        assert!(result[1].content.contains("analyze the data"));
    }

    #[test]
    fn selective_trim_keeps_all_long_messages() {
        // Both messages are at least 40 bytes, so neither is a filler candidate.
        let msgs = vec![
            UnifiedMessage {
                role: "user".into(),
                content: "This is a long enough message that should never be trimmed away".into(),
                parts: None,
            },
            UnifiedMessage {
                role: "assistant".into(),
                content: "I agree, this response is also long enough to stay around".into(),
                parts: None,
            },
        ];
        let result = selective_trim(&msgs);
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn topic_extract_takes_first_sentence() {
        let msgs = vec![
            UnifiedMessage {
                role: "system".into(),
                content: "You are helpful.".into(),
                parts: None,
            },
            UnifiedMessage {
                role: "user".into(),
                content:
                    "Deploy the model to production. Then run the test suite. Finally update docs."
                        .into(),
                parts: None,
            },
        ];
        // The system message is untouched; the user message is cut to its first sentence.
        let result = topic_extract(&msgs);
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].content, "You are helpful.");
        assert_eq!(result[1].content, "Deploy the model to production.");
    }

    #[test]
    fn skeleton_compress_creates_outline() {
        let msgs = vec![
            UnifiedMessage {
                role: "system".into(),
                content: "System prompt".into(),
                parts: None,
            },
            UnifiedMessage {
                role: "user".into(),
                content: "How does authentication work in this app?".into(),
                parts: None,
            },
            UnifiedMessage {
                role: "assistant".into(),
                content: "Authentication uses JWT tokens with a 24-hour expiry. The flow starts at the login endpoint.".into(),
                parts: None,
            },
        ];
        // Output is [system, single assistant outline with one "[role] topic" line per message].
        let result = skeleton_compress(&msgs);
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].content, "System prompt");
        assert_eq!(result[1].role, "assistant");
        assert!(result[1].content.contains("[Conversation Skeleton]"));
        assert!(result[1].content.contains("[user]"));
        assert!(result[1].content.contains("[assistant]"));
    }

    #[test]
    fn skeleton_compress_empty_non_system() {
        // With no non-system messages, no outline message is added.
        let msgs = vec![UnifiedMessage {
            role: "system".into(),
            content: "sys".into(),
            parts: None,
        }];
        let result = skeleton_compress(&msgs);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].role, "system");
    }

    #[test]
    fn compact_to_stage_verbatim_is_identity() {
        let msgs = vec![
            UnifiedMessage {
                role: "user".into(),
                content: "test".into(),
                parts: None,
            },
            UnifiedMessage {
                role: "assistant".into(),
                content: "resp".into(),
                parts: None,
            },
        ];
        let result = compact_to_stage(&msgs, CompactionStage::Verbatim);
        assert_eq!(result.len(), msgs.len());
        assert_eq!(result[0].content, "test");
        assert_eq!(result[1].content, "resp");
    }

    #[test]
    fn compact_to_stage_dispatches_correctly() {
        let msgs = vec![
            UnifiedMessage {
                role: "user".into(),
                content: "hi".into(),
                parts: None,
            },
            UnifiedMessage {
                role: "user".into(),
                content: "Analyze the market data and identify trends in revenue growth over Q3"
                    .into(),
                parts: None,
            },
        ];
        // SelectiveTrim must route to selective_trim, which drops the "hi" filler.
        let trimmed = compact_to_stage(&msgs, CompactionStage::SelectiveTrim);
        assert_eq!(trimmed.len(), 1);
        assert!(trimmed[0].content.contains("Analyze"));
    }

    // ---- extract_topic_sentence ----

    #[test]
    fn extract_topic_sentence_with_period() {
        assert_eq!(
            extract_topic_sentence("First sentence. Second sentence. Third."),
            "First sentence."
        );
    }

    #[test]
    fn extract_topic_sentence_with_question() {
        assert_eq!(
            extract_topic_sentence("What is this? More details here."),
            "What is this?"
        );
    }

    #[test]
    fn extract_topic_sentence_no_punctuation() {
        // Shorter than the 120-byte fallback cap, so returned whole.
        let short = "Just some text without ending";
        assert_eq!(extract_topic_sentence(short), short);
    }

    #[test]
    fn extract_topic_sentence_very_long() {
        // No terminator: the fallback caps the result at 120 bytes.
        let long = "x".repeat(200);
        let result = extract_topic_sentence(&long);
        assert!(result.len() <= 120);
    }

    // ---- inject_instruction_reminder ----

    // Convenience constructor for a message without parts.
    fn make_msg(role: &str, content: &str) -> UnifiedMessage {
        UnifiedMessage {
            role: role.into(),
            content: content.into(),
            parts: None,
        }
    }

    #[test]
    fn inject_reminder_skips_short_conversations() {
        // 4 non-system turns; presumably below ANTI_FADE_TURN_THRESHOLD (TODO confirm).
        let mut msgs = vec![
            make_msg("system", "You are helpful."),
            make_msg("user", "Hello"),
            make_msg("assistant", "Hi!"),
            make_msg("user", "How are you?"),
            make_msg("assistant", "Good, thanks!"),
        ];
        let injected = inject_instruction_reminder(&mut msgs, "[Reminder] Be helpful.");
        assert!(!injected);
        assert_eq!(msgs.len(), 5);
    }

    #[test]
    fn inject_reminder_fires_for_long_conversations() {
        // 20 non-system turns.
        let mut msgs = vec![make_msg("system", "You are helpful.")];
        for i in 0..10 {
            msgs.push(make_msg("user", &format!("question {i}")));
            msgs.push(make_msg("assistant", &format!("answer {i}")));
        }
        let len_before = msgs.len();
        let injected = inject_instruction_reminder(&mut msgs, "[Reminder] Always be thorough.");
        assert!(injected);
        assert_eq!(msgs.len(), len_before + 1);

        // The reminder is a user-role message tagged "[System Note]".
        let reminder_idx = msgs
            .iter()
            .rposition(|m| m.content.contains("[System Note]"))
            .unwrap();
        assert_eq!(msgs[reminder_idx].role, "user");
        assert!(
            msgs[reminder_idx]
                .content
                .contains("[Reminder] Always be thorough.")
        );
    }

    #[test]
    fn inject_reminder_places_before_last_user_message() {
        let mut msgs = vec![make_msg("system", "System prompt.")];
        for i in 0..5 {
            msgs.push(make_msg("user", &format!("q{i}")));
            msgs.push(make_msg("assistant", &format!("a{i}")));
        }
        msgs.push(make_msg("user", "final question"));

        let injected = inject_instruction_reminder(&mut msgs, "[Reminder] Key directive.");
        assert!(injected);

        // The live request stays last...
        assert_eq!(msgs.last().unwrap().content, "final question");
        assert_eq!(msgs.last().unwrap().role, "user");

        // ...with the reminder immediately before it.
        let second_last = &msgs[msgs.len() - 2];
        assert_eq!(second_last.role, "user");
        assert!(second_last.content.contains("[System Note]"));
        assert!(second_last.content.contains("[Reminder]"));
    }

    #[test]
    fn inject_reminder_no_user_messages_appends_at_end() {
        // No user message to anchor on, so the reminder is appended.
        let mut msgs = vec![make_msg("system", "System prompt.")];
        for i in 0..10 {
            msgs.push(make_msg("assistant", &format!("response {i}")));
        }
        let len_before = msgs.len();
        let injected = inject_instruction_reminder(&mut msgs, "[Reminder] Test.");
        assert!(injected);
        assert_eq!(msgs.len(), len_before + 1);
        assert_eq!(
            msgs.last().unwrap().content,
            "[System Note] [Reminder] Test."
        );
        assert_eq!(msgs.last().unwrap().role, "user");
    }
}