Skip to main content

astrid_runtime/
context.rs

1//! Context management and auto-summarization.
2//!
3//! Handles context window overflow by summarizing old messages.
4
5use astrid_llm::{LlmProvider, Message, MessageContent};
6use tracing::{debug, info};
7
8use crate::error::RuntimeResult;
9use crate::session::AgentSession;
10
/// Context manager for handling context overflow.
///
/// Holds a token budget, the fraction of that budget at which summarization
/// should kick in, and how many recent messages are always kept verbatim.
#[derive(Debug, Clone)]
pub struct ContextManager {
    /// Maximum context tokens before summarization.
    max_context_tokens: usize,
    /// Threshold (0.0-1.0) at which to trigger summarization.
    summarization_threshold: f32,
    /// Number of recent messages to always keep.
    keep_recent_count: usize,
}
20
21impl ContextManager {
22    /// Create a new context manager.
23    #[must_use]
24    pub fn new(max_context_tokens: usize) -> Self {
25        Self {
26            max_context_tokens,
27            summarization_threshold: 0.85,
28            keep_recent_count: 10,
29        }
30    }
31
32    /// Set the summarization threshold.
33    #[must_use]
34    pub fn with_threshold(mut self, threshold: f32) -> Self {
35        self.summarization_threshold = threshold.clamp(0.5, 0.95);
36        self
37    }
38
39    /// Set how many recent messages to keep.
40    #[must_use]
41    pub fn keep_recent(mut self, count: usize) -> Self {
42        self.keep_recent_count = count;
43        self
44    }
45
46    /// Check if summarization is needed.
47    #[must_use]
48    pub fn needs_summarization(&self, session: &AgentSession) -> bool {
49        session.is_near_limit(self.max_context_tokens, self.summarization_threshold)
50    }
51
52    /// Summarize old messages in a session.
53    ///
54    /// This removes old messages and replaces them with a summary.
55    ///
56    /// # Errors
57    ///
58    /// Returns an error if the LLM provider fails to generate a summary.
59    pub async fn summarize<P: LlmProvider>(
60        &self,
61        session: &mut AgentSession,
62        provider: &P,
63    ) -> RuntimeResult<SummarizationResult> {
64        if session.messages.len() <= self.keep_recent_count {
65            return Ok(SummarizationResult {
66                messages_evicted: 0,
67                tokens_freed: 0,
68                summary: None,
69            });
70        }
71
72        // Safety: checked `len() > keep_recent_count` above
73        #[allow(clippy::arithmetic_side_effects)]
74        let evict_count = session.messages.len() - self.keep_recent_count;
75        let messages_to_summarize: Vec<_> = session.messages.drain(..evict_count).collect();
76
77        info!(
78            evict_count = evict_count,
79            remaining = session.messages.len(),
80            "Summarizing old context"
81        );
82
83        // Calculate tokens freed (approximate)
84        let tokens_freed: usize = messages_to_summarize
85            .iter()
86            .map(|m| match &m.content {
87                MessageContent::Text(t) => t.len() / 4,
88                _ => 100,
89            })
90            .sum();
91
92        // Build summary prompt
93        let messages_text = format_messages_for_summary(&messages_to_summarize);
94        let summary_prompt = format!(
95            "Summarize the following conversation, preserving key facts, decisions, \
96             and context that would be important for continuing the conversation:\n\n{messages_text}"
97        );
98
99        // Get summary from LLM
100        let summary = provider.complete_simple(&summary_prompt).await?;
101
102        debug!(summary_len = summary.len(), "Generated context summary");
103
104        // Insert summary as a system message at the beginning
105        let summary_message =
106            Message::system(format!("[Previous conversation summary]\n{summary}"));
107        session.messages.insert(0, summary_message);
108
109        // Update token count
110        session.token_count = session.token_count.saturating_sub(tokens_freed);
111        session.token_count = session.token_count.saturating_add(summary.len() / 4); // Add summary tokens
112
113        Ok(SummarizationResult {
114            messages_evicted: evict_count,
115            tokens_freed,
116            summary: Some(summary),
117        })
118    }
119
120    /// Get context statistics.
121    #[must_use]
122    #[allow(clippy::cast_precision_loss)]
123    pub fn stats(&self, session: &AgentSession) -> ContextStats {
124        let utilization = session.token_count as f32 / self.max_context_tokens as f32;
125
126        ContextStats {
127            current_tokens: session.token_count,
128            max_tokens: self.max_context_tokens,
129            utilization,
130            message_count: session.messages.len(),
131            needs_summarization: self.needs_summarization(session),
132        }
133    }
134}
135
136impl Default for ContextManager {
137    fn default() -> Self {
138        Self::new(100_000) // Default to ~100k tokens
139    }
140}
141
/// Outcome of a call to [`ContextManager::summarize`].
///
/// A no-op run reports zero evicted messages, zero freed tokens and no summary.
#[derive(Debug)]
#[derive(Clone)]
pub struct SummarizationResult {
    /// How many messages were removed from the session.
    pub messages_evicted: usize,
    /// Estimated (not exact) number of tokens those messages occupied.
    pub tokens_freed: usize,
    /// The summary text produced by the provider, or `None` when nothing was summarized.
    pub summary: Option<String>,
}
152
/// Snapshot of how much of the context budget a session is using.
#[derive(Debug)]
#[derive(Clone)]
pub struct ContextStats {
    /// Tokens currently attributed to the session.
    pub current_tokens: usize,
    /// The configured token budget.
    pub max_tokens: usize,
    /// Fraction of the budget in use; 1.0 means the budget is exhausted.
    pub utilization: f32,
    /// How many messages the session currently holds.
    pub message_count: usize,
    /// `true` when the session has crossed the summarization threshold.
    pub needs_summarization: bool,
}
167
168impl ContextStats {
169    /// Get utilization as a percentage.
170    #[must_use]
171    pub fn utilization_percent(&self) -> f32 {
172        self.utilization * 100.0
173    }
174}
175
176/// Format messages for summarization.
177fn format_messages_for_summary(messages: &[Message]) -> String {
178    messages
179        .iter()
180        .map(|m| {
181            let role = match m.role {
182                astrid_llm::MessageRole::User => "User",
183                astrid_llm::MessageRole::Assistant => "Assistant",
184                astrid_llm::MessageRole::System => "System",
185                astrid_llm::MessageRole::Tool => "Tool",
186            };
187
188            let content = match &m.content {
189                MessageContent::Text(t) => t.clone(),
190                MessageContent::ToolCalls(calls) => {
191                    let call_strs: Vec<_> = calls
192                        .iter()
193                        .map(|c| format!("{}({})", &c.name, &c.arguments))
194                        .collect();
195                    let joined = call_strs.join(", ");
196                    format!("[Tool calls: {joined}]")
197                },
198                MessageContent::ToolResult(r) => {
199                    let result_content = if r.content.len() > 200 {
200                        format!("{}...", &r.content[..200])
201                    } else {
202                        r.content.clone()
203                    };
204                    format!("[Tool result: {result_content}]")
205                },
206                MessageContent::MultiPart(_) => "[Multi-part content]".to_string(),
207            };
208
209            format!("{role}: {content}")
210        })
211        .collect::<Vec<_>>()
212        .join("\n\n")
213}
214
#[cfg(test)]
mod tests {
    use super::*;

    /// With 900 of 1000 tokens in use, the default threshold should call for summarization.
    #[test]
    fn test_context_manager() {
        let ctx = ContextManager::new(1000);
        let mut sess = AgentSession::new([0u8; 8], "");

        // Fill the history with some messages.
        for n in 0..50 {
            sess.add_message(Message::user(format!("Message {n}")));
        }

        // Pin the token count directly instead of relying on per-message accounting.
        sess.token_count = 900;

        let must_summarize = ctx.needs_summarization(&sess);
        assert!(must_summarize);
    }

    /// 500 of 1000 tokens is exactly half of the budget.
    #[test]
    fn test_context_stats() {
        let mut sess = AgentSession::new([0u8; 8], "");
        sess.token_count = 500;

        let stats = ContextManager::new(1000).stats(&sess);
        assert_eq!(stats.utilization, 0.5);
        assert_eq!(stats.utilization_percent(), 50.0);
    }
}