use adk_core::{AdkError, ErrorCategory, ErrorComponent, Event, Llm};
use async_trait::async_trait;
use std::sync::Arc;
/// Error describing a context whose token count is above the allowed limit.
///
/// Its `Display` text is `Context overflow: {token_count} tokens (limit: {limit})`.
#[derive(Debug, thiserror::Error)]
#[error("Context overflow: {token_count} tokens (limit: {limit})")]
pub struct ContextOverflowError {
/// Token count that was measured for the context.
pub token_count: usize,
/// Token limit that `token_count` is reported against.
pub limit: usize,
}
/// Converts a [`ContextOverflowError`] into the crate-wide [`AdkError`].
///
/// The resulting error is tagged as a model / invalid-input error with the
/// code `runner.context_overflow`; the message is the overflow error's
/// `Display` text.
impl From<ContextOverflowError> for AdkError {
    fn from(err: ContextOverflowError) -> Self {
        let message = err.to_string();
        Self::new(
            ErrorComponent::Model,
            ErrorCategory::InvalidInput,
            "runner.context_overflow",
            message,
        )
    }
}
/// A pluggable way of shrinking a list of events.
///
/// Implementations must be `Send + Sync`.
#[async_trait]
pub trait CompactionStrategy: Send + Sync {
/// Takes ownership of `events` and returns the compacted list, or an error.
///
/// `budget` is passed through by the caller (the retry helper in this file
/// passes its `context_budget`); the two strategies defined in this file
/// ignore it.
async fn compact(&self, events: Vec<Event>, budget: usize) -> Result<Vec<Event>, AdkError>;
}
/// Compaction strategy that keeps the first event plus the most recent ones
/// and drops everything in between.
pub struct TruncationCompaction {
/// Number of trailing (most recent) events to keep, in addition to the first event.
pub preserve_recent: usize,
}
#[async_trait]
impl CompactionStrategy for TruncationCompaction {
    /// Keeps the first event and the last `preserve_recent` events, dropping
    /// everything in between.
    ///
    /// The list is returned untouched when it holds at most
    /// `preserve_recent + 1` events. `_budget` is ignored by this strategy.
    ///
    /// # Errors
    ///
    /// This implementation never returns an error.
    async fn compact(&self, events: Vec<Event>, _budget: usize) -> Result<Vec<Event>, AdkError> {
        let len = events.len();
        // `saturating_add` keeps an extreme `preserve_recent` from overflowing.
        if len <= self.preserve_recent.saturating_add(1) {
            return Ok(events);
        }

        // Here `len > preserve_recent + 1`, so `tail_start >= 2` and the
        // drained range `1..tail_start` is non-empty and in bounds.
        let tail_start = len - self.preserve_recent;

        // We own the vector, so remove the middle in place instead of cloning
        // every event that is kept.
        let mut compacted = events;
        compacted.drain(1..tail_start);

        tracing::debug!(
            original_count = len,
            compacted_count = compacted.len(),
            dropped = len - compacted.len(),
            "truncation compaction applied"
        );
        Ok(compacted)
    }
}
/// Compaction strategy that replaces the oldest events with a single
/// model-generated summary event.
pub struct SummarisationCompaction {
/// Model used to generate the summary text.
pub model: Arc<dyn Llm>,
/// Number of leading (oldest) events that are folded into the summary.
pub turns_to_summarise: usize,
}
#[async_trait]
impl CompactionStrategy for SummarisationCompaction {
    /// Replaces the first `turns_to_summarise` events with one summary event
    /// written by `self.model`, keeping the remaining events unchanged.
    ///
    /// The list is returned untouched when `turns_to_summarise` is zero
    /// (nothing to summarise) or when it holds at most `turns_to_summarise`
    /// events. `_budget` is ignored by this strategy.
    ///
    /// The summary event is authored by `"system"`, carries `"model"`-role
    /// content, and its text is prefixed with `[Context Summary]\n`.
    ///
    /// # Errors
    ///
    /// Propagates any error from starting the model request or from an item
    /// of the response stream.
    async fn compact(&self, events: Vec<Event>, _budget: usize) -> Result<Vec<Event>, AdkError> {
        use futures::StreamExt;

        let len = events.len();
        // With zero turns there is nothing to summarise; calling the model
        // anyway would only prepend an extra event and grow the context.
        if self.turns_to_summarise == 0 || len <= self.turns_to_summarise {
            return Ok(events);
        }

        // `len > turns_to_summarise` here, so the split point is in bounds.
        let summarize_end = self.turns_to_summarise;
        let mut events_to_summarize = events;
        // Move the tail out rather than cloning it later.
        let events_to_preserve = events_to_summarize.split_off(summarize_end);
        let preserved_count = events_to_preserve.len();

        let summary_text = build_summary_prompt(&events_to_summarize);
        let request = adk_core::LlmRequest::new(
            self.model.name().to_string(),
            vec![adk_core::Content::new("user").with_text(summary_text)],
        );
        // NOTE(review): the `false` argument presumably disables streaming —
        // confirm against the `Llm` trait.
        let mut stream = self.model.generate_content(request, false).await?;

        // Concatenate every text part of every response chunk.
        let mut summary_content = String::new();
        while let Some(response) = stream.next().await {
            let response = response?;
            if let Some(content) = &response.content {
                for part in &content.parts {
                    if let adk_core::Part::Text { text } = part {
                        summary_content.push_str(text);
                    }
                }
            }
        }

        let mut summary_event = Event::new("compaction");
        summary_event.author = "system".to_string();
        summary_event.set_content(
            adk_core::Content::new("model")
                .with_text(format!("[Context Summary]\n{summary_content}")),
        );

        let mut compacted = Vec::with_capacity(1 + preserved_count);
        compacted.push(summary_event);
        compacted.extend(events_to_preserve);

        tracing::debug!(
            original_count = len,
            summarized_count = summarize_end,
            preserved_count = preserved_count,
            "summarisation compaction applied"
        );
        Ok(compacted)
    }
}
/// Builds the prompt sent to the summarising model.
///
/// The prompt starts with a fixed instruction, followed by one line per event
/// that has content, formatted as `[role]: text`. Only text parts are
/// included; multiple text parts of one event are concatenated without a
/// separator. Events without content are skipped.
fn build_summary_prompt(events: &[Event]) -> String {
    let mut out = String::from(
        "Summarize the following conversation history into a concise summary \
         that preserves key facts, decisions, and context. Be brief but complete.\n\n",
    );

    for content in events.iter().filter_map(|event| event.content()) {
        let header = format!("[{}]: ", content.role);
        out.push_str(&header);

        content.parts.iter().for_each(|part| {
            if let adk_core::Part::Text { text } = part {
                out.push_str(text);
            }
        });

        out.push('\n');
    }

    out
}
/// Settings for compacting a list of events down to a token budget.
pub struct CompactionConfig {
/// Strategy invoked on each compaction attempt.
pub strategy: Box<dyn CompactionStrategy>,
/// Token budget the compacted events are compared against.
pub context_budget: usize,
/// Maximum number of compaction attempts made by `apply_compaction_with_retry`.
pub max_retries: usize,
}
impl CompactionConfig {
    /// Number of attempts used by [`CompactionConfig::new`].
    const DEFAULT_MAX_RETRIES: usize = 2;

    /// Creates a configuration with the given strategy and budget and the
    /// default `max_retries` of 2.
    pub fn new(strategy: Box<dyn CompactionStrategy>, context_budget: usize) -> Self {
        let max_retries = Self::DEFAULT_MAX_RETRIES;
        Self { strategy, context_budget, max_retries }
    }
}
/// Manual `Debug` implementation: the boxed strategy is a trait object that
/// is not required to be `Debug`, so it is rendered as a fixed placeholder.
impl std::fmt::Debug for CompactionConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("CompactionConfig");
        out.field("context_budget", &self.context_budget);
        out.field("max_retries", &self.max_retries);
        out.field("strategy", &"<dyn CompactionStrategy>");
        out.finish()
    }
}
/// Returns `true` when `err` looks like a "context / token limit exceeded"
/// failure.
///
/// Only errors whose component is `Model` and whose category is
/// `InvalidInput` are considered. Among those, the error matches when its
/// code is `runner.context_overflow`, or when its lower-cased message
/// contains any of a fixed list of substrings.
pub fn is_token_limit_error(err: &AdkError) -> bool {
    // Substrings checked against the lower-cased message.
    // NOTE(review): "token" already matches everything "too many tokens" and
    // "max_tokens" match, and "too long" covers "prompt is too long"; the
    // list is kept verbatim.
    const TOKEN_LIMIT_PATTERNS: &[&str] = &[
        "token",
        "context length",
        "context_length",
        "too long",
        "too many tokens",
        "payload size exceeds",
        "maximum context",
        "max_tokens",
        "input too large",
        "prompt is too long",
        "exceeds the model",
    ];

    let is_model_invalid_input = err.component == adk_core::ErrorComponent::Model
        && err.category == adk_core::ErrorCategory::InvalidInput;
    if !is_model_invalid_input {
        return false;
    }

    if err.code == "runner.context_overflow" {
        return true;
    }

    let lowered = err.message.to_lowercase();
    TOKEN_LIMIT_PATTERNS.iter().any(|pattern| lowered.contains(pattern))
}
/// Repeatedly applies `config.strategy` to `events` until the estimated token
/// count fits within `config.context_budget`.
///
/// Up to `config.max_retries` attempts are made. Each attempt feeds the
/// previous attempt's output back into the strategy, then re-estimates the
/// size with [`estimate_event_tokens`]. The strategy is always applied at
/// least once when `max_retries > 0`, even if the input already fits.
///
/// # Errors
///
/// * Propagates any error returned by the strategy.
/// * Returns a `runner.context_overflow` error (converted from
///   [`ContextOverflowError`]) when the events are still over budget after
///   the last attempt. With `max_retries == 0` no compaction runs and this
///   error is always returned.
pub async fn apply_compaction_with_retry(
    config: &CompactionConfig,
    events: Vec<Event>,
) -> Result<Vec<Event>, AdkError> {
    let mut current_events = events;
    for attempt in 0..config.max_retries {
        tracing::info!(
            attempt = attempt + 1,
            max_retries = config.max_retries,
            event_count = current_events.len(),
            budget = config.context_budget,
            "applying context compaction"
        );
        current_events = config.strategy.compact(current_events, config.context_budget).await?;
        let estimated_tokens = estimate_event_tokens(&current_events);
        if estimated_tokens <= config.context_budget {
            tracing::info!(
                estimated_tokens,
                budget = config.context_budget,
                "compaction succeeded, context within budget"
            );
            return Ok(current_events);
        }
        tracing::warn!(
            estimated_tokens,
            budget = config.context_budget,
            attempt = attempt + 1,
            "compaction did not bring context under budget, retrying"
        );
    }
    // Every attempt left the context over budget (or no attempt was allowed).
    let final_tokens = estimate_event_tokens(&current_events);
    Err(ContextOverflowError { token_count: final_tokens, limit: config.context_budget }.into())
}
/// Produces a rough token estimate for a list of events.
///
/// Each text part contributes its UTF-8 byte length (`str::len`), every
/// other kind of part contributes a flat 20, and events without content
/// contribute nothing. The grand total is divided by 4 (integer division).
pub fn estimate_event_tokens(events: &[Event]) -> usize {
    let mut total_bytes: usize = 0;

    for event in events {
        if let Some(content) = event.content() {
            for part in content.parts.iter() {
                total_bytes += if let adk_core::Part::Text { text } = part {
                    text.len()
                } else {
                    20
                };
            }
        }
    }

    total_bytes / 4
}
#[cfg(test)]
mod tests {
    use super::*;
    use adk_core::{Content, Event};

    /// Builds `count` events with text `message {i}`; the first is authored by
    /// `"system"`, the rest by `"user"`.
    fn make_events(count: usize) -> Vec<Event> {
        let mut events = Vec::with_capacity(count);
        for i in 0..count {
            let author = if i == 0 { "system" } else { "user" };
            let mut event = Event::new("test-inv");
            event.author = author.to_string();
            event.set_content(Content::new("user").with_text(format!("message {i}")));
            events.push(event);
        }
        events
    }

    /// Shorthand for a model / invalid-input error with the given code and message.
    fn model_invalid_input(code: &'static str, message: &'static str) -> AdkError {
        AdkError::new(
            adk_core::ErrorComponent::Model,
            adk_core::ErrorCategory::InvalidInput,
            code,
            message,
        )
    }

    #[tokio::test]
    async fn test_truncation_preserves_system_and_recent() {
        let events = make_events(10);
        let truncation = TruncationCompaction { preserve_recent: 3 };

        let compacted = truncation.compact(events.clone(), 4096).await.unwrap();

        assert_eq!(compacted.len(), 4);
        assert_eq!(compacted[0].author, "system");
        // The three survivors after the first event are the original events 7, 8 and 9.
        for (offset, original_index) in (7..10).enumerate() {
            assert_eq!(compacted[offset + 1].id, events[original_index].id);
        }
    }

    #[tokio::test]
    async fn test_truncation_no_op_when_few_events() {
        let truncation = TruncationCompaction { preserve_recent: 5 };
        let compacted = truncation.compact(make_events(3), 4096).await.unwrap();
        assert_eq!(compacted.len(), 3);
    }

    #[tokio::test]
    async fn test_truncation_exact_boundary() {
        let truncation = TruncationCompaction { preserve_recent: 4 };
        let compacted = truncation.compact(make_events(5), 4096).await.unwrap();
        assert_eq!(compacted.len(), 5);
    }

    #[test]
    fn test_context_overflow_error_display() {
        let rendered = ContextOverflowError { token_count: 50_000, limit: 32_000 }.to_string();
        assert_eq!(rendered, "Context overflow: 50000 tokens (limit: 32000)");
    }

    #[test]
    fn test_context_overflow_error_into_adk_error() {
        let overflow = ContextOverflowError { token_count: 50_000, limit: 32_000 };
        let adk_err = AdkError::from(overflow);
        assert!(adk_err.is_model());
        assert_eq!(adk_err.code, "runner.context_overflow");
    }

    #[test]
    fn test_compaction_config_new_defaults_max_retries() {
        let config =
            CompactionConfig::new(Box::new(TruncationCompaction { preserve_recent: 5 }), 100_000);
        assert_eq!(config.context_budget, 100_000);
        assert_eq!(config.max_retries, 2);
    }

    #[test]
    fn test_compaction_config_custom_max_retries() {
        let strategy: Box<dyn CompactionStrategy> =
            Box::new(TruncationCompaction { preserve_recent: 3 });
        let config = CompactionConfig { strategy, context_budget: 50_000, max_retries: 5 };
        assert_eq!(config.context_budget, 50_000);
        assert_eq!(config.max_retries, 5);
    }

    #[test]
    fn test_compaction_config_debug() {
        let config =
            CompactionConfig::new(Box::new(TruncationCompaction { preserve_recent: 3 }), 32_000);
        let rendered = format!("{config:?}");
        for expected in ["CompactionConfig", "32000", "max_retries: 2"] {
            assert!(rendered.contains(expected));
        }
    }

    #[test]
    fn test_build_summary_prompt() {
        let mut user_event = Event::new("inv");
        user_event.set_content(Content::new("user").with_text("Hello"));
        let mut model_event = Event::new("inv");
        model_event.set_content(Content::new("model").with_text("Hi there!"));

        let prompt = build_summary_prompt(&[user_event, model_event]);

        assert!(prompt.contains("[user]: Hello"));
        assert!(prompt.contains("[model]: Hi there!"));
        assert!(prompt.contains("Summarize"));
    }

    #[test]
    fn test_is_token_limit_error_detects_openai_style() {
        let err = model_invalid_input(
            "model.openai.bad_request",
            "This model's maximum context length is 128000 tokens",
        );
        assert!(is_token_limit_error(&err));
    }

    #[test]
    fn test_is_token_limit_error_detects_anthropic_style() {
        let err = model_invalid_input(
            "model.anthropic.bad_request",
            "prompt is too long: 200000 tokens > 100000 maximum",
        );
        assert!(is_token_limit_error(&err));
    }

    #[test]
    fn test_is_token_limit_error_detects_context_overflow_code() {
        let err = model_invalid_input(
            "runner.context_overflow",
            "Context overflow: 50000 tokens (limit: 32000)",
        );
        assert!(is_token_limit_error(&err));
    }

    #[test]
    fn test_is_token_limit_error_rejects_non_model_error() {
        let err = AdkError::new(
            adk_core::ErrorComponent::Tool,
            adk_core::ErrorCategory::InvalidInput,
            "tool.error",
            "token limit exceeded",
        );
        assert!(!is_token_limit_error(&err));
    }

    #[test]
    fn test_is_token_limit_error_rejects_non_invalid_input() {
        let err = AdkError::new(
            adk_core::ErrorComponent::Model,
            adk_core::ErrorCategory::Internal,
            "model.internal",
            "token limit exceeded",
        );
        assert!(!is_token_limit_error(&err));
    }

    #[test]
    fn test_is_token_limit_error_rejects_unrelated_invalid_input() {
        let err = model_invalid_input("model.openai.bad_request", "invalid JSON in request body");
        assert!(!is_token_limit_error(&err));
    }

    #[test]
    fn test_estimate_event_tokens_empty() {
        let no_events: &[Event] = &[];
        assert_eq!(estimate_event_tokens(no_events), 0);
    }

    #[test]
    fn test_estimate_event_tokens_with_content() {
        let mut event = Event::new("inv");
        event.set_content(Content::new("user").with_text("Hello world"));
        // "Hello world" is 11 bytes; the estimate divides by 4.
        assert_eq!(estimate_event_tokens(&[event]), 11 / 4);
    }

    #[tokio::test]
    async fn test_apply_compaction_with_retry_succeeds_first_try() {
        let config = CompactionConfig {
            strategy: Box::new(TruncationCompaction { preserve_recent: 2 }),
            context_budget: 100,
            max_retries: 2,
        };

        let outcome = apply_compaction_with_retry(&config, make_events(10)).await;

        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap().len(), 3);
    }

    #[tokio::test]
    async fn test_apply_compaction_with_retry_fails_when_budget_too_small() {
        let config = CompactionConfig {
            strategy: Box::new(TruncationCompaction { preserve_recent: 2 }),
            context_budget: 0,
            max_retries: 2,
        };
        let events: Vec<Event> = (0..5)
            .map(|i| {
                let mut event = Event::new("inv");
                event.set_content(
                    Content::new("user").with_text(format!("message {i} with some content")),
                );
                event
            })
            .collect();

        let outcome = apply_compaction_with_retry(&config, events).await;

        assert!(outcome.is_err());
        assert_eq!(outcome.unwrap_err().code, "runner.context_overflow");
    }
}