open_agent/context.rs

//! Context management utilities for manual history management
//!
//! This module provides low-level helpers for managing conversation history.
//! These are opt-in utilities - nothing is automatic. You decide when and how
//! to manage context.
//!
//! # Features
//!
//! - Token estimation (character-based approximation)
//! - Message truncation with system prompt preservation
//! - Manual history management patterns
//!
//! # Examples
//!
//! ```rust
//! use open_agent::{estimate_tokens, truncate_messages, Message};
//!
//! // Estimate tokens
//! let messages: Vec<Message> = vec![/* your messages */];
//! let tokens = estimate_tokens(&messages);
//! println!("Estimated tokens: {}", tokens);
//!
//! // Truncate when needed
//! if tokens > 28000 {
//!     let truncated = truncate_messages(&messages, 10, true);
//!     // Use truncated messages...
//! }
//! ```
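//!
//! The same pattern can be driven by [`is_approaching_limit`]. A minimal
//! sketch, assuming the history is a plain `Vec<Message>` that you own:
//!
//! ```rust
//! use open_agent::{is_approaching_limit, truncate_messages, Message};
//!
//! # fn manage(history: &mut Vec<Message>) {
//! // Truncate once the estimate crosses 90% of a 32k context window.
//! if is_approaching_limit(history, 32_000, 0.9) {
//!     let truncated = truncate_messages(history, 10, true);
//!     *history = truncated;
//! }
//! # }
//! ```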

use crate::types::Message;

/// Estimate token count for a message list
///
/// Uses character-based approximation (1 token ≈ 4 characters).
///
/// # Arguments
///
/// * `messages` - List of messages to estimate tokens for
///
/// # Returns
///
/// Estimated token count
///
/// # Note
///
/// This is an APPROXIMATION. Actual token counts vary by model family:
/// - GPT, Llama, Qwen, and Mistral models: roughly 70-85% accurate (each family uses its own tokenizer)
/// - Always include a 10-20% safety margin when checking limits
///
/// For more accurate estimation, consider using tiktoken bindings
/// (not included to keep dependencies minimal).
///
/// # Examples
///
/// ```rust
/// use open_agent::{Message, estimate_tokens};
///
/// let messages = vec![
///     Message::system("You are a helpful assistant"),
///     Message::user("Hello!"),
/// ];
///
/// let tokens = estimate_tokens(&messages);
/// println!("Estimated tokens: {}", tokens);
///
/// // Check if approaching context limit
/// if tokens > 28000 {
///     println!("Need to truncate!");
/// }
/// ```
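///
/// As a worked instance of the safety-margin guidance above (the 32k window
/// is only an illustrative number):
///
/// ```rust
/// let context_window = 32_000usize;
/// let budget = (context_window as f64 * 0.85) as usize; // keep a 15% margin
/// assert_eq!(budget, 27_200);
/// ```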
pub fn estimate_tokens(messages: &[Message]) -> usize {
    // Character-based approximation: 1 token ≈ 4 characters
    // This is a conservative estimate that works across model families

    if messages.is_empty() {
        return 0;
    }

    let mut total_chars = 0;

    for message in messages {
        // Count role overhead (8 chars ≈ 2 tokens)
        total_chars += 8; // ~2 tokens for role formatting

        // Count content
        for block in &message.content {
            match block {
                crate::types::ContentBlock::Text(text) => {
                    total_chars += text.text.len();
                }
                crate::types::ContentBlock::ToolUse(tool) => {
                    // Tool calls add significant overhead
                    total_chars += tool.name.len();
                    total_chars += tool.id.len();
                    total_chars += tool.input.to_string().len();
                }
                crate::types::ContentBlock::ToolResult(result) => {
                    // Tool results add overhead
                    total_chars += result.tool_use_id.len();
                    total_chars += result.content.to_string().len();
                }
            }
        }
    }

    // Add conversation-level overhead (16 chars ≈ 4 tokens)
    total_chars += 16;

    // Convert characters to tokens (4 chars ≈ 1 token, round up for safety)
    total_chars.div_ceil(4)
}

/// Truncate message history, keeping recent messages
///
/// Always preserves the system prompt (if present) and keeps the most
/// recent N messages. This is a simple truncation - it does NOT attempt
/// to preserve tool chains or important context.
///
/// # Arguments
///
/// * `messages` - List of messages to truncate
/// * `keep` - Number of recent messages to keep (suggested: 10)
/// * `preserve_system` - Keep system message if present (suggested: true)
///
/// # Returns
///
/// Truncated message list (new Vec, original unchanged)
///
/// # Examples
///
/// ```rust
/// use open_agent::{Client, truncate_messages, estimate_tokens};
///
/// # async fn example(mut client: Client) {
/// // Manual truncation when needed
/// let tokens = estimate_tokens(client.history());
/// if tokens > 28000 {
///     let truncated = truncate_messages(client.history(), 10, true);
///     *client.history_mut() = truncated;
/// }
/// # }
/// ```
///
/// # Note
///
/// This is a SIMPLE truncation. For domain-specific needs (e.g.,
/// preserving tool call chains, keeping important context), implement
/// your own logic or use this as a starting point.
///
/// Warning: Truncating mid-conversation may remove context that the
/// model needs to properly respond. Use judiciously at natural breakpoints.
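///
/// As one illustration of such custom logic, here is a minimal budget-based
/// sketch built only on this module's public helpers (the halving policy and
/// the numbers below are arbitrary choices, not something the crate prescribes):
///
/// ```rust
/// use open_agent::{is_approaching_limit, truncate_messages, Message};
///
/// // Hypothetical helper: halve the kept tail until the history fits the budget.
/// fn fit_to_budget(messages: &[Message], limit: usize) -> Vec<Message> {
///     let mut keep = messages.len();
///     let mut result = messages.to_vec();
///     while is_approaching_limit(&result, limit, 0.9) && keep > 1 {
///         keep /= 2;
///         result = truncate_messages(messages, keep, true);
///     }
///     result
/// }
///
/// # let history: Vec<Message> = Vec::new();
/// let trimmed = fit_to_budget(&history, 32_000);
/// # assert!(trimmed.is_empty());
/// ```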
pub fn truncate_messages(messages: &[Message], keep: usize, preserve_system: bool) -> Vec<Message> {
    if messages.is_empty() {
        return Vec::new();
    }

    if messages.len() <= keep {
        return messages.to_vec();
    }

    // Check if first message is system prompt
    let has_system = preserve_system
        && !messages.is_empty()
        && messages[0].role == crate::types::MessageRole::System;

    if has_system {
        // Keep system + last N messages
        let mut result = vec![messages[0].clone()];
        if keep > 0 && messages.len() > 1 {
            let start = messages.len().saturating_sub(keep);
            result.extend_from_slice(&messages[start..]);
        }
        result
    } else {
        // Just keep last N messages
        if keep > 0 {
            let start = messages.len().saturating_sub(keep);
            messages[start..].to_vec()
        } else {
            Vec::new()
        }
    }
}

/// Check if history is approaching a token limit
///
/// Convenience function that combines estimation with a threshold check.
///
/// # Arguments
///
/// * `messages` - Messages to check
/// * `limit` - Token limit (e.g., 32000 for a 32k context window)
/// * `margin` - Safety margin as a fraction of the limit (e.g., 0.9 = 90%)
///
/// # Returns
///
/// `true` if the estimated token count exceeds `limit * margin`. For example,
/// with `limit = 32_000` and `margin = 0.9`, the check fires once the estimate
/// exceeds 28,800 tokens.
///
/// # Examples
///
/// ```rust
/// use open_agent::{is_approaching_limit, Message};
///
/// # fn example(messages: Vec<Message>) {
/// if is_approaching_limit(&messages, 32000, 0.9) {
///     println!("Time to truncate!");
/// }
/// # }
/// ```
pub fn is_approaching_limit(messages: &[Message], limit: usize, margin: f32) -> bool {
    let estimated = estimate_tokens(messages);
    let threshold = (limit as f32 * margin) as usize;
    estimated > threshold
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{ContentBlock, Message, MessageRole, TextBlock};

    #[test]
    fn test_estimate_tokens_empty() {
        let messages: Vec<Message> = vec![];
        assert_eq!(estimate_tokens(&messages), 0);
    }

    #[test]
    fn test_estimate_tokens_simple() {
        let messages = vec![Message::new(
            MessageRole::User,
            vec![ContentBlock::Text(TextBlock::new("Hello world"))],
        )];

        let tokens = estimate_tokens(&messages);
        // "Hello world" = 11 text chars + 8 role chars + 16 conversation chars,
        // i.e. 35 chars ≈ 9 tokens
        assert!((3..=10).contains(&tokens));
    }

    #[test]
    fn test_truncate_messages_empty() {
        let messages: Vec<Message> = vec![];
        let truncated = truncate_messages(&messages, 10, true);
        assert_eq!(truncated.len(), 0);
    }

    #[test]
    fn test_truncate_messages_preserve_system() {
        let messages = vec![
            Message::system("System prompt"),
            Message::user("Message 1"),
            Message::user("Message 2"),
            Message::user("Message 3"),
            Message::user("Message 4"),
        ];

        let truncated = truncate_messages(&messages, 2, true);

        // Should have system + last 2 = 3 messages
        assert_eq!(truncated.len(), 3);
        assert_eq!(truncated[0].role, MessageRole::System);
    }

    #[test]
    fn test_truncate_messages_no_preserve() {
        let messages = vec![
            Message::system("System prompt"),
            Message::user("Message 1"),
            Message::user("Message 2"),
            Message::user("Message 3"),
        ];

        let truncated = truncate_messages(&messages, 2, false);

        // Should have only last 2 messages
        assert_eq!(truncated.len(), 2);
        assert_eq!(truncated[0].role, MessageRole::User);
    }

    #[test]
    fn test_truncate_messages_keep_all() {
        let messages = vec![Message::user("Message 1"), Message::user("Message 2")];

        let truncated = truncate_messages(&messages, 10, true);
        assert_eq!(truncated.len(), 2);
    }
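
    // `keep == 0` is a valid edge case: with `preserve_system` only the system
    // prompt survives; without it, nothing does.
    #[test]
    fn test_truncate_messages_keep_zero() {
        let messages = vec![Message::system("System prompt"), Message::user("Message 1")];

        let truncated = truncate_messages(&messages, 0, true);
        assert_eq!(truncated.len(), 1);
        assert_eq!(truncated[0].role, MessageRole::System);

        let truncated = truncate_messages(&messages, 0, false);
        assert!(truncated.is_empty());
    }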

    #[test]
    fn test_is_approaching_limit() {
        let messages = vec![Message::user("x".repeat(1000))];

        // ~250 tokens, should not exceed 90% of 1000
        assert!(!is_approaching_limit(&messages, 1000, 0.9));

        // Should exceed 90% of 200
        assert!(is_approaching_limit(&messages, 200, 0.9));
    }
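
    // The comparison is strict: an estimate exactly equal to the threshold is
    // not treated as approaching the limit.
    #[test]
    fn test_is_approaching_limit_boundary() {
        let messages = vec![Message::user("Hello world")];
        let estimated = estimate_tokens(&messages);

        assert!(!is_approaching_limit(&messages, estimated, 1.0));
        assert!(is_approaching_limit(&messages, estimated - 1, 1.0));
    }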
}