pe-core 0.1.0

Core types for Potential Expectations — messages, channels, state, traits
Documentation
//! Token counting and message trimming utilities.
//!
//! Provides `TokenCounter` trait for pluggable token counting,
//! `CharTokenCounter` for test approximation, and `trim_messages`
//! for context window management.
//!
//! Based on Group 16.6 and Group 19 of the pre-plan.

use crate::message::{Message, MessageContent};

/// Pluggable token measurement used by [`trim_messages`] to cost messages.
///
/// Implementations must be `Send + Sync` so a single counter can be shared.
pub trait TokenCounter: Send + Sync {
    /// Return the number of tokens used by `messages`, taken together.
    fn count_messages(&self, messages: &[Message]) -> u32;

    /// Return the number of tokens used by the plain string `text`.
    fn count_text(&self, text: &str) -> u32;
}

/// Approximate token counter: one token per four characters (rounded down),
/// with a floor of one token per text.
///
/// Characters are counted as Unicode scalar values (`str::chars`), not bytes.
/// Suitable for tests — no tokenizer dependency required.
/// For production, implement `TokenCounter` with a real tokenizer
/// (tiktoken, etc.).
#[derive(Debug, Clone, Copy, Default)]
pub struct CharTokenCounter;

impl TokenCounter for CharTokenCounter {
    /// Sum the per-message estimates, saturating at `u32::MAX` instead of
    /// overflowing (a plain `sum` panics in debug builds and wraps in release).
    fn count_messages(&self, messages: &[Message]) -> u32 {
        messages
            .iter()
            .map(|m| self.count_text(&message_to_text(m)))
            .fold(0, u32::saturating_add)
    }

    /// Estimate `text` as `chars / 4`, at least 1, clamped to `u32::MAX`.
    fn count_text(&self, text: &str) -> u32 {
        let estimate = (text.chars().count() / 4).max(1);
        // `as u32` would silently truncate very large counts, making a huge
        // text look cheap; clamp instead.
        u32::try_from(estimate).unwrap_or(u32::MAX)
    }
}

/// Which end of the conversation [`trim_messages`] keeps when it has to drop messages.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TrimStrategy {
    /// Keep the newest messages; the oldest are discarded first.
    Last,
    /// Keep the oldest messages; the newest are discarded first.
    First,
}

/// Message role used to constrain where a trimmed window may begin or end.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MessageRole {
    /// Matches `Message::Human`.
    Human,
    /// Matches `Message::Ai`.
    Ai,
    /// Matches `Message::System`.
    System,
    /// Matches `Message::Tool`.
    Tool,
}

/// Configuration passed to [`trim_messages`].
pub struct TrimOptions<'a> {
    /// End of the conversation to keep messages from.
    pub strategy: TrimStrategy,
    /// Maximum total tokens the kept messages may use, as measured by `token_counter`.
    pub max_tokens: u32,
    /// Counter used to cost each message individually.
    pub token_counter: &'a dyn TokenCounter,
    /// When `Some`, leading messages are dropped until the window begins with this role.
    pub start_on: Option<MessageRole>,
    /// When `Some`, trailing messages are dropped until the window ends with any of these roles.
    pub end_on: Option<Vec<MessageRole>>,
}

/// Trim a message list to fit within a token budget.
///
/// Messages are selected according to the strategy (Last = keep newest,
/// First = keep oldest), then boundary constraints (start_on, end_on)
/// are enforced by dropping messages at the boundary.
///
/// # Example
///
/// ```
/// use pe_core::token::{trim_messages, TrimOptions, TrimStrategy, CharTokenCounter};
/// use pe_core::message::Message;
///
/// let messages = vec![
///     Message::system("You are helpful"),
///     Message::human("What is Rust?"),
///     Message::ai("Rust is a systems programming language."),
/// ];
///
/// let trimmed = trim_messages(&messages, TrimOptions {
///     strategy: TrimStrategy::Last,
///     max_tokens: 20,
///     token_counter: &CharTokenCounter,
///     start_on: None,
///     end_on: None,
/// });
///
/// // Some messages will be trimmed to fit the budget
/// assert!(trimmed.len() <= messages.len());
/// ```
pub fn trim_messages(messages: &[Message], opts: TrimOptions) -> Vec<Message> {
    if messages.is_empty() {
        return vec![];
    }

    // Build candidate list based on strategy and budget
    let mut result: Vec<Message> = match opts.strategy {
        TrimStrategy::Last => {
            // Iterate from end, accumulate until budget
            let mut selected = Vec::new();
            let mut budget = opts.max_tokens;

            for msg in messages.iter().rev() {
                let cost = opts.token_counter.count_messages(std::slice::from_ref(msg));
                if cost > budget {
                    break;
                }
                budget -= cost;
                selected.push(msg.clone());
            }
            selected.reverse();
            selected
        }
        TrimStrategy::First => {
            // Iterate from start, accumulate until budget
            let mut selected = Vec::new();
            let mut budget = opts.max_tokens;

            for msg in messages {
                let cost = opts.token_counter.count_messages(std::slice::from_ref(msg));
                if cost > budget {
                    break;
                }
                budget -= cost;
                selected.push(msg.clone());
            }
            selected
        }
    };

    // Enforce start_on: drop leading messages that don't match
    if let Some(ref start_role) = opts.start_on {
        if let Some(start_idx) = result.iter().position(|m| message_has_role(m, start_role)) {
            result.drain(..start_idx);
        } else {
            result.clear();
        }
    }

    // Enforce end_on: drop trailing messages that don't match
    if let Some(ref end_roles) = opts.end_on {
        while !result.is_empty()
            && !end_roles
                .iter()
                .any(|r| message_has_role(result.last().unwrap(), r))
        {
            result.pop();
        }
    }

    result
}

/// Return `true` when `msg` is the `Message` variant that corresponds to `role`.
fn message_has_role(msg: &Message, role: &MessageRole) -> bool {
    let actual = match msg {
        Message::Human(_) => MessageRole::Human,
        Message::Ai(_) => MessageRole::Ai,
        Message::System(_) => MessageRole::System,
        Message::Tool(_) => MessageRole::Tool,
    };
    actual == *role
}

/// Extract the textual payload of `msg` for token estimation.
///
/// System and tool messages hold their text directly. Human and AI messages
/// hold `MessageContent`, which is flattened by `content_to_text`.
fn message_to_text(msg: &Message) -> String {
    let structured = match msg {
        Message::System(system) => return system.content.clone(),
        Message::Tool(tool) => return tool.content.clone(),
        Message::Human(human) => &human.content,
        Message::Ai(ai) => &ai.content,
    };
    content_to_text(structured)
}

/// Flatten message content into plain text for token estimation.
///
/// Plain text is returned as-is. For block content, text blocks are joined
/// with a single space. Every non-text block contributes an empty segment, so
/// it still adds a separator.
/// NOTE(review): those extra separators slightly inflate character-based counts.
fn content_to_text(content: &MessageContent) -> String {
    match content {
        MessageContent::Text(t) => t.clone(),
        // Borrow each block's text rather than cloning it; `join` allocates once.
        MessageContent::Blocks(blocks) => blocks
            .iter()
            .map(|b| match b {
                crate::message::ContentBlock::Text { text } => text.as_str(),
                _ => "",
            })
            .collect::<Vec<&str>>()
            .join(" "),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// A long system prompt: 63 chars, which is 15 tokens with `CharTokenCounter`.
    const LONG_SYSTEM: &str = "You are helpful. This is a long system prompt with many tokens.";

    #[test]
    fn test_char_token_counter() {
        let counter = CharTokenCounter;
        // "hello world" = 11 chars / 4 = 2 (rounded down, min 1)
        assert_eq!(counter.count_text("hello world"), 2);
        assert_eq!(counter.count_text("hi"), 1); // 2 / 4 = 0, but min 1
        assert_eq!(counter.count_text(""), 1); // min 1
        // Characters, not bytes: 8 chars (24 UTF-8 bytes) / 4 = 2
        assert_eq!(counter.count_text("日本語のテキスト"), 2);
    }

    #[test]
    fn test_count_messages_sums_per_message() {
        let counter = CharTokenCounter;
        let messages = vec![Message::human("hello world"), Message::ai("hi")];
        // 2 + 1
        assert_eq!(counter.count_messages(&messages), 3);
        assert_eq!(counter.count_messages(&[]), 0);
    }

    #[test]
    fn test_trim_empty_input() {
        let result = trim_messages(
            &[],
            TrimOptions {
                strategy: TrimStrategy::Last,
                max_tokens: 100,
                token_counter: &CharTokenCounter,
                start_on: None,
                end_on: None,
            },
        );
        assert!(result.is_empty());
    }

    #[test]
    fn test_trim_all_under_budget() {
        let messages = vec![Message::human("hello"), Message::ai("hi")];
        let result = trim_messages(
            &messages,
            TrimOptions {
                strategy: TrimStrategy::Last,
                max_tokens: 1000,
                token_counter: &CharTokenCounter,
                start_on: None,
                end_on: None,
            },
        );
        assert_eq!(result.len(), 2);
        assert!(matches!(result[0], Message::Human(_)));
        assert!(matches!(result[1], Message::Ai(_)));
    }

    #[test]
    fn test_trim_last_strategy() {
        let messages = vec![
            Message::system(LONG_SYSTEM),
            Message::human("short q"),
            Message::ai("short a"),
        ];
        let result = trim_messages(
            &messages,
            TrimOptions {
                strategy: TrimStrategy::Last,
                max_tokens: 5,
                token_counter: &CharTokenCounter,
                start_on: None,
                end_on: None,
            },
        );
        // ai (1) + human (1) fit; the 15-token system prompt does not.
        assert_eq!(result.len(), 2);
        assert!(matches!(result[0], Message::Human(_)));
        assert!(matches!(result[1], Message::Ai(_)));
    }

    #[test]
    fn test_trim_first_strategy() {
        let messages = vec![
            Message::human("first"),
            Message::ai("second"),
            Message::human("this is a much longer message that uses more tokens"),
        ];
        let result = trim_messages(
            &messages,
            TrimOptions {
                strategy: TrimStrategy::First,
                max_tokens: 5,
                token_counter: &CharTokenCounter,
                start_on: None,
                end_on: None,
            },
        );
        // "first" (1) + "second" (1) fit; the 12-token message does not.
        assert_eq!(result.len(), 2);
        assert!(matches!(result[0], Message::Human(_)));
        assert!(matches!(result[1], Message::Ai(_)));
    }

    #[test]
    fn test_trim_stops_at_first_oversized_message() {
        let messages = vec![
            Message::human("hi"),
            Message::system(LONG_SYSTEM),
            Message::ai("short a"),
        ];
        let result = trim_messages(
            &messages,
            TrimOptions {
                strategy: TrimStrategy::Last,
                max_tokens: 5,
                token_counter: &CharTokenCounter,
                start_on: None,
                end_on: None,
            },
        );
        // The system prompt does not fit, so selection stops there even though
        // the older 1-token human message would fit: the window stays contiguous.
        assert_eq!(result.len(), 1);
        assert!(matches!(result[0], Message::Ai(_)));
    }

    #[test]
    fn test_trim_zero_budget() {
        let messages = vec![Message::human("hello")];
        for strategy in [TrimStrategy::Last, TrimStrategy::First] {
            let result = trim_messages(
                &messages,
                TrimOptions {
                    strategy,
                    max_tokens: 0,
                    token_counter: &CharTokenCounter,
                    start_on: None,
                    end_on: None,
                },
            );
            // Every message costs at least 1 token.
            assert!(result.is_empty());
        }
    }

    #[test]
    fn test_trim_start_on_human() {
        let messages = vec![
            Message::system("sys"),
            Message::ai("ai response"),
            Message::human("question"),
        ];
        let result = trim_messages(
            &messages,
            TrimOptions {
                strategy: TrimStrategy::Last,
                max_tokens: 1000,
                token_counter: &CharTokenCounter,
                start_on: Some(MessageRole::Human),
                end_on: None,
            },
        );
        // Leading system and AI messages are dropped.
        assert_eq!(result.len(), 1);
        assert!(matches!(result[0], Message::Human(_)));
    }

    #[test]
    fn test_trim_start_on_no_match_is_empty() {
        let messages = vec![Message::ai("a"), Message::ai("b")];
        let result = trim_messages(
            &messages,
            TrimOptions {
                strategy: TrimStrategy::Last,
                max_tokens: 1000,
                token_counter: &CharTokenCounter,
                start_on: Some(MessageRole::Human),
                end_on: None,
            },
        );
        assert!(result.is_empty());
    }

    #[test]
    fn test_trim_end_on_human_or_tool() {
        let messages = vec![Message::human("q"), Message::ai("response")];
        let result = trim_messages(
            &messages,
            TrimOptions {
                strategy: TrimStrategy::Last,
                max_tokens: 1000,
                token_counter: &CharTokenCounter,
                start_on: None,
                end_on: Some(vec![MessageRole::Human, MessageRole::Tool]),
            },
        );
        // The trailing AI message is dropped; the human message remains.
        assert_eq!(result.len(), 1);
        assert!(matches!(result[0], Message::Human(_)));
    }

    #[test]
    fn test_trim_end_on_no_match_is_empty() {
        let messages = vec![Message::ai("a")];
        let result = trim_messages(
            &messages,
            TrimOptions {
                strategy: TrimStrategy::Last,
                max_tokens: 1000,
                token_counter: &CharTokenCounter,
                start_on: None,
                end_on: Some(vec![MessageRole::Human]),
            },
        );
        assert!(result.is_empty());
    }

    #[test]
    fn test_trim_start_and_end_on_together() {
        let messages = vec![
            Message::system("sys"),
            Message::human("q1"),
            Message::ai("a1"),
            Message::human("q2"),
            Message::ai("a2"),
        ];
        let result = trim_messages(
            &messages,
            TrimOptions {
                strategy: TrimStrategy::Last,
                max_tokens: 1000,
                token_counter: &CharTokenCounter,
                start_on: Some(MessageRole::Human),
                end_on: Some(vec![MessageRole::Human]),
            },
        );
        // Leading system message and trailing AI message are both dropped.
        assert_eq!(result.len(), 3);
        assert!(matches!(result[0], Message::Human(_)));
        assert!(matches!(result[1], Message::Ai(_)));
        assert!(matches!(result[2], Message::Human(_)));
    }
}