// aagt_core/agent/context.rs
use crate::agent::message::Message;
use crate::error::Result;
11
/// Tunable limits for assembling a model context window.
#[derive(Debug, Clone)]
pub struct ContextConfig {
    /// Total size of the model's context window, in tokens.
    pub max_tokens: usize,
    /// Hard cap on how many trailing history messages are considered,
    /// independent of token budget.
    pub max_history_messages: usize,
    /// Tokens reserved for the model's response.
    pub response_reserve: usize,
    /// NOTE(review): not read anywhere in this module — presumably toggles
    /// provider cache-control markers; confirm against callers.
    pub enable_cache_control: bool,
    /// When true, pruned history is condensed into a summary log message.
    pub smart_pruning: bool,
}
26
27impl Default for ContextConfig {
28 fn default() -> Self {
29 Self {
30 max_tokens: 128000, max_history_messages: 50,
32 response_reserve: 4096,
33 enable_cache_control: false,
34 smart_pruning: false,
35 }
36 }
37}
38
/// Source of extra context messages (e.g. RAG results) appended after the
/// system prompt when a context is built.
#[async_trait::async_trait]
pub trait ContextInjector: Send + Sync {
    /// Returns the messages to inject. Failures are non-fatal: the manager
    /// logs and skips a failing injector rather than aborting the build.
    async fn inject(&self) -> Result<Vec<Message>>;
}
45
/// Assembles the message list sent to the model: system prompt first, then
/// injected context, then as much recent history as the token budget allows.
pub struct ContextManager {
    // Budget and pruning configuration.
    config: ContextConfig,
    // Optional system prompt placed at the head of every built context.
    system_prompt: Option<String>,
    // Additional context sources, run in registration order.
    injectors: Vec<Box<dyn ContextInjector>>,
}
52
53impl ContextManager {
54 pub fn new(config: ContextConfig) -> Self {
56 Self {
57 config,
58 system_prompt: None,
59 injectors: Vec::new(),
60 }
61 }
62
63 pub fn set_system_prompt(&mut self, prompt: impl Into<String>) {
65 self.system_prompt = Some(prompt.into());
66 }
67
68 pub fn add_injector(&mut self, injector: Box<dyn ContextInjector>) {
70 self.injectors.push(injector);
71 }
72
73 pub async fn build_context(&self, history: &[Message]) -> Result<Vec<Message>> {
81 let bpe = tiktoken_rs::cl100k_base().map_err(|e| {
83 crate::error::Error::Internal(format!("Failed to load tokenizer: {}", e))
84 })?;
85
86 let mut final_context_start = Vec::new();
87
88 if let Some(prompt) = &self.system_prompt {
90 final_context_start.push(Message::system(prompt.clone()));
91 }
92
93 for injector in &self.injectors {
96 match injector.inject().await {
97 Ok(msgs) => final_context_start.extend(msgs),
98 Err(e) => tracing::warn!("Context injector failed: {}", e),
99 }
100 }
101
102 const SAFETY_MARGIN: usize = 1000;
105
106 let reserved_response = self.config.response_reserve;
107 let max_window = self.config.max_tokens;
108
109 let mut current_usage = 0;
111 for msg in &final_context_start {
112 current_usage += bpe.encode_with_special_tokens(&msg.content.as_text()).len();
113 current_usage += 4; }
115
116 let total_reserved = reserved_response + SAFETY_MARGIN + current_usage;
118 if total_reserved > max_window {
119 tracing::warn!(
120 "System prompt + RAG context exceeds context window! (Usage: {}, Limit: {})",
121 current_usage,
122 max_window - reserved_response - SAFETY_MARGIN
123 );
124 }
126
127 let history_budget = if max_window > total_reserved {
128 max_window - total_reserved
129 } else {
130 0
131 };
132
133 let mut selected_history = Vec::new();
135 let mut history_usage = 0;
136 let mut pruned_messages = Vec::new();
137
138 let history_slice = if history.len() > self.config.max_history_messages {
139 let (pruned, selected) = history.split_at(history.len() - self.config.max_history_messages);
140 pruned_messages.extend(pruned.iter().cloned());
141 selected
142 } else {
143 history
144 };
145
146 for msg in history_slice.iter().rev() {
148 let tokens = bpe.encode_with_special_tokens(&msg.content.as_text()).len();
149 let cost = tokens + 4;
150
151 if history_usage + cost <= history_budget {
152 history_usage += cost;
153 selected_history.push(msg.clone());
154 } else {
155 pruned_messages.push(msg.clone());
156 }
157 }
158
159 if self.config.smart_pruning && !pruned_messages.is_empty() {
161 let mut log = String::from("### Historical Observation Log (Pruned Summaries)\n");
162 for msg in pruned_messages {
165 match msg.role {
166 crate::agent::message::Role::Assistant => {
167 let text = msg.content.as_text();
168 let snippet = if text.len() > 60 {
169 format!("{}...", &text[..60].replace('\n', " "))
170 } else {
171 text.replace('\n', " ")
172 };
173 log.push_str(&format!("- Assistant: {}\n", snippet));
174 }
175 crate::agent::message::Role::Tool => {
176 let name = msg.name.as_deref().unwrap_or("unknown_tool");
177 log.push_str(&format!("- Tool executed: {}\n", name));
178 }
179 _ => {}
180 }
181 }
182 final_context_start.push(Message::system(log));
183 }
184
185 let mut final_messages = final_context_start;
187 selected_history.reverse();
188 final_messages.extend(selected_history);
189
190 Ok(final_messages)
191 }
192
193 pub fn estimate_tokens(messages: &[Message]) -> usize {
195 if let Ok(bpe) = tiktoken_rs::cl100k_base() {
196 messages
197 .iter()
198 .map(|m| bpe.encode_with_special_tokens(&m.content.as_text()).len() + 4)
199 .sum()
200 } else {
201 messages
203 .iter()
204 .map(|m| m.content.as_text().len() / 4)
205 .sum::<usize>()
206 }
207 }
208}
209
#[cfg(test)]
mod tests {
    use super::*;

    /// With `smart_pruning` on and a 2-message cap, a 4-message history must
    /// yield: system prompt, pruning log, and the 2 most recent messages.
    #[tokio::test]
    async fn test_smart_pruning_generation() {
        let cfg = ContextConfig {
            max_tokens: 10000,
            max_history_messages: 2,
            response_reserve: 1000,
            smart_pruning: true,
            ..Default::default()
        };
        let mut manager = ContextManager::new(cfg);
        manager.set_system_prompt("System Prompt");

        let history = [
            Message::assistant("I am thinking about the first task."),
            Message::user("What about the second one?"),
            Message::assistant("Executing the third part now."),
            Message::user("Final question."),
        ];

        let ctx = manager.build_context(&history).await.unwrap();

        assert_eq!(ctx.len(), 4, "Context should contain System, Log, and 2 history messages");

        let log_text = ctx[1].content.as_text();
        assert!(log_text.contains("Observation Log"), "Should contain Observation Log");
        assert!(log_text.contains("Assistant"), "Should mention Assistant in log");
    }

    /// A trivial history under default limits passes through untouched.
    #[tokio::test]
    async fn test_basic_inclusion() {
        let manager = ContextManager::new(ContextConfig::default());
        let history = [Message::user("test")];
        let ctx = manager.build_context(&history).await.unwrap();
        assert!(!ctx.is_empty());
    }
}