// agent_base/engine/context.rs
use crate::types::ChatMessage;
2
/// Configuration for trimming a chat history so that its estimated token
/// count fits inside a fixed budget.
///
/// Trimming removes messages from the middle of the history: the first
/// `keep_first_n` and the last `keep_last_n` messages are never removed.
#[derive(Clone, Debug)]
pub struct ContextWindowManager {
    /// Budget, in estimated tokens, for the whole message list.
    /// A value of `0` makes `trim` a no-op.
    pub max_tokens: usize,
    /// Number of messages at the start of the history that are never removed.
    pub keep_first_n: usize,
    /// Number of messages at the end of the history that are never removed.
    pub keep_last_n: usize,
}
11
12impl Default for ContextWindowManager {
13 fn default() -> Self {
14 Self {
15 max_tokens: 128_000,
16 keep_first_n: 1,
17 keep_last_n: 20,
18 }
19 }
20}
21
22impl ContextWindowManager {
23 const IMAGE_OVERHEAD_TOKENS: usize = 85;
25
26 pub fn new(max_tokens: usize) -> Self {
27 Self {
28 max_tokens,
29 ..Default::default()
30 }
31 }
32
33 pub fn with_keep_first_n(mut self, n: usize) -> Self {
34 self.keep_first_n = n;
35 self
36 }
37
38 pub fn with_keep_last_n(mut self, n: usize) -> Self {
39 self.keep_last_n = n;
40 self
41 }
42
43 pub fn estimate_tokens(text: &str) -> usize {
46 if text.is_empty() {
47 return 0;
48 }
49 let chars = text.chars().count();
50 let cjk_count = text.chars().filter(|c| is_cjk(*c)).count();
51 let latin_count = chars - cjk_count;
52 (cjk_count as f64 / 1.5 + latin_count as f64 / 4.0).ceil() as usize
54 }
55
56 fn message_tokens(msg: &ChatMessage) -> usize {
57 match msg {
58 ChatMessage::System { content } => Self::estimate_tokens(content),
59 ChatMessage::User { content, images } => {
60 let mut tokens = Self::estimate_tokens(content);
61 for img in images {
62 match img {
63 crate::types::ImageAttachment::Url { url, detail: _ } => {
64 tokens += Self::estimate_tokens(url);
65 }
66 crate::types::ImageAttachment::Base64 { data, media_type, detail: _ } => {
67 tokens += data.len() / 4;
68 if let Some(mt) = media_type {
69 tokens += Self::estimate_tokens(mt);
70 }
71 }
72 }
73 tokens += Self::IMAGE_OVERHEAD_TOKENS;
74 }
75 tokens
76 }
77 ChatMessage::Assistant { content, reasoning_content: _, tool_calls } => {
78 let mut tokens = content
79 .as_deref()
80 .map(|c| Self::estimate_tokens(c))
81 .unwrap_or(0);
82 if let Some(tc) = tool_calls {
83 for t in tc {
84 tokens += Self::estimate_tokens(&t.name);
85 tokens += Self::estimate_tokens(&t.arguments);
86 tokens += Self::estimate_tokens(&t.id);
87 }
88 }
89 tokens
90 }
91 ChatMessage::Tool { tool_call_id, content } => {
92 Self::estimate_tokens(tool_call_id) + Self::estimate_tokens(content)
93 }
94 }
95 }
96
97 pub fn trim(&self, messages: &mut Vec<ChatMessage>) {
104 if messages.is_empty() || self.max_tokens == 0 {
105 return;
106 }
107
108 let total_tokens: usize = messages.iter().map(|m| Self::message_tokens(m)).sum();
109 if total_tokens <= self.max_tokens {
110 return;
111 }
112
113 let keep_first = self.keep_first_n.min(messages.len());
114 let keep_last = self.keep_last_n.min(messages.len().saturating_sub(keep_first));
115
116 let trim_start = keep_first;
118 let trim_end = messages.len().saturating_sub(keep_last);
119 if trim_start >= trim_end {
120 return;
121 }
122
123 let mut current_tokens: usize = total_tokens;
124 let remove_idx = trim_start;
125
126 while current_tokens > self.max_tokens && remove_idx < trim_end {
127 let removed = Self::message_tokens(&messages[remove_idx]);
128 messages.remove(remove_idx);
129 current_tokens = current_tokens.saturating_sub(removed);
130 let new_trim_end = messages.len().saturating_sub(keep_last);
132 if remove_idx >= new_trim_end {
133 break;
134 }
135 }
136 }
137}
138
/// Reports whether `c` falls in one of the CJK-related Unicode ranges that the
/// token estimator weights more heavily than other text.
///
/// Covers Han ideographs (including the supplementary-plane extensions and
/// the compatibility blocks), CJK punctuation, fullwidth/halfwidth forms,
/// Hiragana, Katakana and Hangul syllables.
fn is_cjk(c: char) -> bool {
    matches!(
        c,
        '\u{4E00}'..='\u{9FFF}' // CJK Unified Ideographs
            | '\u{3400}'..='\u{4DBF}' // CJK Unified Ideographs Extension A
            | '\u{F900}'..='\u{FAFF}' // CJK Compatibility Ideographs
            | '\u{20000}'..='\u{2FA1F}' // Extensions B-F, I and Compatibility Supplement
            | '\u{30000}'..='\u{323AF}' // Extensions G-H
            | '\u{3000}'..='\u{303F}' // CJK Symbols and Punctuation
            | '\u{FF00}'..='\u{FFEF}' // Halfwidth and Fullwidth Forms
            | '\u{3040}'..='\u{309F}' // Hiragana
            | '\u{30A0}'..='\u{30FF}' // Katakana
            | '\u{AC00}'..='\u{D7AF}' // Hangul Syllables
    )
}
151
#[cfg(test)]
mod tests {
    use super::*;

    /// Seven-message history used by the trimming tests: one system prompt
    /// followed by alternating user/assistant turns.
    fn sample_history() -> Vec<ChatMessage> {
        vec![
            ChatMessage::system("system"),
            ChatMessage::user("message number one"),
            ChatMessage::assistant("message number two"),
            ChatMessage::user("message number three"),
            ChatMessage::assistant("message number four"),
            ChatMessage::user("message number five"),
            ChatMessage::assistant("message number six"),
        ]
    }

    #[test]
    fn test_estimate_tokens_empty() {
        assert_eq!(ContextWindowManager::estimate_tokens(""), 0);
    }

    #[test]
    fn test_estimate_tokens_english() {
        let text = "Hello world this is a test";
        let tokens = ContextWindowManager::estimate_tokens(text);
        assert!(tokens > 0 && tokens <= 15);
    }

    #[test]
    fn test_estimate_tokens_cjk() {
        // Four ideographs at 1.5 characters per token round up to 3.
        assert_eq!(ContextWindowManager::estimate_tokens("你好世界"), 3);
    }

    #[test]
    fn test_trim_no_trim_needed() {
        let mgr = ContextWindowManager::new(1000);
        let mut msgs = vec![
            ChatMessage::system("You are a helpful assistant."),
            ChatMessage::user("Hello"),
            ChatMessage::assistant("Hi there!"),
        ];
        let original_len = msgs.len();
        mgr.trim(&mut msgs);
        assert_eq!(msgs.len(), original_len);
    }

    #[test]
    fn test_trim_zero_budget_is_noop() {
        let mgr = ContextWindowManager::new(0);
        let mut msgs = sample_history();
        let original_len = msgs.len();
        mgr.trim(&mut msgs);
        assert_eq!(msgs.len(), original_len);
    }

    #[test]
    fn test_trim_keeps_first_and_last() {
        let mgr = ContextWindowManager::new(8)
            .with_keep_first_n(1)
            .with_keep_last_n(2);
        let mut msgs = sample_history();
        mgr.trim(&mut msgs);
        assert_eq!(msgs.len(), 3);
        assert!(matches!(msgs[0], ChatMessage::System { .. }));
    }

    #[test]
    fn test_trim_removes_from_middle_not_tail() {
        let mgr = ContextWindowManager::new(8)
            .with_keep_first_n(1)
            .with_keep_last_n(2);
        // The variant order differs between the removable middle (users) and
        // the protected tail (assistants), so the survivors are identifiable.
        let mut msgs = vec![
            ChatMessage::system("system"),
            ChatMessage::user("message number one"),
            ChatMessage::user("message number two"),
            ChatMessage::assistant("message number three"),
            ChatMessage::assistant("message number four"),
        ];
        mgr.trim(&mut msgs);
        assert_eq!(msgs.len(), 3);
        assert!(matches!(msgs[0], ChatMessage::System { .. }));
        assert!(matches!(msgs[1], ChatMessage::Assistant { .. }));
        assert!(matches!(msgs[2], ChatMessage::Assistant { .. }));
    }

    #[test]
    fn test_trim_stops_once_within_budget() {
        // The history is estimated at 32 tokens; dropping the three oldest
        // middle messages brings it to 17, which fits a budget of 19.
        let mgr = ContextWindowManager::new(19)
            .with_keep_first_n(1)
            .with_keep_last_n(2);
        let mut msgs = sample_history();
        mgr.trim(&mut msgs);
        assert_eq!(msgs.len(), 4);
        assert!(matches!(msgs[0], ChatMessage::System { .. }));
    }
}