use roboticus_llm::format::UnifiedMessage;
/// Escalating levels of conversation compaction.
///
/// Variants are declared from least to most aggressive, so the derived
/// `Ord` implementation orders them by aggressiveness.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum CompactionStage {
    /// Messages are passed through unchanged.
    Verbatim,
    /// Short filler messages (greetings, acknowledgements) are dropped.
    SelectiveTrim,
    /// Longer non-system messages are run through the prompt compressor.
    SemanticCompress,
    /// Every non-system message is reduced to its topic sentence.
    TopicExtract,
    /// All non-system messages are collapsed into a single outline message.
    Skeleton,
}

impl CompactionStage {
    /// Picks the compaction stage for a given `excess_ratio`.
    ///
    /// Each stage covers an inclusive upper bound: `<= 1.0` is `Verbatim`,
    /// `<= 1.5` is `SelectiveTrim`, `<= 2.5` is `SemanticCompress`,
    /// `<= 4.0` is `TopicExtract`, and anything else (including NaN, which
    /// fails every comparison) is `Skeleton`.
    ///
    /// NOTE(review): `excess_ratio` is presumably "current size / allowed
    /// size" — confirm against callers.
    pub fn from_excess(excess_ratio: f64) -> Self {
        // Inclusive upper bounds, checked in ascending order.
        const UPPER_BOUNDS: [(f64, CompactionStage); 4] = [
            (1.0, CompactionStage::Verbatim),
            (1.5, CompactionStage::SelectiveTrim),
            (2.5, CompactionStage::SemanticCompress),
            (4.0, CompactionStage::TopicExtract),
        ];

        UPPER_BOUNDS
            .iter()
            .find(|(limit, _)| excess_ratio <= *limit)
            .map_or(Self::Skeleton, |&(_, stage)| stage)
    }
}
pub fn compact_to_stage(
messages: &[UnifiedMessage],
stage: CompactionStage,
) -> Vec<UnifiedMessage> {
match stage {
CompactionStage::Verbatim => messages.to_vec(),
CompactionStage::SelectiveTrim => selective_trim(messages),
CompactionStage::SemanticCompress => semantic_compress(messages),
CompactionStage::TopicExtract => topic_extract(messages),
CompactionStage::Skeleton => skeleton_compress(messages),
}
}
/// Drops short filler messages (greetings, acknowledgements) from `messages`.
///
/// A message is kept when any of the following holds:
/// * its role is `"system"`;
/// * its content is at least 40 bytes long;
/// * its trimmed, lower-cased content is not one of the known filler phrases.
fn selective_trim(messages: &[UnifiedMessage]) -> Vec<UnifiedMessage> {
    const FILLER: &[&str] = &[
        "hello",
        "hi",
        "hey",
        "thanks",
        "thank you",
        "ok",
        "okay",
        "sure",
        "got it",
        "sounds good",
        "no problem",
        "np",
        "ack",
        "roger",
    ];
    // Content of this many bytes or more is never treated as filler.
    const MIN_SUBSTANTIVE_LEN: usize = 40;

    let is_filler = |content: &str| {
        let normalized = content.trim().to_lowercase();
        FILLER.contains(&normalized.as_str())
    };

    let mut kept = Vec::new();
    for message in messages {
        let keep = message.role == "system"
            || message.content.len() >= MIN_SUBSTANTIVE_LEN
            || !is_filler(message.content.as_str());
        if keep {
            kept.push(message.clone());
        }
    }
    kept
}
/// Runs longer non-system messages through `PromptCompressor`.
///
/// System messages and messages shorter than 100 bytes are cloned unchanged.
/// Every other message is rebuilt with the same role, compressed content and
/// `parts` reset to `None`.
fn semantic_compress(messages: &[UnifiedMessage]) -> Vec<UnifiedMessage> {
    use roboticus_llm::compression::PromptCompressor;

    // Messages below this byte length are not worth compressing.
    const MIN_COMPRESSIBLE_LEN: usize = 100;

    // NOTE(review): 0.6 is presumably the compressor's target ratio — confirm.
    let compressor = PromptCompressor::new(0.6);

    let mut compacted = Vec::with_capacity(messages.len());
    for message in messages {
        let passthrough =
            message.role == "system" || message.content.len() < MIN_COMPRESSIBLE_LEN;
        let next = if passthrough {
            message.clone()
        } else {
            UnifiedMessage {
                role: message.role.clone(),
                content: compressor.compress(&message.content),
                parts: None,
            }
        };
        compacted.push(next);
    }
    compacted
}
/// Reduces every non-system message to its topic sentence.
///
/// System messages are cloned unchanged. Other messages keep their role, get
/// their content replaced by `extract_topic_sentence(content)` and have
/// `parts` reset to `None`.
fn topic_extract(messages: &[UnifiedMessage]) -> Vec<UnifiedMessage> {
    let mut reduced = Vec::with_capacity(messages.len());
    for message in messages {
        if message.role == "system" {
            reduced.push(message.clone());
            continue;
        }
        reduced.push(UnifiedMessage {
            role: message.role.clone(),
            content: extract_topic_sentence(&message.content),
            parts: None,
        });
    }
    reduced
}
/// Collapses the whole conversation into system messages plus one outline.
///
/// All system messages are kept, in order. Each non-system message becomes an
/// outline line of the form `[role] topic sentence`; lines of 10 bytes or
/// fewer are discarded. If any lines remain they are joined with newlines
/// under a `[Conversation Skeleton]` header and appended as a single
/// `assistant` message. With no lines, only the system messages are returned.
fn skeleton_compress(messages: &[UnifiedMessage]) -> Vec<UnifiedMessage> {
    let mut result: Vec<UnifiedMessage> = Vec::new();
    let mut outline: Vec<String> = Vec::new();

    // One pass: system messages go to the result, the rest feed the outline.
    for message in messages {
        if message.role == "system" {
            result.push(message.clone());
            continue;
        }
        let line = format!(
            "[{}] {}",
            message.role,
            extract_topic_sentence(&message.content)
        );
        if line.len() > 10 {
            outline.push(line);
        }
    }

    if !outline.is_empty() {
        result.push(UnifiedMessage {
            role: "assistant".into(),
            content: format!("[Conversation Skeleton]\n{}", outline.join("\n")),
            parts: None,
        });
    }
    result
}
/// Returns the leading "topic sentence" of `text`, trimmed of whitespace.
///
/// Terminators are searched in priority order: the first `". "`, then the
/// first `".\n"`, then the first `'?'`, then the first `'!'`. The result ends
/// just after the matched terminator character.
///
/// When no terminator is found, at most the first 120 bytes of `text` are
/// used. The cut is moved back to the nearest UTF-8 character boundary so
/// that multi-byte text (CJK, emoji, accented letters) never causes a
/// slicing panic.
fn extract_topic_sentence(text: &str) -> String {
    // Byte cap applied when the text contains no sentence terminator.
    const FALLBACK_MAX_BYTES: usize = 120;

    let end = text
        .find(". ")
        .or_else(|| text.find(".\n"))
        .or_else(|| text.find('?'))
        .or_else(|| text.find('!'))
        // Every terminator starts with a one-byte ASCII character, so the
        // byte right after it is always a valid boundary.
        .map(|i| i + 1)
        .unwrap_or_else(|| {
            let mut cut = text.len().min(FALLBACK_MAX_BYTES);
            // Offset 0 is always a boundary, so this loop terminates.
            while !text.is_char_boundary(cut) {
                cut -= 1;
            }
            cut
        });

    text[..end].trim().to_string()
}
/// Complexity tier of a request, used to select a context token budget.
///
/// `determine_level` maps a complexity score onto these tiers and
/// `token_budget_with_config` maps each tier onto a configured budget.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ComplexityLevel {
/// Lowest tier: scores below 0.3.
L0,
/// Scores from 0.3 up to (but not including) 0.6.
L1,
/// Scores from 0.6 up to (but not including) 0.9.
L2,
/// Highest tier: scores of 0.9 and above.
L3,
}
/// Maps a complexity score onto a [`ComplexityLevel`].
///
/// Scores below 0.3 give `L0`, below 0.6 give `L1`, below 0.9 give `L2`, and
/// everything else (including NaN, which fails every comparison) gives `L3`.
pub fn determine_level(complexity_score: f64) -> ComplexityLevel {
    match complexity_score {
        score if score < 0.3 => ComplexityLevel::L0,
        score if score < 0.6 => ComplexityLevel::L1,
        score if score < 0.9 => ComplexityLevel::L2,
        _ => ComplexityLevel::L3,
    }
}
/// Like [`determine_level`], but never returns a level below the channel floor.
///
/// `channel_minimum` of `None` leaves the scored level untouched. Otherwise
/// `0`, `1` and `2` map to `L0`, `L1` and `L2`, and any larger value maps to
/// `L3`. The higher of the scored level and the floor is returned.
pub fn determine_level_with_minimum(
    complexity_score: f64,
    channel_minimum: Option<u8>,
) -> ComplexityLevel {
    let scored = determine_level(complexity_score);

    let floor = match channel_minimum {
        None => return scored,
        Some(0) => ComplexityLevel::L0,
        Some(1) => ComplexityLevel::L1,
        Some(2) => ComplexityLevel::L2,
        Some(_) => ComplexityLevel::L3,
    };

    if level_ordinal(floor) > level_ordinal(scored) {
        floor
    } else {
        scored
    }
}
fn level_ordinal(level: ComplexityLevel) -> u8 {
match level {
ComplexityLevel::L0 => 0,
ComplexityLevel::L1 => 1,
ComplexityLevel::L2 => 2,
ComplexityLevel::L3 => 3,
}
}
pub fn token_budget(level: ComplexityLevel) -> usize {
token_budget_with_config(level, &Default::default())
}
/// Returns the token budget that `cfg` assigns to `level`.
///
/// Each level reads the configuration field of the same name
/// (`L0` → `cfg.l0`, …, `L3` → `cfg.l3`).
pub fn token_budget_with_config(
    level: ComplexityLevel,
    cfg: &roboticus_core::config::ContextBudgetConfig,
) -> usize {
    let roboticus_core::config::ContextBudgetConfig { l0, l1, l2, l3, .. } = cfg;
    let selected = match level {
        ComplexityLevel::L0 => l0,
        ComplexityLevel::L1 => l1,
        ComplexityLevel::L2 => l2,
        ComplexityLevel::L3 => l3,
    };
    *selected
}
/// Token accounting for an assembled context, split by message category.
///
/// All token figures are estimates produced by `estimate_tokens`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct ContextFootprint {
/// Token budget the context was assembled against.
pub token_budget: usize,
/// Estimated tokens used by the system prompt.
pub system_prompt_tokens: usize,
/// Estimated tokens used by the memory block (0 when none was included).
pub memory_tokens: usize,
/// Estimated tokens used by the included history messages.
pub history_tokens: usize,
/// Number of history messages included.
pub history_depth: usize,
}
/// Estimates the token count of `text` as one token per four bytes, rounded up.
///
/// The estimate is based on the UTF-8 byte length, not on the character count.
pub fn estimate_tokens(text: &str) -> usize {
    let bytes = text.len();
    let whole = bytes / 4;
    if bytes % 4 == 0 {
        whole
    } else {
        whole + 1
    }
}
pub fn build_context(
level: ComplexityLevel,
system_prompt: &str,
memories: &str,
history: &[UnifiedMessage],
) -> Vec<UnifiedMessage> {
build_context_with_budget(level, system_prompt, memories, history, &Default::default())
}
/// Assembles a context for `level` under `budget_cfg`, discarding the footprint.
///
/// Thin wrapper around [`build_context_with_budget_footprint`] that returns
/// only the message list.
pub fn build_context_with_budget(
    level: ComplexityLevel,
    system_prompt: &str,
    memories: &str,
    history: &[UnifiedMessage],
    budget_cfg: &roboticus_core::config::ContextBudgetConfig,
) -> Vec<UnifiedMessage> {
    let (messages, _footprint) =
        build_context_with_budget_footprint(level, system_prompt, memories, history, budget_cfg);
    messages
}
/// Assembles the message list for a request and reports its token footprint.
///
/// The context is built in priority order, all within one token budget:
///
/// 1. The system prompt, always first. If it does not fit it is truncated.
/// 2. The memory block, as a second `system` message, only if it fits whole.
/// 3. As many of the most recent `history` messages as fit, in their original
///    order. Collection stops at the first message that does not fit.
///
/// The budget starts as the configured budget for `level`. If the system
/// prompt exceeds `budget_cfg.soul_token_cap(budget)` and
/// `soul_max_context_pct` is positive, the budget is recomputed as
/// `sys_tokens / soul_max_context_pct`, capped at the configured L3 budget.
///
/// Returns the messages together with a [`ContextFootprint`] whose
/// `token_budget` is always the effective budget used.
pub fn build_context_with_budget_footprint(
    level: ComplexityLevel,
    system_prompt: &str,
    memories: &str,
    history: &[UnifiedMessage],
    budget_cfg: &roboticus_core::config::ContextBudgetConfig,
) -> (Vec<UnifiedMessage>, ContextFootprint) {
    let mut budget = token_budget_with_config(level, budget_cfg);
    let sys_tokens = estimate_tokens(system_prompt);

    // A system prompt larger than its configured share of the context
    // resizes the budget so the prompt occupies that share, up to L3.
    let soul_cap = budget_cfg.soul_token_cap(budget);
    if sys_tokens > soul_cap && budget_cfg.soul_max_context_pct > 0.0 {
        let needed = (sys_tokens as f64 / budget_cfg.soul_max_context_pct) as usize;
        let l3_budget = token_budget_with_config(ComplexityLevel::L3, budget_cfg);
        budget = needed.min(l3_budget);
    }

    let mut used = 0usize;
    let mut messages = Vec::new();
    let mut footprint = ContextFootprint {
        token_budget: budget,
        ..ContextFootprint::default()
    };

    if sys_tokens <= budget {
        messages.push(UnifiedMessage {
            role: "system".into(),
            content: system_prompt.to_string(),
            parts: None,
        });
        used += sys_tokens;
        footprint.system_prompt_tokens += sys_tokens;
    } else {
        // `estimate_tokens` counts bytes, so truncate by bytes as well.
        // Counting characters here would let multi-byte prompts stay over
        // budget after truncation. Back off to a char boundary so the slice
        // is valid UTF-8; offset 0 is always a boundary.
        let mut cut = budget.saturating_mul(4).min(system_prompt.len());
        while !system_prompt.is_char_boundary(cut) {
            cut -= 1;
        }
        let truncated = system_prompt[..cut].to_string();
        let truncated_tokens = estimate_tokens(&truncated);
        messages.push(UnifiedMessage {
            role: "system".into(),
            content: truncated,
            parts: None,
        });
        used += truncated_tokens;
        footprint.system_prompt_tokens += truncated_tokens;
        tracing::warn!(
            sys_tokens,
            budget,
            "system prompt exceeds budget — truncated to fit"
        );
    }

    // Memories are all-or-nothing: they are skipped when they do not fit.
    if !memories.is_empty() {
        let mem_tokens = estimate_tokens(memories);
        if used + mem_tokens <= budget {
            messages.push(UnifiedMessage {
                role: "system".into(),
                content: memories.to_string(),
                parts: None,
            });
            used += mem_tokens;
            footprint.memory_tokens += mem_tokens;
        }
    }

    // Walk history newest-first and stop at the first message that overflows.
    let mut history_buf: Vec<&UnifiedMessage> = Vec::new();
    let mut history_tokens = 0usize;
    for msg in history.iter().rev() {
        let msg_tokens = estimate_tokens(&msg.content);
        if used + history_tokens + msg_tokens > budget {
            break;
        }
        history_tokens += msg_tokens;
        history_buf.push(msg);
    }
    // Restore chronological order before appending.
    history_buf.reverse();
    footprint.history_depth += history_buf.len();
    footprint.history_tokens = history_tokens;
    messages.extend(history_buf.into_iter().cloned());

    // Safety net: if the assembled context still exceeds the budget, prune it
    // and recompute the footprint from what survived.
    let prune_cfg = PruningConfig {
        max_tokens: budget,
        soft_trim_ratio: 1.0,
        ..PruningConfig::default()
    };
    if needs_pruning(&messages, &prune_cfg) {
        let trimmed = soft_trim(&messages, &prune_cfg).messages;
        // `classify_context_snapshot` does not know the budget, so carry it
        // over explicitly instead of reporting 0.
        let footprint = ContextFootprint {
            token_budget: budget,
            ..classify_context_snapshot(&trimmed, memories.is_empty())
        };
        return (trimmed, footprint);
    }

    (messages, footprint)
}
/// Recomputes a [`ContextFootprint`] from an already assembled message list.
///
/// Non-system messages count as history. Among system messages, the second
/// one (index 1) is counted as the memory block unless `memories_empty` is
/// true; every other system message is counted as system prompt.
///
/// `token_budget` is left at its default of 0 because the budget cannot be
/// derived from the messages.
pub fn classify_context_snapshot(
    messages: &[UnifiedMessage],
    memories_empty: bool,
) -> ContextFootprint {
    // Position among system messages where the memory block is expected.
    const MEMORY_SYSTEM_INDEX: usize = 1;

    let mut footprint = ContextFootprint::default();
    let mut system_index = 0usize;

    for message in messages {
        let tokens = estimate_tokens(&message.content);

        if message.role != "system" {
            footprint.history_tokens += tokens;
            footprint.history_depth += 1;
            continue;
        }

        let is_memory = !memories_empty && system_index == MEMORY_SYSTEM_INDEX;
        system_index += 1;

        if is_memory {
            footprint.memory_tokens += tokens;
        } else {
            footprint.system_prompt_tokens += tokens;
        }
    }

    footprint
}
/// Inserts an instruction reminder into long conversations.
///
/// Nothing happens, and `false` is returned, when the number of non-system
/// messages is below `crate::prompt::ANTI_FADE_TURN_THRESHOLD`. Otherwise a
/// `user` message with content `[System Note] {reminder}` is inserted
/// immediately before the last `user` message, or appended at the end when
/// there is no `user` message, and `true` is returned.
pub fn inject_instruction_reminder(messages: &mut Vec<UnifiedMessage>, reminder: &str) -> bool {
    // Gather both facts about the conversation in a single pass.
    let mut conversational_turns = 0usize;
    let mut last_user: Option<usize> = None;
    for (idx, message) in messages.iter().enumerate() {
        if message.role != "system" {
            conversational_turns += 1;
        }
        if message.role == "user" {
            last_user = Some(idx);
        }
    }

    if conversational_turns < crate::prompt::ANTI_FADE_TURN_THRESHOLD {
        return false;
    }

    let note = UnifiedMessage {
        role: "user".into(),
        content: format!("[System Note] {reminder}"),
        parts: None,
    };
    match last_user {
        Some(idx) => messages.insert(idx, note),
        None => messages.push(note),
    }
    true
}
/// Thresholds that control when and how a conversation is pruned.
#[derive(Debug, Clone)]
pub struct PruningConfig {
/// Token capacity that the ratios below are applied to.
pub max_tokens: usize,
/// Fraction of `max_tokens` above which `needs_pruning` reports true; also
/// the token target that `soft_trim` trims down to.
pub soft_trim_ratio: f64,
/// Fraction of `max_tokens` above which `needs_hard_clear` reports true.
pub hard_clear_ratio: f64,
/// Number of most recent non-system messages that `soft_trim` considers
/// keeping and that `extract_trimmable` excludes from its result.
pub preserve_recent: usize,
}
/// Defaults: 128 000 tokens, soft trim at 80 %, hard clear at 95 %, and the
/// ten most recent non-system messages preserved.
impl Default for PruningConfig {
fn default() -> Self {
Self {
max_tokens: 128_000,
soft_trim_ratio: 0.8,
hard_clear_ratio: 0.95,
preserve_recent: 10,
}
}
}
/// Outcome of a pruning pass such as `soft_trim`.
#[derive(Debug, Clone)]
pub struct PruningResult {
/// Messages that remain after pruning.
pub messages: Vec<UnifiedMessage>,
/// Number of non-system messages that were dropped.
pub trimmed_count: usize,
/// Optional summary of the dropped messages; `soft_trim` always leaves
/// this as `None`.
pub compaction_summary: Option<String>,
/// Estimated token count of `messages`.
pub total_tokens: usize,
}
/// Sums the estimated token counts of every message's content.
///
/// Returns 0 for an empty slice.
pub fn count_tokens(messages: &[UnifiedMessage]) -> usize {
    let mut total = 0usize;
    for message in messages {
        total += estimate_tokens(&message.content);
    }
    total
}
/// Reports whether `messages` exceed the soft-trim threshold.
///
/// The threshold is `max_tokens * soft_trim_ratio`, truncated to an integer;
/// the check is strictly greater-than.
pub fn needs_pruning(messages: &[UnifiedMessage], config: &PruningConfig) -> bool {
    let threshold = (config.max_tokens as f64 * config.soft_trim_ratio) as usize;
    count_tokens(messages) > threshold
}
/// Reports whether `messages` exceed the hard-clear threshold.
///
/// The threshold is `max_tokens * hard_clear_ratio`, truncated to an integer;
/// the check is strictly greater-than.
pub fn needs_hard_clear(messages: &[UnifiedMessage], config: &PruningConfig) -> bool {
    let threshold = (config.max_tokens as f64 * config.hard_clear_ratio) as usize;
    count_tokens(messages) > threshold
}
/// Trims a conversation down towards `max_tokens * soft_trim_ratio`.
///
/// All system messages are kept unconditionally. Of the non-system messages,
/// only the last `preserve_recent` are candidates. They are visited
/// newest-first and each is kept if it still fits in the remaining token
/// allowance; a message that does not fit is skipped, but older candidates
/// are still considered. Kept messages are returned in their original order
/// after the system messages.
///
/// `trimmed_count` is the number of non-system messages dropped, and
/// `compaction_summary` is always `None`.
pub fn soft_trim(messages: &[UnifiedMessage], config: &PruningConfig) -> PruningResult {
    let target_tokens = (config.max_tokens as f64 * config.soft_trim_ratio) as usize;

    // Split by reference; only messages that survive are cloned.
    let (system, conversation): (Vec<&UnifiedMessage>, Vec<&UnifiedMessage>) =
        messages.iter().partition(|m| m.role == "system");

    let mut result: Vec<UnifiedMessage> = system.into_iter().cloned().collect();
    let mut available = target_tokens.saturating_sub(count_tokens(&result));

    // Only the most recent `preserve_recent` messages may be kept.
    let window_start = conversation.len().saturating_sub(config.preserve_recent);
    let mut kept: Vec<UnifiedMessage> = Vec::new();
    for &candidate in conversation[window_start..].iter().rev() {
        let cost = estimate_tokens(&candidate.content);
        if cost <= available {
            kept.push(candidate.clone());
            available -= cost;
        }
    }
    // Candidates were collected newest-first; restore chronological order.
    kept.reverse();

    let trimmed_count = conversation.len() - kept.len();
    result.extend(kept);
    let total_tokens = count_tokens(&result);

    PruningResult {
        messages: result,
        trimmed_count,
        compaction_summary: None,
        total_tokens,
    }
}
/// Returns the non-system messages that fall outside the preserved tail.
///
/// That is every non-system message except the last `preserve_recent`, in
/// original order. The result is empty when there are no more than
/// `preserve_recent` non-system messages.
pub fn extract_trimmable(
    messages: &[UnifiedMessage],
    config: &PruningConfig,
) -> Vec<UnifiedMessage> {
    let conversational = messages.iter().filter(|m| m.role != "system").count();
    let trim_end = conversational.saturating_sub(config.preserve_recent);

    messages
        .iter()
        .filter(|m| m.role != "system")
        .take(trim_end)
        .cloned()
        .collect()
}
/// Builds the prompt that asks for a summary of the `trimmed` messages.
///
/// The prompt is a fixed instruction paragraph followed by one
/// `role: content` line per message.
pub fn build_compaction_prompt(trimmed: &[UnifiedMessage]) -> String {
    const INSTRUCTIONS: &str =
        "Summarize the following conversation history into a concise paragraph. \
         Capture key facts, decisions, and context. Do not include greetings or filler.\n\n";

    trimmed
        .iter()
        .fold(String::from(INSTRUCTIONS), |mut prompt, message| {
            prompt.push_str(&format!("{}: {}\n", message.role, message.content));
            prompt
        })
}
/// Compresses message contents in place with `PromptCompressor`.
///
/// The last `user` message is never touched, and messages shorter than 200
/// bytes are skipped. Every other message, whatever its role, has its content
/// replaced by the compressor's output.
pub fn compress_context(messages: &mut [UnifiedMessage], target_ratio: f64) {
    use roboticus_llm::compression::PromptCompressor;

    // Messages below this byte length are left as they are.
    const MIN_COMPRESSIBLE_LEN: usize = 200;

    let compressor = PromptCompressor::new(target_ratio);
    // The most recent user message must reach the model verbatim.
    let protected = messages.iter().rposition(|m| m.role == "user");

    messages
        .iter_mut()
        .enumerate()
        .filter(|(idx, message)| {
            Some(*idx) != protected && message.content.len() >= MIN_COMPRESSIBLE_LEN
        })
        .for_each(|(_, message)| {
            message.content = compressor.compress(&message.content);
        });
}
/// Inserts a `[Conversation Summary]` system message into `messages`.
///
/// The summary goes immediately before the first non-system message, i.e.
/// after any leading system messages. It is appended at the end when every
/// message is a system message or the list is empty.
pub fn insert_compaction_summary(messages: &mut Vec<UnifiedMessage>, summary: String) {
    let summary_message = UnifiedMessage {
        role: "system".into(),
        content: format!("[Conversation Summary] {summary}"),
        parts: None,
    };

    match messages.iter().position(|m| m.role != "system") {
        Some(first_conversational) => messages.insert(first_conversational, summary_message),
        None => messages.push(summary_message),
    }
}
// Unit tests for level selection, context assembly, pruning, compaction
// stages and instruction-reminder injection.
#[cfg(test)]
mod tests {
use super::*;
// Scores on and around each threshold map to the expected level.
#[test]
fn level_determination() {
assert_eq!(determine_level(0.0), ComplexityLevel::L0);
assert_eq!(determine_level(0.29), ComplexityLevel::L0);
assert_eq!(determine_level(0.3), ComplexityLevel::L1);
assert_eq!(determine_level(0.59), ComplexityLevel::L1);
assert_eq!(determine_level(0.6), ComplexityLevel::L2);
assert_eq!(determine_level(0.89), ComplexityLevel::L2);
assert_eq!(determine_level(0.9), ComplexityLevel::L3);
assert_eq!(determine_level(1.0), ComplexityLevel::L3);
}
// Pins the per-level budgets of the default configuration.
#[test]
fn budget_values() {
assert_eq!(token_budget(ComplexityLevel::L0), 8_000);
assert_eq!(token_budget(ComplexityLevel::L1), 8_000);
assert_eq!(token_budget(ComplexityLevel::L2), 16_000);
assert_eq!(token_budget(ComplexityLevel::L3), 32_000);
}
// A small context keeps the system prompt first and stays within budget.
#[test]
fn context_assembly_respects_budget() {
let sys = "You are a helpful agent.";
let mem = "User prefers concise answers.";
let history = vec![
UnifiedMessage {
role: "user".into(),
content: "Hello".into(),
parts: None,
},
UnifiedMessage {
role: "assistant".into(),
content: "Hi there!".into(),
parts: None,
},
];
let ctx = build_context(ComplexityLevel::L0, sys, mem, &history);
assert!(!ctx.is_empty());
assert_eq!(ctx[0].role, "system");
assert_eq!(ctx[0].content, sys);
let total_chars: usize = ctx.iter().map(|m| m.content.len()).sum();
let total_tokens = total_chars.div_ceil(4);
assert!(total_tokens <= token_budget(ComplexityLevel::L0));
}
// The most recent history message is still the last one in the context.
#[test]
fn context_truncates_old_history() {
let sys = "System prompt";
let mem = "";
let big_msg = "x".repeat(8000);
let history = vec![
UnifiedMessage {
role: "user".into(),
content: big_msg,
parts: None,
},
UnifiedMessage {
role: "user".into(),
content: "recent message".into(),
parts: None,
},
];
let ctx = build_context(ComplexityLevel::L0, sys, mem, &history);
assert!(ctx.len() >= 2);
assert_eq!(ctx.last().unwrap().content, "recent message");
}
// Pins the default pruning thresholds.
#[test]
fn pruning_config_defaults() {
let cfg = PruningConfig::default();
assert_eq!(cfg.max_tokens, 128_000);
assert_eq!(cfg.soft_trim_ratio, 0.8);
assert_eq!(cfg.hard_clear_ratio, 0.95);
assert_eq!(cfg.preserve_recent, 10);
}
// count_tokens agrees with estimate_tokens for a single message.
#[test]
fn count_tokens_basic() {
let msgs = vec![UnifiedMessage {
role: "user".into(),
content: "hello world".into(),
parts: None,
}];
let tokens = count_tokens(&msgs);
assert!(tokens > 0);
assert_eq!(tokens, estimate_tokens("hello world"));
}
// A tiny conversation does not trigger pruning.
#[test]
fn needs_pruning_under_threshold() {
let msgs = vec![UnifiedMessage {
role: "user".into(),
content: "short".into(),
parts: None,
}];
let cfg = PruningConfig::default();
assert!(!needs_pruning(&msgs, &cfg));
}
// A 500 000-byte message exceeds the default soft-trim threshold.
#[test]
fn needs_pruning_over_threshold() {
let big = "x".repeat(500_000);
let msgs = vec![UnifiedMessage {
role: "user".into(),
content: big,
parts: None,
}];
let cfg = PruningConfig::default();
assert!(needs_pruning(&msgs, &cfg));
}
// soft_trim keeps the system message first and the newest message last.
#[test]
fn soft_trim_preserves_recent() {
let mut msgs = Vec::new();
msgs.push(UnifiedMessage {
role: "system".into(),
content: "sys".into(),
parts: None,
});
for i in 0..20 {
msgs.push(UnifiedMessage {
role: if i % 2 == 0 { "user" } else { "assistant" }.into(),
content: format!("message {i}"),
parts: None,
});
}
let cfg = PruningConfig {
max_tokens: 200,
soft_trim_ratio: 0.8,
preserve_recent: 5,
..Default::default()
};
let result = soft_trim(&msgs, &cfg);
assert!(result.messages[0].role == "system");
assert!(result.trimmed_count > 0);
let last = result.messages.last().unwrap();
assert_eq!(last.content, "message 19");
}
// With 10 user messages and preserve_recent = 3, the oldest 7 are trimmable.
#[test]
fn extract_trimmable_gets_old_messages() {
let mut msgs = Vec::new();
msgs.push(UnifiedMessage {
role: "system".into(),
content: "sys".into(),
parts: None,
});
for i in 0..10 {
msgs.push(UnifiedMessage {
role: "user".into(),
content: format!("msg {i}"),
parts: None,
});
}
let cfg = PruningConfig {
preserve_recent: 3,
..Default::default()
};
let trimmed = extract_trimmable(&msgs, &cfg);
assert_eq!(trimmed.len(), 7);
assert_eq!(trimmed[0].content, "msg 0");
}
// The compaction prompt contains the instruction and one line per message.
#[test]
fn build_compaction_prompt_format() {
let msgs = vec![
UnifiedMessage {
role: "user".into(),
content: "hi".into(),
parts: None,
},
UnifiedMessage {
role: "assistant".into(),
content: "hello".into(),
parts: None,
},
];
let prompt = build_compaction_prompt(&msgs);
assert!(prompt.contains("Summarize"));
assert!(prompt.contains("user: hi"));
assert!(prompt.contains("assistant: hello"));
}
// The summary lands after the leading system message, before the first user message.
#[test]
fn insert_compaction_summary_placement() {
let mut msgs = vec![
UnifiedMessage {
role: "system".into(),
content: "sys".into(),
parts: None,
},
UnifiedMessage {
role: "user".into(),
content: "hi".into(),
parts: None,
},
];
insert_compaction_summary(&mut msgs, "summary here".into());
assert_eq!(msgs.len(), 3);
assert_eq!(msgs[0].role, "system");
assert_eq!(msgs[1].role, "system");
assert!(msgs[1].content.contains("summary here"));
assert_eq!(msgs[2].role, "user");
}
// A tiny conversation does not trigger a hard clear.
#[test]
fn needs_hard_clear_under_threshold() {
let msgs = vec![UnifiedMessage {
role: "user".into(),
content: "short".into(),
parts: None,
}];
let cfg = PruningConfig::default();
assert!(!needs_hard_clear(&msgs, &cfg));
}
// A 500 000-byte message exceeds the default hard-clear threshold.
#[test]
fn needs_hard_clear_over_threshold() {
let big = "y".repeat(500_000);
let msgs = vec![UnifiedMessage {
role: "user".into(),
content: big,
parts: None,
}];
let cfg = PruningConfig::default();
assert!(needs_hard_clear(&msgs, &cfg));
}
// Without system messages the summary is inserted at the very front.
#[test]
fn insert_compaction_summary_no_system_messages() {
let mut msgs = vec![
UnifiedMessage {
role: "user".into(),
content: "hello".into(),
parts: None,
},
UnifiedMessage {
role: "assistant".into(),
content: "hi".into(),
parts: None,
},
];
insert_compaction_summary(&mut msgs, "compacted info".into());
assert_eq!(msgs.len(), 3);
assert_eq!(msgs[0].role, "system");
assert!(msgs[0].content.contains("compacted info"));
assert_eq!(msgs[1].role, "user");
}
// With only system messages the summary is appended at the end.
#[test]
fn insert_compaction_summary_all_system_messages() {
let mut msgs = vec![
UnifiedMessage {
role: "system".into(),
content: "sys1".into(),
parts: None,
},
UnifiedMessage {
role: "system".into(),
content: "sys2".into(),
parts: None,
},
];
insert_compaction_summary(&mut msgs, "final summary".into());
assert_eq!(msgs.len(), 3);
assert_eq!(msgs[2].role, "system");
assert!(msgs[2].content.contains("final summary"));
}
// An oversized system prompt is truncated but still present and non-empty.
#[test]
fn build_context_sys_prompt_exceeds_budget() {
let big_sys = "z".repeat(200_000);
let mem = "";
let history = vec![UnifiedMessage {
role: "user".into(),
content: "hi".into(),
parts: None,
}];
let ctx = build_context(ComplexityLevel::L0, &big_sys, mem, &history);
assert!(!ctx.is_empty());
assert_eq!(ctx[0].role, "system");
assert!(ctx[0].content.len() < big_sys.len());
assert!(!ctx[0].content.is_empty());
}
// With no history the context is just the system prompt and the memory block.
#[test]
fn build_context_empty_history() {
let sys = "Agent prompt";
let mem = "Memory info";
let history: Vec<UnifiedMessage> = vec![];
let ctx = build_context(ComplexityLevel::L1, sys, mem, &history);
assert_eq!(ctx.len(), 2); assert_eq!(ctx[0].content, sys);
assert_eq!(ctx[1].content, mem);
}
// The returned footprint splits tokens by category and matches a re-classification.
#[test]
fn build_context_returns_footprint_with_expected_split() {
let sys = "system prompt";
let mem = "memory block";
let history = vec![
UnifiedMessage {
role: "user".into(),
content: "hello".into(),
parts: None,
},
UnifiedMessage {
role: "assistant".into(),
content: "world".into(),
parts: None,
},
];
let (ctx, fp) = build_context_with_budget_footprint(
ComplexityLevel::L1,
sys,
mem,
&history,
&Default::default(),
);
assert_eq!(ctx.len(), 4);
assert_eq!(fp.token_budget, token_budget(ComplexityLevel::L1));
assert_eq!(fp.system_prompt_tokens, estimate_tokens(sys));
assert_eq!(fp.memory_tokens, estimate_tokens(mem));
assert_eq!(
fp.history_tokens,
estimate_tokens("hello") + estimate_tokens("world")
);
assert_eq!(fp.history_depth, 2);
let classified = classify_context_snapshot(&ctx, false);
assert_eq!(classified.system_prompt_tokens, fp.system_prompt_tokens);
assert_eq!(classified.memory_tokens, fp.memory_tokens);
assert_eq!(classified.history_tokens, fp.history_tokens);
assert_eq!(classified.history_depth, fp.history_depth);
}
// soft_trim on a system-only conversation trims nothing.
#[test]
fn soft_trim_no_non_system_messages() {
let msgs = vec![UnifiedMessage {
role: "system".into(),
content: "sys".into(),
parts: None,
}];
let cfg = PruningConfig {
max_tokens: 200,
preserve_recent: 5,
..Default::default()
};
let result = soft_trim(&msgs, &cfg);
assert_eq!(result.messages.len(), 1);
assert_eq!(result.trimmed_count, 0);
}
// Fewer messages than preserve_recent means nothing is trimmable.
#[test]
fn extract_trimmable_fewer_than_preserve() {
let msgs = vec![UnifiedMessage {
role: "user".into(),
content: "only one".into(),
parts: None,
}];
let cfg = PruningConfig {
preserve_recent: 5,
..Default::default()
};
let trimmed = extract_trimmable(&msgs, &cfg);
assert!(
trimmed.is_empty(),
"nothing to trim if fewer than preserve_recent"
);
}
// An empty slice counts as zero tokens.
#[test]
fn count_tokens_empty() {
assert_eq!(count_tokens(&[]), 0);
}
// Each stage boundary is inclusive on its upper end.
#[test]
fn compaction_stage_from_excess_boundaries() {
assert_eq!(CompactionStage::from_excess(0.5), CompactionStage::Verbatim);
assert_eq!(CompactionStage::from_excess(1.0), CompactionStage::Verbatim);
assert_eq!(
CompactionStage::from_excess(1.01),
CompactionStage::SelectiveTrim
);
assert_eq!(
CompactionStage::from_excess(1.5),
CompactionStage::SelectiveTrim
);
assert_eq!(
CompactionStage::from_excess(1.51),
CompactionStage::SemanticCompress
);
assert_eq!(
CompactionStage::from_excess(2.5),
CompactionStage::SemanticCompress
);
assert_eq!(
CompactionStage::from_excess(2.51),
CompactionStage::TopicExtract
);
assert_eq!(
CompactionStage::from_excess(4.0),
CompactionStage::TopicExtract
);
assert_eq!(
CompactionStage::from_excess(4.01),
CompactionStage::Skeleton
);
assert_eq!(
CompactionStage::from_excess(100.0),
CompactionStage::Skeleton
);
}
// Stages compare from least to most aggressive.
#[test]
fn compaction_stage_ordering() {
assert!(CompactionStage::Verbatim < CompactionStage::SelectiveTrim);
assert!(CompactionStage::SelectiveTrim < CompactionStage::SemanticCompress);
assert!(CompactionStage::SemanticCompress < CompactionStage::TopicExtract);
assert!(CompactionStage::TopicExtract < CompactionStage::Skeleton);
}
// Filler messages are dropped; the system message and the real request stay.
#[test]
fn selective_trim_removes_filler() {
let msgs = vec![
UnifiedMessage {
role: "system".into(),
content: "sys prompt".into(),
parts: None,
},
UnifiedMessage {
role: "user".into(),
content: "hello".into(),
parts: None,
},
UnifiedMessage {
role: "assistant".into(),
content: "ok".into(),
parts: None,
},
UnifiedMessage {
role: "user".into(),
content: "Please analyze the data and find anomalies in the revenue stream".into(),
parts: None,
},
UnifiedMessage {
role: "assistant".into(),
content: "thanks".into(),
parts: None,
},
];
let result = selective_trim(&msgs);
assert_eq!(result.len(), 2);
assert_eq!(result[0].role, "system");
assert!(result[1].content.contains("analyze the data"));
}
// Messages of 40 bytes or more are never dropped.
#[test]
fn selective_trim_keeps_all_long_messages() {
let msgs = vec![
UnifiedMessage {
role: "user".into(),
content: "This is a long enough message that should never be trimmed away".into(),
parts: None,
},
UnifiedMessage {
role: "assistant".into(),
content: "I agree, this response is also long enough to stay around".into(),
parts: None,
},
];
let result = selective_trim(&msgs);
assert_eq!(result.len(), 2);
}
// System content is untouched; user content is cut to its first sentence.
#[test]
fn topic_extract_takes_first_sentence() {
let msgs = vec![
UnifiedMessage {
role: "system".into(),
content: "You are helpful.".into(),
parts: None,
},
UnifiedMessage {
role: "user".into(),
content:
"Deploy the model to production. Then run the test suite. Finally update docs."
.into(),
parts: None,
},
];
let result = topic_extract(&msgs);
assert_eq!(result.len(), 2);
assert_eq!(result[0].content, "You are helpful."); assert_eq!(result[1].content, "Deploy the model to production."); }
// The skeleton keeps the system prompt and adds one assistant outline message.
#[test]
fn skeleton_compress_creates_outline() {
let msgs = vec![
UnifiedMessage {
role: "system".into(),
content: "System prompt".into(),
parts: None,
},
UnifiedMessage {
role: "user".into(),
content: "How does authentication work in this app?".into(),
parts: None,
},
UnifiedMessage {
role: "assistant".into(),
content: "Authentication uses JWT tokens with a 24-hour expiry. The flow starts at the login endpoint.".into(),
parts: None,
},
];
let result = skeleton_compress(&msgs);
assert_eq!(result.len(), 2);
assert_eq!(result[0].content, "System prompt");
assert_eq!(result[1].role, "assistant");
assert!(result[1].content.contains("[Conversation Skeleton]"));
assert!(result[1].content.contains("[user]"));
assert!(result[1].content.contains("[assistant]"));
}
// With only system messages no outline message is added.
#[test]
fn skeleton_compress_empty_non_system() {
let msgs = vec![UnifiedMessage {
role: "system".into(),
content: "sys".into(),
parts: None,
}];
let result = skeleton_compress(&msgs);
assert_eq!(result.len(), 1);
assert_eq!(result[0].role, "system");
}
// The Verbatim stage returns the messages unchanged.
#[test]
fn compact_to_stage_verbatim_is_identity() {
let msgs = vec![
UnifiedMessage {
role: "user".into(),
content: "test".into(),
parts: None,
},
UnifiedMessage {
role: "assistant".into(),
content: "resp".into(),
parts: None,
},
];
let result = compact_to_stage(&msgs, CompactionStage::Verbatim);
assert_eq!(result.len(), msgs.len());
assert_eq!(result[0].content, "test");
assert_eq!(result[1].content, "resp");
}
// The SelectiveTrim stage routes to selective_trim (the "hi" filler is dropped).
#[test]
fn compact_to_stage_dispatches_correctly() {
let msgs = vec![
UnifiedMessage {
role: "user".into(),
content: "hi".into(),
parts: None,
},
UnifiedMessage {
role: "user".into(),
content: "Analyze the market data and identify trends in revenue growth over Q3"
.into(),
parts: None,
},
];
let trimmed = compact_to_stage(&msgs, CompactionStage::SelectiveTrim);
assert_eq!(trimmed.len(), 1);
assert!(trimmed[0].content.contains("Analyze"));
}
// A period followed by a space ends the topic sentence.
#[test]
fn extract_topic_sentence_with_period() {
assert_eq!(
extract_topic_sentence("First sentence. Second sentence. Third."),
"First sentence."
);
}
// A question mark ends the topic sentence when no ". " is present.
#[test]
fn extract_topic_sentence_with_question() {
assert_eq!(
extract_topic_sentence("What is this? More details here."),
"What is this?"
);
}
// Short text without a terminator is returned whole.
#[test]
fn extract_topic_sentence_no_punctuation() {
let short = "Just some text without ending";
assert_eq!(extract_topic_sentence(short), short);
}
// Long text without a terminator is capped at 120 bytes.
#[test]
fn extract_topic_sentence_very_long() {
let long = "x".repeat(200);
let result = extract_topic_sentence(&long);
assert!(result.len() <= 120);
}
// Helper: builds a message with the given role and content and no parts.
fn make_msg(role: &str, content: &str) -> UnifiedMessage {
UnifiedMessage {
role: role.into(),
content: content.into(),
parts: None,
}
}
// Four non-system turns do not trigger a reminder.
#[test]
fn inject_reminder_skips_short_conversations() {
let mut msgs = vec![
make_msg("system", "You are helpful."),
make_msg("user", "Hello"),
make_msg("assistant", "Hi!"),
make_msg("user", "How are you?"),
make_msg("assistant", "Good, thanks!"),
];
let injected = inject_instruction_reminder(&mut msgs, "[Reminder] Be helpful.");
assert!(!injected);
assert_eq!(msgs.len(), 5);
}
// Twenty non-system turns trigger exactly one injected user-role note.
#[test]
fn inject_reminder_fires_for_long_conversations() {
let mut msgs = vec![make_msg("system", "You are helpful.")];
for i in 0..10 {
msgs.push(make_msg("user", &format!("question {i}")));
msgs.push(make_msg("assistant", &format!("answer {i}")));
}
let len_before = msgs.len();
let injected = inject_instruction_reminder(&mut msgs, "[Reminder] Always be thorough.");
assert!(injected);
assert_eq!(msgs.len(), len_before + 1);
let reminder_idx = msgs
.iter()
.rposition(|m| m.content.contains("[System Note]"))
.unwrap();
assert_eq!(msgs[reminder_idx].role, "user");
assert!(
msgs[reminder_idx]
.content
.contains("[Reminder] Always be thorough.")
);
}
// The note is inserted directly before the final user message.
#[test]
fn inject_reminder_places_before_last_user_message() {
let mut msgs = vec![make_msg("system", "System prompt.")];
for i in 0..5 {
msgs.push(make_msg("user", &format!("q{i}")));
msgs.push(make_msg("assistant", &format!("a{i}")));
}
msgs.push(make_msg("user", "final question"));
let injected = inject_instruction_reminder(&mut msgs, "[Reminder] Key directive.");
assert!(injected);
assert_eq!(msgs.last().unwrap().content, "final question");
assert_eq!(msgs.last().unwrap().role, "user");
let second_last = &msgs[msgs.len() - 2];
assert_eq!(second_last.role, "user");
assert!(second_last.content.contains("[System Note]"));
assert!(second_last.content.contains("[Reminder]"));
}
// Without any user message the note is appended at the end.
#[test]
fn inject_reminder_no_user_messages_appends_at_end() {
let mut msgs = vec![make_msg("system", "System prompt.")];
for i in 0..10 {
msgs.push(make_msg("assistant", &format!("response {i}")));
}
let len_before = msgs.len();
let injected = inject_instruction_reminder(&mut msgs, "[Reminder] Test.");
assert!(injected);
assert_eq!(msgs.len(), len_before + 1);
assert_eq!(
msgs.last().unwrap().content,
"[System Note] [Reminder] Test."
);
assert_eq!(msgs.last().unwrap().role, "user");
}
// A channel minimum raises a lower scored level.
#[test]
fn determine_level_with_minimum_enforces_floor() {
assert_eq!(
determine_level_with_minimum(0.1, Some(1)),
ComplexityLevel::L1,
);
}
// A channel minimum never lowers a higher scored level.
#[test]
fn determine_level_with_minimum_does_not_lower() {
assert_eq!(
determine_level_with_minimum(0.8, Some(1)),
ComplexityLevel::L2,
);
}
// Without a channel minimum the scored level is returned as-is.
#[test]
fn determine_level_with_minimum_none_passthrough() {
assert_eq!(determine_level_with_minimum(0.1, None), ComplexityLevel::L0,);
}
}