astrid_runtime/
context.rs1use astrid_llm::{LlmProvider, Message, MessageContent};
6use tracing::{debug, info};
7
8use crate::error::RuntimeResult;
9use crate::session::AgentSession;
10
/// Keeps an agent session's conversation within a token budget by
/// summarizing old messages once utilization crosses a threshold.
pub struct ContextManager {
    /// Upper bound on the context window, in (estimated) tokens.
    max_context_tokens: usize,
    /// Fraction of `max_context_tokens` at which summarization triggers
    /// (clamped to `[0.5, 0.95]` by `with_threshold`).
    summarization_threshold: f32,
    /// Number of most-recent messages kept verbatim when summarizing.
    keep_recent_count: usize,
}
20
21impl ContextManager {
22 #[must_use]
24 pub fn new(max_context_tokens: usize) -> Self {
25 Self {
26 max_context_tokens,
27 summarization_threshold: 0.85,
28 keep_recent_count: 10,
29 }
30 }
31
32 #[must_use]
34 pub fn with_threshold(mut self, threshold: f32) -> Self {
35 self.summarization_threshold = threshold.clamp(0.5, 0.95);
36 self
37 }
38
39 #[must_use]
41 pub fn keep_recent(mut self, count: usize) -> Self {
42 self.keep_recent_count = count;
43 self
44 }
45
46 #[must_use]
48 pub fn needs_summarization(&self, session: &AgentSession) -> bool {
49 session.is_near_limit(self.max_context_tokens, self.summarization_threshold)
50 }
51
52 pub async fn summarize<P: LlmProvider>(
60 &self,
61 session: &mut AgentSession,
62 provider: &P,
63 ) -> RuntimeResult<SummarizationResult> {
64 if session.messages.len() <= self.keep_recent_count {
65 return Ok(SummarizationResult {
66 messages_evicted: 0,
67 tokens_freed: 0,
68 summary: None,
69 });
70 }
71
72 #[allow(clippy::arithmetic_side_effects)]
74 let evict_count = session.messages.len() - self.keep_recent_count;
75 let messages_to_summarize: Vec<_> = session.messages.drain(..evict_count).collect();
76
77 info!(
78 evict_count = evict_count,
79 remaining = session.messages.len(),
80 "Summarizing old context"
81 );
82
83 let tokens_freed: usize = messages_to_summarize
85 .iter()
86 .map(|m| match &m.content {
87 MessageContent::Text(t) => t.len() / 4,
88 _ => 100,
89 })
90 .sum();
91
92 let messages_text = format_messages_for_summary(&messages_to_summarize);
94 let summary_prompt = format!(
95 "Summarize the following conversation, preserving key facts, decisions, \
96 and context that would be important for continuing the conversation:\n\n{messages_text}"
97 );
98
99 let summary = provider.complete_simple(&summary_prompt).await?;
101
102 debug!(summary_len = summary.len(), "Generated context summary");
103
104 let summary_message =
106 Message::system(format!("[Previous conversation summary]\n{summary}"));
107 session.messages.insert(0, summary_message);
108
109 session.token_count = session.token_count.saturating_sub(tokens_freed);
111 session.token_count = session.token_count.saturating_add(summary.len() / 4); Ok(SummarizationResult {
114 messages_evicted: evict_count,
115 tokens_freed,
116 summary: Some(summary),
117 })
118 }
119
120 #[must_use]
122 #[allow(clippy::cast_precision_loss)]
123 pub fn stats(&self, session: &AgentSession) -> ContextStats {
124 let utilization = session.token_count as f32 / self.max_context_tokens as f32;
125
126 ContextStats {
127 current_tokens: session.token_count,
128 max_tokens: self.max_context_tokens,
129 utilization,
130 message_count: session.messages.len(),
131 needs_summarization: self.needs_summarization(session),
132 }
133 }
134}
135
136impl Default for ContextManager {
137 fn default() -> Self {
138 Self::new(100_000) }
140}
141
/// Outcome of a single `ContextManager::summarize` pass.
#[derive(Debug, Clone)]
pub struct SummarizationResult {
    /// Number of messages removed from the session.
    pub messages_evicted: usize,
    /// Estimated token count freed by evicting those messages.
    pub tokens_freed: usize,
    /// The generated summary text, or `None` when nothing needed summarizing.
    pub summary: Option<String>,
}
152
/// Point-in-time view of a session's context usage, produced by
/// `ContextManager::stats`.
#[derive(Debug, Clone)]
pub struct ContextStats {
    /// Current estimated token count of the session.
    pub current_tokens: usize,
    /// The manager's configured context-window limit, in tokens.
    pub max_tokens: usize,
    /// `current_tokens / max_tokens` as a fraction (0.0–1.0 nominally).
    pub utilization: f32,
    /// Number of messages currently held in the session.
    pub message_count: usize,
    /// Whether utilization has crossed the summarization threshold.
    pub needs_summarization: bool,
}
167
168impl ContextStats {
169 #[must_use]
171 pub fn utilization_percent(&self) -> f32 {
172 self.utilization * 100.0
173 }
174}
175
176fn format_messages_for_summary(messages: &[Message]) -> String {
178 messages
179 .iter()
180 .map(|m| {
181 let role = match m.role {
182 astrid_llm::MessageRole::User => "User",
183 astrid_llm::MessageRole::Assistant => "Assistant",
184 astrid_llm::MessageRole::System => "System",
185 astrid_llm::MessageRole::Tool => "Tool",
186 };
187
188 let content = match &m.content {
189 MessageContent::Text(t) => t.clone(),
190 MessageContent::ToolCalls(calls) => {
191 let call_strs: Vec<_> = calls
192 .iter()
193 .map(|c| format!("{}({})", &c.name, &c.arguments))
194 .collect();
195 let joined = call_strs.join(", ");
196 format!("[Tool calls: {joined}]")
197 },
198 MessageContent::ToolResult(r) => {
199 let result_content = if r.content.len() > 200 {
200 format!("{}...", &r.content[..200])
201 } else {
202 r.content.clone()
203 };
204 format!("[Tool result: {result_content}]")
205 },
206 MessageContent::MultiPart(_) => "[Multi-part content]".to_string(),
207 };
208
209 format!("{role}: {content}")
210 })
211 .collect::<Vec<_>>()
212 .join("\n\n")
213}
214
#[cfg(test)]
mod tests {
    use super::*;

    /// A session pushed near its token cap must be flagged for summarization.
    #[test]
    fn test_context_manager() {
        let mgr = ContextManager::new(1000);
        let mut sess = AgentSession::new([0u8; 8], "");

        // Populate the history, then force the token estimate near the cap
        // (set *after* the loop so add_message's own accounting is overridden).
        for i in 0..50 {
            sess.add_message(Message::user(format!("Message {i}")));
        }
        sess.token_count = 900;

        assert!(mgr.needs_summarization(&sess));
    }

    /// Stats must report utilization as a fraction and as a percentage.
    #[test]
    fn test_context_stats() {
        let mgr = ContextManager::new(1000);
        let mut sess = AgentSession::new([0u8; 8], "");
        sess.token_count = 500;

        let stats = mgr.stats(&sess);
        assert!((stats.utilization - 0.5_f32).abs() < f32::EPSILON);
        assert!((stats.utilization_percent() - 50.0_f32).abs() < f32::EPSILON);
    }
}