forge_guardrails/context/
manager.rs1use crate::context::hardware::estimate_tokens_heuristic;
5use crate::context::strategies::CompactStrategy;
6use crate::core::message::{Message, ToolCallInfo};
7use std::collections::hash_map::DefaultHasher;
8use std::hash::{Hash, Hasher};
9
/// Record of one compaction performed by [`ContextManager::maybe_compact`].
///
/// All fields are plain integers, so the type supports total equality (`Eq`)
/// in addition to `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CompactEvent {
    /// Step index supplied by the caller of `maybe_compact`.
    pub step_index: i64,
    /// Token count of the history before compaction (a stored count when it
    /// still applies, otherwise the heuristic estimate).
    pub tokens_before: i64,
    /// Heuristic token estimate of the compacted history.
    pub tokens_after: i64,
    /// Token budget the strategy was asked to fit within.
    pub budget_tokens: i64,
    /// Number of messages before compaction.
    pub messages_before: usize,
    /// Number of messages after compaction.
    pub messages_after: usize,
    /// Phase reported by the compaction strategy. `maybe_compact` only emits
    /// events for a non-zero phase; the meaning of each phase number is defined
    /// by the strategy.
    pub phase_reached: i64,
}
28
29pub type OnCompactFn = Box<dyn Fn(&CompactEvent) + Send + Sync>;
31
/// Threshold callback: receives `(tokens, budget, pct)`, where `pct` is
/// `tokens / budget`, and returns an optional warning message.
pub type OnThresholdFn = Box<dyn Fn(i64, i64, f64) -> Option<String> + Sync + Send>;
34
/// A token count supplied from outside the heuristic, remembered together with
/// the fingerprint of the messages it was recorded for.
#[derive(Debug, Clone, Copy)]
struct StoredTokenCount {
    /// The recorded token count.
    count: i64,
    /// Fingerprint of the messages the count belongs to. `None` means no
    /// messages had been observed when the count was recorded.
    messages_fingerprint: Option<u64>,
}

impl StoredTokenCount {
    /// Reports whether this count may be reused for messages with the given
    /// fingerprint. A count with no recorded fingerprint is accepted for any
    /// messages.
    fn matches(self, messages_fingerprint: u64) -> bool {
        match self.messages_fingerprint {
            None => true,
            Some(recorded) => recorded == messages_fingerprint,
        }
    }
}
48
/// Tracks context-window usage for a conversation, runs a [`CompactStrategy`]
/// against a token budget, and drives optional compaction/threshold callbacks.
pub struct ContextManager {
    /// Compaction policy consulted by `maybe_compact`.
    strategy: Box<dyn CompactStrategy>,
    /// Token budget passed to the strategy and used as the denominator of the
    /// usage fraction in `check_thresholds`.
    budget_tokens: i64,
    /// Optional hook invoked with a [`CompactEvent`] after each compaction.
    on_compact: Option<OnCompactFn>,
    /// Usage fractions that can trigger `on_context_threshold`; sorted
    /// ascending by `new`.
    context_thresholds: Option<Vec<f64>>,
    /// Optional warning hook driven by `check_thresholds`.
    on_context_threshold: Option<OnThresholdFn>,
    /// Externally supplied token count plus the fingerprint of the messages it
    /// was recorded for; cleared after a compaction or when found stale.
    stored_token_count: Option<StoredTokenCount>,
    /// Fingerprint of the messages most recently observed by the manager.
    last_observed_messages_fingerprint: Option<u64>,
    /// `fired_thresholds[i]` is set once `context_thresholds[i]` has fired and
    /// is cleared again when usage drops back below that threshold.
    fired_thresholds: Vec<bool>,
}
64
65impl ContextManager {
66 pub fn new(
68 strategy: Box<dyn CompactStrategy>,
69 budget_tokens: i64,
70 on_compact: Option<OnCompactFn>,
71 context_thresholds: Option<Vec<f64>>,
72 on_context_threshold: Option<OnThresholdFn>,
73 ) -> Self {
74 let fired = context_thresholds
75 .as_ref()
76 .map(|t| vec![false; t.len()])
77 .unwrap_or_default();
78 let sorted_thresholds = context_thresholds.map(|mut t| {
80 t.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
81 t
82 });
83 Self {
84 strategy,
85 budget_tokens,
86 on_compact,
87 context_thresholds: sorted_thresholds,
88 on_context_threshold,
89 stored_token_count: None,
90 last_observed_messages_fingerprint: None,
91 fired_thresholds: fired,
92 }
93 }
94
95 pub fn budget(&self) -> i64 {
97 self.budget_tokens
98 }
99
100 pub fn estimate_tokens(&self, messages: &[Message]) -> i64 {
105 let fingerprint = message_fingerprint(messages);
106 match self.stored_token_count {
107 Some(stored) if stored.matches(fingerprint) => stored.count,
108 _ => estimate_tokens_heuristic(messages),
109 }
110 }
111
112 pub fn update_token_count(&mut self, count: i64) {
118 self.stored_token_count = Some(StoredTokenCount {
119 count,
120 messages_fingerprint: self.last_observed_messages_fingerprint,
121 });
122 }
123
124 fn estimate_current_tokens(&mut self, messages: &[Message]) -> i64 {
125 let fingerprint = self.observe_messages(messages);
126 match self.stored_token_count {
127 Some(stored) if stored.matches(fingerprint) => stored.count,
128 Some(_) => {
129 self.stored_token_count = None;
130 estimate_tokens_heuristic(messages)
131 }
132 None => estimate_tokens_heuristic(messages),
133 }
134 }
135
136 fn observe_messages(&mut self, messages: &[Message]) -> u64 {
137 let fingerprint = message_fingerprint(messages);
138 self.last_observed_messages_fingerprint = Some(fingerprint);
139 fingerprint
140 }
141
142 pub fn maybe_compact<'a>(
147 &mut self,
148 messages: &'a [Message],
149 step_index: i64,
150 step_hint: Option<&str>,
151 ) -> std::borrow::Cow<'a, [Message]> {
152 let tokens_before = self.estimate_current_tokens(messages);
153 let (compacted, phase) = self
154 .strategy
155 .compact(messages, self.budget_tokens, step_hint);
156
157 if phase == 0 {
158 return std::borrow::Cow::Borrowed(messages);
159 }
160
161 let tokens_after = estimate_tokens_heuristic(&compacted);
162 let event = CompactEvent {
163 step_index,
164 tokens_before,
165 tokens_after,
166 budget_tokens: self.budget_tokens,
167 messages_before: messages.len(),
168 messages_after: compacted.len(),
169 phase_reached: phase,
170 };
171
172 if let Some(ref callback) = self.on_compact {
173 callback(&event);
174 }
175
176 self.stored_token_count = None;
178
179 std::borrow::Cow::Owned(compacted)
180 }
181
182 pub fn check_thresholds(&mut self, messages: &[Message]) -> Option<String> {
188 if self.context_thresholds.is_none() || self.on_context_threshold.is_none() {
189 return None;
190 }
191
192 if self.budget_tokens <= 0 {
193 return None;
194 }
195
196 let tokens = self.estimate_current_tokens(messages);
197 let pct = tokens as f64 / self.budget_tokens as f64;
198 let thresholds = self.context_thresholds.as_ref()?;
199
200 for (i, &threshold) in thresholds.iter().enumerate() {
202 if pct < threshold && self.fired_thresholds[i] {
203 self.fired_thresholds[i] = false;
204 }
205 }
206
207 let mut fired_idx: Option<usize> = None;
210 for (i, &threshold) in thresholds.iter().enumerate().rev() {
211 if pct >= threshold && !self.fired_thresholds[i] {
212 fired_idx = Some(i);
213 break;
214 }
215 }
216
217 let idx = fired_idx?;
218
219 self.fired_thresholds[idx] = true;
220 let callback = self.on_context_threshold.as_ref()?;
221 callback(tokens, self.budget_tokens, pct)
222 }
223}
224
225fn message_fingerprint(messages: &[Message]) -> u64 {
226 let mut hasher = DefaultHasher::new();
227 messages.len().hash(&mut hasher);
228 for message in messages {
229 message.role.hash(&mut hasher);
230 message.content.hash(&mut hasher);
231 message.metadata.msg_type.hash(&mut hasher);
232 message.metadata.step_index.hash(&mut hasher);
233 message.metadata.original_type.hash(&mut hasher);
234 message.metadata.token_estimate.hash(&mut hasher);
235 message.tool_name.hash(&mut hasher);
236 message.tool_call_id.hash(&mut hasher);
237 hash_tool_calls(&message.tool_calls, &mut hasher);
238 }
239 hasher.finish()
240}
241
242fn hash_tool_calls(tool_calls: &Option<Vec<ToolCallInfo>>, hasher: &mut DefaultHasher) {
243 match tool_calls {
244 Some(calls) => {
245 true.hash(hasher);
246 calls.len().hash(hasher);
247 for call in calls {
248 call.name.hash(hasher);
249 call.call_id.hash(hasher);
250 match &call.args {
251 Some(args) => {
252 true.hash(hasher);
253 args.len().hash(hasher);
254 for (key, value) in args {
255 key.hash(hasher);
256 value.to_string().hash(hasher);
257 }
258 }
259 None => false.hash(hasher),
260 }
261 }
262 }
263 None => false.hash(hasher),
264 }
265}
266
/// Default threshold callback (matches the [`OnThresholdFn`] signature) that
/// formats a context-usage warning.
///
/// `pct` is the usage fraction, e.g. `0.8` for 80%. The wording escalates at
/// `0.65` ("filling up") and `0.80` ("nearly full").
///
/// The displayed percentage is `pct * 100` rounded to the nearest whole number.
/// Truncation would under-report because of binary floating-point error: for
/// example, `0.57 * 100.0` is `56.999…`.
///
/// Always returns `Some`.
pub fn default_context_warning(tokens: i64, budget: i64, pct: f64) -> Option<String> {
    let pct_display = (pct * 100.0).round() as i64;
    let label = if pct >= 0.80 {
        "Context window nearly full:"
    } else if pct >= 0.65 {
        "Context window filling up:"
    } else {
        "Context usage at"
    };
    Some(format!(
        "{} {}% ({} / {} tokens)",
        label, pct_display, tokens, budget
    ))
}
292
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::strategies::NoCompact;
    use crate::core::message::{Message, MessageMeta, MessageRole, MessageType};

    /// One user-input message whose content is `len` copies of `'a'`.
    fn single_user_message(len: usize) -> Vec<Message> {
        vec![Message::new(
            MessageRole::User,
            "a".repeat(len),
            MessageMeta::new(MessageType::UserInput),
        )]
    }

    /// Manager with no compaction, a 1000-token budget, and no callbacks.
    fn plain_manager() -> ContextManager {
        ContextManager::new(Box::new(NoCompact), 1000, None, None, None)
    }

    #[test]
    fn compact_event_fields() {
        let event = CompactEvent {
            step_index: 5,
            tokens_before: 1000,
            tokens_after: 500,
            budget_tokens: 800,
            messages_before: 10,
            messages_after: 6,
            phase_reached: 2,
        };
        assert_eq!(
            (event.step_index, event.tokens_after, event.phase_reached),
            (5, 500, 2)
        );
    }

    #[test]
    fn estimate_tokens_heuristic_fallback() {
        let manager = plain_manager();
        // With no stored count, 100 characters estimate to 25 tokens.
        assert_eq!(manager.estimate_tokens(&single_user_message(100)), 25);
    }

    #[test]
    fn update_token_count_overrides_heuristic() {
        let mut manager = plain_manager();
        manager.update_token_count(500);
        assert_eq!(manager.estimate_tokens(&single_user_message(100)), 500);
    }

    #[test]
    fn default_warning_escalates() {
        let low = default_context_warning(400, 800, 0.50).unwrap();
        assert!(low.contains("50%"));
        assert!(["nearly full", "filling up"]
            .iter()
            .all(|phrase| !low.contains(phrase)));

        let mid = default_context_warning(520, 800, 0.65).unwrap();
        assert!(mid.contains("filling up"));

        let high = default_context_warning(640, 800, 0.80).unwrap();
        assert!(high.contains("nearly full"));
    }
}