use super::{Capability, CapabilityStatus, ModelViewContext, ModelViewProvider};
use crate::events::TokenUsage;
use crate::message::{ContentPart, Message, MessageRole};
use crate::message_filter::MessageFilterProvider;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
/// Stable identifier of the compaction capability (returned by
/// `CompactionCapability::id`).
pub const COMPACTION_CAPABILITY_ID: &str = "compaction";
/// Which compaction approach to use when a conversation outgrows the model's
/// context window.
///
/// Serialized in snake_case (e.g. `"observation_masking"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum CompactionStrategy {
    /// Default. Per the capability description, cascades through all
    /// available options.
    #[default]
    Auto,
    /// Provider-native compaction (e.g. OpenAI `/responses/compact`).
    Native,
    /// Replace old tool outputs with short summaries.
    ObservationMasking,
    /// Summarize older history with an LLM.
    Summarization,
}
impl std::fmt::Display for CompactionStrategy {
    /// Writes the strategy using the same snake_case names as its serde form.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::Auto => "auto",
            Self::Native => "native",
            Self::ObservationMasking => "observation_masking",
            Self::Summarization => "summarization",
        };
        f.write_str(name)
    }
}
/// How observation masking renders a tool output it replaces.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum MaskingSummaryFormat {
    /// Default. A single bracketed line. It contains the text itself when the
    /// text is at most 100 bytes, and otherwise the tool name with line and
    /// byte counts.
    #[default]
    OneLine,
    /// The first and last three lines with a count of omitted lines. Texts of
    /// six lines or fewer are kept whole.
    HeadTail,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ObservationMaskingConfig {
#[serde(default = "default_keep_recent_tool_outputs")]
pub keep_recent_tool_outputs: usize,
#[serde(default)]
pub summary_format: MaskingSummaryFormat,
}
impl Default for ObservationMaskingConfig {
fn default() -> Self {
Self {
keep_recent_tool_outputs: default_keep_recent_tool_outputs(),
summary_format: MaskingSummaryFormat::default(),
}
}
}
fn default_keep_recent_tool_outputs() -> usize {
2
}
/// Thresholds deciding when stale tool results are masked in the model view.
///
/// Masking triggers when any one threshold is met.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CostControlConfig {
    /// Master switch; when `false`, nothing is masked.
    #[serde(default = "default_cost_control_enabled")]
    pub enabled: bool,
    /// Number of most recent unprotected tool results that are never masked.
    #[serde(default = "default_cost_control_keep_recent_tool_results")]
    pub keep_recent_tool_results: usize,
    /// Trigger: the unprotected tool-result count reaches this value.
    #[serde(default = "default_cost_control_mask_after_tool_results")]
    pub mask_after_tool_results: usize,
    /// Trigger: the estimated unprotected tool-result bytes reach this value.
    #[serde(default = "default_cost_control_max_live_tool_result_bytes")]
    pub max_live_tool_result_bytes: usize,
    /// Trigger: prior-turn input tokens minus cache reads reach this value.
    #[serde(default = "default_cost_control_max_uncached_input_tokens")]
    pub max_uncached_input_tokens: u32,
    /// Trigger: the prior-turn cache-read ratio falls below this value.
    #[serde(default = "default_cost_control_min_cache_read_ratio")]
    pub min_cache_read_ratio: f32,
}
impl Default for CostControlConfig {
    fn default() -> Self {
        let enabled = default_cost_control_enabled();
        let keep_recent_tool_results = default_cost_control_keep_recent_tool_results();
        let mask_after_tool_results = default_cost_control_mask_after_tool_results();
        let max_live_tool_result_bytes = default_cost_control_max_live_tool_result_bytes();
        let max_uncached_input_tokens = default_cost_control_max_uncached_input_tokens();
        let min_cache_read_ratio = default_cost_control_min_cache_read_ratio();
        Self {
            enabled,
            keep_recent_tool_results,
            mask_after_tool_results,
            max_live_tool_result_bytes,
            max_uncached_input_tokens,
            min_cache_read_ratio,
        }
    }
}
/// Default for `CostControlConfig::enabled`: on.
fn default_cost_control_enabled() -> bool {
    true
}
/// Default for `CostControlConfig::keep_recent_tool_results`.
fn default_cost_control_keep_recent_tool_results() -> usize {
    2
}
/// Default for `CostControlConfig::mask_after_tool_results`.
fn default_cost_control_mask_after_tool_results() -> usize {
    4
}
/// Default for `CostControlConfig::max_live_tool_result_bytes` (24 KiB).
fn default_cost_control_max_live_tool_result_bytes() -> usize {
    24 * 1024
}
/// Default for `CostControlConfig::max_uncached_input_tokens`.
fn default_cost_control_max_uncached_input_tokens() -> u32 {
    100_000
}
/// Default for `CostControlConfig::min_cache_read_ratio`.
fn default_cost_control_min_cache_read_ratio() -> f32 {
    0.35
}
/// Settings for the LLM-summarization strategy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SummarizationConfig {
    /// Optional model override. Presumably used by the summarizing caller;
    /// it is not read in this file.
    #[serde(default)]
    pub model: Option<String>,
    /// Items the summary must retain. An empty list falls back to
    /// `default_preserve()` when the prompt is built.
    #[serde(default = "default_preserve")]
    pub preserve: Vec<String>,
    /// Extra instruction appended to the preserve list in the prompt.
    #[serde(default)]
    pub instructions: Option<String>,
}
impl Default for SummarizationConfig {
    fn default() -> Self {
        let preserve = default_preserve();
        Self {
            model: None,
            instructions: None,
            preserve,
        }
    }
}
/// Built-in list of items a conversation summary must retain.
fn default_preserve() -> Vec<String> {
    const ITEMS: [&str; 5] = [
        "decisions",
        "files_modified",
        "errors",
        "current_plan",
        "skill_instructions",
    ];
    ITEMS.iter().copied().map(String::from).collect()
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompactionConfig {
#[serde(default)]
pub strategy: CompactionStrategy,
#[serde(default = "default_proactive")]
pub proactive: bool,
#[serde(default = "default_budget_percent")]
pub budget_percent: f32,
#[serde(default)]
pub observation_masking: ObservationMaskingConfig,
#[serde(default)]
pub summarization: SummarizationConfig,
#[serde(default)]
pub memory_tiers: HierarchicalMemoryConfig,
#[serde(default)]
pub cost_control: CostControlConfig,
}
impl Default for CompactionConfig {
fn default() -> Self {
Self {
strategy: CompactionStrategy::default(),
proactive: default_proactive(),
budget_percent: default_budget_percent(),
observation_masking: ObservationMaskingConfig::default(),
summarization: SummarizationConfig::default(),
memory_tiers: HierarchicalMemoryConfig::default(),
cost_control: CostControlConfig::default(),
}
}
}
fn default_proactive() -> bool {
true
}
fn default_budget_percent() -> f32 {
0.85
}
impl CompactionConfig {
pub fn from_json(value: &serde_json::Value) -> Self {
serde_json::from_value(value.clone()).unwrap_or_default()
}
}
/// Capability that keeps conversations within the model's context window.
///
/// Registers a no-op message filter (`CompactionFilterProvider`) and a
/// model-view provider (`CompactionModelViewProvider`) that masks stale tool
/// results.
pub struct CompactionCapability;
impl Capability for CompactionCapability {
    fn id(&self) -> &str {
        COMPACTION_CAPABILITY_ID
    }
    fn name(&self) -> &str {
        "Compaction"
    }
    fn description(&self) -> &str {
        // The second line is a continuation of the raw string literal; its
        // leading whitespace is part of the text, so it stays unindented.
        r#"Configurable context compaction when conversations exceed LLM context windows.
Choose between native provider compaction (e.g., OpenAI /responses/compact), observation masking (strip old tool outputs), or LLM summarization. The `auto` strategy cascades through all available options."#
    }
    fn status(&self) -> CapabilityStatus {
        CapabilityStatus::Available
    }
    fn icon(&self) -> Option<&str> {
        Some("shrink")
    }
    fn category(&self) -> Option<&str> {
        Some("Optimization")
    }
    fn message_filter_provider(&self) -> Option<Arc<dyn MessageFilterProvider>> {
        Some(Arc::new(CompactionFilterProvider))
    }
    fn model_view_provider(&self) -> Option<Arc<dyn ModelViewProvider>> {
        Some(Arc::new(CompactionModelViewProvider))
    }
}
struct CompactionModelViewProvider;
impl ModelViewProvider for CompactionModelViewProvider {
fn apply_model_view(
&self,
messages: Vec<Message>,
config: &serde_json::Value,
context: &ModelViewContext<'_>,
) -> Vec<Message> {
let config = CompactionConfig::from_json(config);
let masking = build_model_view_messages_owned(messages, &config, context.prior_usage);
if masking.masked_count > 0 {
tracing::info!(
session_id = %context.session_id,
masked_count = masking.masked_count,
tool_result_bytes_before = masking.tool_result_bytes_before,
tool_result_bytes_after = masking.tool_result_bytes_after,
"CompactionCapability: masked stale tool results for model view"
);
}
masking.messages
}
fn priority(&self) -> i32 {
50
}
}
/// Message filter registered by the compaction capability. It leaves queries
/// unchanged; masking is done by `CompactionModelViewProvider`.
struct CompactionFilterProvider;
impl MessageFilterProvider for CompactionFilterProvider {
    fn apply_filters(
        &self,
        _query: &mut crate::message_filter::MessageQuery,
        _config: &serde_json::Value,
    ) {
        // Intentionally a no-op.
    }
    fn priority(&self) -> i32 {
        50
    }
}
/// Roughly estimates the token count of `msg` at about 4 bytes per token.
///
/// Counted bytes are:
/// - the byte length of each text,
/// - a flat 50 for every non-text content part,
/// - for each tool call, its name plus serialized arguments plus 20 bytes of
///   overhead.
pub fn estimate_tokens(msg: &LlmMessage) -> usize {
    let content_bytes: usize = match &msg.content {
        LlmMessageContent::Text(text) => text.len(),
        LlmMessageContent::Parts(parts) => parts
            .iter()
            .map(|part| {
                if let LlmContentPart::Text { text } = part {
                    text.len()
                } else {
                    50
                }
            })
            .sum(),
    };
    let mut tool_call_bytes: usize = 0;
    if let Some(calls) = &msg.tool_calls {
        for call in calls {
            tool_call_bytes += call.name.len() + call.arguments.to_string().len() + 20;
        }
    }
    (content_bytes + tool_call_bytes) / 4
}
pub fn estimate_total_tokens(messages: &[LlmMessage]) -> usize {
messages.iter().map(estimate_tokens).sum()
}
/// Returns `true` when both of the following hold:
/// - proactive compaction is enabled;
/// - the estimated token total exceeds `budget_percent` of
///   `context_window_tokens`.
pub fn should_compact_proactively(
    messages: &[LlmMessage],
    config: &CompactionConfig,
    context_window_tokens: usize,
) -> bool {
    config.proactive && {
        let budget = (context_window_tokens as f32 * config.budget_percent) as usize;
        estimate_total_tokens(messages) > budget
    }
}
/// Trims `messages` so their estimated token total fits in `target_tokens`,
/// favouring the most recent messages.
///
/// - With `has_system_prompt`, `messages[0]` is treated as the system prompt.
///   It is kept only if its estimate is strictly below the budget, and it is
///   excluded from the trimmed conversation either way.
/// - Protected messages (tool calls and results for `PROTECTED_TOOL_NAMES`)
///   are reserved first.
///   - If they alone exceed the remaining budget, they are walked
///     newest-to-oldest and each one that still fits is kept. In that case no
///     unprotected message is kept.
///   - Otherwise, unprotected messages are added from the end backwards until
///     the first one that does not fit.
///
/// Kept messages retain their original relative order.
///
/// NOTE(review): trimming is per message, so a kept tool result can lose the
/// assistant message that issued its call (or vice versa).
pub fn aggressive_trim(
    messages: &[LlmMessage],
    target_tokens: usize,
    has_system_prompt: bool,
) -> Vec<LlmMessage> {
    let mut result = Vec::new();
    let mut token_budget = target_tokens;
    let start_idx = if has_system_prompt && !messages.is_empty() {
        let sys_tokens = estimate_tokens(&messages[0]);
        if sys_tokens < token_budget {
            result.push(messages[0].clone());
            token_budget -= sys_tokens;
        }
        1
    } else {
        0
    };
    let conversation = &messages[start_idx..];
    // Indices into `conversation` of protected tool calls and results.
    let protected_indices: std::collections::HashSet<usize> = conversation
        .iter()
        .enumerate()
        .filter(|(_, m)| {
            is_protected_tool_result(conversation, m) || is_protected_tool_call_message(m)
        })
        .map(|(i, _)| i)
        .collect();
    let mut protected_budget: usize = 0;
    for &idx in &protected_indices {
        protected_budget += estimate_tokens(&conversation[idx]);
    }
    if protected_budget > token_budget {
        // Protected messages alone overflow. Walk them newest-to-oldest and
        // keep each one that still fits, skipping (not stopping at) any that
        // do not.
        let mut protected_with_indices: Vec<(usize, LlmMessage)> = protected_indices
            .iter()
            .map(|&idx| (idx, conversation[idx].clone()))
            .collect();
        protected_with_indices.sort_by_key(|(i, _)| *i);
        let mut remaining = token_budget;
        let mut kept: Vec<(usize, LlmMessage)> = Vec::new();
        for (idx, msg) in protected_with_indices.into_iter().rev() {
            let t = estimate_tokens(&msg);
            if t <= remaining {
                kept.push((idx, msg));
                remaining -= t;
            }
        }
        // Restore conversation order.
        kept.sort_by_key(|(i, _)| *i);
        result.extend(kept.into_iter().map(|(_, m)| m));
        return result;
    }
    token_budget -= protected_budget;
    // Fill the remaining budget with the newest unprotected messages. Stop at
    // the first one that does not fit, so the kept unprotected messages form
    // a contiguous tail.
    let mut keep_from_end = Vec::new();
    for (i, msg) in conversation.iter().enumerate().rev() {
        if protected_indices.contains(&i) {
            continue;
        }
        let msg_tokens = estimate_tokens(msg);
        if msg_tokens <= token_budget {
            keep_from_end.push((i, msg.clone()));
            token_budget -= msg_tokens;
        } else {
            break;
        }
    }
    // Merge protected and tail messages back into conversation order.
    let mut all_kept: Vec<(usize, LlmMessage)> = Vec::new();
    for &idx in &protected_indices {
        all_kept.push((idx, conversation[idx].clone()));
    }
    all_kept.extend(keep_from_end);
    all_kept.sort_by_key(|(i, _)| *i);
    result.extend(all_kept.into_iter().map(|(_, m)| m));
    result
}
/// Running totals of compactions performed in a session.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SessionCompactionMetrics {
    /// Number of recorded compactions.
    pub compaction_count: u32,
    /// Total messages removed across all recorded compactions.
    pub total_messages_saved: u64,
    /// Usage count per strategy. Keyed by each `+`-separated component of the
    /// strategy string passed to `record`.
    pub strategy_counts: HashMap<String, u32>,
    /// Sum of the durations passed to `record`, in milliseconds.
    pub total_duration_ms: u64,
}
impl SessionCompactionMetrics {
    /// Records one compaction.
    ///
    /// `strategy_used` may combine several strategies joined by `+`; each
    /// component is counted once. A growing message count counts as zero
    /// messages saved.
    pub fn record(
        &mut self,
        strategy_used: &str,
        messages_before: usize,
        messages_after: usize,
        duration_ms: u64,
    ) {
        let saved = messages_before.saturating_sub(messages_after) as u64;
        self.compaction_count += 1;
        self.total_messages_saved += saved;
        self.total_duration_ms += duration_ms;
        strategy_used.split('+').for_each(|name| {
            *self.strategy_counts.entry(name.to_owned()).or_default() += 1;
        });
    }
}
/// Age bucket assigned by `classify_memory_tiers`, counted from the end of
/// the conversation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MemoryTier {
    /// Within the last `hot_messages` messages.
    Hot,
    /// Within the `warm_messages` messages preceding the hot window.
    Warm,
    /// Older than both windows.
    Cold,
}
/// Window sizes for hierarchical (hot/warm/cold) memory.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HierarchicalMemoryConfig {
    /// Number of most recent messages kept verbatim (the hot tier).
    #[serde(default = "default_hot_messages")]
    pub hot_messages: usize,
    /// Number of messages before the hot window that get observation masking
    /// (the warm tier).
    #[serde(default = "default_warm_messages")]
    pub warm_messages: usize,
}
impl Default for HierarchicalMemoryConfig {
    fn default() -> Self {
        let hot_messages = default_hot_messages();
        let warm_messages = default_warm_messages();
        Self {
            hot_messages,
            warm_messages,
        }
    }
}
/// Default for `HierarchicalMemoryConfig::hot_messages`.
fn default_hot_messages() -> usize {
    20
}
/// Default for `HierarchicalMemoryConfig::warm_messages`.
fn default_warm_messages() -> usize {
    100
}
/// Labels each message with its `MemoryTier` by distance from the end.
///
/// - The last `hot_messages` messages are `Hot`.
/// - The `warm_messages` before them are `Warm`.
/// - Everything older is `Cold`.
///
/// The result preserves input order.
pub fn classify_memory_tiers<'a>(
    messages: &'a [LlmMessage],
    config: &HierarchicalMemoryConfig,
) -> Vec<(MemoryTier, &'a LlmMessage)> {
    let hot_limit = config.hot_messages;
    // Saturate so very large configured windows (e.g. `usize::MAX` meaning
    // "everything") cannot overflow. This mirrors the saturating arithmetic
    // used by `apply_hierarchical_memory`.
    let warm_limit = hot_limit.saturating_add(config.warm_messages);
    let len = messages.len();
    messages
        .iter()
        .enumerate()
        .map(|(i, msg)| {
            // `i < len`, so this cannot underflow.
            let from_end = len - 1 - i;
            let tier = if from_end < hot_limit {
                MemoryTier::Hot
            } else if from_end < warm_limit {
                MemoryTier::Warm
            } else {
                MemoryTier::Cold
            };
            (tier, msg)
        })
        .collect()
}
pub fn apply_hierarchical_memory(
messages: &[LlmMessage],
config: &HierarchicalMemoryConfig,
masking_config: &ObservationMaskingConfig,
cold_summary: Option<&str>,
) -> Vec<LlmMessage> {
let len = messages.len();
let hot_start = len.saturating_sub(config.hot_messages);
let warm_start = hot_start.saturating_sub(config.warm_messages);
let mut result = Vec::new();
if warm_start > 0 {
let cold_msgs = &messages[..warm_start];
let protected_cold: Vec<LlmMessage> = cold_msgs
.iter()
.filter(|m| is_protected_tool_result(cold_msgs, m) || is_protected_tool_call_message(m))
.cloned()
.collect();
if let Some(summary) = cold_summary {
result.push(build_summary_message(summary));
}
result.extend(protected_cold);
}
if warm_start < hot_start {
let warm_msgs = &messages[warm_start..hot_start];
let protected_call_ids: std::collections::HashSet<String> = warm_msgs
.iter()
.filter(|m| is_protected_tool_result(messages, m))
.filter_map(|m| m.tool_call_id.clone())
.collect();
let masked = apply_observation_masking_with_protected(
warm_msgs,
masking_config,
&protected_call_ids,
);
result.extend(masked.messages);
}
if hot_start < len {
result.extend_from_slice(&messages[hot_start..]);
}
result
}
use crate::llm_driver_registry::{LlmContentPart, LlmMessage, LlmMessageContent, LlmMessageRole};
/// Tool names whose calls and results are never masked and are prioritized
/// when trimming. `activate_skill` results carry durable skill instructions.
const PROTECTED_TOOL_NAMES: &[&str] = &["activate_skill"];
/// `true` when `tool_msg` is a tool result whose call, looked up in
/// `messages`, names a protected tool.
fn is_protected_tool_result(messages: &[LlmMessage], tool_msg: &LlmMessage) -> bool {
    tool_msg.role == LlmMessageRole::Tool
        && PROTECTED_TOOL_NAMES.contains(&find_tool_call_name(messages, tool_msg).as_str())
}
/// `true` for assistant messages that issue at least one protected tool call.
fn is_protected_tool_call_message(msg: &LlmMessage) -> bool {
    msg.role == LlmMessageRole::Assistant
        && msg
            .tool_calls
            .iter()
            .flatten()
            .any(|call| PROTECTED_TOOL_NAMES.contains(&call.name.as_str()))
}
/// Output of observation masking over `LlmMessage`s.
#[derive(Debug)]
pub struct ObservationMaskingResult {
    /// Input messages with stale tool outputs replaced by summaries.
    pub messages: Vec<LlmMessage>,
    /// Number of tool messages that were replaced.
    pub masked_count: usize,
}
/// Replaces all but the newest `keep_recent_tool_outputs` unprotected tool
/// outputs with summaries in `config.summary_format`.
pub fn apply_observation_masking(
    messages: &[LlmMessage],
    config: &ObservationMaskingConfig,
) -> ObservationMaskingResult {
    let no_extra_protection = std::collections::HashSet::new();
    apply_observation_masking_with_protected(messages, config, &no_extra_protection)
}
/// Output of cost-control masking over stored `Message`s.
#[derive(Debug)]
pub struct CostControlMaskingResult {
    /// Messages to send to the model, with stale tool results masked.
    pub messages: Vec<Message>,
    /// Number of tool-result messages that were masked.
    pub masked_count: usize,
    /// Estimated serialized size of unprotected tool results before masking.
    pub tool_result_bytes_before: usize,
    /// Estimated serialized size of tool results after masking. Equals
    /// `tool_result_bytes_before` when nothing was masked.
    pub tool_result_bytes_after: usize,
}
/// Builds the message list shown to the model from `stored_messages`,
/// masking stale tool results per `compaction_config.cost_control`.
///
/// Clones the input; prefer `build_model_view_messages_owned` when the
/// caller can give up ownership.
pub fn build_model_view_messages(
    stored_messages: &[Message],
    compaction_config: &CompactionConfig,
    prior_usage: Option<&TokenUsage>,
) -> CostControlMaskingResult {
    let owned = stored_messages.to_vec();
    apply_cost_control_masking_owned(owned, compaction_config, prior_usage)
}
/// Owning variant of `build_model_view_messages`; avoids copying the input.
pub fn build_model_view_messages_owned(
    stored_messages: Vec<Message>,
    compaction_config: &CompactionConfig,
    prior_usage: Option<&TokenUsage>,
) -> CostControlMaskingResult {
    apply_cost_control_masking_owned(
        stored_messages,
        compaction_config,
        prior_usage,
    )
}
/// Borrowing entry point for cost-control masking (clones `messages`).
pub fn apply_cost_control_masking(
    messages: &[Message],
    config: &CompactionConfig,
    prior_usage: Option<&TokenUsage>,
) -> CostControlMaskingResult {
    apply_cost_control_masking_owned(Vec::from(messages), config, prior_usage)
}
/// Masks all but the newest `keep_recent_tool_results` unprotected tool
/// results once cost-control thresholds are exceeded.
///
/// Messages are returned unchanged (with `masked_count == 0`) when any of
/// these holds:
/// - cost control is disabled;
/// - there are no more unprotected tool results than
///   `keep_recent_tool_results`;
/// - `should_apply_cost_control_masking` reports no threshold reached.
///
/// Protected tool results (see `PROTECTED_TOOL_NAMES`) are never masked and
/// are excluded from both byte totals. This keeps `tool_result_bytes_before`
/// and `tool_result_bytes_after` measuring the same set of messages.
fn apply_cost_control_masking_owned(
    messages: Vec<Message>,
    config: &CompactionConfig,
    prior_usage: Option<&TokenUsage>,
) -> CostControlMaskingResult {
    let cost_config = &config.cost_control;
    // Positions of maskable (unprotected) tool results, oldest first.
    let tool_indices: Vec<usize> = messages
        .iter()
        .enumerate()
        .filter(|(_, message)| {
            message.role == MessageRole::ToolResult
                && !is_protected_message_tool_result(&messages, message)
        })
        .map(|(index, _)| index)
        .collect();
    let tool_result_bytes_before: usize = tool_indices
        .iter()
        .map(|&index| message_tool_result_len(&messages[index]))
        .sum();
    let keep_recent = cost_config.keep_recent_tool_results;
    if !cost_config.enabled
        || tool_indices.len() <= keep_recent
        || !should_apply_cost_control_masking(
            tool_indices.len(),
            tool_result_bytes_before,
            cost_config,
            prior_usage,
        )
    {
        return CostControlMaskingResult {
            messages,
            masked_count: 0,
            tool_result_bytes_before,
            tool_result_bytes_after: tool_result_bytes_before,
        };
    }
    let to_mask_count = tool_indices.len().saturating_sub(keep_recent);
    // Resolve tool names now, while `messages` can still be borrowed.
    let tool_names: std::collections::HashMap<usize, String> = tool_indices[..to_mask_count]
        .iter()
        .map(|&index| (index, find_message_tool_call_name(&messages, &messages[index])))
        .collect();
    let mut masked_count = 0;
    let mut masked_messages = Vec::with_capacity(messages.len());
    for (index, message) in messages.into_iter().enumerate() {
        if let Some(tool_name) = tool_names.get(&index) {
            masked_messages.push(mask_tool_result_message(&message, tool_name));
            masked_count += 1;
        } else {
            masked_messages.push(message);
        }
    }
    // Masking preserves positions, so re-measure exactly the messages counted
    // in `tool_result_bytes_before`. Protected results stay out of both totals.
    let tool_result_bytes_after: usize = tool_indices
        .iter()
        .map(|&index| message_tool_result_len(&masked_messages[index]))
        .sum();
    CostControlMaskingResult {
        messages: masked_messages,
        masked_count,
        tool_result_bytes_before,
        tool_result_bytes_after,
    }
}
/// Decides whether cost-control masking should run.
///
/// Returns `true` when any one of these holds:
/// - the tool-result count reaches `mask_after_tool_results`;
/// - the tool-result bytes reach `max_live_tool_result_bytes`;
/// - the previous turn's uncached input tokens reach
///   `max_uncached_input_tokens`;
/// - the previous turn's cache-read ratio is below `min_cache_read_ratio`.
fn should_apply_cost_control_masking(
    tool_result_count: usize,
    tool_result_bytes: usize,
    config: &CostControlConfig,
    prior_usage: Option<&TokenUsage>,
) -> bool {
    let over_count = tool_result_count >= config.mask_after_tool_results;
    let over_bytes = tool_result_bytes >= config.max_live_tool_result_bytes;
    if over_count || over_bytes {
        return true;
    }
    prior_usage.is_some_and(|usage| {
        let cache_read = usage.cache_read_tokens.unwrap_or(0);
        let uncached = usage.input_tokens.saturating_sub(cache_read);
        let cache_ratio_too_low = usage.input_tokens > 0
            && (cache_read as f32 / usage.input_tokens as f32) < config.min_cache_read_ratio;
        uncached >= config.max_uncached_input_tokens || cache_ratio_too_low
    })
}
/// `Message` counterpart of `is_protected_tool_result`. It is `true` for a
/// tool result whose originating call names a protected tool.
fn is_protected_message_tool_result(messages: &[Message], tool_msg: &Message) -> bool {
    if tool_msg.role == MessageRole::ToolResult {
        let tool_name = find_message_tool_call_name(messages, tool_msg);
        PROTECTED_TOOL_NAMES
            .iter()
            .any(|&protected| protected == tool_name)
    } else {
        false
    }
}
/// Resolves the tool name for `tool_msg`'s call id by scanning agent messages
/// newest-first. Returns `"unknown_tool"` when there is no id or no match.
fn find_message_tool_call_name(messages: &[Message], tool_msg: &Message) -> String {
    const UNKNOWN: &str = "unknown_tool";
    let Some(call_id) = tool_msg.tool_call_id() else {
        return UNKNOWN.to_string();
    };
    let agent_messages = messages
        .iter()
        .rev()
        .filter(|msg| msg.role == MessageRole::Agent);
    for msg in agent_messages {
        if let Some(tool_call) = msg
            .tool_calls()
            .into_iter()
            .find(|candidate| candidate.id == call_id)
        {
            return tool_call.name.clone();
        }
    }
    UNKNOWN.to_string()
}
/// Estimated payload size of a tool-result message: serialized result bytes
/// plus the error text length. Returns 0 when the message has no tool-result
/// content.
fn message_tool_result_len(message: &Message) -> usize {
    message.tool_result_content().map_or(0, |content| {
        let result_len = content.result.as_ref().map_or(0, estimate_json_value_len);
        let error_len = content.error.as_ref().map_or(0, |error| error.len());
        result_len + error_len
    })
}
fn mask_tool_result_message(message: &Message, tool_name: &str) -> Message {
let Some(result) = message.tool_result_content() else {
return message.clone();
};
let summary = summarize_tool_result(tool_name, result.result.as_ref(), result.error.as_ref());
let was_error = result.error.is_some();
let mut masked = message.clone();
for part in &mut masked.content {
if let ContentPart::ToolResult(tool_result) = part {
if was_error {
tool_result.result = None;
tool_result.error = Some(summary);
} else {
tool_result.result = Some(serde_json::json!({
"masked": true,
"summary": summary,
}));
tool_result.error = None;
}
break;
}
}
masked
}
fn summarize_tool_result(
tool_name: &str,
result: Option<&serde_json::Value>,
error: Option<&String>,
) -> String {
if let Some(error) = error {
return format!("[{tool_name} error: {}]", truncate_inline(error, 160));
}
let Some(value) = result else {
return format!("[{tool_name} returned no result]");
};
let Some(object) = value.as_object() else {
return format!(
"[{tool_name} -> {}, {} bytes]",
value_kind(value),
estimate_json_value_len(value)
);
};
match tool_name {
"read_file" | "daytona_read_file" | "sandbox_read_file" | "e2b_read_file"
| "docker_read_file" | "deno_read_file" | "sprites_read_file" | "read_github_file" => {
summarize_read_file_result(tool_name, object, value)
}
"bash" | "daytona_exec" | "sandbox_exec" | "e2b_exec" | "docker_exec" | "deno_exec" => {
summarize_exec_result(tool_name, object, value)
}
"list_directory" => summarize_list_directory_result(tool_name, object, value),
"grep_files" => summarize_grep_files_result(tool_name, object, value),
_ => summarize_generic_tool_result(tool_name, object, value),
}
}
/// Summarizes a file-read result.
///
/// Includes the path, the shown line range, the size, and the `truncated`,
/// `total_lines`, `truncation.next_offset` and `content_hash` hints. Absent
/// optional fields are omitted.
fn summarize_read_file_result(
    tool_name: &str,
    object: &serde_json::Map<String, serde_json::Value>,
    value: &serde_json::Value,
) -> String {
    let truncated = matches!(
        object.get("truncated").and_then(|v| v.as_bool()),
        Some(true)
    );
    let path = match object.get("path").and_then(|v| v.as_str()) {
        Some(path) => path,
        None => "(unknown path)",
    };
    let line_range = match object.get("lines_shown").and_then(|v| v.as_object()) {
        Some(shown) => match (
            shown.get("start").and_then(|v| v.as_u64()),
            shown.get("end").and_then(|v| v.as_u64()),
        ) {
            (Some(start), Some(end)) => format!(" lines {start}-{end}"),
            _ => String::new(),
        },
        None => String::new(),
    };
    let total_lines = object
        .get("total_lines")
        .and_then(|v| v.as_u64())
        .map_or_else(String::new, |lines| format!(", total_lines={lines}"));
    let next_offset = object
        .get("truncation")
        .and_then(|v| v.as_object())
        .and_then(|truncation| truncation.get("next_offset"))
        .and_then(|v| v.as_u64())
        .map_or_else(String::new, |offset| format!(", next_offset={offset}"));
    let hash = object
        .get("content_hash")
        .and_then(|v| v.as_str())
        .map_or_else(String::new, |hash| format!(", hash={hash}"));
    let size = estimate_json_value_len(value);
    format!(
        "[{tool_name} {path}{line_range}, {} bytes, truncated={truncated}{total_lines}{next_offset}{hash}]",
        size
    )
}
/// Summarizes a command-execution result.
///
/// Includes the exit code, the stdout/stderr lengths, the total size, and the
/// optional `full_output` path and `total_lines` hints.
fn summarize_exec_result(
    tool_name: &str,
    object: &serde_json::Map<String, serde_json::Value>,
    value: &serde_json::Value,
) -> String {
    let text_len = |key: &str| {
        object
            .get(key)
            .and_then(|v| v.as_str())
            .map_or(0, str::len)
    };
    let stdout_len = text_len("stdout");
    let stderr_len = text_len("stderr");
    let exit = object
        .get("exit_code")
        .and_then(|v| v.as_i64())
        .map_or_else(String::new, |code| format!(" exit={code}"));
    let full_output = object
        .get("full_output")
        .and_then(|v| v.as_str())
        .map_or_else(String::new, |path| format!(", full_output={path}"));
    let total_lines = object
        .get("total_lines")
        .and_then(|v| v.as_u64())
        .map_or_else(String::new, |lines| format!(", total_lines={lines}"));
    let result_len = estimate_json_value_len(value);
    format!(
        "[{tool_name}{exit}, stdout={} bytes, stderr={} bytes, result={} bytes{full_output}{total_lines}]",
        stdout_len, stderr_len, result_len
    )
}
/// Summarizes a directory listing as its path, entry count and size.
///
/// The entry count is taken from `count`, falling back to the length of the
/// `entries` array, and then to 0.
fn summarize_list_directory_result(
    tool_name: &str,
    object: &serde_json::Map<String, serde_json::Value>,
    value: &serde_json::Value,
) -> String {
    let path = object
        .get("path")
        .and_then(serde_json::Value::as_str)
        .unwrap_or("(unknown path)");
    let count = match object.get("count").and_then(|v| v.as_u64()) {
        Some(count) => count,
        None => object
            .get("entries")
            .and_then(|v| v.as_array())
            .map_or(0, |entries| entries.len() as u64),
    };
    let size = estimate_json_value_len(value);
    format!("[{tool_name} {path}, {count} entries, {} bytes]", size)
}
/// Summarizes a grep result.
///
/// Includes the search pattern (debug-quoted, at most 80 chars), the match
/// count (0 if absent) and the size.
fn summarize_grep_files_result(
    tool_name: &str,
    object: &serde_json::Map<String, serde_json::Value>,
    value: &serde_json::Value,
) -> String {
    let match_count = object
        .get("match_count")
        .and_then(|v| v.as_u64())
        .unwrap_or_default();
    let pattern = match object.get("pattern").and_then(|v| v.as_str()) {
        Some(pattern) => format!(" pattern={:?}", truncate_inline(pattern, 80)),
        None => String::new(),
    };
    let size = estimate_json_value_len(value);
    format!("[{tool_name}{pattern}, matches={match_count}, {} bytes]", size)
}
/// Fallback summary for object results: the size and the first five keys.
fn summarize_generic_tool_result(
    tool_name: &str,
    object: &serde_json::Map<String, serde_json::Value>,
    value: &serde_json::Value,
) -> String {
    let keys = object
        .keys()
        .take(5)
        .map(String::as_str)
        .collect::<Vec<&str>>()
        .join(",");
    let size = estimate_json_value_len(value);
    format!("[{tool_name} result, {} bytes, keys={keys}]", size)
}
fn value_kind(value: &serde_json::Value) -> &'static str {
match value {
serde_json::Value::Null => "null",
serde_json::Value::Bool(_) => "bool",
serde_json::Value::Number(_) => "number",
serde_json::Value::String(_) => "string",
serde_json::Value::Array(_) => "array",
serde_json::Value::Object(_) => "object",
}
}
/// Length in bytes of `value`'s compact JSON serialization, measured without
/// allocating. Returns 0 if serialization fails.
fn estimate_json_value_len(value: &serde_json::Value) -> usize {
    let mut counter = CountingWriter::default();
    match serde_json::to_writer(&mut counter, value) {
        Ok(()) => counter.bytes,
        Err(_) => 0,
    }
}
/// `io::Write` sink that discards its input and only tallies the bytes.
#[derive(Default)]
struct CountingWriter {
    bytes: usize,
}
impl std::io::Write for CountingWriter {
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let accepted = buf.len();
        self.bytes += accepted;
        Ok(accepted)
    }
    fn flush(&mut self) -> std::io::Result<()> {
        // Nothing is buffered.
        Ok(())
    }
}
/// Returns `text` unchanged if it has at most `max_chars` characters.
/// Otherwise returns its first `max_chars` characters followed by `"..."`.
fn truncate_inline(text: &str, max_chars: usize) -> String {
    // The byte offset of the first character past the limit, if any.
    match text.char_indices().nth(max_chars) {
        None => text.to_string(),
        Some((cut, _)) => format!("{}...", &text[..cut]),
    }
}
/// Observation masking that additionally leaves untouched any tool result
/// whose `tool_call_id` is in `extra_protected_call_ids`.
///
/// The extra set is used, for example, when the protected call lies outside
/// `messages`. All but the newest `keep_recent_tool_outputs` maskable tool
/// results are replaced by a `Tool` message holding a summary in
/// `config.summary_format`. Call ids, tool calls and phase are preserved;
/// thinking is dropped.
fn apply_observation_masking_with_protected(
    messages: &[LlmMessage],
    config: &ObservationMaskingConfig,
    extra_protected_call_ids: &std::collections::HashSet<String>,
) -> ObservationMaskingResult {
    let mut maskable: Vec<usize> = Vec::new();
    for (index, msg) in messages.iter().enumerate() {
        if msg.role != LlmMessageRole::Tool || is_protected_tool_result(messages, msg) {
            continue;
        }
        let explicitly_protected = msg
            .tool_call_id
            .as_ref()
            .is_some_and(|id| extra_protected_call_ids.contains(id));
        if !explicitly_protected {
            maskable.push(index);
        }
    }
    let keep = config.keep_recent_tool_outputs;
    if maskable.len() <= keep {
        return ObservationMaskingResult {
            messages: messages.to_vec(),
            masked_count: 0,
        };
    }
    let stale: std::collections::HashSet<usize> =
        maskable[..maskable.len() - keep].iter().copied().collect();
    let masked_messages: Vec<LlmMessage> = messages
        .iter()
        .enumerate()
        .map(|(index, msg)| {
            if !stale.contains(&index) {
                return msg.clone();
            }
            let tool_name = find_tool_call_name(messages, msg);
            let summary = match config.summary_format {
                MaskingSummaryFormat::OneLine => format_one_line_summary(&tool_name, &msg.content),
                MaskingSummaryFormat::HeadTail => format_head_tail_summary(&msg.content),
            };
            LlmMessage {
                role: LlmMessageRole::Tool,
                content: LlmMessageContent::Text(summary),
                tool_calls: msg.tool_calls.clone(),
                tool_call_id: msg.tool_call_id.clone(),
                phase: msg.phase,
                thinking: None,
                thinking_signature: None,
            }
        })
        .collect();
    ObservationMaskingResult {
        masked_count: stale.len(),
        messages: masked_messages,
    }
}
/// Resolves the tool name for `tool_msg`'s call id.
///
/// Scans assistant messages newest-first and their calls in order. Returns
/// `"unknown_tool"` when there is no id or no match.
fn find_tool_call_name(messages: &[LlmMessage], tool_msg: &LlmMessage) -> String {
    tool_msg
        .tool_call_id
        .as_ref()
        .and_then(|call_id| {
            messages
                .iter()
                .rev()
                .filter(|msg| msg.role == LlmMessageRole::Assistant)
                .filter_map(|msg| msg.tool_calls.as_ref())
                .flatten()
                .find(|tc| tc.id == *call_id)
        })
        .map_or_else(|| "unknown_tool".to_string(), |tc| tc.name.clone())
}
/// Plain text of a message. For multi-part content, the text parts are
/// joined with single spaces and non-text parts are ignored.
fn extract_text(content: &LlmMessageContent) -> String {
    let parts = match content {
        LlmMessageContent::Text(text) => return text.clone(),
        LlmMessageContent::Parts(parts) => parts,
    };
    let texts: Vec<&str> = parts
        .iter()
        .filter_map(|part| match part {
            LlmContentPart::Text { text } => Some(text.as_str()),
            _ => None,
        })
        .collect();
    texts.join(" ")
}
/// One-line mask for a tool output.
///
/// Embeds the text when it is at most 100 bytes; otherwise reports line and
/// byte counts.
fn format_one_line_summary(tool_name: &str, content: &LlmMessageContent) -> String {
    let text = extract_text(content);
    let byte_len = text.len();
    if byte_len > 100 {
        let line_count = text.lines().count();
        format!("[{tool_name} → {line_count} lines, {byte_len} bytes]")
    } else {
        format!("[{tool_name} → {text}]")
    }
}
/// Keeps the first and last three lines of a tool output and notes how many
/// lines were omitted. Outputs of six lines or fewer are returned whole.
fn format_head_tail_summary(content: &LlmMessageContent) -> String {
    let text = extract_text(content);
    let lines: Vec<&str> = text.lines().collect();
    let total = lines.len();
    if total > 6 {
        let head = lines[..3].join("\n");
        let tail = lines[total - 3..].join("\n");
        format!("{}\n... ({} lines omitted) ...\n{}", head, total - 6, tail)
    } else {
        text
    }
}
/// Builds the instruction prompt for LLM summarization.
///
/// The prompt lists the `preserve` items (or the defaults when the list is
/// empty) and appends the optional custom instruction. It also asks for
/// `activate_skill` results to be kept verbatim.
pub fn build_summarization_prompt(config: &SummarizationConfig) -> String {
    let fallback;
    let preserve_items: &[String] = if config.preserve.is_empty() {
        fallback = default_preserve();
        &fallback
    } else {
        &config.preserve
    };
    let preserve_list = preserve_items
        .iter()
        .map(|item| format!("- {item}"))
        .collect::<Vec<_>>()
        .join("\n");
    let custom_instructions = match config.instructions.as_deref() {
        Some(instr) => format!("\n- {instr}"),
        None => String::new(),
    };
    // The raw string's continuation lines are prompt text and stay unindented.
    format!(
        r#"<task>
Summarize the following conversation history. The summary replaces these
messages in the agent's context window — it must contain everything the
agent needs to continue working.
</task>
<preserve>
{preserve_list}{custom_instructions}
</preserve>
<format>
Produce a structured summary. Use sections. Be concise but complete.
Do not include tool output verbatim — reference files by path.
IMPORTANT: Any activate_skill tool results contain durable skill instructions.
Include them verbatim in a dedicated "Active Skills" section — do not summarize
or paraphrase skill instructions.
</format>"#
    )
}
/// Renders `messages` as `[role]: text` blocks separated by blank lines, for
/// input to the summarizer.
///
/// Text longer than 2000 bytes is cut at a char boundary and annotated with
/// its total character count. Protected tool results (e.g. `activate_skill`)
/// are never truncated.
pub fn format_messages_for_summarization(messages: &[LlmMessage]) -> String {
    // Byte budget for each unprotected message's text.
    const MAX_CONTENT_BYTES: usize = 2000;
    let mut parts = Vec::with_capacity(messages.len());
    for msg in messages {
        let role = match msg.role {
            LlmMessageRole::System => "system",
            LlmMessageRole::User => "user",
            LlmMessageRole::Assistant => "assistant",
            LlmMessageRole::Tool => "tool",
        };
        let content = extract_text(&msg.content);
        // Check the cheap length first. The protection check scans the whole
        // history, so it only runs for long messages.
        let truncated = if content.len() > MAX_CONTENT_BYTES
            && !is_protected_tool_result(messages, msg)
        {
            let safe_prefix = truncate_at_char_boundary(&content, MAX_CONTENT_BYTES);
            // The label says "chars", so report characters, not UTF-8 bytes.
            format!(
                "{}... [truncated, {} chars total]",
                safe_prefix,
                content.chars().count()
            )
        } else {
            content
        };
        parts.push(format!("[{role}]: {truncated}"));
    }
    parts.join("\n\n")
}
/// Longest prefix of `content` that is at most `max_bytes` long and ends on a
/// UTF-8 character boundary.
fn truncate_at_char_boundary(content: &str, max_bytes: usize) -> &str {
    if max_bytes >= content.len() {
        return content;
    }
    // Index 0 is always a boundary, so the search cannot come up empty.
    let end = (0..=max_bytes)
        .rev()
        .find(|&idx| content.is_char_boundary(idx))
        .unwrap_or(0);
    &content[..end]
}
/// Wraps `summary_text` in `[CONVERSATION_SUMMARY]` markers as a system
/// message.
pub fn build_summary_message(summary_text: &str) -> LlmMessage {
    let wrapped =
        format!("[CONVERSATION_SUMMARY]\n{summary_text}\n[/CONVERSATION_SUMMARY]");
    LlmMessage {
        role: LlmMessageRole::System,
        content: LlmMessageContent::Text(wrapped),
        tool_calls: None,
        tool_call_id: None,
        phase: None,
        thinking: None,
        thinking_signature: None,
    }
}
/// Record of a single compaction step.
///
/// NOTE(review): not produced in this file. The field meanings below are
/// inferred from their names; confirm them against the producers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompactionStep {
    /// Name of the strategy applied. Presumably a `CompactionStrategy`
    /// display name.
    pub strategy: String,
    /// Presumably the message count after this step.
    pub messages_after: usize,
    /// Presumably the step duration in milliseconds (per the `_ms` suffix).
    pub duration_ms: u64,
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tool_types::ToolCall;
use serde_json::json;
fn make_user_msg(text: &str) -> LlmMessage {
LlmMessage {
role: LlmMessageRole::User,
content: LlmMessageContent::Text(text.to_string()),
tool_calls: None,
tool_call_id: None,
phase: None,
thinking: None,
thinking_signature: None,
}
}
fn make_assistant_msg(text: &str) -> LlmMessage {
LlmMessage {
role: LlmMessageRole::Assistant,
content: LlmMessageContent::Text(text.to_string()),
tool_calls: None,
tool_call_id: None,
phase: None,
thinking: None,
thinking_signature: None,
}
}
fn make_assistant_with_tool_call(call_id: &str, tool_name: &str) -> LlmMessage {
LlmMessage {
role: LlmMessageRole::Assistant,
content: LlmMessageContent::Text(String::new()),
tool_calls: Some(vec![ToolCall {
id: call_id.to_string(),
name: tool_name.to_string(),
arguments: json!({"path": "src/main.rs"}),
}]),
tool_call_id: None,
phase: None,
thinking: None,
thinking_signature: None,
}
}
fn make_tool_result(call_id: &str, output: &str) -> LlmMessage {
LlmMessage {
role: LlmMessageRole::Tool,
content: LlmMessageContent::Text(output.to_string()),
tool_calls: None,
tool_call_id: Some(call_id.to_string()),
phase: None,
thinking: None,
thinking_signature: None,
}
}
#[test]
fn test_capability_metadata() {
let cap = CompactionCapability;
assert_eq!(cap.id(), COMPACTION_CAPABILITY_ID);
assert_eq!(cap.name(), "Compaction");
assert_eq!(cap.status(), CapabilityStatus::Available);
assert_eq!(cap.category(), Some("Optimization"));
assert!(cap.tools().is_empty());
assert!(cap.message_filter_provider().is_some());
}
#[test]
fn test_default_config() {
let config = CompactionConfig::default();
assert_eq!(config.strategy, CompactionStrategy::Auto);
assert!(config.proactive);
assert!((config.budget_percent - 0.85).abs() < f32::EPSILON);
assert_eq!(config.observation_masking.keep_recent_tool_outputs, 2);
assert_eq!(
config.observation_masking.summary_format,
MaskingSummaryFormat::OneLine
);
assert!(config.summarization.model.is_none());
assert_eq!(config.summarization.preserve.len(), 5);
assert!(config.summarization.instructions.is_none());
assert!(config.cost_control.enabled);
assert_eq!(config.cost_control.keep_recent_tool_results, 2);
}
#[test]
fn test_config_from_empty_json() {
let config = CompactionConfig::from_json(&json!({}));
assert_eq!(config.strategy, CompactionStrategy::Auto);
assert!(config.proactive);
}
#[test]
fn test_config_native_only() {
let config = CompactionConfig::from_json(&json!({"strategy": "native"}));
assert_eq!(config.strategy, CompactionStrategy::Native);
assert!(config.proactive);
}
/// Every observation-masking knob, plus `proactive`, can be overridden from JSON.
#[test]
fn test_config_observation_masking_with_custom_settings() {
    let raw = json!({
        "strategy": "observation_masking",
        "proactive": false,
        "observation_masking": {
            "keep_recent_tool_outputs": 10,
            "summary_format": "head_tail"
        }
    });
    let parsed = CompactionConfig::from_json(&raw);
    assert_eq!(parsed.strategy, CompactionStrategy::ObservationMasking);
    assert!(!parsed.proactive);

    let masking = &parsed.observation_masking;
    assert_eq!(masking.keep_recent_tool_outputs, 10);
    assert_eq!(masking.summary_format, MaskingSummaryFormat::HeadTail);
}
/// All cost-control thresholds are read from the `cost_control` JSON section.
#[test]
fn test_config_cost_control_with_custom_settings() {
    let raw = json!({
        "cost_control": {
            "enabled": true,
            "keep_recent_tool_results": 1,
            "mask_after_tool_results": 2,
            "max_live_tool_result_bytes": 4096,
            "max_uncached_input_tokens": 50000,
            "min_cache_read_ratio": 0.5
        }
    });
    let parsed = CompactionConfig::from_json(&raw);
    let cost = &parsed.cost_control;

    assert!(cost.enabled);
    assert_eq!(cost.keep_recent_tool_results, 1);
    assert_eq!(cost.mask_after_tool_results, 2);
    assert_eq!(cost.max_live_tool_result_bytes, 4096);
    assert_eq!(cost.max_uncached_input_tokens, 50000);
    assert!((cost.min_cache_read_ratio - 0.5).abs() < f32::EPSILON);
}
/// Summarization model, instructions, and preserve list all come through from JSON.
#[test]
fn test_config_summarization_with_custom_model() {
    let raw = json!({
        "strategy": "summarization",
        "summarization": {
            "model": "claude-haiku-4-5-20251001",
            "instructions": "Focus on API decisions",
            "preserve": ["decisions", "errors"]
        }
    });
    let parsed = CompactionConfig::from_json(&raw);
    assert_eq!(parsed.strategy, CompactionStrategy::Summarization);

    let summary_cfg = &parsed.summarization;
    assert_eq!(summary_cfg.model.as_deref(), Some("claude-haiku-4-5-20251001"));
    assert_eq!(
        summary_cfg.instructions.as_deref(),
        Some("Focus on API decisions")
    );
    assert_eq!(summary_cfg.preserve.len(), 2);
}
/// Builds one assistant→tool exchange as `Message`s.
///
/// The exchange is an assistant message issuing a single `tool_name` call,
/// whose arguments carry a fixed `path`. It is followed by the tool-result
/// message that carries `result`.
fn make_message_tool_turn(
    call_id: &str,
    tool_name: &str,
    result: serde_json::Value,
) -> Vec<Message> {
    let call = ToolCall {
        id: call_id.to_owned(),
        name: tool_name.to_owned(),
        arguments: json!({"path": "/workspace/src/lib.rs"}),
    };
    let request = Message::assistant_with_tools("", vec![call]);
    let response = Message::tool_result(call_id, Some(result), None);
    vec![request, response]
}
/// Cost-control masking replaces older `read_file` payloads with compact summaries.
///
/// Each summary keeps the path, line range, and continuation offset. The
/// newest result stays live.
#[test]
fn test_cost_control_masks_old_read_file_results() {
    let mut history = vec![Message::user("inspect files")];
    history.extend((0..5).flat_map(|index| {
        make_message_tool_turn(
            &format!("call_{index}"),
            "read_file",
            json!({
                "path": "/workspace/src/lib.rs",
                "content": format!("{}{}", "line\n".repeat(400), index),
                "total_lines": 900,
                "lines_shown": {"start": 1, "end": 400},
                "truncated": true,
                "content_hash": format!("sha256:{index}"),
                "truncation": {"truncated": true, "next_offset": 400, "reason": "line_cap"}
            }),
        )
    }));
    let raw = json!({
        "cost_control": {
            "keep_recent_tool_results": 2,
            "mask_after_tool_results": 4
        }
    });
    let cfg = CompactionConfig::from_json(&raw);

    let outcome = apply_cost_control_masking(&history, &cfg, None);
    assert_eq!(outcome.masked_count, 3);
    assert!(outcome.tool_result_bytes_after < outcome.tool_result_bytes_before);

    // messages[2] is the first tool result (after the user prompt and first assistant call).
    let oldest = outcome.messages[2].tool_result_content().unwrap();
    let masked = oldest.result.as_ref().unwrap();
    assert_eq!(masked["masked"], true);
    let summary = masked["summary"].as_str().unwrap();
    for needle in [
        "read_file",
        "/workspace/src/lib.rs",
        "lines 1-400",
        "next_offset=400",
    ] {
        assert!(summary.contains(needle));
    }
    assert!(!summary.contains("line\nline"));

    let newest = outcome
        .messages
        .last()
        .unwrap()
        .tool_result_content()
        .unwrap();
    assert!(newest.result.as_ref().unwrap().get("content").is_some());
}
/// Under default compaction settings, the model view masks seven of nine large
/// `read_file` results.
///
/// This shrinks tool-result bytes by more than 4x while the newest result
/// keeps its content.
#[test]
fn test_model_view_masks_with_compaction_config() {
    let mut history = vec![Message::user("inspect files repeatedly")];
    history.extend((0..9).flat_map(|index| {
        let body = "large transcript line\n".repeat(1000) + &index.to_string();
        make_message_tool_turn(
            &format!("call_{index}"),
            "read_file",
            json!({
                "path": "/workspace/session_019e4c9dd1b17021af70ad3227361b16.jsonl",
                "content": body,
                "total_lines": 1000,
                "lines_shown": {"start": 1, "end": 1000},
                "truncated": false,
                "content_hash": format!("sha256:{index}")
            }),
        )
    }));
    let cfg = CompactionConfig::default();

    let view = build_model_view_messages(&history, &cfg, None);
    assert_eq!(view.masked_count, 7);
    assert!(view.tool_result_bytes_after < view.tool_result_bytes_before / 4);

    let oldest = view.messages[2].tool_result_content().unwrap();
    let masked = oldest.result.as_ref().unwrap();
    assert_eq!(masked["masked"], true);
    assert!(masked["summary"].as_str().unwrap().contains("read_file"));

    let newest = view.messages.last().unwrap().tool_result_content().unwrap();
    assert!(newest.result.as_ref().unwrap().get("content").is_some());
}
/// The capability's model-view provider applies the same masking as
/// `build_model_view_messages`.
///
/// The oldest tool result is masked, and the newest stays intact.
#[test]
fn test_compaction_capability_contributes_model_view_provider() {
    let history: Vec<Message> = std::iter::once(Message::user("inspect files repeatedly"))
        .chain((0..9).flat_map(|index| {
            make_message_tool_turn(
                &format!("call_{index}"),
                "read_file",
                json!({
                    "path": "/workspace/src/lib.rs",
                    "content": "large file line\n".repeat(1000) + &index.to_string(),
                    "total_lines": 1000,
                    "lines_shown": {"start": 1, "end": 1000},
                    "truncated": false
                }),
            )
        }))
        .collect();

    let capability = CompactionCapability;
    let provider = capability.model_view_provider().unwrap();
    let context = ModelViewContext {
        prior_usage: None,
        session_id: crate::typed_id::SessionId::new(),
    };
    let view = provider.apply_model_view(history, &json!({}), &context);

    let oldest = view[2].tool_result_content().unwrap();
    assert_eq!(oldest.result.as_ref().unwrap()["masked"], true);
    let newest = view.last().unwrap().tool_result_content().unwrap();
    assert!(newest.result.as_ref().unwrap().get("content").is_some());
}
/// With `cost_control.enabled = false`, the model view leaves every tool
/// result untouched, even when the thresholds are aggressive.
#[test]
fn test_model_view_respects_disabled_cost_control_config() {
    let mut history = vec![Message::user("inspect files repeatedly")];
    history.extend((0..5).flat_map(|index| {
        make_message_tool_turn(
            &format!("call_{index}"),
            "read_file",
            json!({
                "path": "/workspace/src/lib.rs",
                "content": "line\n".repeat(400),
                "total_lines": 400,
                "lines_shown": {"start": 1, "end": 400},
                "truncated": false
            }),
        )
    }));
    let raw = json!({
        "cost_control": {
            "enabled": false,
            "keep_recent_tool_results": 1,
            "mask_after_tool_results": 2
        }
    });
    let cfg = CompactionConfig::from_json(&raw);

    let view = build_model_view_messages(&history, &cfg, None);
    assert_eq!(view.masked_count, 0);
    assert_eq!(view.tool_result_bytes_after, view.tool_result_bytes_before);
}
/// Prior token usage alone triggers masking.
///
/// Small bash outputs get masked when the previous request exceeded
/// `max_uncached_input_tokens` with zero cache reads. This happens even
/// though the byte and count thresholds are far from being hit.
#[test]
fn test_cost_control_uses_prior_usage_signal() {
    let mut history = vec![Message::user("run commands")];
    history.extend((0..3).flat_map(|index| {
        make_message_tool_turn(
            &format!("call_{index}"),
            "bash",
            json!({
                "stdout": "small output",
                "stderr": "",
                "exit_code": 0,
                "success": true
            }),
        )
    }));
    let raw = json!({
        "cost_control": {
            "keep_recent_tool_results": 1,
            "mask_after_tool_results": 99,
            "max_live_tool_result_bytes": 999999,
            "max_uncached_input_tokens": 1000
        }
    });
    let cfg = CompactionConfig::from_json(&raw);
    let prior = TokenUsage::with_cache(10_000, 100, Some(0), None);

    let outcome = apply_cost_control_masking(&history, &cfg, Some(&prior));
    assert_eq!(outcome.masked_count, 2);

    let oldest = outcome.messages[2].tool_result_content().unwrap();
    let summary = oldest.result.as_ref().unwrap()["summary"].as_str().unwrap();
    assert!(summary.contains("bash exit=0"));
}
/// The default config also reacts to a large uncached prior request.
///
/// It masks exactly one of three small bash results.
#[test]
fn test_model_view_uses_provider_cache_signal_from_compaction_config() {
    let mut history = vec![Message::user("run commands")];
    history.extend((0..3).flat_map(|index| {
        make_message_tool_turn(
            &format!("call_{index}"),
            "bash",
            json!({
                "stdout": "small output",
                "stderr": "",
                "exit_code": 0,
                "success": true
            }),
        )
    }));
    let prior = TokenUsage::with_cache(150_000, 100, Some(0), None);
    let cfg = CompactionConfig::default();

    let view = build_model_view_messages(&history, &cfg, Some(&prior));
    assert_eq!(view.masked_count, 1);
    let oldest = view.messages[2].tool_result_content().unwrap();
    assert_eq!(oldest.result.as_ref().unwrap()["masked"], true);
}
/// Unparseable values fall back to defaults instead of failing.
#[test]
fn test_config_falls_back_to_defaults_for_invalid_json() {
    let bogus = json!({
        "strategy": "nonexistent_strategy",
        "budget_percent": "not-a-number"
    });
    let parsed = CompactionConfig::from_json(&bogus);
    assert!(parsed.proactive);
    assert_eq!(parsed.strategy, CompactionStrategy::Auto);
}
/// Overriding a few keys keeps defaults for everything else, including sibling
/// fields inside a partially specified section.
#[test]
fn test_config_partial_override() {
    let raw = json!({
        "budget_percent": 0.7,
        "observation_masking": {
            "keep_recent_tool_outputs": 3
        }
    });
    let parsed = CompactionConfig::from_json(&raw);
    let masking = &parsed.observation_masking;

    assert_eq!(parsed.strategy, CompactionStrategy::Auto);
    assert!(parsed.proactive);
    assert!((parsed.budget_percent - 0.7).abs() < f32::EPSILON);
    assert_eq!(masking.keep_recent_tool_outputs, 3);
    assert_eq!(masking.summary_format, MaskingSummaryFormat::OneLine);
}
/// Every strategy survives a serialize → deserialize round trip.
#[test]
fn test_strategy_serialization_roundtrip() {
    let all = [
        CompactionStrategy::Auto,
        CompactionStrategy::Native,
        CompactionStrategy::ObservationMasking,
        CompactionStrategy::Summarization,
    ];
    all.iter().copied().for_each(|original| {
        let encoded = serde_json::to_value(original).unwrap();
        let decoded: CompactionStrategy = serde_json::from_value(encoded).unwrap();
        assert_eq!(original, decoded);
    });
}
/// `Display` renders each strategy as its snake_case name.
#[test]
fn test_strategy_display() {
    let cases = [
        (CompactionStrategy::Auto, "auto"),
        (CompactionStrategy::Native, "native"),
        (CompactionStrategy::ObservationMasking, "observation_masking"),
        (CompactionStrategy::Summarization, "summarization"),
    ];
    for (strategy, expected) in cases {
        assert_eq!(strategy.to_string(), expected);
    }
}
/// Every masking summary format survives a serialize → deserialize round trip.
#[test]
fn test_masking_format_serialization_roundtrip() {
    let formats = [MaskingSummaryFormat::OneLine, MaskingSummaryFormat::HeadTail];
    formats.iter().copied().for_each(|original| {
        let encoded = serde_json::to_value(original).unwrap();
        let decoded: MaskingSummaryFormat = serde_json::from_value(encoded).unwrap();
        assert_eq!(original, decoded);
    });
}
/// Budget fractions near both ends of the range are accepted verbatim.
#[test]
fn test_budget_percent_boundary_values() {
    let parse_budget = |value: serde_json::Value| {
        CompactionConfig::from_json(&json!({"budget_percent": value})).budget_percent
    };
    assert!((parse_budget(json!(0.1)) - 0.1).abs() < f32::EPSILON);
    assert!((parse_budget(json!(0.99)) - 0.99).abs() < f32::EPSILON);
}
/// Zero is a legal value for `keep_recent_tool_outputs`.
#[test]
fn test_keep_recent_tool_outputs_zero() {
    let raw = json!({
        "observation_masking": {"keep_recent_tool_outputs": 0}
    });
    let masking = CompactionConfig::from_json(&raw).observation_masking;
    assert_eq!(masking.keep_recent_tool_outputs, 0);
}
/// A conversation without tool traffic passes through masking unchanged.
#[test]
fn test_masking_no_tool_messages() {
    let history = vec![make_user_msg("hello"), make_assistant_msg("hi")];
    let cfg = ObservationMaskingConfig::default();
    let outcome = apply_observation_masking(&history, &cfg);
    assert_eq!(outcome.messages.len(), 2);
    assert_eq!(outcome.masked_count, 0);
}
/// Nothing is masked while the tool-output count is within `keep_recent_tool_outputs`.
#[test]
fn test_masking_fewer_than_keep_recent() {
    let history = vec![
        make_user_msg("read file"),
        make_assistant_with_tool_call("call_1", "read_file"),
        make_tool_result("call_1", "file contents"),
        make_assistant_msg("done"),
    ];
    let cfg = ObservationMaskingConfig {
        keep_recent_tool_outputs: 5,
        ..Default::default()
    };
    assert_eq!(apply_observation_masking(&history, &cfg).masked_count, 0);
}
/// Only tool outputs older than the `keep_recent_tool_outputs` window are
/// replaced, and they become a bracketed one-line summary naming the tool.
#[test]
fn test_masking_masks_old_outputs() {
    let history = vec![
        make_user_msg("start"),
        make_assistant_with_tool_call("call_1", "read_file"),
        make_tool_result(
            "call_1",
            "old file contents that are very long and should be masked by the observation masking strategy because it exceeds 100 chars",
        ),
        make_assistant_msg("got it"),
        make_user_msg("next"),
        make_assistant_with_tool_call("call_2", "search"),
        make_tool_result("call_2", "search results"),
        make_assistant_msg("found it"),
        make_user_msg("more"),
        make_assistant_with_tool_call("call_3", "bash"),
        make_tool_result("call_3", "command output"),
    ];
    let cfg = ObservationMaskingConfig {
        keep_recent_tool_outputs: 2,
        ..Default::default()
    };

    let outcome = apply_observation_masking(&history, &cfg);
    assert_eq!(outcome.masked_count, 1);

    let oldest = &outcome.messages[2];
    assert_eq!(oldest.role, LlmMessageRole::Tool);
    let text = extract_text(&oldest.content);
    assert!(
        text.starts_with('['),
        "Expected masked summary, got: {text}"
    );
    assert!(text.contains("read_file"), "Expected tool name: {text}");

    for (position, expected) in [(6, "search results"), (10, "command output")] {
        assert_eq!(extract_text(&outcome.messages[position].content), expected);
    }
}
/// A masked tool result keeps its `tool_call_id`, so call/result pairing stays intact.
#[test]
fn test_masking_preserves_tool_call_id() {
    let history = vec![
        make_assistant_with_tool_call("call_1", "read_file"),
        make_tool_result("call_1", "content"),
        make_assistant_with_tool_call("call_2", "bash"),
        make_tool_result("call_2", "output"),
    ];
    let cfg = ObservationMaskingConfig {
        keep_recent_tool_outputs: 1,
        ..Default::default()
    };
    let outcome = apply_observation_masking(&history, &cfg);
    let masked_id = outcome.messages[1].tool_call_id.clone();
    assert_eq!(masked_id, Some("call_1".to_string()));
}
/// The head/tail format keeps the first and last lines of a long output and
/// notes how many lines were dropped.
#[test]
fn test_masking_head_tail_format() {
    let mut long_output = String::new();
    for i in 0..20 {
        if i > 0 {
            long_output.push('\n');
        }
        long_output.push_str(&format!("line {i}"));
    }
    let history = vec![
        make_assistant_with_tool_call("call_1", "bash"),
        make_tool_result("call_1", &long_output),
        make_assistant_with_tool_call("call_2", "bash"),
        make_tool_result("call_2", "recent output"),
    ];
    let cfg = ObservationMaskingConfig {
        keep_recent_tool_outputs: 1,
        summary_format: MaskingSummaryFormat::HeadTail,
    };

    let outcome = apply_observation_masking(&history, &cfg);
    let masked_text = extract_text(&outcome.messages[1].content);
    assert!(masked_text.contains("line 0"), "Should contain first lines");
    assert!(masked_text.contains("line 19"), "Should contain last lines");
    assert!(masked_text.contains("lines omitted"), "Should indicate omissions");
}
/// A short masked output is carried inline in its one-line summary.
#[test]
fn test_masking_short_output_inline() {
    let history = vec![
        make_assistant_with_tool_call("call_1", "get_time"),
        make_tool_result("call_1", "2024-01-01"),
        make_assistant_with_tool_call("call_2", "bash"),
        make_tool_result("call_2", "ok"),
    ];
    let cfg = ObservationMaskingConfig {
        keep_recent_tool_outputs: 1,
        ..Default::default()
    };
    let outcome = apply_observation_masking(&history, &cfg);
    let text = extract_text(&outcome.messages[1].content);
    assert!(text.contains("2024-01-01"), "Short output included: {text}");
}
/// With `keep_recent_tool_outputs = 0`, every tool output is masked.
#[test]
fn test_masking_all_when_keep_zero() {
    let history = vec![
        make_assistant_with_tool_call("call_1", "a"),
        make_tool_result("call_1", "output1"),
        make_assistant_with_tool_call("call_2", "b"),
        make_tool_result("call_2", "output2"),
    ];
    let cfg = ObservationMaskingConfig {
        keep_recent_tool_outputs: 0,
        ..Default::default()
    };
    assert_eq!(apply_observation_masking(&history, &cfg).masked_count, 2);
}
/// Masking an empty conversation yields an empty result.
#[test]
fn test_masking_empty_messages() {
    let cfg = ObservationMaskingConfig::default();
    let outcome = apply_observation_masking(&[], &cfg);
    assert!(outcome.messages.is_empty());
    assert_eq!(outcome.masked_count, 0);
}
/// Masking rewrites content in place and never adds or removes messages.
#[test]
fn test_masking_preserves_message_count() {
    let history = vec![
        make_user_msg("start"),
        make_assistant_with_tool_call("c1", "read_file"),
        make_tool_result("c1", "content 1"),
        make_assistant_msg("ok"),
        make_user_msg("next"),
        make_assistant_with_tool_call("c2", "bash"),
        make_tool_result("c2", "content 2"),
        make_assistant_msg("done"),
    ];
    let cfg = ObservationMaskingConfig {
        keep_recent_tool_outputs: 1,
        ..Default::default()
    };
    let outcome = apply_observation_masking(&history, &cfg);
    assert_eq!(outcome.messages.len(), history.len());
}
/// A tool result with no matching assistant call is still masked, and its
/// summary uses a fallback tool name.
#[test]
fn test_masking_unknown_tool_call_id() {
    let history = vec![
        make_tool_result("orphan", "some output"),
        make_assistant_with_tool_call("call_2", "bash"),
        make_tool_result("call_2", "recent"),
    ];
    let cfg = ObservationMaskingConfig {
        keep_recent_tool_outputs: 1,
        ..Default::default()
    };
    let outcome = apply_observation_masking(&history, &cfg);
    assert_eq!(outcome.masked_count, 1);
    let text = extract_text(&outcome.messages[0].content);
    assert!(text.contains("unknown_tool"), "Fallback name: {text}");
}
/// Out of ten tool outputs, exactly the newest `keep_recent_tool_outputs` (3)
/// remain verbatim.
#[test]
fn test_masking_many_tool_calls_keeps_exactly_n() {
    let history: Vec<_> = (0..10)
        .flat_map(|i| {
            let id = format!("call_{i}");
            [
                make_assistant_with_tool_call(&id, &format!("tool_{i}")),
                make_tool_result(&id, &format!("output {i}")),
            ]
        })
        .collect();
    let cfg = ObservationMaskingConfig {
        keep_recent_tool_outputs: 3,
        ..Default::default()
    };

    let outcome = apply_observation_masking(&history, &cfg);
    assert_eq!(outcome.masked_count, 7);
    // Tool results sit at odd indices; calls 7, 8 and 9 are the survivors.
    for (position, expected) in [(15, "output 7"), (17, "output 8"), (19, "output 9")] {
        assert_eq!(extract_text(&outcome.messages[position].content), expected);
    }
}
/// The default prompt contains the task wrapper and every default preserve category.
#[test]
fn test_summarization_prompt_default() {
    let cfg = SummarizationConfig::default();
    let prompt = build_summarization_prompt(&cfg);
    for expected in ["<task>", "decisions", "files_modified", "errors", "current_plan"] {
        assert!(prompt.contains(expected));
    }
}
/// Custom instructions are embedded in the summarization prompt.
#[test]
fn test_summarization_prompt_custom_instructions() {
    let custom = "Focus on API changes";
    let cfg = SummarizationConfig {
        instructions: Some(custom.to_owned()),
        ..SummarizationConfig::default()
    };
    let prompt = build_summarization_prompt(&cfg);
    assert!(prompt.contains(custom));
}
/// A non-empty custom preserve list replaces the defaults entirely.
#[test]
fn test_summarization_prompt_custom_preserve() {
    let cfg = SummarizationConfig {
        preserve: ["auth_tokens", "database_schema"]
            .iter()
            .map(|item| item.to_string())
            .collect(),
        ..SummarizationConfig::default()
    };
    let prompt = build_summarization_prompt(&cfg);
    for wanted in ["auth_tokens", "database_schema"] {
        assert!(prompt.contains(wanted));
    }
    assert!(!prompt.contains("decisions"));
}
/// An empty preserve list falls back to the default categories.
#[test]
fn test_summarization_prompt_empty_preserve_uses_defaults() {
    let cfg = SummarizationConfig {
        preserve: Vec::new(),
        ..SummarizationConfig::default()
    };
    assert!(build_summarization_prompt(&cfg).contains("decisions"));
}
/// Messages are rendered as `[role]: text` lines for the summarizer.
#[test]
fn test_format_messages_for_summarization() {
    let history = vec![
        make_user_msg("What is 2+2?"),
        make_assistant_msg("The answer is 4."),
    ];
    let transcript = format_messages_for_summarization(&history);
    for line in ["[user]: What is 2+2?", "[assistant]: The answer is 4."] {
        assert!(transcript.contains(line));
    }
}
/// Oversized message content is truncated and marked as such.
#[test]
fn test_format_messages_truncates_long_content() {
    let oversized = "x".repeat(5000);
    let history = vec![make_user_msg(&oversized)];
    let transcript = format_messages_for_summarization(&history);
    assert!(transcript.len() < oversized.len());
    assert!(transcript.contains("truncated"));
}
/// Truncating multibyte text must not panic by slicing through a UTF-8 character.
///
/// Each "é" is 2 bytes in UTF-8, so 1001 of them total 2002 bytes. That byte
/// length is the figure the truncation marker reports, despite the word "chars".
#[test]
fn test_format_messages_truncates_utf8_without_panic() {
    let multibyte = "é".repeat(1001);
    let messages = vec![make_user_msg(&multibyte)];
    let formatted = format_messages_for_summarization(&messages);
    assert!(formatted.contains("[truncated, 2002 chars total]"));
}
/// The summary is wrapped in `[CONVERSATION_SUMMARY]` markers inside a system message.
#[test]
fn test_build_summary_message() {
    let summary_msg = build_summary_message("The user asked about APIs.");
    assert_eq!(summary_msg.role, LlmMessageRole::System);
    let body = extract_text(&summary_msg.content);
    for fragment in [
        "[CONVERSATION_SUMMARY]",
        "The user asked about APIs.",
        "[/CONVERSATION_SUMMARY]",
    ] {
        assert!(body.contains(fragment));
    }
}
/// Content short enough for head/tail display is returned unchanged.
#[test]
fn test_head_tail_short_content_unchanged() {
    let text = "line1\nline2\nline3";
    let content = LlmMessageContent::Text(text.to_string());
    assert_eq!(format_head_tail_summary(&content), text);
}
/// Six lines (three head plus three tail) fit exactly, so nothing is omitted.
#[test]
fn test_head_tail_exactly_six_lines() {
    let text = "1\n2\n3\n4\n5\n6";
    let content = LlmMessageContent::Text(text.to_string());
    assert_eq!(format_head_tail_summary(&content), text);
}
/// Seven lines: the middle one is dropped and reported as omitted.
#[test]
fn test_head_tail_seven_lines() {
    let content = LlmMessageContent::Text("1\n2\n3\n4\n5\n6\n7".to_string());
    let summary = format_head_tail_summary(&content);
    for fragment in ["1\n2\n3", "5\n6\n7", "1 lines omitted"] {
        assert!(summary.contains(fragment));
    }
}
/// An empty output still produces a well-formed one-line summary.
#[test]
fn test_one_line_empty_output() {
    let empty = LlmMessageContent::Text(String::new());
    assert_eq!(format_one_line_summary("bash", &empty), "[bash → ]");
}
/// Output of exactly 100 characters is still inlined verbatim.
#[test]
fn test_one_line_exactly_100_chars() {
    let payload = "x".repeat(100);
    let content = LlmMessageContent::Text(payload.clone());
    let summary = format_one_line_summary("bash", &content);
    assert!(summary.contains(&payload));
}
/// Output past 100 characters is replaced by a line and byte count.
#[test]
fn test_one_line_101_chars_summarized() {
    let content = LlmMessageContent::Text("x".repeat(101));
    let summary = format_one_line_summary("bash", &content);
    for unit in ["lines", "bytes"] {
        assert!(summary.contains(unit));
    }
}
/// Multi-part text content contributes every part to the summary.
#[test]
fn test_one_line_multipart_content() {
    let parts = ["part1", "part2"]
        .iter()
        .map(|piece| LlmContentPart::Text {
            text: piece.to_string(),
        })
        .collect();
    let content = LlmMessageContent::Parts(parts);
    let summary = format_one_line_summary("tool", &content);
    assert!(summary.contains("part1"));
    assert!(summary.contains("part2"));
}
/// `CompactionStep` serializes each field under its own name.
#[test]
fn test_compaction_step_serialization() {
    let step = CompactionStep {
        strategy: "observation_masking".to_string(),
        messages_after: 42,
        duration_ms: 12,
    };
    let encoded = serde_json::to_value(&step).unwrap();
    assert_eq!(encoded["duration_ms"], 12);
    assert_eq!(encoded["messages_after"], 42);
    assert_eq!(encoded["strategy"], "observation_masking");
}
/// Token estimate for a plain text message.
///
/// "hello world" is 11 bytes, and the expectation encodes the estimator's
/// length / 4 heuristic with integer division (2 tokens).
#[test]
fn test_estimate_tokens_text() {
    let msg = make_user_msg("hello world");
    let tokens = estimate_tokens(&msg);
    assert_eq!(tokens, 11 / 4);
}
/// An empty message is estimated at zero tokens.
#[test]
fn test_estimate_tokens_empty() {
    assert_eq!(estimate_tokens(&make_user_msg("")), 0);
}
/// The total estimate is the sum of the per-message estimates: 400/4 + 200/4 = 150.
#[test]
fn test_estimate_total_tokens() {
    let messages = vec![
        make_user_msg(&"a".repeat(400)),
        make_assistant_msg(&"b".repeat(200)),
    ];
    assert_eq!(estimate_total_tokens(&messages), 150);
}
/// Tool calls count toward the estimate even when the text content is empty.
#[test]
fn test_estimate_tokens_with_tool_calls() {
    let call_msg = make_assistant_with_tool_call("call_1", "read_file");
    assert!(
        estimate_tokens(&call_msg) > 0,
        "Tool call should contribute tokens"
    );
}
/// A tiny conversation in a 128k-token window stays under budget, so no
/// proactive compaction happens.
#[test]
fn test_should_compact_proactively_under_budget() {
    let messages = vec![make_user_msg("short")];
    let config = CompactionConfig::default();
    assert!(!should_compact_proactively(&messages, &config, 128_000));
}
/// ~1000 estimated tokens (4000 bytes / 4) in a 1000-token window exceed the
/// default budget fraction, so proactive compaction triggers.
#[test]
fn test_should_compact_proactively_over_budget() {
    let big_text = "x".repeat(4000);
    let messages = vec![make_user_msg(&big_text)];
    let config = CompactionConfig::default();
    assert!(should_compact_proactively(&messages, &config, 1000));
}
/// With `proactive: false`, an over-budget conversation is still not compacted proactively.
#[test]
fn test_should_compact_proactively_disabled() {
    let history = vec![make_user_msg(&"x".repeat(4000))];
    let cfg = CompactionConfig {
        proactive: false,
        ..Default::default()
    };
    let triggered = should_compact_proactively(&history, &cfg, 1000);
    assert!(!triggered);
}
/// `aggressive_trim` drops messages to fit the token target, and the trimmed
/// result still begins with a user-role message.
///
/// Each message is 400 bytes, roughly 100 tokens under the length / 4 estimate
/// pinned by the `estimate_tokens` tests. That the trimmer uses the same
/// estimate is an assumption. Five messages therefore exceed the 300-token
/// target.
#[test]
fn test_aggressive_trim_keeps_newest() {
    let messages = vec![
        make_user_msg(&"s".repeat(400)),
        make_user_msg(&"a".repeat(400)),
        make_assistant_msg(&"b".repeat(400)),
        make_user_msg(&"c".repeat(400)),
        make_assistant_msg(&"d".repeat(400)),
    ];
    let target_tokens = 300;
    let result = aggressive_trim(&messages, target_tokens, true);
    assert!(
        result.len() < messages.len(),
        "Expected trim, got {} messages",
        result.len()
    );
    assert_eq!(result[0].role, LlmMessageRole::User);
}
/// Trimming an empty conversation yields nothing.
#[test]
fn test_aggressive_trim_empty() {
    let trimmed = aggressive_trim(&[], 100, false);
    assert_eq!(trimmed.len(), 0);
}
/// When everything fits the target, no message is dropped.
#[test]
fn test_aggressive_trim_everything_fits() {
    let history = vec![make_user_msg("hi"), make_assistant_msg("hello")];
    let trimmed = aggressive_trim(&history, 100_000, false);
    assert_eq!(trimmed.len(), history.len());
}
/// Recording a `+`-joined strategy name credits each component strategy separately.
#[test]
fn test_session_metrics_record() {
    let mut metrics = SessionCompactionMetrics::default();
    metrics.record("observation_masking+native", 100, 50, 200);

    assert_eq!(metrics.compaction_count, 1);
    assert_eq!(metrics.total_messages_saved, 50);
    assert_eq!(metrics.total_duration_ms, 200);
    for name in ["observation_masking", "native"] {
        assert_eq!(metrics.strategy_counts[name], 1);
    }
}
/// Successive recordings accumulate counts, savings, and durations.
#[test]
fn test_session_metrics_accumulate() {
    let mut metrics = SessionCompactionMetrics::default();
    metrics.record("observation_masking", 100, 80, 10);
    metrics.record("summarization", 80, 40, 500);

    assert_eq!(metrics.compaction_count, 2);
    assert_eq!(metrics.total_messages_saved, 60);
    assert_eq!(metrics.total_duration_ms, 510);
    for name in ["observation_masking", "summarization"] {
        assert_eq!(metrics.strategy_counts[name], 1);
    }
}
/// Metrics serialize with their counters under stable field names.
#[test]
fn test_session_metrics_serialization() {
    let mut metrics = SessionCompactionMetrics::default();
    metrics.record("auto", 50, 30, 100);
    let encoded = serde_json::to_value(&metrics).unwrap();
    assert_eq!(encoded["total_messages_saved"], 20);
    assert_eq!(encoded["compaction_count"], 1);
}
/// Tiers are assigned from the newest message backwards.
///
/// The last `hot_messages` are Hot, the next `warm_messages` are Warm, and the
/// rest are Cold.
#[test]
fn test_classify_memory_tiers_basic() {
    let history: Vec<LlmMessage> = (0..30)
        .map(|i| make_user_msg(&format!("msg {i}")))
        .collect();
    let tiers = HierarchicalMemoryConfig {
        hot_messages: 5,
        warm_messages: 10,
    };

    let classified = classify_memory_tiers(&history, &tiers);
    assert_eq!(classified.len(), 30);
    let expectations = [
        (29, MemoryTier::Hot),
        (25, MemoryTier::Hot),
        (24, MemoryTier::Warm),
        (15, MemoryTier::Warm),
        (14, MemoryTier::Cold),
        (0, MemoryTier::Cold),
    ];
    for (position, tier) in expectations {
        assert_eq!(classified[position].0, tier);
    }
}
/// A conversation shorter than the default hot window is entirely Hot.
#[test]
fn test_classify_memory_tiers_all_hot() {
    let history: Vec<LlmMessage> = (0..3)
        .map(|i| make_user_msg(&format!("msg {i}")))
        .collect();
    let tiers = HierarchicalMemoryConfig::default();
    let classified = classify_memory_tiers(&history, &tiers);
    let every_hot = classified.iter().all(|(tier, _)| *tier == MemoryTier::Hot);
    assert!(every_hot);
}
/// The cold tier collapses into a leading summary message, and the newest
/// message survives at the end.
#[test]
fn test_apply_hierarchical_memory_basic() {
    let old_turns = (0..5).flat_map(|i| {
        let id = format!("old_{i}");
        [
            make_assistant_with_tool_call(&id, "read_file"),
            make_tool_result(&id, &format!("old content {i}")),
        ]
    });
    let mid_turns = (0..3).flat_map(|i| {
        let id = format!("mid_{i}");
        [
            make_assistant_with_tool_call(&id, "bash"),
            make_tool_result(&id, &format!("mid output {i}")),
        ]
    });
    let history: Vec<LlmMessage> = old_turns
        .chain(mid_turns)
        .chain([make_user_msg("what now?"), make_assistant_msg("let me check")])
        .collect();
    let tiers = HierarchicalMemoryConfig {
        hot_messages: 2,
        warm_messages: 6,
    };
    let masking_cfg = ObservationMaskingConfig::default();

    let compacted = apply_hierarchical_memory(
        &history,
        &tiers,
        &masking_cfg,
        Some("Summary of old work"),
    );
    assert!(compacted.len() <= 9);
    assert!(extract_text(&compacted[0].content).contains("CONVERSATION_SUMMARY"));
    let newest = extract_text(&compacted.last().unwrap().content);
    assert!(newest.contains("let me check"));
}
/// With no cold messages and no summary, the conversation passes through at full length.
#[test]
fn test_apply_hierarchical_memory_no_cold() {
    let history = vec![make_user_msg("hello"), make_assistant_msg("hi")];
    let tiers = HierarchicalMemoryConfig {
        hot_messages: 5,
        warm_messages: 5,
    };
    let masking_cfg = ObservationMaskingConfig::default();
    let compacted = apply_hierarchical_memory(&history, &tiers, &masking_cfg, None);
    assert_eq!(compacted.len(), history.len());
}
/// `HierarchicalMemoryConfig` deserializes both window sizes from JSON.
#[test]
fn test_memory_tier_config_from_json() {
    let raw = json!({
        "hot_messages": 10,
        "warm_messages": 50
    });
    let tiers: HierarchicalMemoryConfig = serde_json::from_value(raw).unwrap();
    assert_eq!(tiers.warm_messages, 50);
    assert_eq!(tiers.hot_messages, 10);
}
/// Default memory windows are 20 hot and 100 warm messages.
#[test]
fn test_memory_tier_config_defaults() {
    let tiers = HierarchicalMemoryConfig::default();
    assert_eq!((tiers.hot_messages, tiers.warm_messages), (20, 100));
}
/// The `memory_tiers` section of `CompactionConfig` is parsed from JSON.
#[test]
fn test_compaction_config_with_memory_tiers() {
    let raw = json!({
        "strategy": "auto",
        "memory_tiers": {
            "hot_messages": 15,
            "warm_messages": 80
        }
    });
    let parsed = CompactionConfig::from_json(&raw);
    let tiers = &parsed.memory_tiers;
    assert_eq!(tiers.hot_messages, 15);
    assert_eq!(tiers.warm_messages, 80);
}
/// Memory tiers serialize as lowercase strings.
#[test]
fn test_memory_tier_serialization() {
    let cases = [
        (MemoryTier::Hot, "hot"),
        (MemoryTier::Warm, "warm"),
        (MemoryTier::Cold, "cold"),
    ];
    for (tier, label) in cases {
        assert_eq!(serde_json::to_value(tier).unwrap(), json!(label));
    }
}
/// `activate_skill` results are never masked.
///
/// Ordinary tool results outside the keep window still are.
#[test]
fn test_masking_skips_activate_skill_results() {
    let history = vec![
        make_assistant_with_tool_call("call_skill", "activate_skill"),
        make_tool_result(
            "call_skill",
            "You are a code review agent. Follow these instructions...",
        ),
        make_assistant_msg("Skill activated"),
        make_assistant_with_tool_call("call_read", "read_file"),
        make_tool_result(
            "call_read",
            "file contents that are long enough to be masked by observation masking because they exceed one hundred characters easily",
        ),
        make_assistant_msg("got it"),
        make_assistant_with_tool_call("call_bash", "bash"),
        make_tool_result("call_bash", "command output"),
    ];
    let cfg = ObservationMaskingConfig {
        keep_recent_tool_outputs: 1,
        ..Default::default()
    };

    let outcome = apply_observation_masking(&history, &cfg);
    let text_at = |position: usize| extract_text(&outcome.messages[position].content);
    assert_eq!(
        text_at(1),
        "You are a code review agent. Follow these instructions..."
    );
    assert!(text_at(4).starts_with('['));
    assert_eq!(text_at(7), "command output");
    assert_eq!(outcome.masked_count, 1);
}
/// Skill results are exempt even with a zero keep window.
///
/// Only the ordinary bash result is masked.
#[test]
fn test_masking_all_activate_skill_exempt_from_count() {
    let history = vec![
        make_assistant_with_tool_call("s1", "activate_skill"),
        make_tool_result("s1", "Skill 1 instructions"),
        make_assistant_with_tool_call("s2", "activate_skill"),
        make_tool_result("s2", "Skill 2 instructions"),
        make_assistant_with_tool_call("c1", "bash"),
        make_tool_result("c1", "output"),
    ];
    let cfg = ObservationMaskingConfig {
        keep_recent_tool_outputs: 0,
        ..Default::default()
    };

    let outcome = apply_observation_masking(&history, &cfg);
    assert_eq!(outcome.masked_count, 1);
    for (position, expected) in [(1, "Skill 1 instructions"), (3, "Skill 2 instructions")] {
        assert_eq!(extract_text(&outcome.messages[position].content), expected);
    }
}
/// Both halves of an `activate_skill` exchange survive aggressive trimming.
///
/// This holds even when the surrounding 400-byte messages must be dropped to
/// meet the token target.
#[test]
fn test_aggressive_trim_preserves_skill_messages() {
    let messages = vec![
        make_user_msg(&"s".repeat(400)),
        make_assistant_with_tool_call("skill1", "activate_skill"),
        make_tool_result("skill1", "Important skill instructions"),
        make_user_msg(&"a".repeat(400)),
        make_assistant_msg(&"b".repeat(400)),
        make_user_msg(&"c".repeat(400)),
        make_assistant_msg(&"d".repeat(400)),
    ];
    let target_tokens = 400;
    let result = aggressive_trim(&messages, target_tokens, true);

    let has_skill_result = result.iter().any(|m| {
        m.role == LlmMessageRole::Tool
            && extract_text(&m.content) == "Important skill instructions"
    });
    assert!(
        has_skill_result,
        "Skill tool result must survive aggressive trim"
    );

    let has_skill_call = result.iter().any(|m| {
        m.tool_calls
            .as_ref()
            .is_some_and(|calls| calls.iter().any(|tc| tc.name == "activate_skill"))
    });
    assert!(
        has_skill_call,
        "Skill tool call must survive aggressive trim"
    );
}
/// Skill instructions that fall into the cold tier are carried into the
/// output alongside the leading summary.
#[test]
fn test_hierarchical_memory_rescues_skill_from_cold_tier() {
    let mut history = vec![
        make_assistant_with_tool_call("skill1", "activate_skill"),
        make_tool_result("skill1", "You must always validate input."),
    ];
    history.extend((0..8).flat_map(|i| {
        let id = format!("old_{i}");
        [
            make_assistant_with_tool_call(&id, "read_file"),
            make_tool_result(&id, &format!("old content {i}")),
        ]
    }));
    history.extend((0..3).flat_map(|i| {
        let id = format!("mid_{i}");
        [
            make_assistant_with_tool_call(&id, "bash"),
            make_tool_result(&id, &format!("mid output {i}")),
        ]
    }));
    history.extend([make_user_msg("what now?"), make_assistant_msg("let me check")]);

    let tiers = HierarchicalMemoryConfig {
        hot_messages: 2,
        warm_messages: 6,
    };
    let masking_cfg = ObservationMaskingConfig::default();
    let compacted = apply_hierarchical_memory(
        &history,
        &tiers,
        &masking_cfg,
        Some("Summary of old work"),
    );

    let has_skill_instructions = compacted
        .iter()
        .any(|m| extract_text(&m.content).contains("You must always validate input."));
    assert!(
        has_skill_instructions,
        "Skill instructions from cold tier must be rescued into output"
    );
    assert!(extract_text(&compacted[0].content).contains("CONVERSATION_SUMMARY"));
}
/// Only tool results answering an `activate_skill` call count as protected.
///
/// Neither regular tool results nor the calls themselves are protected.
#[test]
fn test_is_protected_tool_result_detection() {
    let history = vec![
        make_assistant_with_tool_call("s1", "activate_skill"),
        make_tool_result("s1", "skill content"),
        make_assistant_with_tool_call("r1", "read_file"),
        make_tool_result("r1", "file content"),
    ];
    let protected = |position: usize| is_protected_tool_result(&history, &history[position]);
    assert!(protected(1));
    assert!(!protected(3));
    assert!(!protected(0));
}
/// Only an assistant message that calls `activate_skill` is a protected tool-call message.
///
/// Fix: the third call previously read `®ular_call`. That is mojibake from
/// `&reg` being decoded as an HTML entity, and it did not compile. It is
/// restored to `&regular_call`.
#[test]
fn test_is_protected_tool_call_message_detection() {
    let skill_call = make_assistant_with_tool_call("s1", "activate_skill");
    let regular_call = make_assistant_with_tool_call("r1", "read_file");
    let user_msg = make_user_msg("hello");
    assert!(is_protected_tool_call_message(&skill_call));
    assert!(!is_protected_tool_call_message(&regular_call));
    assert!(!is_protected_tool_call_message(&user_msg));
}
#[test]
fn test_default_preserve_includes_skill_instructions() {
    let defaults = SummarizationConfig::default();
    // Membership check without allocating a temporary String for the needle.
    let keeps_skill_instructions = defaults
        .preserve
        .iter()
        .any(|item| item.as_str() == "skill_instructions");
    assert!(
        keeps_skill_instructions,
        "Default preserve list must include skill_instructions"
    );
}
#[test]
fn test_summarization_prompt_mentions_skill_protection() {
    let summarization_config = SummarizationConfig::default();
    let prompt_text = build_summarization_prompt(&summarization_config);
    // The default prompt must name the skill tool so the LLM knows what to keep.
    let mentions_skill_tool = prompt_text.contains("activate_skill");
    assert!(
        mentions_skill_tool,
        "Summarization prompt must instruct LLM to preserve skill content"
    );
}
#[test]
fn test_aggressive_trim_protected_exceed_budget() {
    let messages = vec![
        // Leading user message (400 chars of 's').
        make_user_msg(&"s".repeat(400)),
        // Two protected skill call/result pairs with 800-char payloads each.
        make_assistant_with_tool_call("skill1", "activate_skill"),
        make_tool_result("skill1", &"x".repeat(800)),
        make_assistant_with_tool_call("skill2", "activate_skill"),
        make_tool_result("skill2", &"y".repeat(800)),
        // Trailing, unprotected user message (400 chars of 'z').
        make_user_msg(&"z".repeat(400)),
    ];

    // The budget of 200 is far smaller than the protected payloads alone
    // (units as defined by aggressive_trim).
    let trimmed = aggressive_trim(&messages, 200, true);

    let trailing_user_kept = trimmed.iter().any(|msg| {
        msg.role == LlmMessageRole::User && extract_text(&msg.content).contains('z')
    });
    assert!(
        !trailing_user_kept,
        "Non-protected messages must be dropped when protected exceed budget"
    );
}
#[test]
fn test_format_messages_no_truncate_protected_tool_result() {
    // Both payloads are the same length; only the protection status differs.
    let skill_body = "a".repeat(5000);
    let file_body = "b".repeat(5000);
    let messages = vec![
        make_assistant_with_tool_call("s1", "activate_skill"),
        make_tool_result("s1", &skill_body),
        make_assistant_with_tool_call("r1", "read_file"),
        make_tool_result("r1", &file_body),
    ];

    let transcript = format_messages_for_summarization(&messages);

    assert!(
        transcript.contains(&skill_body),
        "Protected tool result must not be truncated"
    );
    assert!(
        transcript.contains("[truncated, 5000 chars total]"),
        "Non-protected tool result should be truncated"
    );
}
#[test]
fn test_hierarchical_memory_cross_tier_boundary_protection() {
    // The skill call opens the history, but its result only arrives after 9
    // read_file pairs, so call (index 0) and result (index 19 of 26) are far apart.
    let mut messages = vec![make_assistant_with_tool_call("skill1", "activate_skill")];
    messages.extend((0..9).flat_map(|i| {
        let id = format!("cold_{i}");
        [
            make_assistant_with_tool_call(&id, "read_file"),
            make_tool_result(&id, &format!("cold content {i}")),
        ]
    }));
    messages.push(make_tool_result(
        "skill1",
        "Cross-tier skill instructions that must survive",
    ));
    messages.extend((0..2).flat_map(|i| {
        let id = format!("warm_{i}");
        [
            make_assistant_with_tool_call(&id, "bash"),
            make_tool_result(&id, &format!("warm output {i}")),
        ]
    }));
    messages.push(make_user_msg("continue"));
    messages.push(make_assistant_msg("ok"));

    // Sized so that the skill result falls in the warm tier while its call is
    // left in the cold tier (assuming tiers are counted from the end).
    let tiers = HierarchicalMemoryConfig {
        hot_messages: 2,
        warm_messages: 5,
    };
    // Keep no recent tool outputs, so recency alone should not exempt the skill
    // result from masking; presumably only skill protection can keep it intact.
    let masking = ObservationMaskingConfig {
        keep_recent_tool_outputs: 0,
        summary_format: MaskingSummaryFormat::OneLine,
    };
    let output = apply_hierarchical_memory(&messages, &tiers, &masking, None);

    let skill_survived = output.iter().any(|msg| {
        extract_text(&msg.content).contains("Cross-tier skill instructions that must survive")
    });
    assert!(
        skill_survived,
        "Skill result in warm tier with call in cold tier must be protected"
    );
}
}