use zeph_context::summarization::extract_overflow_ref;
use zeph_llm::provider::MessagePart;
use crate::compaction::{
BlockScore, score_blocks_mig, score_blocks_subgoal, score_blocks_subgoal_mig,
score_blocks_task_aware,
};
use crate::state::ContextSummarizationView;
pub(crate) fn prune_tool_outputs(
summ: &mut ContextSummarizationView<'_>,
min_to_free: usize,
) -> usize {
use zeph_config::PruningStrategy;
match &summ.context_manager.compression.pruning_strategy {
PruningStrategy::TaskAware => prune_tool_outputs_scored(summ, min_to_free),
PruningStrategy::Mig => prune_tool_outputs_mig(summ, min_to_free),
PruningStrategy::Subgoal => prune_tool_outputs_subgoal(summ, min_to_free),
PruningStrategy::SubgoalMig => prune_tool_outputs_subgoal_mig(summ, min_to_free),
PruningStrategy::Reactive => prune_tool_outputs_oldest_first(summ, min_to_free),
}
}
#[allow(clippy::cast_precision_loss)]
pub(crate) fn prune_tool_outputs_oldest_first(
summ: &mut ContextSummarizationView<'_>,
min_to_free: usize,
) -> usize {
let protect = summ.context_manager.prune_protect_tokens;
let mut tail_tokens = 0usize;
let mut protection_boundary = summ.messages.len();
if protect > 0 {
for (i, msg) in summ.messages.iter().enumerate().rev() {
tail_tokens += summ.token_counter.count_message_tokens(msg);
if tail_tokens >= protect {
protection_boundary = i;
break;
}
if i == 0 {
protection_boundary = 0;
}
}
}
let mut freed = 0usize;
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs()
.cast_signed();
for msg in &mut summ.messages[..protection_boundary] {
if freed >= min_to_free {
break;
}
if msg.metadata.focus_pinned {
continue;
}
let mut modified = false;
for part in &mut msg.parts {
if let &mut MessagePart::ToolOutput {
ref mut body,
ref mut compacted_at,
..
} = part
&& compacted_at.is_none()
&& !body.is_empty()
&& !body.starts_with("[archived:")
{
freed += summ.token_counter.count_tokens(body);
let ref_notice = extract_overflow_ref(body)
.map(|p| format!("[tool output pruned; use read_overflow {p} to retrieve]"))
.unwrap_or_default();
freed -= summ.token_counter.count_tokens(&ref_notice);
*compacted_at = Some(now);
*body = ref_notice;
modified = true;
}
}
if modified {
msg.rebuild_content();
}
}
if freed > 0 {
if let Some(metrics) = summ.metrics {
metrics.record_tool_output_prune(1);
}
tracing::info!(freed, protection_boundary, "pruned tool outputs");
}
freed
}
/// Returns the index of the first message in the protected tail.
///
/// Messages are walked from the end towards the start, summing their token
/// counts; the index at which the running sum first reaches
/// `prune_protect_tokens` is the boundary. Messages at or after the boundary are
/// meant to be left alone by callers.
///
/// * When `prune_protect_tokens` is zero the result is `summ.messages.len()`.
/// * When the running sum never reaches the threshold (including an empty
///   message list) the result is `0`.
fn prune_protection_boundary(summ: &ContextSummarizationView<'_>) -> usize {
    let budget = summ.context_manager.prune_protect_tokens;
    if budget == 0 {
        return summ.messages.len();
    }

    let mut accumulated = 0usize;
    summ.messages
        .iter()
        .enumerate()
        .rev()
        .find_map(|(idx, message)| {
            accumulated += summ.token_counter.count_message_tokens(message);
            (accumulated >= budget).then_some(idx)
        })
        .unwrap_or(0)
}
/// Prunes tool outputs in ascending order of task-aware `relevance`.
///
/// Falls back to `prune_tool_outputs_oldest_first` when `summ.current_task_goal`
/// is `None`. Otherwise the blocks are scored with `score_blocks_task_aware`,
/// sorted so the lowest `relevance` comes first, and handed to
/// `evict_sorted_blocks` under the strategy label `"task_aware"`.
fn prune_tool_outputs_scored(summ: &mut ContextSummarizationView<'_>, min_to_free: usize) -> usize {
    use std::cmp::Ordering;

    // Without a goal there is nothing to score against.
    let Some(task_goal) = summ.current_task_goal.clone() else {
        return prune_tool_outputs_oldest_first(summ, min_to_free);
    };

    let mut ranked = score_blocks_task_aware(summ.messages, &task_goal, &summ.token_counter);
    // Pairs that `partial_cmp` cannot order are treated as equal.
    ranked.sort_unstable_by(|lhs, rhs| {
        lhs.relevance
            .partial_cmp(&rhs.relevance)
            .unwrap_or(Ordering::Equal)
    });
    evict_sorted_blocks(summ, &ranked, min_to_free, "task_aware")
}
/// Prunes tool outputs in ascending order of their `mig` score.
///
/// Blocks are scored with `score_blocks_mig` (given the optional current task
/// goal), sorted so the lowest `mig` comes first, and handed to
/// `evict_sorted_blocks` under the strategy label `"mig"`.
fn prune_tool_outputs_mig(summ: &mut ContextSummarizationView<'_>, min_to_free: usize) -> usize {
    use std::cmp::Ordering;

    let task_goal = summ.current_task_goal.as_deref();
    let mut ranked = score_blocks_mig(summ.messages, task_goal, &summ.token_counter);
    // Pairs that `partial_cmp` cannot order are treated as equal.
    ranked.sort_unstable_by(|lhs, rhs| match lhs.mig.partial_cmp(&rhs.mig) {
        Some(order) => order,
        None => Ordering::Equal,
    });
    evict_sorted_blocks(summ, &ranked, min_to_free, "mig")
}
/// Prunes tool outputs in ascending order of subgoal `relevance`.
///
/// Blocks are scored with `score_blocks_subgoal` against `summ.subgoal_registry`,
/// sorted so the lowest `relevance` comes first, and handed to
/// `evict_sorted_blocks` under the strategy label `"subgoal"`.
fn prune_tool_outputs_subgoal(
    summ: &mut ContextSummarizationView<'_>,
    min_to_free: usize,
) -> usize {
    let mut by_relevance =
        score_blocks_subgoal(summ.messages, summ.subgoal_registry, &summ.token_counter);
    // Pairs that `partial_cmp` cannot order are treated as equal.
    by_relevance.sort_unstable_by(|left, right| {
        let ordering = left.relevance.partial_cmp(&right.relevance);
        ordering.unwrap_or(std::cmp::Ordering::Equal)
    });
    evict_sorted_blocks(summ, &by_relevance, min_to_free, "subgoal")
}
/// Prunes tool outputs in ascending order of their subgoal-based `mig` score.
///
/// Blocks are scored with `score_blocks_subgoal_mig` against
/// `summ.subgoal_registry`, sorted so the lowest `mig` comes first, and handed to
/// `evict_sorted_blocks` under the strategy label `"subgoal_mig"`.
fn prune_tool_outputs_subgoal_mig(
    summ: &mut ContextSummarizationView<'_>,
    min_to_free: usize,
) -> usize {
    use std::cmp::Ordering;

    let mut by_mig =
        score_blocks_subgoal_mig(summ.messages, summ.subgoal_registry, &summ.token_counter);
    // Pairs that `partial_cmp` cannot order are treated as equal.
    by_mig.sort_unstable_by(|left, right| {
        if let Some(order) = left.mig.partial_cmp(&right.mig) {
            order
        } else {
            Ordering::Equal
        }
    });
    evict_sorted_blocks(summ, &by_mig, min_to_free, "subgoal_mig")
}
/// Prunes tool outputs from messages in the order given by `sorted_scores` until
/// at least `min_to_free` tokens have been reclaimed.
///
/// For each score entry, the message at `msg_index` is skipped when it lies at or
/// after the index returned by `prune_protection_boundary`, or when its metadata
/// has `focus_pinned` set. Otherwise every `MessagePart::ToolOutput` in it that
/// has no `compacted_at` timestamp, a non-empty body, and a body that does not
/// start with `[archived:` is replaced: the body becomes a `read_overflow` notice
/// when `extract_overflow_ref` finds a reference in it, or an empty string
/// otherwise, and `compacted_at` is set to the current Unix time in seconds.
///
/// The budget check happens once per score entry, so all eligible parts of a
/// message are pruned together. Modified messages have `rebuild_content` called
/// after the eviction loop finishes.
///
/// `strategy` is only used as a field on the log line.
///
/// Returns the net number of tokens freed (body tokens minus notice tokens),
/// clamped at zero.
fn evict_sorted_blocks(
    summ: &mut ContextSummarizationView<'_>,
    sorted_scores: &[BlockScore],
    min_to_free: usize,
    strategy: &str,
) -> usize {
    let protection_boundary = prune_protection_boundary(summ);
    let mut freed = 0usize;
    let now = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs()
        .cast_signed();
    let mut pruned_indices = Vec::new();
    for block in sorted_scores {
        if freed >= min_to_free {
            break;
        }
        if block.msg_index >= protection_boundary {
            continue;
        }
        let msg = &mut summ.messages[block.msg_index];
        if msg.metadata.focus_pinned {
            continue;
        }
        let mut modified = false;
        for part in &mut msg.parts {
            if let MessagePart::ToolOutput {
                body, compacted_at, ..
            } = part
                && compacted_at.is_none()
                && !body.is_empty()
                // Same exclusion as `prune_tool_outputs_oldest_first`: bodies that
                // start with the `[archived:` marker are left as they are.
                && !body.starts_with("[archived:")
            {
                let body_tokens = summ.token_counter.count_tokens(body);
                let ref_notice = extract_overflow_ref(body)
                    .map(|p| format!("[tool output pruned; use read_overflow {p} to retrieve]"))
                    .unwrap_or_default();
                let notice_tokens = summ.token_counter.count_tokens(&ref_notice);
                // The notice can tokenize larger than a short body; saturate instead
                // of underflowing the unsigned running total.
                freed = freed
                    .saturating_add(body_tokens)
                    .saturating_sub(notice_tokens);
                *compacted_at = Some(now);
                *body = ref_notice;
                modified = true;
            }
        }
        if modified {
            pruned_indices.push(block.msg_index);
        }
    }
    // Rebuild once per touched message, after all part bodies have been rewritten.
    for &idx in &pruned_indices {
        summ.messages[idx].rebuild_content();
    }
    if freed > 0 {
        if let Some(metrics) = summ.metrics {
            metrics.record_tool_output_prune(pruned_indices.len());
        }
        tracing::info!(
            freed,
            pruned = pruned_indices.len(),
            strategy,
            "pruned tool outputs"
        );
    }
    freed
}