use super::{Capability, CapabilityStatus};
use crate::message_filter::MessageFilterProvider;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
/// Stable identifier of the compaction capability; returned by `CompactionCapability::id`.
pub const COMPACTION_CAPABILITY_ID: &str = "compaction";
/// Technique used to shrink a conversation that no longer fits the context window.
///
/// Serialized with `snake_case` names (`auto`, `native`, `observation_masking`,
/// `summarization`); `Auto` is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum CompactionStrategy {
/// Default strategy. The capability description states that it cascades through
/// all available options; the cascade itself is not implemented in this part of the file.
#[default]
Auto,
/// Compaction performed by the model provider (the capability description gives
/// OpenAI `/responses/compact` as an example).
Native,
/// Replace older tool outputs with short summaries (see `apply_observation_masking`).
ObservationMasking,
/// Replace history with an LLM-written summary (see `build_summarization_prompt`).
Summarization,
}
/// Prints the strategy with the same `snake_case` spelling used by its serde form.
impl std::fmt::Display for CompactionStrategy {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Pick the label first, then emit it with a single write.
        let label = match *self {
            Self::Auto => "auto",
            Self::Native => "native",
            Self::ObservationMasking => "observation_masking",
            Self::Summarization => "summarization",
        };
        f.write_str(label)
    }
}
/// Shape of the placeholder text that replaces a masked tool output.
///
/// Serialized as `one_line` / `head_tail`; `OneLine` is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum MaskingSummaryFormat {
/// Bracketed single summary produced by `format_one_line_summary`: the tool name
/// plus either the short output itself or its line/byte counts.
#[default]
OneLine,
/// First three and last three lines of the output with an omission marker
/// between them, produced by `format_head_tail_summary`.
HeadTail,
}
/// Settings for the observation-masking strategy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ObservationMaskingConfig {
/// Number of most recent maskable tool outputs that are left untouched.
/// Defaults to `default_keep_recent_tool_outputs()` when absent from the input.
#[serde(default = "default_keep_recent_tool_outputs")]
pub keep_recent_tool_outputs: usize,
/// Format of the replacement text for masked outputs; defaults to `OneLine`.
#[serde(default)]
pub summary_format: MaskingSummaryFormat,
}
impl Default for ObservationMaskingConfig {
    /// Uses the same per-field defaults that serde applies to missing fields.
    fn default() -> Self {
        let keep_recent_tool_outputs = default_keep_recent_tool_outputs();
        let summary_format = MaskingSummaryFormat::default();
        Self {
            keep_recent_tool_outputs,
            summary_format,
        }
    }
}
/// Default for `ObservationMaskingConfig::keep_recent_tool_outputs`.
fn default_keep_recent_tool_outputs() -> usize {
    const KEEP_RECENT_TOOL_OUTPUTS: usize = 2;
    KEEP_RECENT_TOOL_OUTPUTS
}
/// Settings for the LLM-summarization strategy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SummarizationConfig {
/// Optional model identifier. NOTE(review): not read anywhere in this part of the
/// file; presumably selects the model that writes the summary — confirm at the caller.
#[serde(default)]
pub model: Option<String>,
/// Topics rendered as the bullet list of the `<preserve>` section of the prompt.
/// An empty list makes `build_summarization_prompt` fall back to `default_preserve()`.
#[serde(default = "default_preserve")]
pub preserve: Vec<String>,
/// Optional free-form instruction appended as one extra bullet after the preserve list.
#[serde(default)]
pub instructions: Option<String>,
}
impl Default for SummarizationConfig {
    /// No model override, no custom instructions, and the stock preserve list.
    fn default() -> Self {
        let preserve = default_preserve();
        Self {
            preserve,
            model: None,
            instructions: None,
        }
    }
}
/// Default topics a summary must retain, in the order they appear in the prompt.
fn default_preserve() -> Vec<String> {
    const DEFAULT_PRESERVE: [&str; 5] = [
        "decisions",
        "files_modified",
        "errors",
        "current_plan",
        "skill_instructions",
    ];
    DEFAULT_PRESERVE
        .iter()
        .map(|item| (*item).to_owned())
        .collect()
}
/// Top-level configuration of the compaction capability.
///
/// Every field is optional in the serialized form; missing fields take the
/// defaults named in the `serde` attributes below.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompactionConfig {
/// Strategy to apply; defaults to `CompactionStrategy::Auto`.
#[serde(default)]
pub strategy: CompactionStrategy,
/// When `false`, `should_compact_proactively` always answers `false`. Defaults to `true`.
#[serde(default = "default_proactive")]
pub proactive: bool,
/// Multiplier applied to the context window size to obtain the proactive token
/// budget (see `should_compact_proactively`). Defaults to `0.85`.
#[serde(default = "default_budget_percent")]
pub budget_percent: f32,
/// Settings for the observation-masking strategy.
#[serde(default)]
pub observation_masking: ObservationMaskingConfig,
/// Settings for the summarization strategy.
#[serde(default)]
pub summarization: SummarizationConfig,
/// Sizes of the hot and warm message tiers used by the hierarchical-memory helpers.
#[serde(default)]
pub memory_tiers: HierarchicalMemoryConfig,
}
impl Default for CompactionConfig {
    /// Mirrors the per-field serde defaults so that an empty JSON object and
    /// `CompactionConfig::default()` describe the same configuration.
    fn default() -> Self {
        let proactive = default_proactive();
        let budget_percent = default_budget_percent();
        Self {
            proactive,
            budget_percent,
            strategy: Default::default(),
            observation_masking: Default::default(),
            summarization: Default::default(),
            memory_tiers: Default::default(),
        }
    }
}
/// Default for `CompactionConfig::proactive`: proactive compaction is enabled.
fn default_proactive() -> bool {
    const PROACTIVE_BY_DEFAULT: bool = true;
    PROACTIVE_BY_DEFAULT
}
/// Default for `CompactionConfig::budget_percent`.
fn default_budget_percent() -> f32 {
    const DEFAULT_BUDGET_PERCENT: f32 = 0.85;
    DEFAULT_BUDGET_PERCENT
}
impl CompactionConfig {
    /// Builds a configuration from a JSON value.
    ///
    /// Missing fields take their per-field defaults. If the value cannot be
    /// deserialized at all (for example an unknown strategy name or a field of
    /// the wrong type), the error is discarded and the complete
    /// `CompactionConfig::default()` is returned; valid fields that sat next to
    /// the invalid one are not retained.
    pub fn from_json(value: &serde_json::Value) -> Self {
        // `&serde_json::Value` implements `serde::Deserializer`, so the config is
        // read straight from the borrowed tree instead of deep-cloning it first.
        Self::deserialize(value).unwrap_or_default()
    }
}
/// Capability entry that advertises configurable context compaction.
pub struct CompactionCapability;

impl Capability for CompactionCapability {
    /// Returns [`COMPACTION_CAPABILITY_ID`].
    fn id(&self) -> &str {
        COMPACTION_CAPABILITY_ID
    }

    /// Human-readable capability name.
    fn name(&self) -> &str {
        "Compaction"
    }

    /// Two-line description of the available strategies.
    fn description(&self) -> &str {
        r#"Configurable context compaction when conversations exceed LLM context windows.
Choose between native provider compaction (e.g., OpenAI /responses/compact), observation masking (strip old tool outputs), or LLM summarization. The `auto` strategy cascades through all available options."#
    }

    /// The capability is always reported as available.
    fn status(&self) -> CapabilityStatus {
        CapabilityStatus::Available
    }

    /// Icon name associated with the capability.
    fn icon(&self) -> Option<&str> {
        let icon_name = "shrink";
        Some(icon_name)
    }

    /// Category label associated with the capability.
    fn category(&self) -> Option<&str> {
        let category_name = "Optimization";
        Some(category_name)
    }

    /// Supplies a fresh [`CompactionFilterProvider`] behind an `Arc`.
    fn message_filter_provider(&self) -> Option<Arc<dyn MessageFilterProvider>> {
        let provider: Arc<dyn MessageFilterProvider> = Arc::new(CompactionFilterProvider);
        Some(provider)
    }
}
/// Message-filter provider handed out by `CompactionCapability::message_filter_provider`.
struct CompactionFilterProvider;
impl MessageFilterProvider for CompactionFilterProvider {
/// Currently a no-op: both arguments are ignored and the query is left unchanged.
fn apply_filters(
&self,
_query: &mut crate::message_filter::MessageQuery,
_config: &serde_json::Value,
) {
}
/// Fixed priority of this provider.
/// NOTE(review): how the number orders providers is defined by `MessageFilterProvider`,
/// which is not visible here — confirm before changing it.
fn priority(&self) -> i32 {
50 }
}
/// Rough token estimate for one message: total byte length divided by four.
///
/// Text content counts its UTF-8 byte length and every non-text content part
/// counts as a flat 50 bytes. Each tool call adds the byte length of its name and
/// of its serialized arguments, plus 20 bytes of overhead. The division truncates.
pub fn estimate_tokens(msg: &LlmMessage) -> usize {
    let content_bytes: usize = match &msg.content {
        LlmMessageContent::Text(body) => body.len(),
        LlmMessageContent::Parts(parts) => {
            let mut subtotal = 0;
            for part in parts {
                subtotal += match part {
                    LlmContentPart::Text { text } => text.len(),
                    // Non-text parts get a fixed placeholder size.
                    _ => 50,
                };
            }
            subtotal
        }
    };

    // `Option<Vec<_>>` flattens to zero calls when no tool calls are attached.
    let tool_call_bytes: usize = msg
        .tool_calls
        .iter()
        .flatten()
        .map(|call| call.name.len() + call.arguments.to_string().len() + 20)
        .sum();

    (content_bytes + tool_call_bytes) / 4
}
pub fn estimate_total_tokens(messages: &[LlmMessage]) -> usize {
messages.iter().map(estimate_tokens).sum()
}
/// Decides whether the conversation should be compacted before the limit is hit.
///
/// Returns `false` whenever `config.proactive` is off. Otherwise the budget is
/// `context_window_tokens * config.budget_percent` (computed in `f32` and truncated
/// to `usize`), and the result is `true` only when the estimated token total of
/// `messages` is strictly greater than that budget.
pub fn should_compact_proactively(
    messages: &[LlmMessage],
    config: &CompactionConfig,
    context_window_tokens: usize,
) -> bool {
    if config.proactive {
        let token_budget = (context_window_tokens as f32 * config.budget_percent) as usize;
        estimate_total_tokens(messages) > token_budget
    } else {
        false
    }
}
/// Trims `messages` down to roughly `target_tokens`, favouring recent messages.
///
/// * When `has_system_prompt` is set and the slice is non-empty, the first message
///   is treated as the system prompt. It is kept only if its estimate is strictly
///   below the budget, and it is never considered part of the conversation.
/// * Protected messages (results of, or assistant calls to, a tool listed in
///   `PROTECTED_TOOL_NAMES`) are budgeted first. If they do not all fit, the newest
///   ones that fit are kept and nothing else is returned.
/// * Otherwise all protected messages are kept and the remaining budget is filled
///   with unprotected messages from newest to oldest, stopping at the first one
///   that does not fit.
///
/// The returned messages keep their original relative order.
pub fn aggressive_trim(
    messages: &[LlmMessage],
    target_tokens: usize,
    has_system_prompt: bool,
) -> Vec<LlmMessage> {
    let mut trimmed = Vec::new();
    let mut budget = target_tokens;

    let conversation = if has_system_prompt && !messages.is_empty() {
        let system_prompt = &messages[0];
        let cost = estimate_tokens(system_prompt);
        if cost < budget {
            trimmed.push(system_prompt.clone());
            budget -= cost;
        }
        &messages[1..]
    } else {
        messages
    };

    // One flag per conversation message: is it protected from trimming?
    let protected: Vec<bool> = conversation
        .iter()
        .map(|m| is_protected_tool_result(conversation, m) || is_protected_tool_call_message(m))
        .collect();
    let protected_cost: usize = conversation
        .iter()
        .zip(&protected)
        .filter(|(_, flag)| **flag)
        .map(|(m, _)| estimate_tokens(m))
        .sum();

    // One flag per conversation message: does it survive the trim?
    let mut keep = vec![false; conversation.len()];

    if protected_cost > budget {
        // Even the protected set is too large: keep the newest protected
        // messages that still fit and drop everything else.
        for (idx, msg) in conversation.iter().enumerate().rev() {
            if !protected[idx] {
                continue;
            }
            let cost = estimate_tokens(msg);
            if cost <= budget {
                keep[idx] = true;
                budget -= cost;
            }
        }
    } else {
        budget -= protected_cost;
        keep.copy_from_slice(&protected);
        // Fill what is left with the most recent unprotected messages; the
        // first one that does not fit ends the walk.
        for (idx, msg) in conversation.iter().enumerate().rev() {
            if protected[idx] {
                continue;
            }
            let cost = estimate_tokens(msg);
            if cost > budget {
                break;
            }
            keep[idx] = true;
            budget -= cost;
        }
    }

    // Walking the conversation front to back restores chronological order.
    for (msg, kept) in conversation.iter().zip(keep) {
        if kept {
            trimmed.push(msg.clone());
        }
    }
    trimmed
}
/// Running totals over the compaction passes reported through `record`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SessionCompactionMetrics {
/// Number of `record` calls so far.
pub compaction_count: u32,
/// Sum of `messages_before - messages_after` over all recorded passes (never negative per pass).
pub total_messages_saved: u64,
/// How often each strategy name appeared in the `strategy_used` arguments.
pub strategy_counts: HashMap<String, u32>,
/// Sum of the reported durations, in milliseconds.
pub total_duration_ms: u64,
}
impl SessionCompactionMetrics {
    /// Folds one compaction pass into the running totals.
    ///
    /// `strategy_used` may name several strategies joined by `+`; every segment
    /// gets its own counter increment. The number of saved messages is
    /// `messages_before - messages_after`, floored at zero.
    pub fn record(
        &mut self,
        strategy_used: &str,
        messages_before: usize,
        messages_after: usize,
        duration_ms: u64,
    ) {
        let saved = messages_before.saturating_sub(messages_after);

        self.compaction_count += 1;
        self.total_messages_saved += saved as u64;
        self.total_duration_ms += duration_ms;

        for name in strategy_used.split('+') {
            let counter = self.strategy_counts.entry(name.to_owned()).or_default();
            *counter += 1;
        }
    }
}
/// Recency tier assigned to a message by `classify_memory_tiers`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MemoryTier {
/// One of the most recent `hot_messages` messages.
Hot,
/// One of the `warm_messages` messages immediately older than the hot tier.
Warm,
/// Anything older than the warm tier.
Cold,
}
/// Sizes of the hot and warm tiers, counted in messages from the end of the conversation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HierarchicalMemoryConfig {
/// Number of most recent messages in the hot tier; defaults to `default_hot_messages()`.
#[serde(default = "default_hot_messages")]
pub hot_messages: usize,
/// Number of messages in the warm tier, directly before the hot tier;
/// defaults to `default_warm_messages()`.
#[serde(default = "default_warm_messages")]
pub warm_messages: usize,
}
impl Default for HierarchicalMemoryConfig {
    /// Uses the same tier sizes serde applies when the fields are missing.
    fn default() -> Self {
        let hot_messages = default_hot_messages();
        let warm_messages = default_warm_messages();
        Self {
            hot_messages,
            warm_messages,
        }
    }
}
/// Default for `HierarchicalMemoryConfig::hot_messages`.
fn default_hot_messages() -> usize {
    const DEFAULT_HOT_MESSAGES: usize = 20;
    DEFAULT_HOT_MESSAGES
}
/// Default for `HierarchicalMemoryConfig::warm_messages`.
fn default_warm_messages() -> usize {
    const DEFAULT_WARM_MESSAGES: usize = 100;
    DEFAULT_WARM_MESSAGES
}
/// Labels every message with its recency tier, preserving input order.
///
/// Counting from the end of `messages`, the last `config.hot_messages` entries
/// are `Hot`, the `config.warm_messages` entries before them are `Warm`, and
/// everything older is `Cold`. An empty slice yields an empty vector.
///
/// The warm boundary is computed with saturating arithmetic, so very large
/// tier sizes (e.g. `usize::MAX` used as "unbounded") neither panic nor wrap
/// around. This matches the saturating boundary maths in
/// `apply_hierarchical_memory`.
pub fn classify_memory_tiers<'a>(
    messages: &'a [LlmMessage],
    config: &HierarchicalMemoryConfig,
) -> Vec<(MemoryTier, &'a LlmMessage)> {
    let len = messages.len();
    // Exclusive upper bound (as distance from the end) of the warm tier.
    let warm_limit = config.hot_messages.saturating_add(config.warm_messages);
    messages
        .iter()
        .enumerate()
        .map(|(i, msg)| {
            // `len >= 1` whenever this closure runs, so the subtraction cannot underflow.
            let from_end = len - 1 - i;
            let tier = if from_end < config.hot_messages {
                MemoryTier::Hot
            } else if from_end < warm_limit {
                MemoryTier::Warm
            } else {
                MemoryTier::Cold
            };
            (tier, msg)
        })
        .collect()
}
/// Rebuilds the conversation according to the hot / warm / cold tiers.
///
/// * **Cold** (everything before the warm tier): replaced by the summary message
///   built from `cold_summary`, when one is given, followed by the cold messages
///   that are protected (see `PROTECTED_TOOL_NAMES`). Other cold messages are dropped.
/// * **Warm**: passed through observation masking with `masking_config`. Tool
///   results whose originating call — looked up across the *whole* conversation —
///   is protected are exempt from masking.
/// * **Hot** (the last `config.hot_messages` messages): copied unchanged.
pub fn apply_hierarchical_memory(
    messages: &[LlmMessage],
    config: &HierarchicalMemoryConfig,
    masking_config: &ObservationMaskingConfig,
    cold_summary: Option<&str>,
) -> Vec<LlmMessage> {
    let hot_start = messages.len().saturating_sub(config.hot_messages);
    let warm_start = hot_start.saturating_sub(config.warm_messages);

    let (older, hot_zone) = messages.split_at(hot_start);
    let (cold_zone, warm_zone) = older.split_at(warm_start);

    let mut rebuilt = Vec::new();

    if !cold_zone.is_empty() {
        if let Some(summary) = cold_summary {
            rebuilt.push(build_summary_message(summary));
        }
        for msg in cold_zone {
            if is_protected_tool_result(cold_zone, msg) || is_protected_tool_call_message(msg) {
                rebuilt.push(msg.clone());
            }
        }
    }

    if !warm_zone.is_empty() {
        // The assistant call behind a warm tool result may sit in the cold zone,
        // so protection is resolved against the full message list and passed
        // to the masker as explicit call ids.
        let mut protected_call_ids: std::collections::HashSet<String> =
            std::collections::HashSet::new();
        for msg in warm_zone {
            if is_protected_tool_result(messages, msg) {
                if let Some(call_id) = &msg.tool_call_id {
                    protected_call_ids.insert(call_id.clone());
                }
            }
        }
        let masked = apply_observation_masking_with_protected(
            warm_zone,
            masking_config,
            &protected_call_ids,
        );
        rebuilt.extend(masked.messages);
    }

    rebuilt.extend_from_slice(hot_zone);
    rebuilt
}
use crate::llm_driver_registry::{LlmContentPart, LlmMessage, LlmMessageContent, LlmMessageRole};
/// Tool names whose calls and results are shielded from compaction in this file:
/// they are never masked, are budgeted first when trimming, survive the cold tier,
/// and are not truncated when formatted for summarization.
const PROTECTED_TOOL_NAMES: &[&str] = &["activate_skill"];
/// Returns `true` when `tool_msg` is a tool-role message whose originating call,
/// looked up in `messages`, names a tool in `PROTECTED_TOOL_NAMES`.
///
/// Messages with any other role are never protected by this check.
fn is_protected_tool_result(messages: &[LlmMessage], tool_msg: &LlmMessage) -> bool {
    // The name lookup only runs for tool-role messages thanks to short-circuiting.
    tool_msg.role == LlmMessageRole::Tool
        && PROTECTED_TOOL_NAMES.contains(&find_tool_call_name(messages, tool_msg).as_str())
}
/// Returns `true` when `msg` is an assistant message carrying at least one tool
/// call whose name is listed in `PROTECTED_TOOL_NAMES`.
fn is_protected_tool_call_message(msg: &LlmMessage) -> bool {
    msg.role == LlmMessageRole::Assistant
        && msg
            .tool_calls
            .iter()
            .flatten()
            .any(|call| PROTECTED_TOOL_NAMES.contains(&call.name.as_str()))
}
/// Outcome of an observation-masking pass.
#[derive(Debug)]
pub struct ObservationMaskingResult {
/// The conversation after masking; same length and order as the input.
pub messages: Vec<LlmMessage>,
/// Number of tool outputs that were replaced by a summary.
pub masked_count: usize,
}
/// Masks older tool outputs in `messages` according to `config`.
///
/// Convenience wrapper around the internal masking routine with no additional
/// protected call ids.
pub fn apply_observation_masking(
    messages: &[LlmMessage],
    config: &ObservationMaskingConfig,
) -> ObservationMaskingResult {
    let no_extra_protection: std::collections::HashSet<String> =
        std::collections::HashSet::new();
    apply_observation_masking_with_protected(messages, config, &no_extra_protection)
}
/// Replaces all but the most recent maskable tool outputs with short summaries.
///
/// A message is maskable when it has the tool role, is not the result of a
/// protected tool, and its `tool_call_id` is not in `extra_protected_call_ids`.
/// The newest `config.keep_recent_tool_outputs` maskable messages stay as they
/// are; older ones keep their role, `tool_calls`, `tool_call_id` and `phase` but
/// get summary text as content and lose their thinking fields. If there is
/// nothing to mask, the input is returned as a plain copy with `masked_count == 0`.
fn apply_observation_masking_with_protected(
    messages: &[LlmMessage],
    config: &ObservationMaskingConfig,
    extra_protected_call_ids: &std::collections::HashSet<String>,
) -> ObservationMaskingResult {
    let is_maskable = |candidate: &LlmMessage| -> bool {
        if candidate.role != LlmMessageRole::Tool || is_protected_tool_result(messages, candidate)
        {
            return false;
        }
        match &candidate.tool_call_id {
            Some(call_id) => !extra_protected_call_ids.contains(call_id),
            None => true,
        }
    };

    // Positions of maskable tool outputs, oldest first.
    let maskable: Vec<usize> = (0..messages.len())
        .filter(|&idx| is_maskable(&messages[idx]))
        .collect();

    let keep_recent = config.keep_recent_tool_outputs;
    if maskable.len() <= keep_recent {
        return ObservationMaskingResult {
            messages: messages.to_vec(),
            masked_count: 0,
        };
    }

    let masked_count = maskable.len() - keep_recent;
    let mut should_mask = vec![false; messages.len()];
    for &idx in &maskable[..masked_count] {
        should_mask[idx] = true;
    }

    let rewritten: Vec<LlmMessage> = messages
        .iter()
        .zip(should_mask)
        .map(|(msg, mask)| {
            if !mask {
                return msg.clone();
            }
            let summary = match config.summary_format {
                MaskingSummaryFormat::OneLine => {
                    let tool_name = find_tool_call_name(messages, msg);
                    format_one_line_summary(&tool_name, &msg.content)
                }
                MaskingSummaryFormat::HeadTail => format_head_tail_summary(&msg.content),
            };
            LlmMessage {
                role: LlmMessageRole::Tool,
                content: LlmMessageContent::Text(summary),
                tool_calls: msg.tool_calls.clone(),
                tool_call_id: msg.tool_call_id.clone(),
                phase: msg.phase,
                thinking: None,
                thinking_signature: None,
            }
        })
        .collect();

    ObservationMaskingResult {
        messages: rewritten,
        masked_count,
    }
}
/// Resolves the tool name behind a tool-result message.
///
/// Scans `messages` from newest to oldest for an assistant message containing a
/// tool call whose id equals `tool_msg.tool_call_id` and returns that call's name.
/// Returns `"unknown_tool"` when the message has no call id or no call matches.
fn find_tool_call_name(messages: &[LlmMessage], tool_msg: &LlmMessage) -> String {
    let Some(call_id) = tool_msg.tool_call_id.as_ref() else {
        return "unknown_tool".to_string();
    };

    messages
        .iter()
        .rev()
        .filter(|candidate| candidate.role == LlmMessageRole::Assistant)
        .filter_map(|candidate| candidate.tool_calls.as_ref())
        .flatten()
        .find(|call| call.id == *call_id)
        .map(|call| call.name.clone())
        .unwrap_or_else(|| "unknown_tool".to_string())
}
/// Flattens message content to plain text.
///
/// Plain text is returned as-is. For multi-part content, the text parts are
/// joined with a single space and every non-text part is ignored.
fn extract_text(content: &LlmMessageContent) -> String {
    match content {
        LlmMessageContent::Text(body) => body.clone(),
        LlmMessageContent::Parts(parts) => {
            let mut fragments: Vec<&str> = Vec::new();
            for part in parts {
                if let LlmContentPart::Text { text } = part {
                    fragments.push(text.as_str());
                }
            }
            fragments.join(" ")
        }
    }
}
/// Builds the bracketed summary used by `MaskingSummaryFormat::OneLine`.
///
/// Outputs of at most 100 bytes are embedded verbatim after the tool name;
/// longer outputs are replaced by their line and byte counts.
fn format_one_line_summary(tool_name: &str, content: &LlmMessageContent) -> String {
    let text = extract_text(content);
    if text.len() > 100 {
        let line_count = text.lines().count();
        let byte_len = text.len();
        format!("[{tool_name} → {line_count} lines, {byte_len} bytes]")
    } else {
        format!("[{tool_name} → {text}]")
    }
}
/// Builds the summary used by `MaskingSummaryFormat::HeadTail`.
///
/// Text of six lines or fewer is returned unchanged. Longer text is reduced to
/// its first three and last three lines with a marker in between stating how
/// many lines were left out.
fn format_head_tail_summary(content: &LlmMessageContent) -> String {
    let text = extract_text(content);
    let lines: Vec<&str> = text.lines().collect();
    let total = lines.len();
    if total <= 6 {
        return text;
    }

    let (head, rest) = lines.split_at(3);
    let tail = &rest[rest.len() - 3..];
    let omitted = total - 6;
    format!(
        "{}\n... ({} lines omitted) ...\n{}",
        head.join("\n"),
        omitted,
        tail.join("\n")
    )
}
/// Renders the instruction prompt for the summarization model.
///
/// The `<preserve>` section lists `config.preserve` as bullets, falling back to
/// `default_preserve()` when the list is empty. `config.instructions`, when set,
/// is appended as one more bullet. The `<task>` and `<format>` sections are fixed text.
pub fn build_summarization_prompt(config: &SummarizationConfig) -> String {
    // Borrow the configured list when possible; only build the defaults when needed.
    let fallback_items;
    let preserve_items: &[String] = if config.preserve.is_empty() {
        fallback_items = default_preserve();
        &fallback_items
    } else {
        &config.preserve
    };

    let bullets: Vec<String> = preserve_items
        .iter()
        .map(|item| format!("- {item}"))
        .collect();
    let preserve_list = bullets.join("\n");

    let custom_instructions = match &config.instructions {
        Some(instr) => format!("\n- {instr}"),
        None => String::new(),
    };

    format!(
        r#"<task>
Summarize the following conversation history. The summary replaces these
messages in the agent's context window — it must contain everything the
agent needs to continue working.
</task>
<preserve>
{preserve_list}{custom_instructions}
</preserve>
<format>
Produce a structured summary. Use sections. Be concise but complete.
Do not include tool output verbatim — reference files by path.
IMPORTANT: Any activate_skill tool results contain durable skill instructions.
Include them verbatim in a dedicated "Active Skills" section — do not summarize
or paraphrase skill instructions.
</format>"#
    )
}
/// Renders `messages` as a plain-text transcript for the summarization model.
///
/// Each message becomes `[role]: text`, and entries are separated by a blank
/// line. Text longer than 2000 bytes is cut at a character boundary and tagged
/// with its full size, except for results of protected tools, which are always
/// included in full. The size in the tag is the UTF-8 byte length, even though
/// the tag wording says "chars".
pub fn format_messages_for_summarization(messages: &[LlmMessage]) -> String {
    let rendered: Vec<String> = messages
        .iter()
        .map(|msg| {
            let role = match msg.role {
                LlmMessageRole::System => "system",
                LlmMessageRole::User => "user",
                LlmMessageRole::Assistant => "assistant",
                LlmMessageRole::Tool => "tool",
            };
            let content = extract_text(&msg.content);
            let too_long = content.len() > 2000;
            let truncated = if too_long && !is_protected_tool_result(messages, msg) {
                format!(
                    "{}... [truncated, {} chars total]",
                    truncate_at_char_boundary(&content, 2000),
                    content.len()
                )
            } else {
                content
            };
            format!("[{role}]: {truncated}")
        })
        .collect();
    rendered.join("\n\n")
}
/// Returns the longest prefix of `content` that is at most `max_bytes` long and
/// ends on a UTF-8 character boundary.
///
/// The whole string is returned when it already fits. The result may be shorter
/// than `max_bytes` when the cut would otherwise split a multi-byte character.
fn truncate_at_char_boundary(content: &str, max_bytes: usize) -> &str {
    if max_bytes >= content.len() {
        return content;
    }
    // Index 0 is always a boundary, so the search cannot come up empty.
    let cut = (0..=max_bytes)
        .rev()
        .find(|&idx| content.is_char_boundary(idx))
        .unwrap_or(0);
    &content[..cut]
}
/// Wraps `summary_text` in a system-role message delimited by
/// `[CONVERSATION_SUMMARY]` … `[/CONVERSATION_SUMMARY]` markers.
///
/// All optional fields of the message are left empty.
pub fn build_summary_message(summary_text: &str) -> LlmMessage {
    let wrapped = format!("[CONVERSATION_SUMMARY]\n{summary_text}\n[/CONVERSATION_SUMMARY]");
    LlmMessage {
        role: LlmMessageRole::System,
        content: LlmMessageContent::Text(wrapped),
        tool_calls: None,
        tool_call_id: None,
        phase: None,
        thinking: None,
        thinking_signature: None,
    }
}
/// Serializable record of a single compaction step.
/// NOTE(review): nothing in this part of the file constructs it outside of tests;
/// the field meanings below are read from their names — confirm against the producer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompactionStep {
/// Name of the strategy that ran (e.g. `"observation_masking"`).
pub strategy: String,
/// Message count after the step.
pub messages_after: usize,
/// Duration of the step in milliseconds.
pub duration_ms: u64,
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tool_types::ToolCall;
use serde_json::json;
fn make_user_msg(text: &str) -> LlmMessage {
LlmMessage {
role: LlmMessageRole::User,
content: LlmMessageContent::Text(text.to_string()),
tool_calls: None,
tool_call_id: None,
phase: None,
thinking: None,
thinking_signature: None,
}
}
fn make_assistant_msg(text: &str) -> LlmMessage {
LlmMessage {
role: LlmMessageRole::Assistant,
content: LlmMessageContent::Text(text.to_string()),
tool_calls: None,
tool_call_id: None,
phase: None,
thinking: None,
thinking_signature: None,
}
}
fn make_assistant_with_tool_call(call_id: &str, tool_name: &str) -> LlmMessage {
LlmMessage {
role: LlmMessageRole::Assistant,
content: LlmMessageContent::Text(String::new()),
tool_calls: Some(vec![ToolCall {
id: call_id.to_string(),
name: tool_name.to_string(),
arguments: json!({"path": "src/main.rs"}),
}]),
tool_call_id: None,
phase: None,
thinking: None,
thinking_signature: None,
}
}
fn make_tool_result(call_id: &str, output: &str) -> LlmMessage {
LlmMessage {
role: LlmMessageRole::Tool,
content: LlmMessageContent::Text(output.to_string()),
tool_calls: None,
tool_call_id: Some(call_id.to_string()),
phase: None,
thinking: None,
thinking_signature: None,
}
}
#[test]
fn test_capability_metadata() {
let cap = CompactionCapability;
assert_eq!(cap.id(), COMPACTION_CAPABILITY_ID);
assert_eq!(cap.name(), "Compaction");
assert_eq!(cap.status(), CapabilityStatus::Available);
assert_eq!(cap.category(), Some("Optimization"));
assert!(cap.tools().is_empty());
assert!(cap.message_filter_provider().is_some());
}
#[test]
fn test_default_config() {
let config = CompactionConfig::default();
assert_eq!(config.strategy, CompactionStrategy::Auto);
assert!(config.proactive);
assert!((config.budget_percent - 0.85).abs() < f32::EPSILON);
assert_eq!(config.observation_masking.keep_recent_tool_outputs, 2);
assert_eq!(
config.observation_masking.summary_format,
MaskingSummaryFormat::OneLine
);
assert!(config.summarization.model.is_none());
assert_eq!(config.summarization.preserve.len(), 5);
assert!(config.summarization.instructions.is_none());
}
#[test]
fn test_config_from_empty_json() {
let config = CompactionConfig::from_json(&json!({}));
assert_eq!(config.strategy, CompactionStrategy::Auto);
assert!(config.proactive);
}
#[test]
fn test_config_native_only() {
let config = CompactionConfig::from_json(&json!({"strategy": "native"}));
assert_eq!(config.strategy, CompactionStrategy::Native);
assert!(config.proactive);
}
#[test]
fn test_config_observation_masking_with_custom_settings() {
let config = CompactionConfig::from_json(&json!({
"strategy": "observation_masking",
"proactive": false,
"observation_masking": {
"keep_recent_tool_outputs": 10,
"summary_format": "head_tail"
}
}));
assert_eq!(config.strategy, CompactionStrategy::ObservationMasking);
assert!(!config.proactive);
assert_eq!(config.observation_masking.keep_recent_tool_outputs, 10);
assert_eq!(
config.observation_masking.summary_format,
MaskingSummaryFormat::HeadTail
);
}
#[test]
fn test_config_summarization_with_custom_model() {
let config = CompactionConfig::from_json(&json!({
"strategy": "summarization",
"summarization": {
"model": "claude-haiku-4-5-20251001",
"instructions": "Focus on API decisions",
"preserve": ["decisions", "errors"]
}
}));
assert_eq!(config.strategy, CompactionStrategy::Summarization);
assert_eq!(
config.summarization.model.as_deref(),
Some("claude-haiku-4-5-20251001")
);
assert_eq!(
config.summarization.instructions.as_deref(),
Some("Focus on API decisions")
);
assert_eq!(config.summarization.preserve.len(), 2);
}
#[test]
fn test_config_falls_back_to_defaults_for_invalid_json() {
let config = CompactionConfig::from_json(&json!({
"strategy": "nonexistent_strategy",
"budget_percent": "not-a-number"
}));
assert_eq!(config.strategy, CompactionStrategy::Auto);
assert!(config.proactive);
}
#[test]
fn test_config_partial_override() {
let config = CompactionConfig::from_json(&json!({
"budget_percent": 0.7,
"observation_masking": {
"keep_recent_tool_outputs": 3
}
}));
assert_eq!(config.strategy, CompactionStrategy::Auto);
assert!(config.proactive);
assert!((config.budget_percent - 0.7).abs() < f32::EPSILON);
assert_eq!(config.observation_masking.keep_recent_tool_outputs, 3);
assert_eq!(
config.observation_masking.summary_format,
MaskingSummaryFormat::OneLine
);
}
#[test]
fn test_strategy_serialization_roundtrip() {
for strategy in [
CompactionStrategy::Auto,
CompactionStrategy::Native,
CompactionStrategy::ObservationMasking,
CompactionStrategy::Summarization,
] {
let json = serde_json::to_value(strategy).unwrap();
let deserialized: CompactionStrategy = serde_json::from_value(json).unwrap();
assert_eq!(strategy, deserialized);
}
}
#[test]
fn test_strategy_display() {
assert_eq!(CompactionStrategy::Auto.to_string(), "auto");
assert_eq!(CompactionStrategy::Native.to_string(), "native");
assert_eq!(
CompactionStrategy::ObservationMasking.to_string(),
"observation_masking"
);
assert_eq!(
CompactionStrategy::Summarization.to_string(),
"summarization"
);
}
#[test]
fn test_masking_format_serialization_roundtrip() {
for format in [
MaskingSummaryFormat::OneLine,
MaskingSummaryFormat::HeadTail,
] {
let json = serde_json::to_value(format).unwrap();
let deserialized: MaskingSummaryFormat = serde_json::from_value(json).unwrap();
assert_eq!(format, deserialized);
}
}
#[test]
fn test_budget_percent_boundary_values() {
let config = CompactionConfig::from_json(&json!({"budget_percent": 0.1}));
assert!((config.budget_percent - 0.1).abs() < f32::EPSILON);
let config = CompactionConfig::from_json(&json!({"budget_percent": 0.99}));
assert!((config.budget_percent - 0.99).abs() < f32::EPSILON);
}
#[test]
fn test_keep_recent_tool_outputs_zero() {
let config = CompactionConfig::from_json(&json!({
"observation_masking": {"keep_recent_tool_outputs": 0}
}));
assert_eq!(config.observation_masking.keep_recent_tool_outputs, 0);
}
#[test]
fn test_masking_no_tool_messages() {
let messages = vec![make_user_msg("hello"), make_assistant_msg("hi")];
let config = ObservationMaskingConfig::default();
let result = apply_observation_masking(&messages, &config);
assert_eq!(result.masked_count, 0);
assert_eq!(result.messages.len(), 2);
}
#[test]
fn test_masking_fewer_than_keep_recent() {
let messages = vec![
make_user_msg("read file"),
make_assistant_with_tool_call("call_1", "read_file"),
make_tool_result("call_1", "file contents"),
make_assistant_msg("done"),
];
let config = ObservationMaskingConfig {
keep_recent_tool_outputs: 5,
summary_format: MaskingSummaryFormat::OneLine,
};
let result = apply_observation_masking(&messages, &config);
assert_eq!(result.masked_count, 0);
}
#[test]
fn test_masking_masks_old_outputs() {
let messages = vec![
make_user_msg("start"),
make_assistant_with_tool_call("call_1", "read_file"),
make_tool_result(
"call_1",
"old file contents that are very long and should be masked by the observation masking strategy because it exceeds 100 chars",
),
make_assistant_msg("got it"),
make_user_msg("next"),
make_assistant_with_tool_call("call_2", "search"),
make_tool_result("call_2", "search results"),
make_assistant_msg("found it"),
make_user_msg("more"),
make_assistant_with_tool_call("call_3", "bash"),
make_tool_result("call_3", "command output"),
];
let config = ObservationMaskingConfig {
keep_recent_tool_outputs: 2,
summary_format: MaskingSummaryFormat::OneLine,
};
let result = apply_observation_masking(&messages, &config);
assert_eq!(result.masked_count, 1);
let masked = &result.messages[2];
assert_eq!(masked.role, LlmMessageRole::Tool);
let text = extract_text(&masked.content);
assert!(
text.starts_with('['),
"Expected masked summary, got: {text}"
);
assert!(text.contains("read_file"), "Expected tool name: {text}");
assert_eq!(extract_text(&result.messages[6].content), "search results");
assert_eq!(extract_text(&result.messages[10].content), "command output");
}
#[test]
fn test_masking_preserves_tool_call_id() {
let messages = vec![
make_assistant_with_tool_call("call_1", "read_file"),
make_tool_result("call_1", "content"),
make_assistant_with_tool_call("call_2", "bash"),
make_tool_result("call_2", "output"),
];
let config = ObservationMaskingConfig {
keep_recent_tool_outputs: 1,
summary_format: MaskingSummaryFormat::OneLine,
};
let result = apply_observation_masking(&messages, &config);
assert_eq!(result.messages[1].tool_call_id, Some("call_1".to_string()));
}
#[test]
fn test_masking_head_tail_format() {
let long_output = (0..20)
.map(|i| format!("line {i}"))
.collect::<Vec<_>>()
.join("\n");
let messages = vec![
make_assistant_with_tool_call("call_1", "bash"),
make_tool_result("call_1", &long_output),
make_assistant_with_tool_call("call_2", "bash"),
make_tool_result("call_2", "recent output"),
];
let config = ObservationMaskingConfig {
keep_recent_tool_outputs: 1,
summary_format: MaskingSummaryFormat::HeadTail,
};
let result = apply_observation_masking(&messages, &config);
let text = extract_text(&result.messages[1].content);
assert!(text.contains("line 0"), "Should contain first lines");
assert!(text.contains("line 19"), "Should contain last lines");
assert!(text.contains("lines omitted"), "Should indicate omissions");
}
#[test]
fn test_masking_short_output_inline() {
let messages = vec![
make_assistant_with_tool_call("call_1", "get_time"),
make_tool_result("call_1", "2024-01-01"),
make_assistant_with_tool_call("call_2", "bash"),
make_tool_result("call_2", "ok"),
];
let config = ObservationMaskingConfig {
keep_recent_tool_outputs: 1,
summary_format: MaskingSummaryFormat::OneLine,
};
let result = apply_observation_masking(&messages, &config);
let text = extract_text(&result.messages[1].content);
assert!(text.contains("2024-01-01"), "Short output included: {text}");
}
#[test]
fn test_masking_all_when_keep_zero() {
let messages = vec![
make_assistant_with_tool_call("call_1", "a"),
make_tool_result("call_1", "output1"),
make_assistant_with_tool_call("call_2", "b"),
make_tool_result("call_2", "output2"),
];
let config = ObservationMaskingConfig {
keep_recent_tool_outputs: 0,
summary_format: MaskingSummaryFormat::OneLine,
};
let result = apply_observation_masking(&messages, &config);
assert_eq!(result.masked_count, 2);
}
#[test]
fn test_masking_empty_messages() {
let result = apply_observation_masking(&[], &ObservationMaskingConfig::default());
assert_eq!(result.masked_count, 0);
assert!(result.messages.is_empty());
}
#[test]
fn test_masking_preserves_message_count() {
let messages = vec![
make_user_msg("start"),
make_assistant_with_tool_call("c1", "read_file"),
make_tool_result("c1", "content 1"),
make_assistant_msg("ok"),
make_user_msg("next"),
make_assistant_with_tool_call("c2", "bash"),
make_tool_result("c2", "content 2"),
make_assistant_msg("done"),
];
let config = ObservationMaskingConfig {
keep_recent_tool_outputs: 1,
summary_format: MaskingSummaryFormat::OneLine,
};
let result = apply_observation_masking(&messages, &config);
assert_eq!(result.messages.len(), messages.len());
}
#[test]
fn test_masking_unknown_tool_call_id() {
let messages = vec![
make_tool_result("orphan", "some output"),
make_assistant_with_tool_call("call_2", "bash"),
make_tool_result("call_2", "recent"),
];
let config = ObservationMaskingConfig {
keep_recent_tool_outputs: 1,
summary_format: MaskingSummaryFormat::OneLine,
};
let result = apply_observation_masking(&messages, &config);
assert_eq!(result.masked_count, 1);
let text = extract_text(&result.messages[0].content);
assert!(text.contains("unknown_tool"), "Fallback name: {text}");
}
#[test]
fn test_masking_many_tool_calls_keeps_exactly_n() {
let mut messages = Vec::new();
for i in 0..10 {
let id = format!("call_{i}");
messages.push(make_assistant_with_tool_call(&id, &format!("tool_{i}")));
messages.push(make_tool_result(&id, &format!("output {i}")));
}
let config = ObservationMaskingConfig {
keep_recent_tool_outputs: 3,
summary_format: MaskingSummaryFormat::OneLine,
};
let result = apply_observation_masking(&messages, &config);
assert_eq!(result.masked_count, 7);
assert_eq!(extract_text(&result.messages[15].content), "output 7");
assert_eq!(extract_text(&result.messages[17].content), "output 8");
assert_eq!(extract_text(&result.messages[19].content), "output 9");
}
#[test]
fn test_summarization_prompt_default() {
let config = SummarizationConfig::default();
let prompt = build_summarization_prompt(&config);
assert!(prompt.contains("<task>"));
assert!(prompt.contains("decisions"));
assert!(prompt.contains("files_modified"));
assert!(prompt.contains("errors"));
assert!(prompt.contains("current_plan"));
}
#[test]
fn test_summarization_prompt_custom_instructions() {
let config = SummarizationConfig {
instructions: Some("Focus on API changes".to_string()),
..Default::default()
};
let prompt = build_summarization_prompt(&config);
assert!(prompt.contains("Focus on API changes"));
}
#[test]
fn test_summarization_prompt_custom_preserve() {
let config = SummarizationConfig {
preserve: vec!["auth_tokens".to_string(), "database_schema".to_string()],
..Default::default()
};
let prompt = build_summarization_prompt(&config);
assert!(prompt.contains("auth_tokens"));
assert!(prompt.contains("database_schema"));
assert!(!prompt.contains("decisions"));
}
#[test]
fn test_summarization_prompt_empty_preserve_uses_defaults() {
let config = SummarizationConfig {
preserve: vec![],
..Default::default()
};
let prompt = build_summarization_prompt(&config);
assert!(prompt.contains("decisions"));
}
#[test]
fn test_format_messages_for_summarization() {
let messages = vec![
make_user_msg("What is 2+2?"),
make_assistant_msg("The answer is 4."),
];
let formatted = format_messages_for_summarization(&messages);
assert!(formatted.contains("[user]: What is 2+2?"));
assert!(formatted.contains("[assistant]: The answer is 4."));
}
#[test]
fn test_format_messages_truncates_long_content() {
let long_content = "x".repeat(5000);
let messages = vec![make_user_msg(&long_content)];
let formatted = format_messages_for_summarization(&messages);
assert!(formatted.contains("truncated"));
assert!(formatted.len() < long_content.len());
}
#[test]
fn test_format_messages_truncates_utf8_without_panic() {
let multibyte = "é".repeat(1001); let messages = vec![make_user_msg(&multibyte)];
let formatted = format_messages_for_summarization(&messages);
assert!(formatted.contains("truncated"));
assert!(formatted.contains("[truncated, 2002 chars total]"));
}
#[test]
fn test_build_summary_message() {
let msg = build_summary_message("The user asked about APIs.");
assert_eq!(msg.role, LlmMessageRole::System);
let text = extract_text(&msg.content);
assert!(text.contains("[CONVERSATION_SUMMARY]"));
assert!(text.contains("The user asked about APIs."));
assert!(text.contains("[/CONVERSATION_SUMMARY]"));
}
#[test]
fn test_head_tail_short_content_unchanged() {
let content = LlmMessageContent::Text("line1\nline2\nline3".to_string());
assert_eq!(format_head_tail_summary(&content), "line1\nline2\nline3");
}
#[test]
fn test_head_tail_exactly_six_lines() {
let content = LlmMessageContent::Text("1\n2\n3\n4\n5\n6".to_string());
assert_eq!(format_head_tail_summary(&content), "1\n2\n3\n4\n5\n6");
}
#[test]
fn test_head_tail_seven_lines() {
let content = LlmMessageContent::Text("1\n2\n3\n4\n5\n6\n7".to_string());
let result = format_head_tail_summary(&content);
assert!(result.contains("1\n2\n3"));
assert!(result.contains("5\n6\n7"));
assert!(result.contains("1 lines omitted"));
}
#[test]
fn test_one_line_empty_output() {
let result = format_one_line_summary("bash", &LlmMessageContent::Text(String::new()));
assert_eq!(result, "[bash → ]");
}
/// A 100-character output must appear verbatim inside the one-line summary.
#[test]
fn test_one_line_exactly_100_chars() {
    let payload = "x".repeat(100);
    let content = LlmMessageContent::Text(payload.clone());

    let summary = format_one_line_summary("bash", &content);

    assert!(summary.contains(&payload));
}
/// One character past 100, the summary is expected to mention "lines" and
/// "bytes".
#[test]
fn test_one_line_101_chars_summarized() {
    let content = LlmMessageContent::Text("x".repeat(101));

    let summary = format_one_line_summary("bash", &content);

    assert!(summary.contains("lines"));
    assert!(summary.contains("bytes"));
}
/// Content made of several text parts must have every part represented in
/// the one-line summary.
#[test]
fn test_one_line_multipart_content() {
    let parts: Vec<LlmContentPart> = ["part1", "part2"]
        .iter()
        .map(|piece| LlmContentPart::Text {
            text: piece.to_string(),
        })
        .collect();
    let content = LlmMessageContent::Parts(parts);

    let summary = format_one_line_summary("tool", &content);

    assert!(summary.contains("part1"));
    assert!(summary.contains("part2"));
}
/// `CompactionStep` must serialize each of its fields under its own name.
#[test]
fn test_compaction_step_serialization() {
    let strategy = "observation_masking".to_string();
    let step = CompactionStep {
        strategy,
        messages_after: 42,
        duration_ms: 12,
    };

    let encoded = serde_json::to_value(&step).unwrap();

    assert_eq!(encoded["strategy"], "observation_masking");
    assert_eq!(encoded["messages_after"], 42);
    assert_eq!(encoded["duration_ms"], 12);
}
/// An 11-character text message is expected to estimate to `11 / 4` tokens
/// (integer division).
#[test]
fn test_estimate_tokens_text() {
    let message = make_user_msg("hello world");

    let estimated = estimate_tokens(&message);

    assert_eq!(estimated, 11 / 4);
}
/// An empty message must estimate to zero tokens.
#[test]
fn test_estimate_tokens_empty() {
    let blank = make_user_msg("");

    let estimated = estimate_tokens(&blank);

    assert_eq!(estimated, 0);
}
/// The total estimate across messages is expected to be the sum of the
/// per-message estimates.
#[test]
fn test_estimate_total_tokens() {
    let user_text = "a".repeat(400);
    let assistant_text = "b".repeat(200);
    let messages = vec![
        make_user_msg(user_text.as_str()),
        make_assistant_msg(assistant_text.as_str()),
    ];

    // 150 = 400 / 4 + 200 / 4 — presumably a chars/4 heuristic; confirm
    // against `estimate_tokens`.
    assert_eq!(estimate_total_tokens(&messages), 150);
}
/// An assistant message that carries a tool call must yield a non-zero
/// token estimate.
#[test]
fn test_estimate_tokens_with_tool_calls() {
    let with_call = make_assistant_with_tool_call("call_1", "read_file");

    let estimated = estimate_tokens(&with_call);

    assert!(estimated > 0, "Tool call should contribute tokens");
}
/// A tiny conversation checked against a 128_000 limit must not trigger
/// proactive compaction under the default configuration.
#[test]
fn test_should_compact_proactively_under_budget() {
    let config = CompactionConfig::default();
    let messages = vec![make_user_msg("short")];

    let triggered = should_compact_proactively(&messages, &config, 128_000);

    assert!(!triggered);
}
/// A 4000-character message checked against a limit of 1000 must trigger
/// proactive compaction under the default configuration.
#[test]
fn test_should_compact_proactively_over_budget() {
    let config = CompactionConfig::default();
    let bulky = "x".repeat(4000);
    let messages = vec![make_user_msg(&bulky)];

    let triggered = should_compact_proactively(&messages, &config, 1000);

    assert!(triggered);
}
/// With `proactive` switched off, the same oversized input that triggers
/// compaction by default must no longer do so.
#[test]
fn test_should_compact_proactively_disabled() {
    let config = CompactionConfig {
        proactive: false,
        ..Default::default()
    };
    let bulky = "x".repeat(4000);
    let messages = vec![make_user_msg(&bulky)];

    let triggered = should_compact_proactively(&messages, &config, 1000);

    assert!(!triggered);
}
/// Five 400-character messages trimmed to a 300-token target must lose at
/// least one message, and the first survivor must be a user message.
#[test]
fn test_aggressive_trim_keeps_newest() {
    // Every message carries 400 characters of a single repeated letter.
    let filler = |letter: &str| letter.repeat(400);
    let messages = vec![
        make_user_msg(&filler("s")),
        make_user_msg(&filler("a")),
        make_assistant_msg(&filler("b")),
        make_user_msg(&filler("c")),
        make_assistant_msg(&filler("d")),
    ];
    let target_tokens = 300;

    let trimmed = aggressive_trim(&messages, target_tokens, true);

    assert!(
        trimmed.len() < messages.len(),
        "Expected trim, got {} messages",
        trimmed.len()
    );
    assert_eq!(trimmed[0].role, LlmMessageRole::User);
}
/// Trimming an empty slice must produce an empty result.
#[test]
fn test_aggressive_trim_empty() {
    let trimmed = aggressive_trim(&[], 100, false);

    assert!(trimmed.is_empty());
}
/// With a 100_000-token target, both short messages must survive the trim.
#[test]
fn test_aggressive_trim_everything_fits() {
    let greeting = make_user_msg("hi");
    let reply = make_assistant_msg("hello");
    let messages = vec![greeting, reply];

    let trimmed = aggressive_trim(&messages, 100_000, false);

    assert_eq!(trimmed.len(), 2);
}
/// A single `record` call with a `+`-joined strategy label must bump the
/// compaction count once, add 50 saved messages (100 → 50) and 200 ms, and
/// count each component strategy separately.
#[test]
fn test_session_metrics_record() {
    let mut tracker = SessionCompactionMetrics::default();

    tracker.record("observation_masking+native", 100, 50, 200);

    assert_eq!(tracker.compaction_count, 1);
    assert_eq!(tracker.total_messages_saved, 50);
    assert_eq!(tracker.total_duration_ms, 200);
    // The combined label is expected to be counted per component.
    assert_eq!(tracker.strategy_counts["observation_masking"], 1);
    assert_eq!(tracker.strategy_counts["native"], 1);
}
/// Two `record` calls must accumulate: count 2, saved messages
/// 20 + 40 = 60, duration 10 + 500 = 510, one hit per strategy.
#[test]
fn test_session_metrics_accumulate() {
    let mut tracker = SessionCompactionMetrics::default();

    // First pass: 100 → 80 messages in 10 ms.
    tracker.record("observation_masking", 100, 80, 10);
    // Second pass: 80 → 40 messages in 500 ms.
    tracker.record("summarization", 80, 40, 500);

    assert_eq!(tracker.compaction_count, 2);
    assert_eq!(tracker.total_messages_saved, 60);
    assert_eq!(tracker.total_duration_ms, 510);
    assert_eq!(tracker.strategy_counts["observation_masking"], 1);
    assert_eq!(tracker.strategy_counts["summarization"], 1);
}
/// Serialized metrics must expose `compaction_count` and
/// `total_messages_saved` with the recorded values.
#[test]
fn test_session_metrics_serialization() {
    let mut tracker = SessionCompactionMetrics::default();
    tracker.record("auto", 50, 30, 100);

    let encoded = serde_json::to_value(&tracker).unwrap();

    assert_eq!(encoded["compaction_count"], 1);
    // 50 → 30 messages means 20 saved.
    assert_eq!(encoded["total_messages_saved"], 20);
}
/// With 30 messages, 5 hot slots and 10 warm slots, the tiers are expected
/// to be assigned from the end: indices 25..=29 hot, 15..=24 warm, 0..=14
/// cold.
#[test]
fn test_classify_memory_tiers_basic() {
    let config = HierarchicalMemoryConfig {
        hot_messages: 5,
        warm_messages: 10,
    };
    let messages: Vec<LlmMessage> = (0..30)
        .map(|n| make_user_msg(&format!("msg {n}")))
        .collect();

    let tiers = classify_memory_tiers(&messages, &config);

    assert_eq!(tiers.len(), 30);
    // Hot band: both edges.
    assert_eq!(tiers[29].0, MemoryTier::Hot);
    assert_eq!(tiers[25].0, MemoryTier::Hot);
    // Warm band: both edges.
    assert_eq!(tiers[24].0, MemoryTier::Warm);
    assert_eq!(tiers[15].0, MemoryTier::Warm);
    // Cold band: both edges.
    assert_eq!(tiers[14].0, MemoryTier::Cold);
    assert_eq!(tiers[0].0, MemoryTier::Cold);
}
/// Three messages under the default configuration must all be classified
/// as hot.
#[test]
fn test_classify_memory_tiers_all_hot() {
    let config = HierarchicalMemoryConfig::default();
    let messages: Vec<LlmMessage> = (0..3)
        .map(|n| make_user_msg(&format!("msg {n}")))
        .collect();

    let tiers = classify_memory_tiers(&messages, &config);

    let every_entry_hot = tiers.iter().all(|(tier, _)| *tier == MemoryTier::Hot);
    assert!(every_entry_hot);
}
/// End-to-end check of hierarchical memory on an 18-message conversation:
/// the output must hold at most 9 messages, start with the conversation
/// summary, and end with the newest assistant message.
#[test]
fn test_apply_hierarchical_memory_basic() {
    // (id prefix, tool name, result label, number of call/result pairs)
    let batches = [
        ("old", "read_file", "old content", 5),
        ("mid", "bash", "mid output", 3),
    ];
    let mut messages = Vec::new();
    for (prefix, tool, label, pairs) in batches {
        for i in 0..pairs {
            let id = format!("{prefix}_{i}");
            messages.push(make_assistant_with_tool_call(&id, tool));
            messages.push(make_tool_result(&id, &format!("{label} {i}")));
        }
    }
    messages.push(make_user_msg("what now?"));
    messages.push(make_assistant_msg("let me check"));

    let tier_config = HierarchicalMemoryConfig {
        hot_messages: 2,
        warm_messages: 6,
    };
    let masking_config = ObservationMaskingConfig::default();

    let compacted = apply_hierarchical_memory(
        &messages,
        &tier_config,
        &masking_config,
        Some("Summary of old work"),
    );

    assert!(compacted.len() <= 9);

    let head_text = extract_text(&compacted[0].content);
    assert!(head_text.contains("CONVERSATION_SUMMARY"));

    let tail_text = extract_text(&compacted[compacted.len() - 1].content);
    assert!(tail_text.contains("let me check"));
}
/// A two-message conversation with 5 hot and 5 warm slots and no summary
/// must pass through with both messages intact.
#[test]
fn test_apply_hierarchical_memory_no_cold() {
    let tier_config = HierarchicalMemoryConfig {
        hot_messages: 5,
        warm_messages: 5,
    };
    let masking_config = ObservationMaskingConfig::default();
    let messages = vec![make_user_msg("hello"), make_assistant_msg("hi")];

    let compacted = apply_hierarchical_memory(&messages, &tier_config, &masking_config, None);

    assert_eq!(compacted.len(), 2);
}
/// `HierarchicalMemoryConfig` must deserialize both tier sizes from JSON.
#[test]
fn test_memory_tier_config_from_json() {
    let raw = json!({
        "hot_messages": 10,
        "warm_messages": 50
    });

    let parsed: HierarchicalMemoryConfig = serde_json::from_value(raw).unwrap();

    assert_eq!(parsed.hot_messages, 10);
    assert_eq!(parsed.warm_messages, 50);
}
/// Pins the default tier sizes: 20 hot messages and 100 warm messages.
#[test]
fn test_memory_tier_config_defaults() {
    let defaults = HierarchicalMemoryConfig::default();

    assert_eq!(defaults.hot_messages, 20);
    assert_eq!(defaults.warm_messages, 100);
}
/// `CompactionConfig::from_json` must pick up the nested `memory_tiers`
/// object.
#[test]
fn test_compaction_config_with_memory_tiers() {
    let raw = json!({
        "strategy": "auto",
        "memory_tiers": {
            "hot_messages": 15,
            "warm_messages": 80
        }
    });

    let parsed = CompactionConfig::from_json(&raw);

    assert_eq!(parsed.memory_tiers.hot_messages, 15);
    assert_eq!(parsed.memory_tiers.warm_messages, 80);
}
/// Each `MemoryTier` variant must serialize to its lowercase name.
#[test]
fn test_memory_tier_serialization() {
    let expectations = [
        (MemoryTier::Hot, "hot"),
        (MemoryTier::Warm, "warm"),
        (MemoryTier::Cold, "cold"),
    ];

    for (tier, wire_name) in expectations {
        assert_eq!(serde_json::to_value(tier).unwrap(), json!(wire_name));
    }
}
/// Observation masking must leave `activate_skill` results untouched, mask
/// the older ordinary tool result, and keep the most recent one.
#[test]
fn test_masking_skips_activate_skill_results() {
    let skill_text = "You are a code review agent. Follow these instructions...";
    let messages = vec![
        // [0..=2] skill activation and its result
        make_assistant_with_tool_call("call_skill", "activate_skill"),
        make_tool_result("call_skill", skill_text),
        make_assistant_msg("Skill activated"),
        // [3..=5] older ordinary tool call with a long result
        make_assistant_with_tool_call("call_read", "read_file"),
        make_tool_result(
            "call_read",
            "file contents that are long enough to be masked by observation masking because they exceed one hundred characters easily",
        ),
        make_assistant_msg("got it"),
        // [6..=7] most recent ordinary tool call
        make_assistant_with_tool_call("call_bash", "bash"),
        make_tool_result("call_bash", "command output"),
    ];
    let masking = ObservationMaskingConfig {
        keep_recent_tool_outputs: 1,
        summary_format: MaskingSummaryFormat::OneLine,
    };

    let outcome = apply_observation_masking(&messages, &masking);

    // The skill result is passed through verbatim.
    assert_eq!(extract_text(&outcome.messages[1].content), skill_text);
    // The older read_file result is replaced by a bracketed summary.
    assert!(extract_text(&outcome.messages[4].content).starts_with('['));
    // The newest tool result is kept.
    assert_eq!(extract_text(&outcome.messages[7].content), "command output");
    assert_eq!(outcome.masked_count, 1);
}
/// With `keep_recent_tool_outputs` set to 0, only the ordinary tool result
/// may be masked; both `activate_skill` results must stay verbatim.
#[test]
fn test_masking_all_activate_skill_exempt_from_count() {
    let first_skill = "Skill 1 instructions";
    let second_skill = "Skill 2 instructions";
    let messages = vec![
        make_assistant_with_tool_call("s1", "activate_skill"),
        make_tool_result("s1", first_skill),
        make_assistant_with_tool_call("s2", "activate_skill"),
        make_tool_result("s2", second_skill),
        make_assistant_with_tool_call("c1", "bash"),
        make_tool_result("c1", "output"),
    ];
    let masking = ObservationMaskingConfig {
        keep_recent_tool_outputs: 0,
        summary_format: MaskingSummaryFormat::OneLine,
    };

    let outcome = apply_observation_masking(&messages, &masking);

    // Only the bash result counts as masked.
    assert_eq!(outcome.masked_count, 1);
    assert_eq!(extract_text(&outcome.messages[1].content), first_skill);
    assert_eq!(extract_text(&outcome.messages[3].content), second_skill);
}
/// Aggressive trimming to a 400-token target must retain both the
/// `activate_skill` tool call and its tool result.
#[test]
fn test_aggressive_trim_preserves_skill_messages() {
    // 400-character filler bodies made of one repeated letter.
    let filler = |letter: &str| letter.repeat(400);
    let messages = vec![
        make_user_msg(&filler("s")),
        make_assistant_with_tool_call("skill1", "activate_skill"),
        make_tool_result("skill1", "Important skill instructions"),
        make_user_msg(&filler("a")),
        make_assistant_msg(&filler("b")),
        make_user_msg(&filler("c")),
        make_assistant_msg(&filler("d")),
    ];
    let target_tokens = 400;

    let trimmed = aggressive_trim(&messages, target_tokens, true);

    let skill_result_kept = trimmed.iter().any(|msg| {
        msg.role == LlmMessageRole::Tool
            && extract_text(&msg.content) == "Important skill instructions"
    });
    assert!(
        skill_result_kept,
        "Skill tool result must survive aggressive trim"
    );

    let skill_call_kept = trimmed.iter().any(|msg| {
        msg.tool_calls
            .as_ref()
            .is_some_and(|calls| calls.iter().any(|call| call.name == "activate_skill"))
    });
    assert!(
        skill_call_kept,
        "Skill tool call must survive aggressive trim"
    );
}
/// A skill activation at the very start of a 26-message conversation must
/// still have its instructions in the hierarchical-memory output, and the
/// output must start with the conversation summary.
///
/// With 2 hot and 6 warm slots, the skill pair at indices 0 and 1 lies
/// outside the last 8 messages.
#[test]
fn test_hierarchical_memory_rescues_skill_from_cold_tier() {
    let mut messages = vec![
        make_assistant_with_tool_call("skill1", "activate_skill"),
        make_tool_result("skill1", "You must always validate input."),
    ];
    // (id prefix, tool name, result label, number of call/result pairs)
    let batches = [
        ("old", "read_file", "old content", 8),
        ("mid", "bash", "mid output", 3),
    ];
    for (prefix, tool, label, pairs) in batches {
        for i in 0..pairs {
            let id = format!("{prefix}_{i}");
            messages.push(make_assistant_with_tool_call(&id, tool));
            messages.push(make_tool_result(&id, &format!("{label} {i}")));
        }
    }
    messages.push(make_user_msg("what now?"));
    messages.push(make_assistant_msg("let me check"));

    let tier_config = HierarchicalMemoryConfig {
        hot_messages: 2,
        warm_messages: 6,
    };
    let masking_config = ObservationMaskingConfig::default();

    let compacted = apply_hierarchical_memory(
        &messages,
        &tier_config,
        &masking_config,
        Some("Summary of old work"),
    );

    let skill_text_present = compacted
        .iter()
        .any(|msg| extract_text(&msg.content).contains("You must always validate input."));
    assert!(
        skill_text_present,
        "Skill instructions from cold tier must be rescued into output"
    );
    assert!(extract_text(&compacted[0].content).contains("CONVERSATION_SUMMARY"));
}
/// Only the tool result that answers an `activate_skill` call is protected;
/// an ordinary tool result and the skill call message itself are not.
#[test]
fn test_is_protected_tool_result_detection() {
    let history = vec![
        make_assistant_with_tool_call("s1", "activate_skill"),
        make_tool_result("s1", "skill content"),
        make_assistant_with_tool_call("r1", "read_file"),
        make_tool_result("r1", "file content"),
    ];

    let skill_call = &history[0];
    let skill_result = &history[1];
    let file_result = &history[3];

    assert!(is_protected_tool_result(&history, skill_result));
    assert!(!is_protected_tool_result(&history, file_result));
    assert!(!is_protected_tool_result(&history, skill_call));
}
/// Only an assistant message carrying an `activate_skill` tool call counts
/// as a protected tool-call message; an ordinary tool call and a plain user
/// message do not.
#[test]
fn test_is_protected_tool_call_message_detection() {
    let skill_call = make_assistant_with_tool_call("s1", "activate_skill");
    let regular_call = make_assistant_with_tool_call("r1", "read_file");
    let user_msg = make_user_msg("hello");

    assert!(is_protected_tool_call_message(&skill_call));
    // The argument had been mangled into `®ular_call` (the `&reg` prefix of
    // `&regular_call` decoded as an HTML entity), which is not valid Rust.
    assert!(!is_protected_tool_call_message(&regular_call));
    assert!(!is_protected_tool_call_message(&user_msg));
}
/// The default summarization `preserve` list must contain
/// `skill_instructions`.
#[test]
fn test_default_preserve_includes_skill_instructions() {
    let defaults = SummarizationConfig::default();
    let wanted = "skill_instructions".to_string();

    assert!(
        defaults.preserve.contains(&wanted),
        "Default preserve list must include skill_instructions"
    );
}
/// The summarization prompt built from the default configuration must
/// mention `activate_skill`.
#[test]
fn test_summarization_prompt_mentions_skill_protection() {
    let defaults = SummarizationConfig::default();

    let prompt_text = build_summarization_prompt(&defaults);

    assert!(
        prompt_text.contains("activate_skill"),
        "Summarization prompt must instruct LLM to preserve skill content"
    );
}
/// With a 200-token target and two 800-character skill results, the
/// trailing user message (all 'z') must be dropped from the output.
#[test]
fn test_aggressive_trim_protected_exceed_budget() {
    let opening = "s".repeat(400);
    let first_skill_body = "x".repeat(800);
    let second_skill_body = "y".repeat(800);
    let closing = "z".repeat(400);
    let messages = vec![
        make_user_msg(&opening),
        make_assistant_with_tool_call("skill1", "activate_skill"),
        make_tool_result("skill1", &first_skill_body),
        make_assistant_with_tool_call("skill2", "activate_skill"),
        make_tool_result("skill2", &second_skill_body),
        make_user_msg(&closing),
    ];

    let trimmed = aggressive_trim(&messages, 200, true);

    // Only the closing user message contains the letter 'z'.
    let closing_survived = trimmed
        .iter()
        .any(|msg| msg.role == LlmMessageRole::User && extract_text(&msg.content).contains('z'));
    assert!(
        !closing_survived,
        "Non-protected messages must be dropped when protected exceed budget"
    );
}
/// When formatting for summarization, a 5000-character `activate_skill`
/// result must appear in full, while an equally long `read_file` result
/// must be truncated.
#[test]
fn test_format_messages_no_truncate_protected_tool_result() {
    let skill_body = "a".repeat(5000);
    let file_body = "b".repeat(5000);
    let messages = vec![
        make_assistant_with_tool_call("s1", "activate_skill"),
        make_tool_result("s1", &skill_body),
        make_assistant_with_tool_call("r1", "read_file"),
        make_tool_result("r1", &file_body),
    ];

    let rendered = format_messages_for_summarization(&messages);

    assert!(
        rendered.contains(&skill_body),
        "Protected tool result must not be truncated"
    );
    assert!(
        rendered.contains("[truncated, 5000 chars total]"),
        "Non-protected tool result should be truncated"
    );
}
/// A skill result must be protected even when its `activate_skill` call
/// sits far earlier in the conversation.
///
/// Layout (26 messages): the skill call at index 0, nine read_file pairs at
/// 1..=18, the skill result at 19, two bash pairs at 20..=23, then a user
/// and an assistant message. With 2 hot and 5 warm slots, the skill result
/// is the oldest of the last 7 messages while its call is not among them.
#[test]
fn test_hierarchical_memory_cross_tier_boundary_protection() {
    let mut messages = vec![make_assistant_with_tool_call("skill1", "activate_skill")];
    for i in 0..9 {
        let id = format!("cold_{i}");
        let body = format!("cold content {i}");
        messages.push(make_assistant_with_tool_call(&id, "read_file"));
        messages.push(make_tool_result(&id, &body));
    }
    // The skill result arrives long after its call.
    messages.push(make_tool_result(
        "skill1",
        "Cross-tier skill instructions that must survive",
    ));
    for i in 0..2 {
        let id = format!("warm_{i}");
        let body = format!("warm output {i}");
        messages.push(make_assistant_with_tool_call(&id, "bash"));
        messages.push(make_tool_result(&id, &body));
    }
    messages.push(make_user_msg("continue"));
    messages.push(make_assistant_msg("ok"));

    let tier_config = HierarchicalMemoryConfig {
        hot_messages: 2,
        warm_messages: 5,
    };
    let masking_config = ObservationMaskingConfig {
        keep_recent_tool_outputs: 0,
        summary_format: MaskingSummaryFormat::OneLine,
    };

    let compacted = apply_hierarchical_memory(&messages, &tier_config, &masking_config, None);

    let skill_text_present = compacted.iter().any(|msg| {
        extract_text(&msg.content).contains("Cross-tier skill instructions that must survive")
    });
    assert!(
        skill_text_present,
        "Skill result in warm tier with call in cold tier must be protected"
    );
}
}