1use roboticus_llm::format::UnifiedMessage;
2
/// Progressively lossier strategies for shrinking a conversation history.
///
/// Variants are declared from least to most aggressive, and the derived `Ord`
/// follows declaration order, so `Verbatim < SelectiveTrim < ... < Skeleton`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum CompactionStage {
    /// Messages are returned unchanged.
    Verbatim,
    /// Short filler messages (greetings, acknowledgements) are dropped.
    SelectiveTrim,
    /// Long non-system messages are run through `PromptCompressor`.
    SemanticCompress,
    /// Each non-system message is reduced to its first sentence.
    TopicExtract,
    /// Non-system messages are collapsed into one outline message.
    Skeleton,
}

impl CompactionStage {
    /// Chooses a stage from how far the context overshoots its budget.
    ///
    /// Inclusive upper bounds: `<= 1.0` → `Verbatim`, `<= 1.5` →
    /// `SelectiveTrim`, `<= 2.5` → `SemanticCompress`, `<= 4.0` →
    /// `TopicExtract`; anything larger (including NaN, which compares false
    /// against every bound) → `Skeleton`.
    pub fn from_excess(excess_ratio: f64) -> Self {
        // (inclusive ceiling, stage) pairs, mildest first.
        const LADDER: [(f64, CompactionStage); 4] = [
            (1.0, CompactionStage::Verbatim),
            (1.5, CompactionStage::SelectiveTrim),
            (2.5, CompactionStage::SemanticCompress),
            (4.0, CompactionStage::TopicExtract),
        ];

        LADDER
            .iter()
            .find(|rung| excess_ratio <= rung.0)
            .map_or(Self::Skeleton, |rung| rung.1)
    }
}
44
45pub fn compact_to_stage(
49 messages: &[UnifiedMessage],
50 stage: CompactionStage,
51) -> Vec<UnifiedMessage> {
52 match stage {
53 CompactionStage::Verbatim => messages.to_vec(),
54 CompactionStage::SelectiveTrim => selective_trim(messages),
55 CompactionStage::SemanticCompress => semantic_compress(messages),
56 CompactionStage::TopicExtract => topic_extract(messages),
57 CompactionStage::Skeleton => skeleton_compress(messages),
58 }
59}
60
61fn selective_trim(messages: &[UnifiedMessage]) -> Vec<UnifiedMessage> {
63 const FILLER: &[&str] = &[
64 "hello",
65 "hi",
66 "hey",
67 "thanks",
68 "thank you",
69 "ok",
70 "okay",
71 "sure",
72 "got it",
73 "sounds good",
74 "no problem",
75 "np",
76 "ack",
77 "roger",
78 ];
79 messages
80 .iter()
81 .filter(|m| {
82 if m.role == "system" {
83 return true;
84 }
85 if m.content.len() >= 40 {
87 return true;
88 }
89 let lower = m.content.trim().to_lowercase();
90 !FILLER.contains(&lower.as_str())
93 })
94 .cloned()
95 .collect()
96}
97
98fn semantic_compress(messages: &[UnifiedMessage]) -> Vec<UnifiedMessage> {
100 use roboticus_llm::compression::PromptCompressor;
101 let compressor = PromptCompressor::new(0.6);
102 messages
103 .iter()
104 .map(|m| {
105 if m.role == "system" || m.content.len() < 100 {
106 m.clone()
107 } else {
108 UnifiedMessage {
109 role: m.role.clone(),
110 content: compressor.compress(&m.content),
111 parts: None,
112 }
113 }
114 })
115 .collect()
116}
117
118fn topic_extract(messages: &[UnifiedMessage]) -> Vec<UnifiedMessage> {
120 messages
121 .iter()
122 .map(|m| {
123 if m.role == "system" {
124 m.clone()
125 } else {
126 UnifiedMessage {
127 role: m.role.clone(),
128 content: extract_topic_sentence(&m.content),
129 parts: None,
130 }
131 }
132 })
133 .collect()
134}
135
136fn skeleton_compress(messages: &[UnifiedMessage]) -> Vec<UnifiedMessage> {
138 let topics: Vec<String> = messages
139 .iter()
140 .filter(|m| m.role != "system")
141 .map(|m| {
142 let topic = extract_topic_sentence(&m.content);
143 format!("[{}] {}", m.role, topic)
144 })
145 .filter(|line| line.len() > 10)
146 .collect();
147
148 if topics.is_empty() {
149 return messages
150 .iter()
151 .filter(|m| m.role == "system")
152 .cloned()
153 .collect();
154 }
155
156 let mut result: Vec<UnifiedMessage> = messages
157 .iter()
158 .filter(|m| m.role == "system")
159 .cloned()
160 .collect();
161 result.push(UnifiedMessage {
162 role: "assistant".into(),
163 content: format!("[Conversation Skeleton]\n{}", topics.join("\n")),
164 parts: None,
165 });
166 result
167}
168
/// Returns the first sentence of `text`, trimmed of surrounding whitespace.
///
/// The sentence ends at the earliest of: `". "`, `".\n"`, `'?'` or `'!'`.
/// The terminator's punctuation character is included. When none of these
/// occurs, the result is at most the first 120 bytes of `text`. That cut is
/// moved back to the nearest UTF-8 character boundary, so multi-byte text
/// never causes a slicing panic.
fn extract_topic_sentence(text: &str) -> String {
    const FALLBACK_MAX_BYTES: usize = 120;

    // Earliest terminator wins; every terminator starts with an ASCII byte,
    // so `idx + 1` is always a valid char boundary.
    let sentence_end = [
        text.find(". "),
        text.find(".\n"),
        text.find('?'),
        text.find('!'),
    ]
    .iter()
    .flatten()
    .min()
    .map(|&idx| idx + 1);

    let end = sentence_end.unwrap_or_else(|| {
        let mut cut = text.len().min(FALLBACK_MAX_BYTES);
        while !text.is_char_boundary(cut) {
            cut -= 1;
        }
        cut
    });

    text[..end].trim().to_string()
}
180
/// Coarse task-complexity tier; each tier selects a context token budget.
///
/// Variants are declared in ascending order, so the derived `Ord` ranks
/// `L0 < L1 < L2 < L3`. This matches how `CompactionStage` is ordered and
/// lets callers compare or `max` levels directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ComplexityLevel {
    L0,
    L1,
    L2,
    L3,
}

/// Maps a complexity score to a tier using half-open bands.
///
/// `< 0.3` → `L0`, `< 0.6` → `L1`, `< 0.9` → `L2`, otherwise `L3`.
/// Negative scores land in `L0`. NaN fails every comparison and therefore
/// lands in `L3`.
pub fn determine_level(complexity_score: f64) -> ComplexityLevel {
    if complexity_score < 0.3 {
        ComplexityLevel::L0
    } else if complexity_score < 0.6 {
        ComplexityLevel::L1
    } else if complexity_score < 0.9 {
        ComplexityLevel::L2
    } else {
        ComplexityLevel::L3
    }
}
202
203pub fn determine_level_with_minimum(
207 complexity_score: f64,
208 channel_minimum: Option<u8>,
209) -> ComplexityLevel {
210 let base = determine_level(complexity_score);
211 let Some(min) = channel_minimum else {
212 return base;
213 };
214 let min_level = match min {
215 0 => ComplexityLevel::L0,
216 1 => ComplexityLevel::L1,
217 2 => ComplexityLevel::L2,
218 _ => ComplexityLevel::L3,
219 };
220 if level_ordinal(base) < level_ordinal(min_level) {
221 min_level
222 } else {
223 base
224 }
225}
226
227fn level_ordinal(level: ComplexityLevel) -> u8 {
228 match level {
229 ComplexityLevel::L0 => 0,
230 ComplexityLevel::L1 => 1,
231 ComplexityLevel::L2 => 2,
232 ComplexityLevel::L3 => 3,
233 }
234}
235
236pub fn token_budget(level: ComplexityLevel) -> usize {
237 token_budget_with_config(level, &Default::default())
238}
239
240pub fn token_budget_with_config(
241 level: ComplexityLevel,
242 cfg: &roboticus_core::config::ContextBudgetConfig,
243) -> usize {
244 match level {
245 ComplexityLevel::L0 => cfg.l0,
246 ComplexityLevel::L1 => cfg.l1,
247 ComplexityLevel::L2 => cfg.l2,
248 ComplexityLevel::L3 => cfg.l3,
249 }
250}
251
/// Breakdown of how an assembled context spends its token budget.
///
/// All token figures are `estimate_tokens` estimates, not tokenizer counts.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct ContextFootprint {
    /// Budget the context was assembled against; it may have been raised to
    /// fit a large system prompt.
    pub token_budget: usize,
    /// Tokens attributed to the system prompt.
    pub system_prompt_tokens: usize,
    /// Tokens attributed to the injected memory block.
    pub memory_tokens: usize,
    /// Tokens attributed to conversation history.
    pub history_tokens: usize,
    /// Number of history messages included.
    pub history_depth: usize,
}
260
/// Cheap token estimate: one token per four bytes of UTF-8, rounded up.
pub fn estimate_tokens(text: &str) -> usize {
    const BYTES_PER_TOKEN: usize = 4;
    let bytes = text.len();
    bytes / BYTES_PER_TOKEN + usize::from(bytes % BYTES_PER_TOKEN != 0)
}
265
266pub fn build_context(
268 level: ComplexityLevel,
269 system_prompt: &str,
270 memories: &str,
271 history: &[UnifiedMessage],
272) -> Vec<UnifiedMessage> {
273 build_context_with_budget(level, system_prompt, memories, history, &Default::default())
274}
275
276pub fn build_context_with_budget(
278 level: ComplexityLevel,
279 system_prompt: &str,
280 memories: &str,
281 history: &[UnifiedMessage],
282 budget_cfg: &roboticus_core::config::ContextBudgetConfig,
283) -> Vec<UnifiedMessage> {
284 build_context_with_budget_footprint(level, system_prompt, memories, history, budget_cfg).0
285}
286
287pub fn build_context_with_budget_footprint(
289 level: ComplexityLevel,
290 system_prompt: &str,
291 memories: &str,
292 history: &[UnifiedMessage],
293 budget_cfg: &roboticus_core::config::ContextBudgetConfig,
294) -> (Vec<UnifiedMessage>, ContextFootprint) {
295 let mut budget = token_budget_with_config(level, budget_cfg);
296
297 let sys_tokens = estimate_tokens(system_prompt);
302 let soul_cap = budget_cfg.soul_token_cap(budget);
303 if sys_tokens > soul_cap && budget_cfg.soul_max_context_pct > 0.0 {
304 let needed = (sys_tokens as f64 / budget_cfg.soul_max_context_pct) as usize;
305 let l3_budget = token_budget_with_config(ComplexityLevel::L3, budget_cfg);
306 budget = needed.min(l3_budget);
307 }
308
309 let mut used = 0usize;
310 let mut messages = Vec::new();
311 let mut footprint = ContextFootprint {
312 token_budget: budget,
313 ..ContextFootprint::default()
314 };
315
316 if sys_tokens <= budget {
321 messages.push(UnifiedMessage {
322 role: "system".into(),
323 content: system_prompt.to_string(),
324 parts: None,
325 });
326 used += sys_tokens;
327 footprint.system_prompt_tokens += sys_tokens;
328 } else {
329 let max_chars = budget.saturating_mul(4);
333 let truncated: String = system_prompt.chars().take(max_chars).collect();
334 let truncated_tokens = estimate_tokens(&truncated);
335 messages.push(UnifiedMessage {
336 role: "system".into(),
337 content: truncated,
338 parts: None,
339 });
340 used += truncated_tokens;
341 footprint.system_prompt_tokens += truncated_tokens;
342 tracing::warn!(
343 sys_tokens,
344 budget,
345 "system prompt exceeds budget — truncated to fit"
346 );
347 }
348
349 if !memories.is_empty() {
350 let mem_tokens = estimate_tokens(memories);
351 if used + mem_tokens <= budget {
352 messages.push(UnifiedMessage {
353 role: "system".into(),
354 content: memories.to_string(),
355 parts: None,
356 });
357 used += mem_tokens;
358 footprint.memory_tokens += mem_tokens;
359 }
360 }
361
362 let mut history_buf: Vec<&UnifiedMessage> = Vec::new();
363 let mut history_tokens = 0usize;
364
365 for msg in history.iter().rev() {
366 let msg_tokens = estimate_tokens(&msg.content);
367 if used + history_tokens + msg_tokens > budget {
368 break;
369 }
370 history_tokens += msg_tokens;
371 history_buf.push(msg);
372 }
373
374 history_buf.reverse();
375 for msg in history_buf {
376 messages.push(msg.clone());
377 footprint.history_depth += 1;
378 }
379 footprint.history_tokens = history_tokens;
380
381 let prune_cfg = PruningConfig {
384 max_tokens: budget,
385 soft_trim_ratio: 1.0,
386 ..PruningConfig::default()
387 };
388 if needs_pruning(&messages, &prune_cfg) {
389 let trimmed = soft_trim(&messages, &prune_cfg).messages;
390 let footprint = classify_context_snapshot(&trimmed, memories.is_empty());
391 return (trimmed, footprint);
392 }
393
394 (messages, footprint)
395}
396
397pub fn classify_context_snapshot(
404 messages: &[UnifiedMessage],
405 memories_empty: bool,
406) -> ContextFootprint {
407 let mut footprint = ContextFootprint::default();
408 let mut system_seen = 0usize;
409 let memory_slot = if memories_empty { None } else { Some(1usize) };
410
411 for msg in messages {
412 let tokens = estimate_tokens(&msg.content);
413 if msg.role == "system" {
414 let idx = system_seen;
415 system_seen += 1;
416 if Some(idx) == memory_slot {
417 footprint.memory_tokens += tokens;
418 } else {
419 footprint.system_prompt_tokens += tokens;
420 }
421 } else {
422 footprint.history_tokens += tokens;
423 footprint.history_depth += 1;
424 }
425 }
426
427 footprint
428}
429
430pub fn inject_instruction_reminder(messages: &mut Vec<UnifiedMessage>, reminder: &str) -> bool {
445 let non_system_turns = messages.iter().filter(|m| m.role != "system").count();
446 if non_system_turns < crate::prompt::ANTI_FADE_TURN_THRESHOLD {
447 return false;
448 }
449
450 let insert_pos = messages
455 .iter()
456 .rposition(|m| m.role == "user")
457 .unwrap_or(messages.len());
458
459 messages.insert(
460 insert_pos,
461 UnifiedMessage {
462 role: "user".into(),
463 content: format!("[System Note] {reminder}"),
464 parts: None,
465 },
466 );
467 true
468}
469
/// Thresholds that drive history pruning.
#[derive(Clone, Debug)]
pub struct PruningConfig {
    /// Token ceiling that both ratios are applied to.
    pub max_tokens: usize,
    /// Fraction of `max_tokens` above which `needs_pruning` fires. It is also
    /// the target that `soft_trim` trims down to.
    pub soft_trim_ratio: f64,
    /// Fraction of `max_tokens` above which `needs_hard_clear` fires.
    pub hard_clear_ratio: f64,
    /// Number of most recent non-system messages that `soft_trim` considers
    /// keeping and that `extract_trimmable` leaves out.
    pub preserve_recent: usize,
}

impl Default for PruningConfig {
    fn default() -> Self {
        PruningConfig {
            preserve_recent: 10,
            hard_clear_ratio: 0.95,
            soft_trim_ratio: 0.8,
            max_tokens: 128_000,
        }
    }
}
488
489#[derive(Debug, Clone)]
490pub struct PruningResult {
491 pub messages: Vec<UnifiedMessage>,
492 pub trimmed_count: usize,
493 pub compaction_summary: Option<String>,
494 pub total_tokens: usize,
495}
496
497pub fn count_tokens(messages: &[UnifiedMessage]) -> usize {
498 messages.iter().map(|m| estimate_tokens(&m.content)).sum()
499}
500
501pub fn needs_pruning(messages: &[UnifiedMessage], config: &PruningConfig) -> bool {
502 let tokens = count_tokens(messages);
503 tokens > ((config.max_tokens as f64 * config.soft_trim_ratio) as usize)
504}
505
506pub fn needs_hard_clear(messages: &[UnifiedMessage], config: &PruningConfig) -> bool {
507 let tokens = count_tokens(messages);
508 tokens > ((config.max_tokens as f64 * config.hard_clear_ratio) as usize)
509}
510
511pub fn soft_trim(messages: &[UnifiedMessage], config: &PruningConfig) -> PruningResult {
513 let target_tokens = (config.max_tokens as f64 * config.soft_trim_ratio) as usize;
514
515 let system_msgs: Vec<_> = messages
516 .iter()
517 .filter(|m| m.role == "system")
518 .cloned()
519 .collect();
520
521 let non_system: Vec<_> = messages
522 .iter()
523 .filter(|m| m.role != "system")
524 .cloned()
525 .collect();
526
527 let preserve_count = config.preserve_recent.min(non_system.len());
528 let preserved = &non_system[non_system.len().saturating_sub(preserve_count)..];
529
530 let mut result: Vec<UnifiedMessage> = system_msgs;
531 let system_tokens = count_tokens(&result);
532
533 let mut available = target_tokens.saturating_sub(system_tokens);
534 let mut kept = Vec::new();
535
536 for msg in preserved.iter().rev() {
537 let msg_tokens = estimate_tokens(&msg.content);
538 if msg_tokens <= available {
539 kept.push(msg.clone());
540 available = available.saturating_sub(msg_tokens);
541 }
542 }
545 kept.reverse();
546
547 let trimmed_count = non_system.len() - kept.len();
548 result.extend(kept);
549
550 let total_tokens = count_tokens(&result);
551
552 PruningResult {
553 messages: result,
554 trimmed_count,
555 compaction_summary: None,
556 total_tokens,
557 }
558}
559
560pub fn extract_trimmable(
562 messages: &[UnifiedMessage],
563 config: &PruningConfig,
564) -> Vec<UnifiedMessage> {
565 let non_system: Vec<_> = messages
566 .iter()
567 .filter(|m| m.role != "system")
568 .cloned()
569 .collect();
570
571 let preserve_count = config.preserve_recent.min(non_system.len());
572 let trim_end = non_system.len().saturating_sub(preserve_count);
573
574 non_system[..trim_end].to_vec()
575}
576
577pub fn build_compaction_prompt(trimmed: &[UnifiedMessage]) -> String {
579 let mut prompt = String::from(
580 "Summarize the following conversation history into a concise paragraph. \
581 Capture key facts, decisions, and context. Do not include greetings or filler.\n\n",
582 );
583
584 for msg in trimmed {
585 prompt.push_str(&format!("{}: {}\n", msg.role, msg.content));
586 }
587
588 prompt
589}
590
591pub fn compress_context(messages: &mut [UnifiedMessage], target_ratio: f64) {
597 use roboticus_llm::compression::PromptCompressor;
598
599 let compressor = PromptCompressor::new(target_ratio);
600
601 let last_user_idx = messages.iter().rposition(|m| m.role == "user");
603
604 for (i, msg) in messages.iter_mut().enumerate() {
605 if Some(i) == last_user_idx {
606 continue; }
608 if msg.content.len() < 200 {
610 continue;
611 }
612 msg.content = compressor.compress(&msg.content);
613 }
614}
615
616pub fn insert_compaction_summary(messages: &mut Vec<UnifiedMessage>, summary: String) {
618 let insert_pos = messages
619 .iter()
620 .position(|m| m.role != "system")
621 .unwrap_or(messages.len());
622
623 messages.insert(
624 insert_pos,
625 UnifiedMessage {
626 role: "system".into(),
627 content: format!("[Conversation Summary] {summary}"),
628 parts: None,
629 },
630 );
631}
632
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a plain-text message with no structured parts.
    fn make_msg(role: &str, content: &str) -> UnifiedMessage {
        UnifiedMessage {
            role: role.into(),
            content: content.into(),
            parts: None,
        }
    }

    #[test]
    fn level_determination() {
        assert_eq!(determine_level(0.0), ComplexityLevel::L0);
        assert_eq!(determine_level(0.29), ComplexityLevel::L0);
        assert_eq!(determine_level(0.3), ComplexityLevel::L1);
        assert_eq!(determine_level(0.59), ComplexityLevel::L1);
        assert_eq!(determine_level(0.6), ComplexityLevel::L2);
        assert_eq!(determine_level(0.89), ComplexityLevel::L2);
        assert_eq!(determine_level(0.9), ComplexityLevel::L3);
        assert_eq!(determine_level(1.0), ComplexityLevel::L3);
    }

    #[test]
    fn budget_values() {
        assert_eq!(token_budget(ComplexityLevel::L0), 8_000);
        assert_eq!(token_budget(ComplexityLevel::L1), 8_000);
        assert_eq!(token_budget(ComplexityLevel::L2), 16_000);
        assert_eq!(token_budget(ComplexityLevel::L3), 32_000);
    }

    #[test]
    fn context_assembly_respects_budget() {
        let sys = "You are a helpful agent.";
        let mem = "User prefers concise answers.";
        let history = vec![make_msg("user", "Hello"), make_msg("assistant", "Hi there!")];

        let ctx = build_context(ComplexityLevel::L0, sys, mem, &history);

        assert!(!ctx.is_empty());
        assert_eq!(ctx[0].role, "system");
        assert_eq!(ctx[0].content, sys);
        // Use the same per-message estimate the assembler budgets with;
        // rounding the summed character count once can under-count.
        assert!(count_tokens(&ctx) <= token_budget(ComplexityLevel::L0));
    }

    #[test]
    fn context_truncates_old_history() {
        let sys = "System prompt";
        let mem = "";
        let big_msg = "x".repeat(8000);
        let history = vec![make_msg("user", &big_msg), make_msg("user", "recent message")];

        let ctx = build_context(ComplexityLevel::L0, sys, mem, &history);
        assert!(ctx.len() >= 2);
        assert_eq!(ctx.last().unwrap().content, "recent message");
    }

    #[test]
    fn pruning_config_defaults() {
        let cfg = PruningConfig::default();
        assert_eq!(cfg.max_tokens, 128_000);
        assert_eq!(cfg.soft_trim_ratio, 0.8);
        assert_eq!(cfg.hard_clear_ratio, 0.95);
        assert_eq!(cfg.preserve_recent, 10);
    }

    #[test]
    fn count_tokens_basic() {
        let msgs = vec![make_msg("user", "hello world")];
        let tokens = count_tokens(&msgs);
        assert!(tokens > 0);
        assert_eq!(tokens, estimate_tokens("hello world"));
    }

    #[test]
    fn estimate_tokens_rounds_up_per_four_bytes() {
        assert_eq!(estimate_tokens(""), 0);
        assert_eq!(estimate_tokens("abc"), 1);
        assert_eq!(estimate_tokens("abcd"), 1);
        assert_eq!(estimate_tokens("abcde"), 2);
    }

    #[test]
    fn needs_pruning_under_threshold() {
        let msgs = vec![make_msg("user", "short")];
        let cfg = PruningConfig::default();
        assert!(!needs_pruning(&msgs, &cfg));
    }

    #[test]
    fn needs_pruning_over_threshold() {
        let big = "x".repeat(500_000);
        let msgs = vec![make_msg("user", &big)];
        let cfg = PruningConfig::default();
        assert!(needs_pruning(&msgs, &cfg));
    }

    #[test]
    fn soft_trim_preserves_recent() {
        let mut msgs = vec![make_msg("system", "sys")];
        for i in 0..20 {
            let role = if i % 2 == 0 { "user" } else { "assistant" };
            msgs.push(make_msg(role, &format!("message {i}")));
        }

        let cfg = PruningConfig {
            max_tokens: 200,
            soft_trim_ratio: 0.8,
            preserve_recent: 5,
            ..Default::default()
        };

        let result = soft_trim(&msgs, &cfg);
        assert!(result.messages[0].role == "system");
        assert!(result.trimmed_count > 0);
        let last = result.messages.last().unwrap();
        assert_eq!(last.content, "message 19");
    }

    #[test]
    fn extract_trimmable_gets_old_messages() {
        let mut msgs = vec![make_msg("system", "sys")];
        for i in 0..10 {
            msgs.push(make_msg("user", &format!("msg {i}")));
        }

        let cfg = PruningConfig {
            preserve_recent: 3,
            ..Default::default()
        };
        let trimmed = extract_trimmable(&msgs, &cfg);
        assert_eq!(trimmed.len(), 7);
        assert_eq!(trimmed[0].content, "msg 0");
    }

    #[test]
    fn extract_trimmable_ignores_system_messages() {
        let msgs = vec![
            make_msg("system", "s1"),
            make_msg("user", "a"),
            make_msg("system", "s2"),
            make_msg("user", "b"),
            make_msg("user", "c"),
        ];
        let cfg = PruningConfig {
            preserve_recent: 1,
            ..Default::default()
        };
        let trimmed = extract_trimmable(&msgs, &cfg);
        let contents: Vec<&str> = trimmed.iter().map(|m| m.content.as_str()).collect();
        assert_eq!(contents, vec!["a", "b"]);
    }

    #[test]
    fn build_compaction_prompt_format() {
        let msgs = vec![make_msg("user", "hi"), make_msg("assistant", "hello")];
        let prompt = build_compaction_prompt(&msgs);
        assert!(prompt.contains("Summarize"));
        assert!(prompt.contains("user: hi"));
        assert!(prompt.contains("assistant: hello"));
    }

    #[test]
    fn insert_compaction_summary_placement() {
        let mut msgs = vec![make_msg("system", "sys"), make_msg("user", "hi")];
        insert_compaction_summary(&mut msgs, "summary here".into());
        assert_eq!(msgs.len(), 3);
        assert_eq!(msgs[0].role, "system");
        assert_eq!(msgs[1].role, "system");
        assert!(msgs[1].content.contains("summary here"));
        assert_eq!(msgs[2].role, "user");
    }

    #[test]
    fn needs_hard_clear_under_threshold() {
        let msgs = vec![make_msg("user", "short")];
        let cfg = PruningConfig::default();
        assert!(!needs_hard_clear(&msgs, &cfg));
    }

    #[test]
    fn needs_hard_clear_over_threshold() {
        let big = "y".repeat(500_000);
        let msgs = vec![make_msg("user", &big)];
        let cfg = PruningConfig::default();
        assert!(needs_hard_clear(&msgs, &cfg));
    }

    #[test]
    fn insert_compaction_summary_no_system_messages() {
        let mut msgs = vec![make_msg("user", "hello"), make_msg("assistant", "hi")];
        insert_compaction_summary(&mut msgs, "compacted info".into());
        assert_eq!(msgs.len(), 3);
        assert_eq!(msgs[0].role, "system");
        assert!(msgs[0].content.contains("compacted info"));
        assert_eq!(msgs[1].role, "user");
    }

    #[test]
    fn insert_compaction_summary_all_system_messages() {
        let mut msgs = vec![make_msg("system", "sys1"), make_msg("system", "sys2")];
        insert_compaction_summary(&mut msgs, "final summary".into());
        assert_eq!(msgs.len(), 3);
        assert_eq!(msgs[2].role, "system");
        assert!(msgs[2].content.contains("final summary"));
    }

    #[test]
    fn build_context_sys_prompt_exceeds_budget() {
        let big_sys = "z".repeat(200_000);
        let mem = "";
        let history = vec![make_msg("user", "hi")];

        let ctx = build_context(ComplexityLevel::L0, &big_sys, mem, &history);
        assert!(!ctx.is_empty());
        assert_eq!(ctx[0].role, "system");
        assert!(ctx[0].content.len() < big_sys.len());
        assert!(!ctx[0].content.is_empty());
    }

    #[test]
    fn build_context_empty_history() {
        let sys = "Agent prompt";
        let mem = "Memory info";
        let history: Vec<UnifiedMessage> = vec![];

        let ctx = build_context(ComplexityLevel::L1, sys, mem, &history);
        assert_eq!(ctx.len(), 2);
        assert_eq!(ctx[0].content, sys);
        assert_eq!(ctx[1].content, mem);
    }

    #[test]
    fn build_context_returns_footprint_with_expected_split() {
        let sys = "system prompt";
        let mem = "memory block";
        let history = vec![make_msg("user", "hello"), make_msg("assistant", "world")];

        let (ctx, fp) = build_context_with_budget_footprint(
            ComplexityLevel::L1,
            sys,
            mem,
            &history,
            &Default::default(),
        );

        assert_eq!(ctx.len(), 4);
        assert_eq!(fp.token_budget, token_budget(ComplexityLevel::L1));
        assert_eq!(fp.system_prompt_tokens, estimate_tokens(sys));
        assert_eq!(fp.memory_tokens, estimate_tokens(mem));
        assert_eq!(
            fp.history_tokens,
            estimate_tokens("hello") + estimate_tokens("world")
        );
        assert_eq!(fp.history_depth, 2);

        let classified = classify_context_snapshot(&ctx, false);
        assert_eq!(classified.system_prompt_tokens, fp.system_prompt_tokens);
        assert_eq!(classified.memory_tokens, fp.memory_tokens);
        assert_eq!(classified.history_tokens, fp.history_tokens);
        assert_eq!(classified.history_depth, fp.history_depth);
    }

    #[test]
    fn classify_snapshot_without_memory_counts_all_system_as_prompt() {
        let msgs = vec![
            make_msg("system", "aaaa"),
            make_msg("system", "bbbb"),
            make_msg("user", "cc"),
        ];

        let no_memory = classify_context_snapshot(&msgs, true);
        assert_eq!(no_memory.system_prompt_tokens, 2);
        assert_eq!(no_memory.memory_tokens, 0);
        assert_eq!(no_memory.history_tokens, 1);
        assert_eq!(no_memory.history_depth, 1);

        let with_memory = classify_context_snapshot(&msgs, false);
        assert_eq!(with_memory.system_prompt_tokens, 1);
        assert_eq!(with_memory.memory_tokens, 1);
    }

    #[test]
    fn soft_trim_no_non_system_messages() {
        let msgs = vec![make_msg("system", "sys")];
        let cfg = PruningConfig {
            max_tokens: 200,
            preserve_recent: 5,
            ..Default::default()
        };
        let result = soft_trim(&msgs, &cfg);
        assert_eq!(result.messages.len(), 1);
        assert_eq!(result.trimmed_count, 0);
    }

    #[test]
    fn extract_trimmable_fewer_than_preserve() {
        let msgs = vec![make_msg("user", "only one")];
        let cfg = PruningConfig {
            preserve_recent: 5,
            ..Default::default()
        };
        let trimmed = extract_trimmable(&msgs, &cfg);
        assert!(
            trimmed.is_empty(),
            "nothing to trim if fewer than preserve_recent"
        );
    }

    #[test]
    fn count_tokens_empty() {
        assert_eq!(count_tokens(&[]), 0);
    }

    #[test]
    fn compaction_stage_from_excess_boundaries() {
        let cases = [
            (0.5, CompactionStage::Verbatim),
            (1.0, CompactionStage::Verbatim),
            (1.01, CompactionStage::SelectiveTrim),
            (1.5, CompactionStage::SelectiveTrim),
            (1.51, CompactionStage::SemanticCompress),
            (2.5, CompactionStage::SemanticCompress),
            (2.51, CompactionStage::TopicExtract),
            (4.0, CompactionStage::TopicExtract),
            (4.01, CompactionStage::Skeleton),
            (100.0, CompactionStage::Skeleton),
        ];
        for &(ratio, expected) in cases.iter() {
            assert_eq!(CompactionStage::from_excess(ratio), expected, "ratio {ratio}");
        }
    }

    #[test]
    fn compaction_stage_ordering() {
        assert!(CompactionStage::Verbatim < CompactionStage::SelectiveTrim);
        assert!(CompactionStage::SelectiveTrim < CompactionStage::SemanticCompress);
        assert!(CompactionStage::SemanticCompress < CompactionStage::TopicExtract);
        assert!(CompactionStage::TopicExtract < CompactionStage::Skeleton);
    }

    #[test]
    fn selective_trim_removes_filler() {
        let msgs = vec![
            make_msg("system", "sys prompt"),
            make_msg("user", "hello"),
            make_msg("assistant", "ok"),
            make_msg(
                "user",
                "Please analyze the data and find anomalies in the revenue stream",
            ),
            make_msg("assistant", "thanks"),
        ];
        let result = selective_trim(&msgs);
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].role, "system");
        assert!(result[1].content.contains("analyze the data"));
    }

    #[test]
    fn selective_trim_keeps_all_long_messages() {
        let msgs = vec![
            make_msg(
                "user",
                "This is a long enough message that should never be trimmed away",
            ),
            make_msg(
                "assistant",
                "I agree, this response is also long enough to stay around",
            ),
        ];
        let result = selective_trim(&msgs);
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn topic_extract_takes_first_sentence() {
        let msgs = vec![
            make_msg("system", "You are helpful."),
            make_msg(
                "user",
                "Deploy the model to production. Then run the test suite. Finally update docs.",
            ),
        ];
        let result = topic_extract(&msgs);
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].content, "You are helpful.");
        assert_eq!(result[1].content, "Deploy the model to production.");
    }

    #[test]
    fn skeleton_compress_creates_outline() {
        let msgs = vec![
            make_msg("system", "System prompt"),
            make_msg("user", "How does authentication work in this app?"),
            make_msg(
                "assistant",
                "Authentication uses JWT tokens with a 24-hour expiry. The flow starts at the login endpoint.",
            ),
        ];
        let result = skeleton_compress(&msgs);
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].content, "System prompt");
        assert_eq!(result[1].role, "assistant");
        assert!(result[1].content.contains("[Conversation Skeleton]"));
        assert!(result[1].content.contains("[user]"));
        assert!(result[1].content.contains("[assistant]"));
    }

    #[test]
    fn skeleton_compress_empty_non_system() {
        let msgs = vec![make_msg("system", "sys")];
        let result = skeleton_compress(&msgs);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].role, "system");
    }

    #[test]
    fn compact_to_stage_verbatim_is_identity() {
        let msgs = vec![make_msg("user", "test"), make_msg("assistant", "resp")];
        let result = compact_to_stage(&msgs, CompactionStage::Verbatim);
        assert_eq!(result.len(), msgs.len());
        assert_eq!(result[0].content, "test");
        assert_eq!(result[1].content, "resp");
    }

    #[test]
    fn compact_to_stage_dispatches_correctly() {
        let msgs = vec![
            make_msg("user", "hi"),
            make_msg(
                "user",
                "Analyze the market data and identify trends in revenue growth over Q3",
            ),
        ];
        let trimmed = compact_to_stage(&msgs, CompactionStage::SelectiveTrim);
        assert_eq!(trimmed.len(), 1);
        assert!(trimmed[0].content.contains("Analyze"));
    }

    #[test]
    fn extract_topic_sentence_with_period() {
        assert_eq!(
            extract_topic_sentence("First sentence. Second sentence. Third."),
            "First sentence."
        );
    }

    #[test]
    fn extract_topic_sentence_with_question() {
        assert_eq!(
            extract_topic_sentence("What is this? More details here."),
            "What is this?"
        );
    }

    #[test]
    fn extract_topic_sentence_no_punctuation() {
        let short = "Just some text without ending";
        assert_eq!(extract_topic_sentence(short), short);
    }

    #[test]
    fn extract_topic_sentence_very_long() {
        let long = "x".repeat(200);
        let result = extract_topic_sentence(&long);
        assert!(result.len() <= 120);
    }

    #[test]
    fn inject_reminder_skips_short_conversations() {
        let mut msgs = vec![
            make_msg("system", "You are helpful."),
            make_msg("user", "Hello"),
            make_msg("assistant", "Hi!"),
            make_msg("user", "How are you?"),
            make_msg("assistant", "Good, thanks!"),
        ];
        let injected = inject_instruction_reminder(&mut msgs, "[Reminder] Be helpful.");
        assert!(!injected);
        assert_eq!(msgs.len(), 5);
    }

    #[test]
    fn inject_reminder_fires_for_long_conversations() {
        let mut msgs = vec![make_msg("system", "You are helpful.")];
        for i in 0..10 {
            msgs.push(make_msg("user", &format!("question {i}")));
            msgs.push(make_msg("assistant", &format!("answer {i}")));
        }
        let len_before = msgs.len();
        let injected = inject_instruction_reminder(&mut msgs, "[Reminder] Always be thorough.");
        assert!(injected);
        assert_eq!(msgs.len(), len_before + 1);

        let reminder_idx = msgs
            .iter()
            .rposition(|m| m.content.contains("[System Note]"))
            .unwrap();
        assert_eq!(msgs[reminder_idx].role, "user");
        assert!(msgs[reminder_idx]
            .content
            .contains("[Reminder] Always be thorough."));
    }

    #[test]
    fn inject_reminder_places_before_last_user_message() {
        let mut msgs = vec![make_msg("system", "System prompt.")];
        for i in 0..5 {
            msgs.push(make_msg("user", &format!("q{i}")));
            msgs.push(make_msg("assistant", &format!("a{i}")));
        }
        msgs.push(make_msg("user", "final question"));

        let injected = inject_instruction_reminder(&mut msgs, "[Reminder] Key directive.");
        assert!(injected);

        assert_eq!(msgs.last().unwrap().content, "final question");
        assert_eq!(msgs.last().unwrap().role, "user");

        let second_last = &msgs[msgs.len() - 2];
        assert_eq!(second_last.role, "user");
        assert!(second_last.content.contains("[System Note]"));
        assert!(second_last.content.contains("[Reminder]"));
    }

    #[test]
    fn inject_reminder_no_user_messages_appends_at_end() {
        let mut msgs = vec![make_msg("system", "System prompt.")];
        for i in 0..10 {
            msgs.push(make_msg("assistant", &format!("response {i}")));
        }
        let len_before = msgs.len();
        let injected = inject_instruction_reminder(&mut msgs, "[Reminder] Test.");
        assert!(injected);
        assert_eq!(msgs.len(), len_before + 1);
        assert_eq!(
            msgs.last().unwrap().content,
            "[System Note] [Reminder] Test."
        );
        assert_eq!(msgs.last().unwrap().role, "user");
    }

    #[test]
    fn determine_level_with_minimum_enforces_floor() {
        assert_eq!(
            determine_level_with_minimum(0.1, Some(1)),
            ComplexityLevel::L1,
        );
    }

    #[test]
    fn determine_level_with_minimum_does_not_lower() {
        assert_eq!(
            determine_level_with_minimum(0.8, Some(1)),
            ComplexityLevel::L2,
        );
    }

    #[test]
    fn determine_level_with_minimum_none_passthrough() {
        assert_eq!(determine_level_with_minimum(0.1, None), ComplexityLevel::L0);
    }

    #[test]
    fn determine_level_with_minimum_saturates_high_floor() {
        assert_eq!(
            determine_level_with_minimum(0.0, Some(200)),
            ComplexityLevel::L3,
        );
    }
}