use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use async_trait::async_trait;
use crate::traits::output_transformer::OutputTransformer;
use crate::traits::provider::Provider;
use crate::types::agent_state::AgentState;
use crate::types::completion::{CompletionRequest, ResponseContent};
use crate::types::message::Message;
/// Output transformer that caps tool output at a character budget.
///
/// The budget is `max_chars`; it is halved when the agent's context utilization
/// is above `aggressive_threshold`.
pub struct BudgetAwareTruncator {
    /// Character budget used while context utilization is at or below the threshold.
    max_chars: usize,
    /// Utilization level above which the budget is halved; stored clamped to `[0.0, 1.0]`.
    aggressive_threshold: f32,
}

impl BudgetAwareTruncator {
    /// Creates a truncator with the given budget and aggressive-mode threshold.
    ///
    /// `aggressive_threshold` is clamped into `[0.0, 1.0]` (a NaN input stays NaN,
    /// per `f32::clamp`).
    #[must_use]
    pub fn new(max_chars: usize, aggressive_threshold: f32) -> Self {
        let threshold = aggressive_threshold.clamp(0.0, 1.0);
        Self {
            max_chars,
            aggressive_threshold: threshold,
        }
    }
}

impl Default for BudgetAwareTruncator {
    /// A 10 000-character budget with an aggressive-mode threshold of 0.8.
    fn default() -> Self {
        let (budget, threshold) = (10_000, 0.8);
        BudgetAwareTruncator::new(budget, threshold)
    }
}
#[async_trait]
impl OutputTransformer for BudgetAwareTruncator {
    /// Truncates `output` to the active character budget.
    ///
    /// The budget is `max_chars`, halved to `max_chars / 2` when
    /// `state.context_utilization()` is above `aggressive_threshold`. Output within
    /// the budget is returned unchanged; otherwise the first `limit` characters are
    /// kept and a note with the original and retained character counts is appended.
    ///
    /// Lengths are counted in `char`s (Unicode scalar values), not bytes, so
    /// multi-byte text that fits is not annotated as truncated, and the reported
    /// original length matches the "chars" unit in the note.
    async fn transform(&self, output: String, _tool_name: &str, state: &AgentState) -> String {
        let limit = if state.context_utilization() > self.aggressive_threshold {
            self.max_chars / 2
        } else {
            self.max_chars
        };
        // Fast path: a string's byte length is an upper bound on its char count,
        // so short outputs skip the scan below.
        if output.len() <= limit {
            return output;
        }
        // Byte offset of the first char past the budget; `None` means the whole
        // output fits even though its byte length exceeds `limit`.
        match output.char_indices().nth(limit) {
            None => output,
            Some((cut, _)) => {
                let total_chars = limit + output[cut..].chars().count();
                format!(
                    "{}\n\n[output truncated from {total_chars} to {limit} chars]",
                    &output[..cut]
                )
            }
        }
    }
}
/// Output transformer that strips text surrounding a JSON object or array.
///
/// This is a delimiter heuristic, not a parser: it returns the span from an
/// opening delimiter to the last matching closing delimiter and does not check
/// that the span is well-formed JSON.
pub struct JsonExtractor;

impl JsonExtractor {
    /// Returns the inclusive byte range from the first `open` to the last `close`
    /// in `text`, or `None` if either is missing or `close` precedes `open`.
    fn delimited_span(text: &str, open: char, close: char) -> Option<(usize, usize)> {
        let start = text.find(open)?;
        let end = text.rfind(close)?;
        if end >= start {
            Some((start, end))
        } else {
            None
        }
    }
}

#[async_trait]
impl OutputTransformer for JsonExtractor {
    /// Extracts a `{…}` or `[…]` span from `output`.
    ///
    /// The object span (first `{` to last `}`) is preferred, except when the array
    /// span (first `[` to last `]`) strictly encloses it. Without that exception a
    /// top-level array of objects such as `[{"a":1},{"b":2}]` would be cut down to
    /// the invalid fragment `{"a":1},{"b":2}`. If neither span exists, `output` is
    /// returned unchanged.
    async fn transform(&self, output: String, _tool_name: &str, _state: &AgentState) -> String {
        let object = Self::delimited_span(&output, '{', '}');
        let array = Self::delimited_span(&output, '[', ']');
        let chosen = match (object, array) {
            (Some(obj), Some(arr)) if arr.0 < obj.0 && arr.1 > obj.1 => Some(arr),
            (Some(obj), _) => Some(obj),
            (None, arr) => arr,
        };
        match chosen {
            // Delimiters are single-byte ASCII, so both indices are char
            // boundaries and `..=end` includes the whole closing delimiter.
            Some((start, end)) => output[start..=end].to_string(),
            None => output,
        }
    }
}
/// Output transformer that applies a sequence of transformers in order.
pub struct TransformerChain {
    /// Transformers in application order.
    transformers: Vec<Box<dyn OutputTransformer>>,
}

impl TransformerChain {
    /// Builds a chain that applies `transformers` in the given order.
    #[must_use]
    pub fn new(transformers: Vec<Box<dyn OutputTransformer>>) -> Self {
        TransformerChain { transformers }
    }
}
#[async_trait]
impl OutputTransformer for TransformerChain {
    /// Threads `output` through every transformer in order, passing each result
    /// to the next; an empty chain returns `output` unchanged.
    async fn transform(&self, output: String, tool_name: &str, state: &AgentState) -> String {
        let mut current = output;
        for transformer in self.transformers.iter() {
            current = transformer.transform(current, tool_name, state).await;
        }
        current
    }
}
/// Default summarization prompt template; the literal `{output}` placeholder is
/// where the tool output is substituted.
const DEFAULT_SUMMARY_PROMPT: &str = concat!(
    "Summarize the following tool output concisely, preserving all key data points and values. ",
    "Be brief but complete:\n\n{output}"
);
/// Output transformer that asks an LLM to summarize oversized tool output and
/// keeps the full text in a cache that a [`FullOutputRetriever`] can read.
pub struct ProgressiveTransformer {
    // Model provider used to request summaries.
    provider: Arc<dyn Provider>,
    // Outputs no longer than this are passed through without summarizing.
    max_summary_length: usize,
    // Prompt template; occurrences of `{output}` are replaced with the tool output.
    summary_prompt: String,
    // Latest full output per tool name (a new insert replaces the old entry);
    // the `Arc` is cloned into each `FullOutputRetriever` handed out.
    cache: Arc<RwLock<HashMap<String, String>>>,
}
impl ProgressiveTransformer {
    /// Creates a transformer that summarizes outputs longer than
    /// `max_summary_length` via `provider`, using the default summary prompt and
    /// an empty full-output cache.
    #[must_use]
    pub fn new(provider: Arc<dyn Provider>, max_summary_length: usize) -> Self {
        Self {
            provider,
            max_summary_length,
            summary_prompt: DEFAULT_SUMMARY_PROMPT.to_string(),
            cache: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Replaces the summary prompt template.
    ///
    /// Every `{output}` in `prompt` is replaced with the tool output. If the
    /// template contains no `{output}` placeholder, the output is appended after a
    /// blank line, so the model always receives the text it is asked to summarize.
    #[must_use]
    pub fn with_summary_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.summary_prompt = prompt.into();
        self
    }

    /// Returns a retriever that reads this transformer's full-output cache.
    ///
    /// The cache is shared through an `Arc`, so outputs cached after this call
    /// are also visible through the returned retriever.
    #[must_use]
    pub fn retriever_tool(&self) -> FullOutputRetriever {
        FullOutputRetriever {
            cache: Arc::clone(&self.cache),
        }
    }

    /// Stores `output` as the full output for `tool_name`, replacing any earlier entry.
    ///
    /// # Panics
    /// Panics if the cache lock is poisoned.
    fn cache_output(&self, tool_name: &str, output: &str) {
        let mut cache = self
            .cache
            .write()
            .expect("ProgressiveTransformer cache lock poisoned");
        cache.insert(tool_name.to_string(), output.to_string());
    }

    /// Renders the summary prompt for `output` from the configured template.
    fn build_prompt(&self, output: &str) -> String {
        const PLACEHOLDER: &str = "{output}";
        if self.summary_prompt.contains(PLACEHOLDER) {
            self.summary_prompt.replace(PLACEHOLDER, output)
        } else {
            // Without this, a custom template lacking the placeholder would send
            // the model the instructions but none of the output to summarize.
            format!("{}\n\n{output}", self.summary_prompt)
        }
    }
}
#[async_trait]
impl OutputTransformer for ProgressiveTransformer {
    /// Summarizes `output` with the LLM when it is longer than `max_summary_length`.
    ///
    /// - Output of at most `max_summary_length` characters is returned unchanged.
    /// - Otherwise the full output is cached under `tool_name`, and a summary is
    ///   requested from the provider (`max_tokens` 500, temperature 0.3). A text
    ///   reply is returned with a note explaining how to fetch the full output via
    ///   `__get_full_output`.
    /// - If the provider errors or replies with tool calls, the output is
    ///   truncated to `max_summary_length` characters with an explanatory note.
    ///   The full output stays cached in both cases.
    ///
    /// Lengths are counted in `char`s, matching both the `chars().take(..)`
    /// truncation and the "chars" unit printed in the notes.
    async fn transform(&self, output: String, tool_name: &str, _state: &AgentState) -> String {
        // Fast path: byte length bounds char count, so short outputs skip the scan.
        if output.len() <= self.max_summary_length {
            return output;
        }
        let char_count = output.chars().count();
        if char_count <= self.max_summary_length {
            return output;
        }
        self.cache_output(tool_name, &output);
        let prompt = self.build_prompt(&output);
        let request = CompletionRequest {
            model: self.provider.model_info().name.clone(),
            messages: vec![Message::user(prompt)],
            tools: vec![],
            max_tokens: Some(500),
            temperature: Some(0.3),
            response_format: None,
            stream: false,
        };
        match self.provider.complete(request).await {
            Ok(response) => match response.content {
                ResponseContent::Text(summary) => format!(
                    "{summary}\n\n\
                     [Full output ({char_count} chars) cached. \
                     Call __get_full_output with {{\"tool_name\": \"{tool_name}\"}} to retrieve it.]"
                ),
                ResponseContent::ToolCalls(_) => Self::truncation_fallback(
                    &output,
                    self.max_summary_length,
                    char_count,
                    "summarizer returned tool calls",
                ),
            },
            Err(e) => {
                tracing::warn!(
                    "ProgressiveTransformer: LLM summarization failed for '{tool_name}': {e}. \
                     Falling back to truncation."
                );
                Self::truncation_fallback(
                    &output,
                    self.max_summary_length,
                    char_count,
                    "LLM summarization failed",
                )
            }
        }
    }
}

impl ProgressiveTransformer {
    /// Keeps the first `max_chars` chars of `output` and appends a note with the
    /// original `total_chars` and the `reason` summarization was not used.
    fn truncation_fallback(
        output: &str,
        max_chars: usize,
        total_chars: usize,
        reason: &str,
    ) -> String {
        let truncated: String = output.chars().take(max_chars).collect();
        format!("{truncated}\n\n[output truncated from {total_chars} chars — {reason}]")
    }
}
/// Read handle onto a shared full-output cache keyed by tool name.
pub struct FullOutputRetriever {
    /// Cached full outputs keyed by tool name.
    cache: Arc<RwLock<HashMap<String, String>>>,
}

impl FullOutputRetriever {
    /// Returns a copy of the cached output for `tool_name`, or a bracketed
    /// explanatory message when nothing is cached under that name.
    ///
    /// # Panics
    /// Panics if the cache lock is poisoned.
    #[must_use]
    pub fn retrieve(&self, tool_name: &str) -> String {
        self.cache
            .read()
            .expect("FullOutputRetriever cache lock poisoned")
            .get(tool_name)
            .cloned()
            .unwrap_or_else(|| {
                format!(
                    "[No cached output found for tool '{tool_name}'. \
                     The output may have expired or the tool name is incorrect.]"
                )
            })
    }

    /// Reports whether an output is cached for `tool_name`.
    ///
    /// # Panics
    /// Panics if the cache lock is poisoned.
    #[must_use]
    pub fn has_cached(&self, tool_name: &str) -> bool {
        self.cache
            .read()
            .expect("FullOutputRetriever cache lock poisoned")
            .contains_key(tool_name)
    }
}