Skip to main content

aagt_core/agent/
context.rs

//! Context Management Module
//!
//! This module provides the `ContextManager` which is responsible for:
//! - Managing conversation history (short-term memory)
//! - Constructing the final prompt/messages for the LLM
//! - Handling token budgeting and windowing
//! - Injecting system prompts and dynamic context (RAG)

use crate::agent::message::Message;
use crate::error::Result;

/// Settings that control how the [`ContextManager`] budgets and windows
/// the conversation context.
#[derive(Debug, Clone)]
pub struct ContextConfig {
    /// Hard upper bound on tokens in the model's context window.
    pub max_tokens: usize,
    /// Hard upper bound on the number of history messages retained.
    pub max_history_messages: usize,
    /// Tokens held back so the model has room to generate a response.
    pub response_reserve: usize,
    /// Emit explicit context-caching markers when true.
    pub enable_cache_control: bool,
    /// Summarize pruned history into an observation log when true.
    pub smart_pruning: bool,
}
26
27impl Default for ContextConfig {
28    fn default() -> Self {
29        Self {
30            max_tokens: 128000, // Modern default (e.g. GPT-4o)
31            max_history_messages: 50,
32            response_reserve: 4096,
33            enable_cache_control: false,
34            smart_pruning: false,
35        }
36    }
37}
38
39/// Trait for injecting dynamic context
40#[async_trait::async_trait]
41pub trait ContextInjector: Send + Sync {
42    /// Generate messages to inject into the context
43    async fn inject(&self) -> Result<Vec<Message>>;
44}
45
46/// Manages the context window for an agent
47pub struct ContextManager {
48    config: ContextConfig,
49    system_prompt: Option<String>,
50    injectors: Vec<Box<dyn ContextInjector>>,
51}
52
53impl ContextManager {
54    /// Create a new ContextManager
55    pub fn new(config: ContextConfig) -> Self {
56        Self {
57            config,
58            system_prompt: None,
59            injectors: Vec::new(),
60        }
61    }
62
63    /// Set the system prompt
64    pub fn set_system_prompt(&mut self, prompt: impl Into<String>) {
65        self.system_prompt = Some(prompt.into());
66    }
67
68    /// Add a context injector
69    pub fn add_injector(&mut self, injector: Box<dyn ContextInjector>) {
70        self.injectors.push(injector);
71    }
72
73    /// Construct the final list of messages to send to the provider
74    ///
75    /// This method applies:
76    /// 1. System prompt injection (Protected)
77    /// 2. Dynamic Context Injection (RAG, etc.) (Protected)
78    /// 3. Token budgeting using tiktoken (Soft Pruning)
79    /// 4. Message windowing (based on max_history_messages)
80    pub async fn build_context(&self, history: &[Message]) -> Result<Vec<Message>> {
81        // 1. Initialize Tokenizer
82        let bpe = tiktoken_rs::cl100k_base().map_err(|e| {
83            crate::error::Error::Internal(format!("Failed to load tokenizer: {}", e))
84        })?;
85
86        let mut final_context_start = Vec::new();
87
88        // --- 1. System Prompt (Protected) ---
89        if let Some(prompt) = &self.system_prompt {
90            final_context_start.push(Message::system(prompt.clone()));
91        }
92
93        // --- 2. Run Injectors (Protected - e.g. RAG) ---
94        // In a more advanced version, we might want to budget RAG too, but for now we treat it as critical context.
95        for injector in &self.injectors {
96            match injector.inject().await {
97                Ok(msgs) => final_context_start.extend(msgs),
98                Err(e) => tracing::warn!("Context injector failed: {}", e),
99            }
100        }
101
102        // --- 3. Calculate Budget ---
103        // Safety Margin: 1000 tokens for formatting, JSON overhead, and fragmentation
104        const SAFETY_MARGIN: usize = 1000;
105
106        let reserved_response = self.config.response_reserve;
107        let max_window = self.config.max_tokens;
108
109        // Calculate current usage from System + RAG
110        let mut current_usage = 0;
111        for msg in &final_context_start {
112            current_usage += bpe.encode_with_special_tokens(&msg.content.as_text()).len();
113            current_usage += 4; // Approx per-message overhead
114        }
115
116        // Check if we already blew the budget
117        let total_reserved = reserved_response + SAFETY_MARGIN + current_usage;
118        if total_reserved > max_window {
119            tracing::warn!(
120                "System prompt + RAG context exceeds context window! (Usage: {}, Limit: {})",
121                current_usage,
122                max_window - reserved_response - SAFETY_MARGIN
123            );
124            // We proceed, but truncation is guaranteed.
125        }
126
127        let history_budget = if max_window > total_reserved {
128            max_window - total_reserved
129        } else {
130            0
131        };
132
133        // --- 4. Select History (Sliding Window & Smart Pruning) ---
134        let mut selected_history = Vec::new();
135        let mut history_usage = 0;
136        let mut pruned_messages = Vec::new();
137
138        let history_slice = if history.len() > self.config.max_history_messages {
139            let (pruned, selected) = history.split_at(history.len() - self.config.max_history_messages);
140            pruned_messages.extend(pruned.iter().cloned());
141            selected
142        } else {
143            history
144        };
145
146        // Iterate REVERSE (Latest first) for selection
147        for msg in history_slice.iter().rev() {
148            let tokens = bpe.encode_with_special_tokens(&msg.content.as_text()).len();
149            let cost = tokens + 4; 
150
151            if history_usage + cost <= history_budget {
152                history_usage += cost;
153                selected_history.push(msg.clone());
154            } else {
155                pruned_messages.push(msg.clone());
156            }
157        }
158
159        // Handle Smart Pruning: Summarize pruned messages into an Observation Log
160        if self.config.smart_pruning && !pruned_messages.is_empty() {
161             let mut log = String::from("### Historical Observation Log (Pruned Summaries)\n");
162             // Pruned messages were collected in reverse or split order, let's sort them roughly by time or just list them
163             // For simplicity, we just extract tool calls and key facts
164             for msg in pruned_messages {
165                 match msg.role {
166                     crate::agent::message::Role::Assistant => {
167                         let text = msg.content.as_text();
168                         let snippet = if text.len() > 60 {
169                             format!("{}...", &text[..60].replace('\n', " "))
170                         } else {
171                             text.replace('\n', " ")
172                         };
173                         log.push_str(&format!("- Assistant: {}\n", snippet));
174                     }
175                     crate::agent::message::Role::Tool => {
176                         let name = msg.name.as_deref().unwrap_or("unknown_tool");
177                         log.push_str(&format!("- Tool executed: {}\n", name));
178                     }
179                     _ => {}
180                 }
181             }
182             final_context_start.push(Message::system(log));
183        }
184
185        // --- 5. Assemble Final Context ---
186        let mut final_messages = final_context_start;
187        selected_history.reverse();
188        final_messages.extend(selected_history);
189
190        Ok(final_messages)
191    }
192
193    /// Estimate token count for a list of messages using tiktoken
194    pub fn estimate_tokens(messages: &[Message]) -> usize {
195        if let Ok(bpe) = tiktoken_rs::cl100k_base() {
196            messages
197                .iter()
198                .map(|m| bpe.encode_with_special_tokens(&m.content.as_text()).len() + 4)
199                .sum()
200        } else {
201            // Fallback to heuristic if tokenizer fails
202            messages
203                .iter()
204                .map(|m| m.content.as_text().len() / 4)
205                .sum::<usize>()
206        }
207    }
208}
209
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_smart_pruning_generation() {
        let config = ContextConfig {
            max_history_messages: 2, // Keep only the 2 most recent messages.
            max_tokens: 10000,
            response_reserve: 1000,
            smart_pruning: true,
            ..Default::default()
        };
        let mut manager = ContextManager::new(config);
        manager.set_system_prompt("System Prompt");

        let history = vec![
            Message::assistant("I am thinking about the first task."),
            Message::user("What about the second one?"),
            Message::assistant("Executing the third part now."),
            Message::user("Final question."),
        ];

        // The two oldest turns fall outside the window and get summarized;
        // "Executing the third part now." and "Final question." survive.
        let built = manager.build_context(&history).await.unwrap();

        // System prompt + observation log + 2 surviving history messages.
        assert_eq!(built.len(), 4, "Context should contain System, Log, and 2 history messages");

        let log_text = built[1].content.as_text();
        assert!(log_text.contains("Observation Log"), "Should contain Observation Log");
        assert!(log_text.contains("Assistant"), "Should mention Assistant in log");
    }

    #[tokio::test]
    async fn test_basic_inclusion() {
        let manager = ContextManager::new(ContextConfig::default());
        let history = vec![Message::user("test")];
        let built = manager.build_context(&history).await.unwrap();
        // No system prompt is set by default, so the context is just the
        // (non-empty) history.
        assert!(!built.is_empty());
    }
}