// open_agent/context.rs

//! Context management utilities for manual history management
//!
//! This module provides low-level helpers for managing conversation history.
//! These are opt-in utilities - nothing is automatic. You decide when and how
//! to manage context.
//!
//! # Features
//!
//! - Token estimation (character-based approximation)
//! - Message truncation with system prompt preservation
//! - Manual history management patterns
//!
//! # Examples
//!
//! ```rust
//! use open_agent::{estimate_tokens, truncate_messages};
//!
//! // Estimate tokens
//! let messages = vec![/* your messages */];
//! let tokens = estimate_tokens(&messages);
//! println!("Estimated tokens: {}", tokens);
//!
//! // Truncate when needed
//! if tokens > 28000 {
//!     let truncated = truncate_messages(&messages, 10, true);
//!     // Use truncated messages...
//! }
//! ```

use crate::types::Message;

32/// Estimate token count for message list
33///
34/// Uses character-based approximation (1 token ≈ 4 characters).
35/// For images, uses OpenAI Vision API token costs adjusted for different
36/// detail levels.
37///
38/// # Arguments
39///
40/// * `messages` - List of messages to estimate tokens for
41///
42/// # Returns
43///
44/// Estimated token count
45///
46/// # Note
47///
48/// This is an APPROXIMATION. Actual token counts vary by model family:
49/// - GPT models: ~70-85% accurate (different tokenizers)
50/// - Llama, Qwen, Mistral: ~70-85% accurate
51/// - Always include 10-20% safety margin when checking limits
52///
53/// Image token costs are based on OpenAI's Vision API and may differ significantly for local models.
54///
55/// For more accurate estimation, consider using tiktoken bindings
56/// (not included to keep dependencies minimal).
57///
58/// # Examples
59///
60/// ```rust
61/// use open_agent::{Message, MessageRole, estimate_tokens};
62///
63/// let messages = vec![
64///     Message::system("You are a helpful assistant"),
65///     Message::user("Hello!"),
66/// ];
67///
68/// let tokens = estimate_tokens(&messages);
69/// println!("Estimated tokens: {}", tokens);
70///
71/// // Check if approaching context limit
72/// if tokens > 28000 {
73///     println!("Need to truncate!");
74/// }
75/// ```
76pub fn estimate_tokens(messages: &[Message]) -> usize {
77    // Character-based approximation: 1 token ≈ 4 characters
78    // This is a conservative estimate that works across model families
79
80    if messages.is_empty() {
81        return 0;
82    }
83
84    let mut total_chars = 0;
85
86    for message in messages {
87        // Count role overhead (approximately 1-2 tokens)
88        total_chars += 8; // ~2 tokens for role formatting
89
90        // Count content
91        for block in &message.content {
92            match block {
93                crate::types::ContentBlock::Text(text) => {
94                    total_chars += text.text.len();
95                }
96                crate::types::ContentBlock::Image(image) => {
97                    // Token estimates based on OpenAI Vision API
98                    // Local models may have significantly different token costs
99                    use crate::types::ImageDetail;
100                    let token_estimate = match image.detail() {
101                        ImageDetail::Low => 85 * 4,   // Fixed ~85 tokens (512x512 max)
102                        ImageDetail::High => 300 * 4, // Conservative upper bound (variable based on dimensions)
103                        ImageDetail::Auto => 200 * 4, // Middle ground default
104                    };
105                    total_chars += token_estimate;
106                }
107                crate::types::ContentBlock::ToolUse(tool) => {
108                    // Tool calls add significant overhead
109                    total_chars += tool.name().len();
110                    total_chars += tool.id().len();
111                    total_chars += tool.input().to_string().len();
112                }
113                crate::types::ContentBlock::ToolResult(result) => {
114                    // Tool results add overhead
115                    total_chars += result.tool_use_id().len();
116                    total_chars += result.content().to_string().len();
117                }
118            }
119        }
120    }
121
122    // Add conversation-level overhead (~2-4 tokens)
123    total_chars += 16;
124
125    // Convert characters to tokens (4 chars ≈ 1 token, round up for safety)
126    total_chars.div_ceil(4)
127}
128
129/// Truncate message history, keeping recent messages
130///
131/// Always preserves the system prompt (if present) and keeps the most
132/// recent N messages. This is a simple truncation - it does NOT attempt
133/// to preserve tool chains or important context.
134///
135/// # Arguments
136///
137/// * `messages` - List of messages to truncate
138/// * `keep` - Number of recent messages to keep (default: 10)
139/// * `preserve_system` - Keep system message if present (default: true)
140///
141/// # Returns
142///
143/// Truncated message list (new Vec, original unchanged)
144///
145/// # Examples
146///
147/// ```rust
148/// use open_agent::{Message, Client, truncate_messages, estimate_tokens};
149///
150/// # async fn example(mut client: Client) {
151/// // Manual truncation when needed
152/// let tokens = estimate_tokens(client.history());
153/// if tokens > 28000 {
154///     let truncated = truncate_messages(client.history(), 10, true);
155///     *client.history_mut() = truncated;
156/// }
157/// # }
158/// ```
159///
160/// # Note
161///
162/// This is a SIMPLE truncation. For domain-specific needs (e.g.,
163/// preserving tool call chains, keeping important context), implement
164/// your own logic or use this as a starting point.
165///
166/// Warning: Truncating mid-conversation may remove context that the
167/// model needs to properly respond. Use judiciously at natural breakpoints.
168pub fn truncate_messages(messages: &[Message], keep: usize, preserve_system: bool) -> Vec<Message> {
169    if messages.is_empty() {
170        return Vec::new();
171    }
172
173    if messages.len() <= keep {
174        return messages.to_vec();
175    }
176
177    // Check if first message is system prompt
178    let has_system = preserve_system
179        && !messages.is_empty()
180        && messages[0].role == crate::types::MessageRole::System;
181
182    if has_system {
183        // Keep system + last N messages
184        let mut result = vec![messages[0].clone()];
185        if keep > 0 && messages.len() > 1 {
186            let start = messages.len().saturating_sub(keep);
187            result.extend_from_slice(&messages[start..]);
188        }
189        result
190    } else {
191        // Just keep last N messages
192        if keep > 0 {
193            let start = messages.len().saturating_sub(keep);
194            messages[start..].to_vec()
195        } else {
196            Vec::new()
197        }
198    }
199}
200
201/// Check if history is approaching a token limit
202///
203/// Convenience function that combines estimation with a threshold check.
204///
205/// # Arguments
206///
207/// * `messages` - Messages to check
208/// * `limit` - Token limit (e.g., 32000 for a 32k context window)
209/// * `margin` - Safety margin as a percentage (default: 0.9 = 90%)
210///
211/// # Returns
212///
213/// `true` if estimated tokens exceed limit * margin
214///
215/// # Examples
216///
217/// ```rust
218/// use open_agent::{is_approaching_limit, Message};
219///
220/// # fn example(messages: Vec<Message>) {
221/// if is_approaching_limit(&messages, 32000, 0.9) {
222///     println!("Time to truncate!");
223/// }
224/// # }
225/// ```
226pub fn is_approaching_limit(messages: &[Message], limit: usize, margin: f32) -> bool {
227    let estimated = estimate_tokens(messages);
228    let threshold = (limit as f32 * margin) as usize;
229    estimated > threshold
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{ContentBlock, Message, MessageRole, TextBlock};

    #[test]
    fn test_estimate_tokens_empty() {
        // No messages means zero tokens, with no framing overhead.
        let history: Vec<Message> = Vec::new();
        assert_eq!(estimate_tokens(&history), 0);
    }

    #[test]
    fn test_estimate_tokens_simple() {
        let history = vec![Message::new(
            MessageRole::User,
            vec![ContentBlock::Text(TextBlock::new("Hello world"))],
        )];

        // "Hello world" = 11 chars + overhead ≈ 5-8 tokens
        let estimate = estimate_tokens(&history);
        assert!(estimate >= 3 && estimate <= 10);
    }

    #[test]
    fn test_truncate_messages_empty() {
        let history: Vec<Message> = Vec::new();
        assert!(truncate_messages(&history, 10, true).is_empty());
    }

    #[test]
    fn test_truncate_messages_preserve_system() {
        let history = vec![
            Message::system("System prompt"),
            Message::user("Message 1"),
            Message::user("Message 2"),
            Message::user("Message 3"),
            Message::user("Message 4"),
        ];

        let kept = truncate_messages(&history, 2, true);

        // System prompt + the two most recent messages = 3 total.
        assert_eq!(kept.len(), 3);
        assert_eq!(kept[0].role, MessageRole::System);
    }

    #[test]
    fn test_truncate_messages_no_preserve() {
        let history = vec![
            Message::system("System prompt"),
            Message::user("Message 1"),
            Message::user("Message 2"),
            Message::user("Message 3"),
        ];

        let kept = truncate_messages(&history, 2, false);

        // Without preservation, the system prompt is dropped with the rest.
        assert_eq!(kept.len(), 2);
        assert_eq!(kept[0].role, MessageRole::User);
    }

    #[test]
    fn test_truncate_messages_keep_all() {
        let history = vec![Message::user("Message 1"), Message::user("Message 2")];

        // Fewer messages than `keep`: everything survives.
        assert_eq!(truncate_messages(&history, 10, true).len(), 2);
    }

    #[test]
    fn test_is_approaching_limit() {
        let history = vec![Message::user("x".repeat(1000))];

        // ~250 estimated tokens: safely under 90% of 1000...
        assert!(!is_approaching_limit(&history, 1000, 0.9));

        // ...but well over 90% of 200.
        assert!(is_approaching_limit(&history, 200, 0.9));
    }

    #[test]
    fn test_estimate_tokens_image_detail_low() {
        use crate::types::{ImageBlock, ImageDetail};

        let image = ImageBlock::from_url("https://example.com/img.jpg")
            .unwrap()
            .with_detail(ImageDetail::Low);
        let history = [Message::new(MessageRole::User, vec![ContentBlock::Image(image)])];

        // Low detail: ~85 tokens * 4 chars/token = 340 chars
        let estimate = estimate_tokens(&history);
        assert!(
            (75..=95).contains(&estimate),
            "Low detail should be ~85 tokens, got {}",
            estimate
        );
    }

    #[test]
    fn test_estimate_tokens_image_detail_high() {
        use crate::types::{ImageBlock, ImageDetail};

        let image = ImageBlock::from_url("https://example.com/img.jpg")
            .unwrap()
            .with_detail(ImageDetail::High);
        let history = [Message::new(MessageRole::User, vec![ContentBlock::Image(image)])];

        // High detail: ~300 tokens * 4 chars/token = 1200 chars (conservative)
        let estimate = estimate_tokens(&history);
        assert!(
            estimate >= 250,
            "High detail should be ~300+ tokens, got {}",
            estimate
        );
    }

    #[test]
    fn test_estimate_tokens_image_detail_auto() {
        use crate::types::{ImageBlock, ImageDetail};

        let image = ImageBlock::from_url("https://example.com/img.jpg")
            .unwrap()
            .with_detail(ImageDetail::Auto);
        let history = [Message::new(MessageRole::User, vec![ContentBlock::Image(image)])];

        // Auto detail: ~200 tokens * 4 chars/token = 800 chars (middle ground)
        let estimate = estimate_tokens(&history);
        assert!(
            (150..=250).contains(&estimate),
            "Auto detail should be ~200 tokens, got {}",
            estimate
        );
    }
}