ai_tokenopt 0.5.5

Adaptive token optimization engine for LLM inference pipelines — compresses prompts, conversation history, tool schemas, and output streams to minimize token usage while preserving response quality.
Documentation
//! Tool result truncation
//!
//! Extractive truncation of tool execution results to reduce token
//! usage in multi-step ReAct chains without losing critical information.

use crate::estimator::TokenEstimator;
use crate::types::{ChatMessage, MessageRole};

/// Compress tool-result messages that precede the last user message.
///
/// Only messages with `MessageRole::Tool` that appear *before* the last
/// user message are truncated (via [`truncate_tool_result`]). Current-turn
/// tool results (those after the last user message) are left intact so the
/// LLM can reason over the full output. If the history contains no user
/// message at all, every tool message is treated as old and compressed.
///
/// Tool messages whose estimated size is already within `max_tokens` are
/// left untouched.
///
/// Returns the total number of estimated tokens saved (saturating at
/// `u32::MAX` rather than wrapping).
pub fn compress_old_tool_results(messages: &mut [ChatMessage], max_tokens: u32) -> u32 {
    // Messages before this index belong to earlier turns. With no user
    // message at all, the whole history counts as "old".
    let last_user_idx = messages
        .iter()
        .rposition(|m| m.role == MessageRole::User)
        .unwrap_or(messages.len());

    let mut tokens_saved: u32 = 0;

    for msg in messages[..last_user_idx]
        .iter_mut()
        .filter(|m| m.role == MessageRole::Tool)
    {
        let before = TokenEstimator::estimate_tokens(&msg.content);
        if before <= max_tokens {
            continue;
        }

        msg.content = truncate_tool_result(&msg.content, max_tokens);
        let after = TokenEstimator::estimate_tokens(&msg.content);
        // The truncated form is not guaranteed to estimate smaller than the
        // original; count that as zero savings instead of underflowing, and
        // saturate the running total instead of wrapping on overflow.
        tokens_saved = tokens_saved.saturating_add(before.saturating_sub(after));
    }

    tokens_saved
}

/// Truncate a tool result to fit within a token budget.
///
/// Strategy, applied in order:
/// - Empty input yields an empty string; input already within `max_tokens`
///   is returned unchanged.
/// - Input that starts with `{` or `[` and parses as JSON is condensed to
///   its key fields (error/status/message-style keys first) or to its
///   leading array elements.
/// - Text flagged by `is_error_result` is kept as a contiguous prefix
///   (character truncation) so the message is not split into fragments.
/// - Any other text keeps the first lines, the last line and lines
///   containing digits, within the budget.
///
/// The fit is best-effort: the token estimates are heuristic and some
/// markers (e.g. "more items" / "lines omitted") are not charged against
/// the budget, so the output may slightly exceed `max_tokens`.
#[must_use]
pub fn truncate_tool_result(result: &str, max_tokens: u32) -> String {
    if result.is_empty() {
        return String::new();
    }

    if TokenEstimator::estimate_tokens(result) <= max_tokens {
        return result.to_string();
    }

    // Structured extraction goes first. A JSON payload that merely contains a
    // word like "timeout" or "failed" (e.g. a config field) would otherwise be
    // misclassified as an error and cut mid-document into invalid JSON. Error
    // fields are still kept, because extraction visits them first.
    let trimmed = result.trim_start();
    if trimmed.starts_with('{') || trimmed.starts_with('[') {
        if let Some(extracted) = extract_json_key_fields(result, max_tokens) {
            return extracted;
        }
    }

    // Plain-text errors: keep a contiguous prefix rather than scattered lines.
    if is_error_result(result) {
        return truncate_text(result, max_tokens);
    }

    // Text extraction: prioritize first and last lines, numeric lines.
    extract_text_key_lines(result, max_tokens)
}

/// Heuristically decide whether a tool result reports a failure.
///
/// The check is case-insensitive. It is true when the text begins with
/// "error", or mentions "failed", "exception", "not found" or "timeout"
/// anywhere.
fn is_error_result(text: &str) -> bool {
    let haystack = text.to_lowercase();
    if haystack.starts_with("error") {
        return true;
    }
    ["failed", "exception", "not found", "timeout"]
        .iter()
        .any(|needle| haystack.contains(needle))
}

/// Condense a JSON document into a budget-limited summary.
///
/// Objects keep well-known keys (`error`, `status`, `message`, `name`,
/// `title`, `id`, `result`) first, then the remaining fields in map
/// iteration order. Any field whose rendered form would push the running
/// estimate past `max_tokens` is skipped. Arrays keep leading elements
/// until the budget runs out, then append a `...(N more items)` marker.
///
/// Returns `None` when the input does not parse, is a scalar, or is an
/// object from which no field fits.
fn extract_json_key_fields(json_str: &str, max_tokens: u32) -> Option<String> {
    const PRIORITY_KEYS: [&str; 7] = [
        "error", "status", "message", "name", "title", "id", "result",
    ];

    let parsed: serde_json::Value = serde_json::from_str(json_str).ok()?;

    match parsed {
        serde_json::Value::Object(fields) => {
            // Two tokens are reserved up front for the surrounding braces.
            let mut budget_used: u32 = 2;
            let mut kept: Vec<String> = Vec::new();

            let prioritized = PRIORITY_KEYS
                .iter()
                .filter_map(|&k| fields.get(k).map(|v| (k, v)));
            let others = fields
                .iter()
                .map(|(k, v)| (k.as_str(), v))
                .filter(|(k, _)| !PRIORITY_KEYS.contains(k));

            for (key, val) in prioritized.chain(others) {
                let rendered = format_json_field(key, val);
                let cost = TokenEstimator::estimate_tokens(&rendered);
                if budget_used + cost > max_tokens {
                    continue;
                }
                budget_used += cost;
                kept.push(rendered);
            }

            (!kept.is_empty()).then(|| format!("{{{}}}", kept.join(", ")))
        },
        serde_json::Value::Array(items) => {
            // Two tokens are reserved up front for the surrounding brackets.
            let mut budget_used: u32 = 2;
            let mut kept: Vec<String> = Vec::new();

            for item in &items {
                let rendered = item.to_string();
                let cost = TokenEstimator::estimate_tokens(&rendered);
                if budget_used + cost > max_tokens {
                    let remaining = items.len() - kept.len();
                    kept.push(format!("...({remaining} more items)"));
                    break;
                }
                budget_used += cost;
                kept.push(rendered);
            }

            Some(format!("[{}]", kept.join(", ")))
        },
        _ => None,
    }
}

/// Render one `"key": value` pair compactly for a truncated JSON summary.
///
/// - Strings longer than 100 characters are cut to their first 97
///   characters followed by `...`.
/// - Arrays with more than 3 elements show the first 3 followed by a
///   `...(N more)` marker.
/// - Non-empty nested objects collapse to `{...}`.
/// - Everything else (including `{}`) is emitted as serialized JSON.
///
/// The key and any truncated string are JSON-escaped, so quotes,
/// backslashes and control characters in the source cannot break the
/// rendered output.
fn format_json_field(key: &str, value: &serde_json::Value) -> String {
    const MAX_STR_CHARS: usize = 100;
    const KEPT_STR_CHARS: usize = MAX_STR_CHARS - 3;
    const MAX_ARRAY_ITEMS: usize = 3;

    // Serialize the key through serde_json so it is escaped like any JSON string.
    let key = serde_json::Value::from(key).to_string();

    match value {
        // Count characters, not bytes: `len()` would flag short multi-byte
        // strings and then append "..." without removing anything.
        serde_json::Value::String(s) if s.chars().nth(MAX_STR_CHARS).is_some() => {
            let mut truncated: String = s.chars().take(KEPT_STR_CHARS).collect();
            truncated.push_str("...");
            // Re-serialize so embedded quotes/newlines stay escaped.
            format!("{key}: {}", serde_json::Value::String(truncated))
        },
        serde_json::Value::Array(arr) if arr.len() > MAX_ARRAY_ITEMS => {
            let preview: Vec<String> = arr
                .iter()
                .take(MAX_ARRAY_ITEMS)
                .map(std::string::ToString::to_string)
                .collect();
            format!(
                "{key}: [{}, ...({} more)]",
                preview.join(", "),
                arr.len() - MAX_ARRAY_ITEMS
            )
        },
        // An empty object has nothing to elide; let it print as `{}` below.
        serde_json::Value::Object(map) if !map.is_empty() => {
            format!("{key}: {{...}}")
        },
        _ => format!("{key}: {value}"),
    }
}

/// Extract key lines from plain text within a token budget.
///
/// Candidate lines are considered in priority order: the first three lines,
/// then the last line, then every line containing an ASCII digit. Each
/// candidate is kept if it still fits in the remaining budget; a line that
/// does not fit is skipped, and lower-priority lines are still tried.
///
/// Kept lines are emitted in their original document order. A
/// `[...N lines omitted]` marker follows when anything was dropped. If no
/// candidate line fits at all, the function falls back to `truncate_text`.
fn extract_text_key_lines(text: &str, max_tokens: u32) -> String {
    let lines: Vec<&str> = text.lines().collect();
    if lines.is_empty() {
        return String::new();
    }

    let head = 0..lines.len().min(3);
    // The last line only adds a new candidate when it is not a head line.
    let tail = (lines.len() > 3).then_some(lines.len() - 1);
    // Lines containing numbers are high-information.
    let numeric = lines
        .iter()
        .enumerate()
        .filter(|(_, line)| line.chars().any(|c| c.is_ascii_digit()))
        .map(|(i, _)| i);

    // `seen` dedups candidates in O(1) each (no quadratic `contains` scan).
    let mut seen = vec![false; lines.len()];
    let mut kept: Vec<usize> = Vec::new();
    let mut tokens_used: u32 = 0;

    for i in head.chain(tail).chain(numeric) {
        if seen[i] {
            continue;
        }
        seen[i] = true;

        let line_tokens = TokenEstimator::estimate_tokens(lines[i]);
        if tokens_used.saturating_add(line_tokens) > max_tokens {
            continue;
        }
        kept.push(i);
        tokens_used += line_tokens;
    }

    if kept.is_empty() {
        return truncate_text(text, max_tokens);
    }

    // Selection ran in priority order; output follows document order.
    kept.sort_unstable();
    let selected: Vec<&str> = kept.iter().map(|&i| lines[i]).collect();

    let omitted = lines.len() - selected.len();
    let mut result = selected.join("\n");
    if omitted > 0 {
        result.push_str(&format!("\n[...{omitted} lines omitted]"));
    }
    result
}

/// Simple character-based text truncation.
///
/// Uses a heuristic of 4 characters per token, so the character limit is
/// `max_tokens * 4`. Text with at most that many characters is returned
/// unchanged. Longer text is cut on a character boundary to `limit - 3`
/// characters and then suffixed with `...`.
fn truncate_text(text: &str, max_tokens: u32) -> String {
    let max_chars = (max_tokens as usize).saturating_mul(4);

    // Measure in characters, not bytes: comparing `len()` (bytes) against a
    // character limit would "truncate" multi-byte text that already fits.
    if text.chars().nth(max_chars).is_none() {
        return text.to_string();
    }

    let keep = max_chars.saturating_sub(3);
    // Byte offset of the first dropped character (always a char boundary).
    let cut = text
        .char_indices()
        .nth(keep)
        .map_or(text.len(), |(idx, _)| idx);
    format!("{}...", &text[..cut])
}

#[cfg(test)]
mod tests {
    // Unit tests for tool-result truncation and history compression.
    // Several assertions allow a small grace margin above the requested
    // budget because `TokenEstimator` and the truncation heuristics are
    // approximate.
    use super::*;

    // Input already within budget is returned verbatim.
    #[test]
    fn short_result_unchanged() {
        let result = truncate_tool_result("OK", 100);
        assert_eq!(result, "OK");
    }

    // A short error message (within budget) must not be altered.
    #[test]
    fn error_preserved() {
        let error = "Error: Connection refused to weather API";
        let result = truncate_tool_result(error, 100);
        assert_eq!(result, error);
    }

    // An over-budget JSON object is condensed; the `name` priority key survives.
    #[test]
    fn json_key_fields_extracted() {
        let json = r#"{"name": "Berlin", "temperature": 22, "wind_speed": 15, "humidity": 65, "description": "Partly cloudy with occasional sunshine throughout the afternoon and evening hours", "pressure": 1013, "visibility": 10000}"#;
        let result = truncate_tool_result(json, 20);
        // Should extract key fields within budget
        assert!(result.contains("name"));
        assert!(TokenEstimator::estimate_tokens(&result) <= 25); // Small grace margin
    }

    // A 50-element array over a 30-token budget keeps leading items and ends
    // with a "...(N more items)" marker.
    #[test]
    fn json_array_truncated() {
        let items: Vec<String> = (0..50)
            .map(|i| format!(r#"{{"id": {i}, "value": "item_{i}"}}"#))
            .collect();
        let json = format!("[{}]", items.join(", "));
        let result = truncate_tool_result(&json, 30);
        assert!(result.contains("more"));
    }

    // Plain text goes through line extraction; the first (header) line must survive.
    #[test]
    fn text_prioritizes_first_lines_and_numbers() {
        let text = "Header: Weather Report\n\
                     Location: Berlin\n\
                     Date: 2026-03-19\n\
                     Some filler text here\n\
                     More filler text\n\
                     Temperature: 22°C\n\
                     End of report";
        let result = truncate_tool_result(text, 20);
        // Should include header and lines with numbers
        assert!(result.contains("Header"));
    }

    #[test]
    fn empty_result_returns_empty() {
        assert!(truncate_tool_result("", 100).is_empty());
    }

    // Builds a tool message. With the `pisovereign` feature,
    // `ChatMessage::tool` also takes a tool-call id.
    fn tool_msg(content: impl Into<String>) -> ChatMessage {
        #[cfg(feature = "pisovereign")]
        {
            ChatMessage::tool("call_id", content)
        }
        #[cfg(not(feature = "pisovereign"))]
        {
            ChatMessage::tool(content)
        }
    }

    // Tool results from earlier turns are compressed; the tool result after
    // the last user message is left byte-identical.
    #[test]
    fn compress_old_tool_results_only_before_last_user() {
        let long_json = r#"{"data": "x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]x]"}"#;
        let mut messages = vec![
            ChatMessage::user("first question"),
            ChatMessage::assistant("answer"),
            tool_msg(long_json),
            ChatMessage::user("second question"),
            tool_msg(long_json),
        ];

        let saved = compress_old_tool_results(&mut messages, 20);

        // Old tool result (index 2) should be compressed
        assert!(saved > 0);
        assert!(TokenEstimator::estimate_tokens(&messages[2].content) <= 25);

        // Current-turn tool result (index 4, after last user msg) should be untouched
        assert_eq!(messages[4].content, long_json);
    }

    // Tool results already within budget are not rewritten and save nothing.
    #[test]
    fn compress_old_tool_results_skips_short_results() {
        let mut messages = vec![tool_msg("OK"), ChatMessage::user("question")];
        let saved = compress_old_tool_results(&mut messages, 100);
        assert_eq!(saved, 0);
        assert_eq!(messages[0].content, "OK");
    }

    #[test]
    fn compress_old_tool_results_empty_messages() {
        let mut messages: Vec<ChatMessage> = vec![];
        let saved = compress_old_tool_results(&mut messages, 100);
        assert_eq!(saved, 0);
    }

    #[test]
    fn compress_old_tool_results_no_user_message_compresses_all() {
        let long = "a]".repeat(300);
        let mut messages = vec![tool_msg(&long), ChatMessage::assistant("response")];
        let saved = compress_old_tool_results(&mut messages, 20);
        // No user message → last_user_idx = len → all messages are "old"
        assert!(saved > 0);
    }
}