use std::cell::RefCell;
use std::collections::BTreeMap;
use std::thread_local;
use serde::{Deserialize, Serialize};
use super::{compact_strategy_name, parse_compact_strategy, CompactStrategy, CompactionPolicy};
use crate::value::VmValue;
/// Default `safety_ratio` for [`CompactionPolicyDeclaration`]: the fraction of
/// `context_window` used to derive a token cap in
/// [`CompactionPolicyDeclaration::token_threshold`].
pub const DEFAULT_SAFETY_RATIO: f64 = 0.7;
/// User-facing compaction strategy names accepted by a policy declaration.
///
/// Each variant lowers to an engine-level [`CompactStrategy`] via
/// [`PolicyStrategy::engine_strategy`]; several user-facing names share one engine strategy
/// (e.g. `HeadAndTail` and `Window` both lower to `Truncate`).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PolicyStrategy {
    /// `"summarize"` (alias `"llm"`).
    Summarize,
    /// `"summarize-then-prune"`: summarize, with `Truncate` as the engine fallback.
    SummarizeThenPrune,
    /// `"head+tail"` (aliases `"head-tail"`, `"head_tail"`).
    HeadAndTail,
    /// `"window"` (alias `"truncate"`).
    Window,
    /// `"observation_mask"` (aliases `"observation-mask"`, `"mask"`).
    ObservationMask,
    /// `"custom"`: a user-supplied compactor closure.
    Custom,
}

impl PolicyStrategy {
    /// Canonical name of the strategy, as accepted by [`PolicyStrategy::parse`].
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Summarize => "summarize",
            Self::SummarizeThenPrune => "summarize-then-prune",
            Self::HeadAndTail => "head+tail",
            Self::Window => "window",
            Self::ObservationMask => "observation_mask",
            Self::Custom => "custom",
        }
    }

    /// Parses a strategy name or one of its aliases, ignoring surrounding whitespace.
    ///
    /// # Errors
    ///
    /// Returns a message listing the canonical names when `value` is not recognized.
    /// Matching is case-sensitive.
    pub fn parse(value: &str) -> Result<Self, String> {
        match value.trim() {
            "summarize" | "llm" => Ok(Self::Summarize),
            "summarize-then-prune" | "summarize_then_prune" => Ok(Self::SummarizeThenPrune),
            "head+tail" | "head-tail" | "head_tail" => Ok(Self::HeadAndTail),
            "window" | "truncate" => Ok(Self::Window),
            "observation_mask" | "observation-mask" | "mask" => Ok(Self::ObservationMask),
            "custom" => Ok(Self::Custom),
            other => Err(format!(
                "unknown compaction policy strategy '{other}' (expected one of: summarize, \
                 summarize-then-prune, head+tail, window, observation_mask, custom)"
            )),
        }
    }

    /// Engine strategy this policy strategy lowers to.
    pub fn engine_strategy(self) -> CompactStrategy {
        match self {
            Self::Summarize | Self::SummarizeThenPrune => CompactStrategy::Llm,
            Self::HeadAndTail | Self::Window => CompactStrategy::Truncate,
            Self::ObservationMask => CompactStrategy::ObservationMask,
            Self::Custom => CompactStrategy::Custom,
        }
    }

    /// Engine strategy to fall back to, if this policy strategy defines one.
    ///
    /// Only `SummarizeThenPrune` has a fallback (`Truncate`). The match is exhaustive on
    /// purpose, so adding a variant forces a decision here instead of silently yielding `None`.
    pub fn engine_fallback(self) -> Option<CompactStrategy> {
        match self {
            Self::SummarizeThenPrune => Some(CompactStrategy::Truncate),
            Self::Summarize
            | Self::HeadAndTail
            | Self::Window
            | Self::ObservationMask
            | Self::Custom => None,
        }
    }
}
/// A declared compaction policy: which strategy to use and the limits that trigger it.
///
/// Built from a VM dictionary by [`parse_policy_dict`] and lowered to the engine config by
/// [`to_auto_compact_config`].
#[derive(Clone, Debug)]
pub struct CompactionPolicyDeclaration {
    /// Strategy to compact with (default: `SummarizeThenPrune`).
    pub strategy: PolicyStrategy,
    /// Explicit token cap; combined with the window-derived cap in `token_threshold`.
    pub max_tokens: Option<usize>,
    /// Message-count cap checked by `evaluate`; `0` disables the turn trigger.
    pub max_turns: Option<usize>,
    /// Context window size used together with `safety_ratio` to derive a token cap.
    pub context_window: Option<usize>,
    /// Fraction of `context_window` used as the derived cap (default: [`DEFAULT_SAFETY_RATIO`]).
    pub safety_ratio: f64,
    /// Forwarded as `AutoCompactConfig::keep_last` (default: 12).
    pub keep_last: usize,
    /// Forwarded as `AutoCompactConfig::keep_first` (default: 0).
    pub keep_first: usize,
    /// Forwarded as `AutoCompactConfig::hard_limit_tokens`.
    pub hard_limit_tokens: Option<usize>,
    /// Forwarded as `AutoCompactConfig::tool_output_max_chars` when set.
    pub tool_output_max_chars: Option<usize>,
    /// User closure, forwarded as `AutoCompactConfig::custom_compactor`; required for `Custom`.
    pub summarize_fn: Option<VmValue>,
    /// Prompt override, forwarded as `AutoCompactConfig::summarize_prompt`.
    pub summarize_prompt: Option<String>,
    /// Additional policy options, forwarded as `AutoCompactConfig::policy`.
    pub instructions: CompactionPolicy,
}

impl Default for CompactionPolicyDeclaration {
    fn default() -> Self {
        Self {
            // Strategy and retention.
            strategy: PolicyStrategy::SummarizeThenPrune,
            keep_first: 0,
            keep_last: 12,
            // Triggers: none configured, so `evaluate` never fires by default.
            max_tokens: None,
            max_turns: None,
            context_window: None,
            safety_ratio: DEFAULT_SAFETY_RATIO,
            // Engine limits left to engine defaults.
            hard_limit_tokens: None,
            tool_output_max_chars: None,
            // Summarization customization.
            summarize_fn: None,
            summarize_prompt: None,
            instructions: CompactionPolicy::default(),
        }
    }
}
impl CompactionPolicyDeclaration {
    /// Returns the effective token cap for this policy, or `None` when no token limit applies.
    ///
    /// Two candidate caps are considered:
    ///
    /// * the explicit `max_tokens`, and
    /// * a derived cap of `floor(context_window * safety_ratio)`. If that product is not a
    ///   finite positive number (for example when `safety_ratio` is `0.0`), the raw
    ///   `context_window` is used instead.
    ///
    /// The more restrictive candidate wins. A candidate of `0` is ignored rather than treated
    /// as a cap. A zero cap would fire on every non-empty conversation, and when combined via
    /// `min` it would also mask a valid `max_tokens`. This mirrors `max_turns`, where `0`
    /// disables the turn trigger in [`Self::evaluate`].
    pub fn token_threshold(&self) -> Option<usize> {
        let explicit = self.max_tokens.filter(|&cap| cap > 0);
        let derived = self
            .context_window
            .map(|window| {
                let raw = (window as f64) * self.safety_ratio;
                if raw.is_finite() && raw > 0.0 {
                    raw.floor() as usize
                } else {
                    window
                }
            })
            // Covers `context_window: 0` and products that floor to zero tokens.
            .filter(|&cap| cap > 0);
        match (explicit, derived) {
            (Some(a), Some(b)) => Some(a.min(b)),
            (a, b) => a.or(b),
        }
    }

    /// Checks the policy's triggers against the current conversation size.
    ///
    /// * `estimated_tokens`: the token trigger fires when it is strictly greater than
    ///   [`Self::token_threshold`].
    /// * `message_count`: the turn trigger fires when `max_turns` is a non-zero value and
    ///   `message_count` is strictly greater than it.
    ///
    /// The returned [`EvaluationContext`] records both triggers, the threshold, the inputs,
    /// and this policy's strategy.
    pub fn evaluate(&self, estimated_tokens: usize, message_count: usize) -> EvaluationContext {
        let token_threshold = self.token_threshold();
        let token_trigger = token_threshold.is_some_and(|cap| estimated_tokens > cap);
        let turn_trigger = self
            .max_turns
            .is_some_and(|cap| cap > 0 && message_count > cap);
        EvaluationContext {
            token_threshold,
            token_trigger,
            turn_trigger,
            estimated_tokens,
            message_count,
            strategy: self.strategy,
        }
    }

    /// Renders the declaration as a JSON object.
    ///
    /// `strategy`, `engine_strategy`, `safety_ratio` and `keep_last` are always present.
    /// Optional limits are included only when set, `keep_first` only when non-zero, and
    /// `token_threshold` only when [`Self::token_threshold`] yields a cap. `instructions` is
    /// included when the nested policy reports metadata. `summarize_fn` and
    /// `summarize_prompt` are not emitted.
    pub fn to_json(&self) -> serde_json::Value {
        let mut map = serde_json::Map::new();
        map.insert(
            "strategy".to_string(),
            serde_json::Value::String(self.strategy.as_str().to_string()),
        );
        map.insert(
            "engine_strategy".to_string(),
            serde_json::Value::String(
                compact_strategy_name(&self.strategy.engine_strategy()).to_string(),
            ),
        );
        if let Some(value) = self.max_tokens {
            map.insert("max_tokens".to_string(), serde_json::json!(value));
        }
        if let Some(value) = self.max_turns {
            map.insert("max_turns".to_string(), serde_json::json!(value));
        }
        if let Some(value) = self.context_window {
            map.insert("context_window".to_string(), serde_json::json!(value));
        }
        map.insert(
            "safety_ratio".to_string(),
            serde_json::json!(self.safety_ratio),
        );
        map.insert("keep_last".to_string(), serde_json::json!(self.keep_last));
        if self.keep_first > 0 {
            map.insert("keep_first".to_string(), serde_json::json!(self.keep_first));
        }
        if let Some(value) = self.hard_limit_tokens {
            map.insert("hard_limit_tokens".to_string(), serde_json::json!(value));
        }
        if let Some(value) = self.tool_output_max_chars {
            map.insert(
                "tool_output_max_chars".to_string(),
                serde_json::json!(value),
            );
        }
        if let Some(threshold) = self.token_threshold() {
            map.insert("token_threshold".to_string(), serde_json::json!(threshold));
        }
        if let Some(policy_json) = self.instructions.metadata_json() {
            map.insert("instructions".to_string(), policy_json);
        }
        serde_json::Value::Object(map)
    }
}
/// Outcome of [`CompactionPolicyDeclaration::evaluate`]: which triggers tripped and the
/// inputs they were checked against.
#[derive(Clone, Debug)]
pub struct EvaluationContext {
    /// Effective token cap at evaluation time (`None` when no token cap applies).
    pub token_threshold: Option<usize>,
    /// Set by `evaluate` when `estimated_tokens` exceeded `token_threshold`.
    pub token_trigger: bool,
    /// Set by `evaluate` when `message_count` exceeded a non-zero `max_turns`.
    pub turn_trigger: bool,
    /// Token estimate the policy was evaluated with.
    pub estimated_tokens: usize,
    /// Message count the policy was evaluated with.
    pub message_count: usize,
    /// Strategy of the evaluated policy.
    pub strategy: PolicyStrategy,
}

impl EvaluationContext {
    /// Returns `true` if at least one trigger tripped.
    pub fn fires(&self) -> bool {
        self.turn_trigger || self.token_trigger
    }

    /// Short label naming the tripped trigger(s); `"manual"` when neither tripped.
    pub fn trigger_label(&self) -> &'static str {
        if self.token_trigger && self.turn_trigger {
            "tokens_and_turns"
        } else if self.token_trigger {
            "tokens"
        } else if self.turn_trigger {
            "turns"
        } else {
            "manual"
        }
    }
}
/// What to do with a session's history after evaluating its compaction policy.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CompactionAction {
    /// Compact immediately.
    CompactNow,
    /// Postpone compaction.
    Defer,
    /// Give up on compaction.
    Abandon,
}

impl CompactionAction {
    /// Stable snake_case name of the action.
    pub fn as_str(self) -> &'static str {
        use CompactionAction::{Abandon, CompactNow, Defer};
        match self {
            Abandon => "abandon",
            CompactNow => "compact_now",
            Defer => "defer",
        }
    }
}
/// Serializable record of a compaction decision for one session.
///
/// NOTE(review): nothing in this file constructs this struct. The string fields look like
/// they carry the `as_str()` names of [`CompactionAction`] / [`PolicyStrategy`] and the
/// label from [`EvaluationContext::trigger_label`]; confirm against the producer.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CompactionDecision {
    /// Decided action (presumably a `CompactionAction::as_str` value; verify).
    pub action: String,
    /// Session the decision applies to.
    pub session_id: String,
    /// Token estimate the decision was based on.
    pub estimated_tokens: usize,
    /// Message count the decision was based on.
    pub message_count: usize,
    /// Trigger label (presumably from `EvaluationContext::trigger_label`; verify).
    pub trigger: String,
    /// Policy strategy name (presumably `PolicyStrategy::as_str`; verify).
    pub strategy: String,
    /// Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub token_threshold: Option<usize>,
    /// Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub turn_threshold: Option<usize>,
    /// Engine strategy name.
    pub engine_strategy: String,
    /// Presumably mirrors the `bool` returned by `policy_for` (`true` = default policy
    /// was used); confirm with the producer.
    pub policy_inherited: bool,
}
/// Registry key under which the fallback policy is stored; `policy_for` consults it when a
/// session has no policy of its own.
const DEFAULT_POLICY_KEY: &str = "";
thread_local! {
    // Map of session id -> declared policy. Because it lives in `thread_local!`, each thread
    // sees its own independent registry.
    static POLICIES: RefCell<BTreeMap<String, CompactionPolicyDeclaration>> =
        const { RefCell::new(BTreeMap::new()) };
}
/// Registers `policy` for `session_id` on the current thread, replacing any existing entry.
///
/// Registering under the empty key sets the fallback policy returned by `policy_for` for
/// sessions without their own policy.
pub fn set_policy(session_id: &str, policy: CompactionPolicyDeclaration) {
    let key = session_id.to_owned();
    POLICIES.with(move |registry| {
        registry.borrow_mut().insert(key, policy);
    });
}
/// Removes the policy registered for `session_id` on the current thread and returns it,
/// or `None` if there was none.
pub fn clear_policy(session_id: &str) -> Option<CompactionPolicyDeclaration> {
    POLICIES.with(|registry| {
        let mut sessions = registry.borrow_mut();
        sessions.remove(session_id)
    })
}
/// Resolves the policy for `session_id` on the current thread.
///
/// Returns a clone of the session's own policy paired with `false`. If there is none, it
/// returns a clone of the default (empty-key) policy paired with `true` ("inherited"). It
/// returns `None` when neither exists.
pub fn policy_for(session_id: &str) -> Option<(CompactionPolicyDeclaration, bool)> {
    POLICIES.with(|registry| {
        let sessions = registry.borrow();
        let explicit = sessions
            .get(session_id)
            .map(|found| (found.clone(), false));
        explicit.or_else(|| {
            sessions
                .get(DEFAULT_POLICY_KEY)
                .map(|fallback| (fallback.clone(), true))
        })
    })
}
/// Drops every registered policy, including the default one, on the current thread.
pub fn reset_registry() {
    POLICIES.with(|registry| {
        let mut sessions = registry.borrow_mut();
        sessions.clear();
    });
}
/// Lowers a policy declaration into the engine's `AutoCompactConfig`.
///
/// The policy's engine strategy is used for both the regular and the hard-limit strategy.
/// `token_threshold` is set to `0` when the declaration yields no token cap. Fields the
/// declaration does not set keep `AutoCompactConfig::default()` values (including
/// `tool_output_max_chars` when it is `None`).
pub fn to_auto_compact_config(policy: &CompactionPolicyDeclaration) -> super::AutoCompactConfig {
    let primary = policy.strategy.engine_strategy();
    let mut config = super::AutoCompactConfig {
        keep_first: policy.keep_first,
        keep_last: policy.keep_last,
        compact_strategy: primary.clone(),
        hard_limit_strategy: primary,
        fallback_strategy: policy.strategy.engine_fallback(),
        summarize_prompt: policy.summarize_prompt.clone(),
        custom_compactor: policy.summarize_fn.clone(),
        policy: policy.instructions.clone(),
        policy_strategy: policy.strategy.as_str().to_string(),
        token_threshold: policy.token_threshold().unwrap_or(0),
        hard_limit_tokens: policy.hard_limit_tokens,
        ..Default::default()
    };
    if let Some(limit) = policy.tool_output_max_chars {
        config.tool_output_max_chars = limit;
    }
    config
}
pub fn parse_policy_dict(
builtin: &str,
dict: &BTreeMap<String, VmValue>,
) -> Result<CompactionPolicyDeclaration, String> {
let mut policy = CompactionPolicyDeclaration::default();
if let Some(value) = dict.get("strategy") {
match value {
VmValue::String(text) => {
policy.strategy =
PolicyStrategy::parse(text).map_err(|e| format!("{builtin}: {e}"))?;
}
VmValue::Nil => {}
other => {
return Err(format!(
"{builtin}: `strategy` must be a string, got {}",
other.type_name()
));
}
}
}
if let Some(value) = optional_usize(dict, "max_tokens", builtin)? {
policy.max_tokens = Some(value);
}
if let Some(value) = optional_usize(dict, "max_turns", builtin)? {
policy.max_turns = Some(value);
}
if let Some(value) = optional_usize(dict, "context_window", builtin)? {
policy.context_window = Some(value);
}
if let Some(value) = optional_f64(dict, "safety_ratio", builtin)? {
if !(0.0..=1.0).contains(&value) {
return Err(format!(
"{builtin}: `safety_ratio` must be between 0.0 and 1.0, got {value}"
));
}
policy.safety_ratio = value;
}
if let Some(value) = optional_usize(dict, "keep_last", builtin)? {
policy.keep_last = value;
}
if let Some(value) = optional_usize(dict, "keep_first", builtin)? {
policy.keep_first = value;
}
if let Some(value) = optional_usize(dict, "hard_limit_tokens", builtin)? {
policy.hard_limit_tokens = Some(value);
}
if let Some(value) = optional_usize(dict, "tool_output_max_chars", builtin)? {
policy.tool_output_max_chars = Some(value);
}
if let Some(value) = dict.get("summarize_fn") {
match value {
VmValue::Closure(_) => {
policy.summarize_fn = Some(value.clone());
}
VmValue::Nil => {}
other => {
return Err(format!(
"{builtin}: `summarize_fn` must be a closure, got {}",
other.type_name()
));
}
}
}
if let Some(value) = dict.get("summarize_prompt") {
match value {
VmValue::String(text) => {
let trimmed = text.trim();
if !trimmed.is_empty() {
policy.summarize_prompt = Some(trimmed.to_string());
}
}
VmValue::Nil => {}
other => {
return Err(format!(
"{builtin}: `summarize_prompt` must be a string, got {}",
other.type_name()
));
}
}
}
policy.instructions = super::parse_compaction_policy_options(Some(dict), builtin)
.map_err(|error| format!("{builtin}: {}", display_vm_error(&error)))?;
if matches!(policy.strategy, PolicyStrategy::Custom) && policy.summarize_fn.is_none() {
return Err(format!(
"{builtin}: `summarize_fn` is required when strategy is 'custom'"
));
}
if matches!(policy.strategy, PolicyStrategy::SummarizeThenPrune)
&& parse_compact_strategy("truncate").is_err()
{
return Err(format!(
"{builtin}: summarize-then-prune fallback 'truncate' is no longer a known engine strategy"
));
}
Ok(policy)
}
fn display_vm_error(error: &crate::value::VmError) -> String {
match error {
crate::value::VmError::Runtime(message) => message.clone(),
other => format!("{other:?}"),
}
}
/// Reads an optional non-negative integer option named `key` from `dict`.
///
/// Returns `Ok(None)` when the key is absent or `nil`, and `Ok(Some(n))` for an in-range
/// integer.
///
/// # Errors
///
/// Fails (with `builtin` as the message prefix) when the value is not an int, is negative,
/// or does not fit in `usize` on the current target.
fn optional_usize(
    dict: &BTreeMap<String, VmValue>,
    key: &str,
    builtin: &str,
) -> Result<Option<usize>, String> {
    match dict.get(key) {
        None | Some(VmValue::Nil) => Ok(None),
        Some(VmValue::Int(value)) => {
            if *value < 0 {
                return Err(format!("{builtin}: `{key}` must be >= 0, got {value}"));
            }
            // An `as` cast would silently truncate on targets where `usize` is narrower than
            // the VM integer type (e.g. 32-bit / wasm32); reject such values instead.
            usize::try_from(*value).map(Some).map_err(|_| {
                format!("{builtin}: `{key}` is too large for this platform, got {value}")
            })
        }
        Some(other) => Err(format!(
            "{builtin}: `{key}` must be an int, got {}",
            other.type_name()
        )),
    }
}
/// Reads an optional numeric option named `key` from `dict`, widening ints to `f64`.
///
/// Returns `Ok(None)` when the key is absent or `nil`. Any other non-number value is an
/// error prefixed with `builtin`. No range checking is done here.
fn optional_f64(
    dict: &BTreeMap<String, VmValue>,
    key: &str,
    builtin: &str,
) -> Result<Option<f64>, String> {
    let Some(raw) = dict.get(key) else {
        return Ok(None);
    };
    match raw {
        VmValue::Nil => Ok(None),
        VmValue::Int(whole) => Ok(Some(*whole as f64)),
        VmValue::Float(fraction) => Ok(Some(*fraction)),
        other => Err(format!(
            "{builtin}: `{key}` must be a number, got {}",
            other.type_name()
        )),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Declaration with only `max_tokens` overridden.
    fn capped_at(tokens: usize) -> CompactionPolicyDeclaration {
        CompactionPolicyDeclaration {
            max_tokens: Some(tokens),
            ..Default::default()
        }
    }

    /// Engine strategy that `name` parses to; panics on unknown names.
    fn engine_for(name: &str) -> CompactStrategy {
        PolicyStrategy::parse(name)
            .expect("known strategy name")
            .engine_strategy()
    }

    #[test]
    fn safety_ratio_picks_more_restrictive_cap() {
        let declaration = CompactionPolicyDeclaration {
            max_tokens: Some(40_000),
            context_window: Some(100_000),
            safety_ratio: 0.5,
            ..Default::default()
        };
        // 100_000 * 0.5 = 50_000, so the explicit 40_000 cap is the tighter one.
        assert_eq!(declaration.token_threshold(), Some(40_000));
    }

    #[test]
    fn ratio_only_when_window_set() {
        let declaration = CompactionPolicyDeclaration {
            context_window: Some(120_000),
            safety_ratio: 0.7,
            ..Default::default()
        };
        assert_eq!(declaration.token_threshold(), Some(84_000));
    }

    #[test]
    fn evaluate_marks_token_trigger() {
        let outcome = capped_at(10_000).evaluate(12_000, 5);
        assert!(outcome.token_trigger);
        assert!(outcome.fires());
        assert_eq!(outcome.trigger_label(), "tokens");
    }

    #[test]
    fn evaluate_marks_turn_trigger() {
        let declaration = CompactionPolicyDeclaration {
            max_turns: Some(20),
            ..Default::default()
        };
        let outcome = declaration.evaluate(0, 25);
        assert!(outcome.turn_trigger);
        assert_eq!(outcome.trigger_label(), "turns");
    }

    #[test]
    fn defer_when_no_thresholds_configured() {
        let outcome = CompactionPolicyDeclaration::default().evaluate(1_000_000, 1_000_000);
        assert!(!outcome.fires());
    }

    #[test]
    fn default_policy_falls_back_to_session_lookup() {
        reset_registry();
        set_policy(DEFAULT_POLICY_KEY, capped_at(50_000));
        let (resolved, inherited) =
            policy_for("session-without-explicit-policy").expect("default policy resolved");
        assert!(inherited);
        assert_eq!(resolved.max_tokens, Some(50_000));
        reset_registry();
    }

    #[test]
    fn session_specific_policy_takes_precedence() {
        reset_registry();
        set_policy("", capped_at(50_000));
        set_policy("session-a", capped_at(80_000));
        let (resolved, inherited) = policy_for("session-a").expect("session policy resolved");
        assert!(!inherited);
        assert_eq!(resolved.max_tokens, Some(80_000));
        reset_registry();
    }

    #[test]
    fn strategy_aliases_round_trip() {
        assert_eq!(engine_for("summarize"), CompactStrategy::Llm);
        assert_eq!(
            PolicyStrategy::parse("summarize-then-prune")
                .unwrap()
                .engine_fallback(),
            Some(CompactStrategy::Truncate)
        );
        assert_eq!(engine_for("window"), CompactStrategy::Truncate);
        assert_eq!(engine_for("head+tail"), CompactStrategy::Truncate);
        assert_eq!(
            engine_for("observation_mask"),
            CompactStrategy::ObservationMask
        );
        assert!(PolicyStrategy::parse("unknown").is_err());
    }
}