Skip to main content

forge_guardrails/context/
manager.rs

1//! Context window manager with token tracking, compaction triggering,
2//! and threshold callbacks.
3
4use crate::context::hardware::estimate_tokens_heuristic;
5use crate::context::strategies::CompactStrategy;
6use crate::core::message::{Message, ToolCallInfo};
7use std::collections::hash_map::DefaultHasher;
8use std::hash::{Hash, Hasher};
9
/// Record of a single compaction, passed by reference to the `on_compact`
/// callback of `ContextManager` after a compaction actually happens.
#[derive(Debug, Clone, PartialEq)]
pub struct CompactEvent {
    /// Step index supplied by the caller of `maybe_compact` for the step at
    /// which the compaction occurred.
    pub step_index: i64,
    /// Token count of the context before compaction: the stored backend count
    /// when one applies to this message list, otherwise the heuristic estimate.
    pub tokens_before: i64,
    /// Heuristic token estimate of the compacted context.
    pub tokens_after: i64,
    /// Total context token budget the strategy compacted against.
    pub budget_tokens: i64,
    /// Number of messages in the conversation before compaction.
    pub messages_before: usize,
    /// Number of messages in the conversation after compaction.
    pub messages_after: usize,
    /// Non-zero compaction phase reported by the strategy (e.g. 1, 2, 3).
    pub phase_reached: i64,
}
28
/// Callback invoked with the [`CompactEvent`] describing each compaction
/// performed by `ContextManager::maybe_compact`.
pub type OnCompactFn = Box<dyn Fn(&CompactEvent) + Send + Sync>;
31
/// Callback fired by `ContextManager::check_thresholds` when a context-usage
/// threshold is newly crossed.
///
/// Arguments are `(tokens, budget_tokens, pct)`, where `pct` is
/// `tokens / budget_tokens`. The returned string (if any) is handed back to the
/// caller of `check_thresholds`; `default_context_warning` has this signature.
pub type OnThresholdFn = Box<dyn Fn(i64, i64, f64) -> Option<String> + Send + Sync>;
34
/// A token count reported through `ContextManager::update_token_count`,
/// optionally bound to the fingerprint of the message list it was measured for.
#[derive(Debug, Clone, Copy)]
struct StoredTokenCount {
    /// The reported token count.
    count: i64,
    /// Fingerprint of the list this count belongs to; `None` means the count is
    /// unscoped and is accepted for any list.
    messages_fingerprint: Option<u64>,
}

impl StoredTokenCount {
    /// Returns `true` when this count may be used for the message list whose
    /// fingerprint is `messages_fingerprint`: either the count is unscoped, or it
    /// was recorded for exactly that fingerprint.
    fn matches(self, messages_fingerprint: u64) -> bool {
        match self.messages_fingerprint {
            None => true,
            Some(recorded) => recorded == messages_fingerprint,
        }
    }
}
48
/// Central context budget tracker.
///
/// Wraps a compaction strategy and provides token estimation, threshold
/// checking, and compaction triggering. Tracks a stored token count scoped to
/// the last observed message list when possible.
pub struct ContextManager {
    /// Strategy consulted by `maybe_compact` to shrink the conversation.
    strategy: Box<dyn CompactStrategy>,
    /// Total context token budget; passed to the strategy and used as the
    /// denominator for threshold percentages.
    budget_tokens: i64,
    /// Optional callback invoked for every compaction that actually happens.
    on_compact: Option<OnCompactFn>,
    /// Usage fractions checked by `check_thresholds`; sorted ascending by `new`.
    context_thresholds: Option<Vec<f64>>,
    /// Callback fired when a threshold is newly crossed.
    on_context_threshold: Option<OnThresholdFn>,
    /// Count supplied by `update_token_count`, kept while it is still
    /// considered valid for the observed message list.
    stored_token_count: Option<StoredTokenCount>,
    /// Fingerprint of the most recent message list observed via
    /// `maybe_compact` or `check_thresholds`.
    last_observed_messages_fingerprint: Option<u64>,
    /// One "already fired" flag per entry of `context_thresholds` (same indices).
    fired_thresholds: Vec<bool>,
}
64
65impl ContextManager {
66    /// Creates a new `ContextManager` with the specified strategy, budget, and callbacks.
67    pub fn new(
68        strategy: Box<dyn CompactStrategy>,
69        budget_tokens: i64,
70        on_compact: Option<OnCompactFn>,
71        context_thresholds: Option<Vec<f64>>,
72        on_context_threshold: Option<OnThresholdFn>,
73    ) -> Self {
74        let fired = context_thresholds
75            .as_ref()
76            .map(|t| vec![false; t.len()])
77            .unwrap_or_default();
78        // Sort thresholds ascending for deterministic processing.
79        let sorted_thresholds = context_thresholds.map(|mut t| {
80            t.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
81            t
82        });
83        Self {
84            strategy,
85            budget_tokens,
86            on_compact,
87            context_thresholds: sorted_thresholds,
88            on_context_threshold,
89            stored_token_count: None,
90            last_observed_messages_fingerprint: None,
91            fired_thresholds: fired,
92        }
93    }
94
95    /// Get the context budget in tokens.
96    pub fn budget(&self) -> i64 {
97        self.budget_tokens
98    }
99
100    /// Estimate token count for messages.
101    ///
102    /// Returns the stored count if `update_token_count` was called for this
103    /// message list, otherwise falls back to the character-count heuristic.
104    pub fn estimate_tokens(&self, messages: &[Message]) -> i64 {
105        let fingerprint = message_fingerprint(messages);
106        match self.stored_token_count {
107            Some(stored) if stored.matches(fingerprint) => stored.count,
108            _ => estimate_tokens_heuristic(messages),
109        }
110    }
111
112    /// Store an actual token count from the backend.
113    ///
114    /// The count is tied to the most recent message list observed by
115    /// `maybe_compact` or `check_thresholds`. If no message list has been
116    /// observed, the count remains unscoped for backwards compatibility.
117    pub fn update_token_count(&mut self, count: i64) {
118        self.stored_token_count = Some(StoredTokenCount {
119            count,
120            messages_fingerprint: self.last_observed_messages_fingerprint,
121        });
122    }
123
124    fn estimate_current_tokens(&mut self, messages: &[Message]) -> i64 {
125        let fingerprint = self.observe_messages(messages);
126        match self.stored_token_count {
127            Some(stored) if stored.matches(fingerprint) => stored.count,
128            Some(_) => {
129                self.stored_token_count = None;
130                estimate_tokens_heuristic(messages)
131            }
132            None => estimate_tokens_heuristic(messages),
133        }
134    }
135
136    fn observe_messages(&mut self, messages: &[Message]) -> u64 {
137        let fingerprint = message_fingerprint(messages);
138        self.last_observed_messages_fingerprint = Some(fingerprint);
139        fingerprint
140    }
141
142    /// Apply compaction if the strategy deems it necessary.
143    ///
144    /// Returns the original message slice when no compaction occurs (phase 0),
145    /// or a new list when compaction occurs (phase > 0).
146    pub fn maybe_compact<'a>(
147        &mut self,
148        messages: &'a [Message],
149        step_index: i64,
150        step_hint: Option<&str>,
151    ) -> std::borrow::Cow<'a, [Message]> {
152        let tokens_before = self.estimate_current_tokens(messages);
153        let (compacted, phase) = self
154            .strategy
155            .compact(messages, self.budget_tokens, step_hint);
156
157        if phase == 0 {
158            return std::borrow::Cow::Borrowed(messages);
159        }
160
161        let tokens_after = estimate_tokens_heuristic(&compacted);
162        let event = CompactEvent {
163            step_index,
164            tokens_before,
165            tokens_after,
166            budget_tokens: self.budget_tokens,
167            messages_before: messages.len(),
168            messages_after: compacted.len(),
169            phase_reached: phase,
170        };
171
172        if let Some(ref callback) = self.on_compact {
173            callback(&event);
174        }
175
176        // Clear stored token count so heuristic runs on next estimate.
177        self.stored_token_count = None;
178
179        std::borrow::Cow::Owned(compacted)
180    }
181
182    /// Check context thresholds and fire the highest unfired threshold
183    /// callback if usage crosses it.
184    ///
185    /// Returns `None` when thresholds or callback are not configured,
186    /// budget is zero/negative, or no threshold is newly crossed.
187    pub fn check_thresholds(&mut self, messages: &[Message]) -> Option<String> {
188        if self.context_thresholds.is_none() || self.on_context_threshold.is_none() {
189            return None;
190        }
191
192        if self.budget_tokens <= 0 {
193            return None;
194        }
195
196        let tokens = self.estimate_current_tokens(messages);
197        let pct = tokens as f64 / self.budget_tokens as f64;
198        let thresholds = self.context_thresholds.as_ref()?;
199
200        // Reset thresholds where usage has dropped below them.
201        for (i, &threshold) in thresholds.iter().enumerate() {
202            if pct < threshold && self.fired_thresholds[i] {
203                self.fired_thresholds[i] = false;
204            }
205        }
206
207        // Find highest unfired threshold that is crossed.
208        // Thresholds are sorted ascending, so iterate in reverse.
209        let mut fired_idx: Option<usize> = None;
210        for (i, &threshold) in thresholds.iter().enumerate().rev() {
211            if pct >= threshold && !self.fired_thresholds[i] {
212                fired_idx = Some(i);
213                break;
214            }
215        }
216
217        let idx = fired_idx?;
218
219        self.fired_thresholds[idx] = true;
220        let callback = self.on_context_threshold.as_ref()?;
221        callback(tokens, self.budget_tokens, pct)
222    }
223}
224
225fn message_fingerprint(messages: &[Message]) -> u64 {
226    let mut hasher = DefaultHasher::new();
227    messages.len().hash(&mut hasher);
228    for message in messages {
229        message.role.hash(&mut hasher);
230        message.content.hash(&mut hasher);
231        message.metadata.msg_type.hash(&mut hasher);
232        message.metadata.step_index.hash(&mut hasher);
233        message.metadata.original_type.hash(&mut hasher);
234        message.metadata.token_estimate.hash(&mut hasher);
235        message.tool_name.hash(&mut hasher);
236        message.tool_call_id.hash(&mut hasher);
237        hash_tool_calls(&message.tool_calls, &mut hasher);
238    }
239    hasher.finish()
240}
241
242fn hash_tool_calls(tool_calls: &Option<Vec<ToolCallInfo>>, hasher: &mut DefaultHasher) {
243    match tool_calls {
244        Some(calls) => {
245            true.hash(hasher);
246            calls.len().hash(hasher);
247            for call in calls {
248                call.name.hash(hasher);
249                call.call_id.hash(hasher);
250                match &call.args {
251                    Some(args) => {
252                        true.hash(hasher);
253                        args.len().hash(hasher);
254                        for (key, value) in args {
255                            key.hash(hasher);
256                            value.to_string().hash(hasher);
257                        }
258                    }
259                    None => false.hash(hasher),
260                }
261            }
262        }
263        None => false.hash(hasher),
264    }
265}
266
/// Default context warning callback; its signature matches `OnThresholdFn`.
///
/// Builds an escalating message from `pct`, the fraction of the budget in use:
/// `pct >= 0.80` reports the window as "nearly full", `pct >= 0.65` as
/// "filling up", and anything lower as a plain usage reminder. Every message
/// includes the percentage (rounded to the nearest whole percent), the token
/// count, and the budget.
///
/// Always returns `Some`.
pub fn default_context_warning(tokens: i64, budget: i64, pct: f64) -> Option<String> {
    // Round instead of truncating: a ratio such as 29 / 100 evaluates to
    // 28.999999999999996 after `* 100.0`, which truncation would show as "28%".
    let pct_display = (pct * 100.0).round() as i64;
    let message = if pct >= 0.80 {
        format!(
            "Context window nearly full: {}% ({} / {} tokens)",
            pct_display, tokens, budget
        )
    } else if pct >= 0.65 {
        format!(
            "Context window filling up: {}% ({} / {} tokens)",
            pct_display, tokens, budget
        )
    } else {
        format!(
            "Context usage at {}% ({} / {} tokens)",
            pct_display, tokens, budget
        )
    };
    Some(message)
}
292
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::strategies::NoCompact;
    use crate::core::message::{Message, MessageMeta, MessageRole, MessageType};

    /// One-message conversation: a user input made of 100 `'a'` characters.
    fn single_user_message() -> Vec<Message> {
        let content = "a".repeat(100);
        let meta = MessageMeta::new(MessageType::UserInput);
        vec![Message::new(MessageRole::User, content, meta)]
    }

    /// Manager with a 1000-token budget, no compaction, and no callbacks or thresholds.
    fn manager_without_callbacks() -> ContextManager {
        ContextManager::new(Box::new(NoCompact), 1000, None, None, None)
    }

    #[test]
    fn compact_event_fields() {
        let event = CompactEvent {
            step_index: 5,
            tokens_before: 1000,
            tokens_after: 500,
            budget_tokens: 800,
            messages_before: 10,
            messages_after: 6,
            phase_reached: 2,
        };
        assert_eq!(
            (event.step_index, event.tokens_after, event.phase_reached),
            (5, 500, 2)
        );
    }

    #[test]
    fn estimate_tokens_heuristic_fallback() {
        let manager = manager_without_callbacks();
        assert_eq!(manager.estimate_tokens(&single_user_message()), 25);
    }

    #[test]
    fn update_token_count_overrides_heuristic() {
        let conversation = single_user_message();
        let mut manager = manager_without_callbacks();
        manager.update_token_count(500);
        assert_eq!(manager.estimate_tokens(&conversation), 500);
    }

    #[test]
    fn default_warning_escalates() {
        let mild = default_context_warning(400, 800, 0.50).unwrap();
        assert!(mild.contains("50%"));
        assert!(!mild.contains("nearly full"));
        assert!(!mild.contains("filling up"));

        let filling = default_context_warning(520, 800, 0.65).unwrap();
        assert!(filling.contains("filling up"));

        let full = default_context_warning(640, 800, 0.80).unwrap();
        assert!(full.contains("nearly full"));
    }
}