Skip to main content

agent_base/engine/
context.rs

1use crate::types::ChatMessage;
2
/// Configuration for keeping a chat history inside a token budget.
///
/// The manager protects a prefix and a suffix of the conversation and only
/// ever drops messages from the span in between.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ContextWindowManager {
    /// Token budget for the whole message list. A value of `0` disables
    /// trimming entirely.
    pub max_tokens: usize,
    /// Number of leading messages that are never dropped (typically the
    /// system prompt).
    pub keep_first_n: usize,
    /// Number of trailing messages that are never dropped (the most recent
    /// part of the conversation).
    pub keep_last_n: usize,
}
11
12impl Default for ContextWindowManager {
13    fn default() -> Self {
14        Self {
15            max_tokens: 128_000,
16            keep_first_n: 1,
17            keep_last_n: 20,
18        }
19    }
20}
21
22impl ContextWindowManager {
23    /// OpenAI Vision API fixed token overhead per image
24    const IMAGE_OVERHEAD_TOKENS: usize = 85;
25
26    pub fn new(max_tokens: usize) -> Self {
27        Self {
28            max_tokens,
29            ..Default::default()
30        }
31    }
32
33    pub fn with_keep_first_n(mut self, n: usize) -> Self {
34        self.keep_first_n = n;
35        self
36    }
37
38    pub fn with_keep_last_n(mut self, n: usize) -> Self {
39        self.keep_last_n = n;
40        self
41    }
42
43    /// Simple token estimation: ~4 chars/token for Latin, ~1.5 for CJK
44    /// Mixed text uses a compromise of 3 chars/token
45    pub fn estimate_tokens(text: &str) -> usize {
46        if text.is_empty() {
47            return 0;
48        }
49        let chars = text.chars().count();
50        let cjk_count = text.chars().filter(|c| is_cjk(*c)).count();
51        let latin_count = chars - cjk_count;
52        // CJK: ~1.5 chars/token, Latin: ~4 chars/token
53        (cjk_count as f64 / 1.5 + latin_count as f64 / 4.0).ceil() as usize
54    }
55
56    fn message_tokens(msg: &ChatMessage) -> usize {
57        match msg {
58            ChatMessage::System { content } => Self::estimate_tokens(content),
59            ChatMessage::User { content, images } => {
60                let mut tokens = Self::estimate_tokens(content);
61                for img in images {
62                    match img {
63                        crate::types::ImageAttachment::Url { url, detail: _ } => {
64                            tokens += Self::estimate_tokens(url);
65                        }
66                        crate::types::ImageAttachment::Base64 { data, media_type, detail: _ } => {
67                            tokens += data.len() / 4;
68                            if let Some(mt) = media_type {
69                                tokens += Self::estimate_tokens(mt);
70                            }
71                        }
72                    }
73                    tokens += Self::IMAGE_OVERHEAD_TOKENS;
74                }
75                tokens
76            }
77            ChatMessage::Assistant { content, reasoning_content: _, tool_calls } => {
78                let mut tokens = content
79                    .as_deref()
80                    .map(|c| Self::estimate_tokens(c))
81                    .unwrap_or(0);
82                if let Some(tc) = tool_calls {
83                    for t in tc {
84                        tokens += Self::estimate_tokens(&t.name);
85                        tokens += Self::estimate_tokens(&t.arguments);
86                        tokens += Self::estimate_tokens(&t.id);
87                    }
88                }
89                tokens
90            }
91            ChatMessage::Tool { tool_call_id, content } => {
92                Self::estimate_tokens(tool_call_id) + Self::estimate_tokens(content)
93            }
94        }
95    }
96
97    /// Trim message list to keep total tokens under `max_tokens`。
98    ///
99    /// Trimming strategy:
100    /// - Always keep the first `keep_first_n` messages (typically system prompt)
101    /// - Always keep the last `keep_last_n` messages (recent conversation)
102    /// - Remove oldest messages from the middle until within budget
103    pub fn trim(&self, messages: &mut Vec<ChatMessage>) {
104        if messages.is_empty() || self.max_tokens == 0 {
105            return;
106        }
107
108        let total_tokens: usize = messages.iter().map(|m| Self::message_tokens(m)).sum();
109        if total_tokens <= self.max_tokens {
110            return;
111        }
112
113        let keep_first = self.keep_first_n.min(messages.len());
114        let keep_last = self.keep_last_n.min(messages.len().saturating_sub(keep_first));
115
116        // Trimmable range: [keep_first, messages.len() - keep_last)
117        let trim_start = keep_first;
118        let trim_end = messages.len().saturating_sub(keep_last);
119        if trim_start >= trim_end {
120            return;
121        }
122
123        let mut current_tokens: usize = total_tokens;
124        let remove_idx = trim_start;
125
126        while current_tokens > self.max_tokens && remove_idx < trim_end {
127            let removed = Self::message_tokens(&messages[remove_idx]);
128            messages.remove(remove_idx);
129            current_tokens = current_tokens.saturating_sub(removed);
130            // Do not increment remove_idx because remove shifts elements down
131            let new_trim_end = messages.len().saturating_sub(keep_last);
132            if remove_idx >= new_trim_end {
133                break;
134            }
135        }
136    }
137}
138
/// Returns `true` when `c` belongs to one of the CJK-related Unicode blocks
/// that the token estimator charges at the denser CJK rate.
///
/// Covers the Basic Multilingual Plane ideograph, kana, Hangul-syllable,
/// CJK punctuation and full-/half-width blocks, plus the compatibility
/// ideographs and the supplementary-plane ideograph extensions, so rare
/// ideographs are not counted as Latin text.
fn is_cjk(c: char) -> bool {
    matches!(
        c,
        '\u{3000}'..='\u{303F}'     // CJK Symbols and Punctuation
        | '\u{3040}'..='\u{309F}'   // Hiragana
        | '\u{30A0}'..='\u{30FF}'   // Katakana
        | '\u{3400}'..='\u{4DBF}'   // CJK Unified Ideographs Extension A
        | '\u{4E00}'..='\u{9FFF}'   // CJK Unified Ideographs
        | '\u{AC00}'..='\u{D7AF}'   // Hangul Syllables
        | '\u{F900}'..='\u{FAFF}'   // CJK Compatibility Ideographs
        | '\u{FF00}'..='\u{FFEF}'   // Halfwidth and Fullwidth Forms
        | '\u{20000}'..='\u{2FA1F}' // Extensions B-F, I and Compatibility Supplement
        | '\u{30000}'..='\u{323AF}' // CJK Unified Ideographs Extensions G-H
    )
}
151
#[cfg(test)]
mod tests {
    use super::*;

    /// Text payload of a message, so trim tests can check *which* messages
    /// survived rather than only how many.
    fn text(msg: &ChatMessage) -> &str {
        match msg {
            ChatMessage::System { content, .. } => content,
            ChatMessage::User { content, .. } => content,
            ChatMessage::Tool { content, .. } => content,
            ChatMessage::Assistant { content, .. } => content.as_deref().unwrap_or(""),
        }
    }

    /// Seven-message conversation: one system prompt followed by six turns.
    fn conversation() -> Vec<ChatMessage> {
        vec![
            ChatMessage::system("system"),
            ChatMessage::user("message number one"),
            ChatMessage::assistant("message number two"),
            ChatMessage::user("message number three"),
            ChatMessage::assistant("message number four"),
            ChatMessage::user("message number five"),
            ChatMessage::assistant("message number six"),
        ]
    }

    #[test]
    fn test_estimate_tokens_empty() {
        assert_eq!(ContextWindowManager::estimate_tokens(""), 0);
    }

    #[test]
    fn test_estimate_tokens_english() {
        // 26 non-CJK chars / 4 = 6.5, rounded up to 7.
        assert_eq!(
            ContextWindowManager::estimate_tokens("Hello world this is a test"),
            7
        );
    }

    #[test]
    fn test_estimate_tokens_cjk() {
        // 4 CJK chars / 1.5 = 2.67, rounded up to 3.
        assert_eq!(ContextWindowManager::estimate_tokens("你好世界"), 3);
    }

    #[test]
    fn test_trim_no_trim_needed() {
        let mgr = ContextWindowManager::new(1000);
        let mut msgs = vec![
            ChatMessage::system("You are a helpful assistant."),
            ChatMessage::user("Hello"),
            ChatMessage::assistant("Hi there!"),
        ];
        let original_len = msgs.len();
        mgr.trim(&mut msgs);
        assert_eq!(msgs.len(), original_len);
    }

    #[test]
    fn test_trim_zero_budget_is_noop() {
        let mgr = ContextWindowManager::new(0);
        let mut msgs = conversation();
        mgr.trim(&mut msgs);
        assert_eq!(msgs.len(), 7);
    }

    #[test]
    fn test_trim_keeps_first_and_last() {
        let mgr = ContextWindowManager::new(8)
            .with_keep_first_n(1)
            .with_keep_last_n(2);
        let mut msgs = conversation();
        mgr.trim(&mut msgs);

        // The budget cannot be met, so the whole middle is dropped while the
        // protected prefix and suffix survive untouched.
        assert_eq!(msgs.len(), 3);
        assert!(matches!(msgs[0], ChatMessage::System { .. }));
        assert_eq!(text(&msgs[0]), "system");
        assert_eq!(text(&msgs[1]), "message number five");
        assert_eq!(text(&msgs[2]), "message number six");
    }

    #[test]
    fn test_trim_stops_once_under_budget() {
        // The conversation is estimated at 32 tokens; dropping the two oldest
        // middle messages brings it to 22, which fits the budget of 24.
        let mgr = ContextWindowManager::new(24)
            .with_keep_first_n(1)
            .with_keep_last_n(2);
        let mut msgs = conversation();
        mgr.trim(&mut msgs);

        let remaining: Vec<&str> = msgs.iter().map(text).collect();
        assert_eq!(
            remaining,
            vec![
                "system",
                "message number three",
                "message number four",
                "message number five",
                "message number six",
            ]
        );
    }
}