mod compaction;
mod truncation;
pub use compaction::{
ARTIFACT_COMPACT_THRESHOLD_TOKENS, ARTIFACT_PREVIEW_MAX_CHARS, ARTIFACT_PREVIEW_MAX_LINES,
compact_artifact, compact_tool_results,
};
pub use truncation::{adjust_split_for_tool_pairs, find_split_point};
use awaken_contract::contract::inference::ContextWindowPolicy;
use awaken_contract::contract::message::{Message, Role};
use awaken_contract::contract::tool::ToolDescriptor;
use awaken_contract::contract::transform::{
InferenceRequestTransform, TransformOutput, estimate_message_tokens, estimate_tokens,
estimate_tool_tokens, patch_dangling_tool_calls,
};
/// Request transform that keeps a conversation inside the token budget
/// described by a [`ContextWindowPolicy`].
pub struct ContextTransform {
    /// Budget and truncation settings consulted on every `transform` call.
    policy: ContextWindowPolicy,
}

impl ContextTransform {
    /// Creates a transform that enforces the given `policy`.
    pub fn new(policy: ContextWindowPolicy) -> Self {
        ContextTransform { policy }
    }
}
impl InferenceRequestTransform for ContextTransform {
    /// Fits `messages` into the context window described by the policy.
    ///
    /// 1. Tool results are compacted in place via [`compact_tool_results`].
    /// 2. The usable budget is `max_context_tokens - max_output_tokens - tool_tokens`
    ///    (each subtraction saturating at zero), where `tool_tokens` is the
    ///    estimated cost of `tool_descriptors`.
    /// 3. If the estimated size of all messages fits in that budget, the
    ///    (compacted) messages are returned as-is; no tool-call patching is done
    ///    on this path.
    /// 4. Otherwise the leading run of system messages is always kept, and the
    ///    oldest history messages before the index chosen by [`find_split_point`]
    ///    (given the remaining budget and `min_recent_messages`) are dropped.
    ///    Tool calls left without results are then fixed up with
    ///    `patch_dangling_tool_calls`.
    ///
    /// Only the *leading* system messages are protected; a system message that
    /// appears after the first non-system message is treated as ordinary history.
    fn transform(
        &self,
        mut messages: Vec<Message>,
        tool_descriptors: &[ToolDescriptor],
    ) -> TransformOutput {
        compact_tool_results(&mut messages);

        let tool_tokens = estimate_tool_tokens(tool_descriptors);
        let available = self
            .policy
            .max_context_tokens
            .saturating_sub(self.policy.max_output_tokens)
            .saturating_sub(tool_tokens);

        if estimate_tokens(&messages) <= available {
            return TransformOutput { messages };
        }

        // The system prompt is everything before the first non-system message.
        let system_end = messages
            .iter()
            .position(|m| m.role != Role::System)
            .unwrap_or(messages.len());
        let system_tokens: usize = messages[..system_end]
            .iter()
            .map(estimate_message_tokens)
            .sum();
        let history_budget = available.saturating_sub(system_tokens);

        let history_len = messages.len() - system_end;
        // `find_split_point` yields the index into the history where the kept
        // suffix starts, i.e. the number of oldest history messages to drop.
        // Clamp it so an out-of-range split can never make `drain` panic.
        let dropped = find_split_point(
            &messages[system_end..],
            history_budget,
            self.policy.min_recent_messages,
        )
        .min(history_len);
        // Draining an empty range is a no-op, so no guard is needed.
        messages.drain(system_end..system_end + dropped);

        let kept = messages.len();
        // Dropping history can leave tool calls without their results.
        patch_dangling_tool_calls(&mut messages);
        tracing::debug!(dropped, kept, "truncation_applied");
        TransformOutput { messages }
    }
}
#[cfg(test)]
mod tests {
use super::*;
use awaken_contract::contract::inference::ContextWindowPolicy;
use awaken_contract::contract::message::ToolCall;
use serde_json::json;
fn make_policy(max_tokens: usize, min_recent: usize) -> ContextWindowPolicy {
ContextWindowPolicy {
max_context_tokens: max_tokens,
max_output_tokens: 0,
min_recent_messages: min_recent,
enable_prompt_cache: false,
autocompact_threshold: None,
compaction_mode: Default::default(),
compaction_raw_suffix_messages: 2,
}
}
#[test]
fn truncation_preserves_all_when_under_budget() {
let transform = ContextTransform::new(make_policy(100_000, 2));
let messages = vec![
Message::system("sys"),
Message::user("hello"),
Message::assistant("hi"),
];
let output = transform.transform(messages.clone(), &[]);
assert_eq!(output.messages.len(), 3);
}
#[test]
fn truncation_keeps_system_and_recent() {
let transform = ContextTransform::new(make_policy(50, 2));
let mut messages = vec![Message::system("sys")];
for i in 0..20 {
messages.push(Message::user(format!("msg {i}")));
messages.push(Message::assistant(format!("reply {i}")));
}
let output = transform.transform(messages, &[]);
assert!(output.messages.len() >= 3);
assert_eq!(output.messages[0].role, Role::System);
}
#[test]
fn truncation_repairs_dangling_tool_calls() {
let transform = ContextTransform::new(make_policy(30, 1));
let messages = vec![
Message::system("sys"),
Message::user("old msg 1"),
Message::assistant_with_tool_calls(
"calling",
vec![ToolCall::new("c1", "search", json!({}))],
),
Message::tool("c1", "result"),
Message::user("old msg 2"),
Message::assistant("old reply"),
Message::user("recent"),
Message::assistant("recent reply"),
];
let output = transform.transform(messages, &[]);
for (i, msg) in output.messages.iter().enumerate() {
if msg.role == Role::Assistant
&& let Some(ref calls) = msg.tool_calls
{
for call in calls {
assert!(
output.messages[i + 1..]
.iter()
.any(|m| m.tool_call_id.as_deref() == Some(&call.id)),
"tool call {} should have a matching result",
call.id
);
}
}
}
}
#[test]
fn truncation_tool_pair_not_broken() {
let transform = ContextTransform::new(make_policy(60, 1));
let messages = vec![
Message::system("sys"),
Message::user("old"),
Message::assistant_with_tool_calls(
"calling",
vec![ToolCall::new("c1", "search", json!({}))],
),
Message::tool("c1", "found"),
Message::user("recent"),
Message::assistant("reply"),
];
let output = transform.transform(messages, &[]);
for (i, msg) in output.messages.iter().enumerate() {
if msg.role == Role::Assistant
&& let Some(ref calls) = msg.tool_calls
{
for call in calls {
let has_result = output.messages[i + 1..]
.iter()
.any(|m| m.tool_call_id.as_deref() == Some(&call.id));
assert!(
has_result,
"tool call {} should have matching result",
call.id
);
}
}
}
}
#[test]
fn truncation_with_tool_descriptors_reduces_budget() {
use awaken_contract::contract::tool::ToolDescriptor;
let transform = ContextTransform::new(make_policy(100, 2));
let messages = vec![
Message::system("sys"),
Message::user("hello"),
Message::assistant("world"),
];
let output_no_tools = transform.transform(messages.clone(), &[]);
let count_no_tools = output_no_tools.messages.len();
let big_tool = ToolDescriptor {
id: "t".into(),
name: "t".into(),
description: "x".repeat(200),
parameters: json!({"type": "object", "properties": {
"a": {"type": "string"}, "b": {"type": "string"},
"c": {"type": "string"}, "d": {"type": "string"},
}}),
category: None,
metadata: Default::default(),
};
let output_with_tools = transform.transform(messages, &[big_tool]);
assert!(output_with_tools.messages.len() <= count_no_tools);
}
#[test]
fn no_truncation_when_within_budget() {
let transform = ContextTransform::new(make_policy(100_000, 2));
let messages = vec![
Message::system("system prompt"),
Message::user("hello"),
Message::assistant("hi there"),
Message::user("how are you?"),
Message::assistant("doing great"),
];
let output = transform.transform(messages.clone(), &[]);
assert_eq!(output.messages.len(), messages.len());
for (a, b) in output.messages.iter().zip(messages.iter()) {
assert_eq!(a.text(), b.text());
}
}
#[test]
fn truncation_drops_oldest_history() {
let transform = ContextTransform::new(make_policy(60, 2));
let filler = |tag: &str| format!("{tag}:{}", "x".repeat(40));
let messages = vec![
Message::system("sys"),
Message::user(filler("old1")),
Message::assistant(filler("old_reply1")),
Message::user(filler("old2")),
Message::assistant(filler("old_reply2")),
Message::user(filler("recent1")),
Message::assistant(filler("recent_reply1")),
];
let output = transform.transform(messages, &[]);
assert_eq!(output.messages[0].role, Role::System);
assert_eq!(output.messages[0].text(), "sys");
let texts: Vec<String> = output.messages.iter().map(|m| m.text()).collect();
assert!(
!texts.iter().any(|t| t.starts_with("old1:")),
"oldest message should be dropped"
);
assert!(
texts.iter().any(|t| t.starts_with("recent_reply1:")),
"most recent message should be preserved"
);
}
#[test]
fn min_recent_always_preserved() {
let transform = ContextTransform::new(make_policy(20, 4));
let messages = vec![
Message::system("s"),
Message::user("a"),
Message::assistant("b"),
Message::user("c"),
Message::assistant("d"),
Message::user("e"),
Message::assistant("f"),
];
let output = transform.transform(messages, &[]);
let history_count = output
.messages
.iter()
.filter(|m| m.role != Role::System)
.count();
assert!(
history_count >= 4,
"min_recent_messages=4 but only {history_count} history messages kept"
);
}
#[test]
fn system_messages_never_truncated() {
let transform = ContextTransform::new(make_policy(60, 1));
let messages = vec![
Message::system("system prompt 1"),
Message::system("system prompt 2"),
Message::system("system prompt 3"),
Message::user("old1"),
Message::assistant("old_reply1"),
Message::user("old2"),
Message::assistant("old_reply2"),
Message::user("recent"),
Message::assistant("recent_reply"),
];
let output = transform.transform(messages, &[]);
let system_msgs: Vec<&Message> = output
.messages
.iter()
.filter(|m| m.role == Role::System)
.collect();
assert_eq!(
system_msgs.len(),
3,
"all system messages must be preserved"
);
assert_eq!(system_msgs[0].text(), "system prompt 1");
assert_eq!(system_msgs[1].text(), "system prompt 2");
assert_eq!(system_msgs[2].text(), "system prompt 3");
}
#[test]
fn truncation_empty_messages() {
let transform = ContextTransform::new(make_policy(100, 2));
let messages = vec![];
let output = transform.transform(messages, &[]);
assert!(output.messages.is_empty());
}
#[test]
fn truncation_system_only() {
let transform = ContextTransform::new(make_policy(100, 2));
let messages = vec![Message::system("system only")];
let output = transform.transform(messages, &[]);
assert_eq!(output.messages.len(), 1);
assert_eq!(output.messages[0].role, Role::System);
}
#[test]
fn truncation_preserves_message_order() {
let transform = ContextTransform::new(make_policy(100_000, 2));
let messages = vec![
Message::system("sys"),
Message::user("u1"),
Message::assistant("a1"),
Message::user("u2"),
Message::assistant("a2"),
];
let output = transform.transform(messages.clone(), &[]);
for (i, msg) in output.messages.iter().enumerate() {
assert_eq!(msg.role, messages[i].role);
assert_eq!(msg.text(), messages[i].text());
}
}
#[test]
fn truncation_with_only_tool_messages() {
let transform = ContextTransform::new(make_policy(100, 1));
let messages = vec![
Message::system("sys"),
Message::user("go"),
Message::assistant_with_tool_calls("", vec![ToolCall::new("c1", "t", json!({}))]),
Message::tool("c1", "result"),
];
let output = transform.transform(messages, &[]);
assert!(!output.messages.is_empty());
assert_eq!(output.messages[0].role, Role::System);
}
}