// atm_core/context.rs
1//! Context window and token tracking.
2
3use serde::{Deserialize, Serialize};
4use std::fmt;
5use std::ops::{Add, AddAssign};
6
/// Represents a count of tokens.
///
/// Used for input tokens, output tokens, cache tokens.
///
/// Newtype over `u64`; `#[serde(transparent)]` makes it serialize and
/// deserialize as the bare number rather than a one-field wrapper.
#[derive(
    Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default, Serialize, Deserialize,
)]
#[serde(transparent)]
pub struct TokenCount(u64);
15
16impl TokenCount {
17    /// Creates a new TokenCount.
18    pub const fn new(count: u64) -> Self {
19        Self(count)
20    }
21
22    /// Creates a zero TokenCount.
23    pub const fn zero() -> Self {
24        Self(0)
25    }
26
27    /// Returns the raw count.
28    pub const fn as_u64(&self) -> u64 {
29        self.0
30    }
31
32    /// Returns true if count is zero.
33    pub const fn is_zero(&self) -> bool {
34        self.0 == 0
35    }
36
37    /// Formats the token count for display.
38    ///
39    /// Uses K/M suffixes for large numbers.
40    pub fn format(&self) -> String {
41        if self.0 < 1_000 {
42            format!("{}", self.0)
43        } else if self.0 < 10_000 {
44            format!("{:.1}K", self.0 as f64 / 1_000.0)
45        } else if self.0 < 1_000_000 {
46            format!("{}K", self.0 / 1_000)
47        } else {
48            format!("{:.1}M", self.0 as f64 / 1_000_000.0)
49        }
50    }
51
52    /// Saturating addition.
53    pub fn saturating_add(self, other: Self) -> Self {
54        Self(self.0.saturating_add(other.0))
55    }
56}
57
58impl Add for TokenCount {
59    type Output = Self;
60
61    fn add(self, other: Self) -> Self {
62        Self(self.0.saturating_add(other.0))
63    }
64}
65
66impl AddAssign for TokenCount {
67    fn add_assign(&mut self, other: Self) {
68        self.0 = self.0.saturating_add(other.0);
69    }
70}
71
72impl From<u64> for TokenCount {
73    fn from(n: u64) -> Self {
74        Self(n)
75    }
76}
77
78impl From<u32> for TokenCount {
79    fn from(n: u32) -> Self {
80        Self(n as u64)
81    }
82}
83
84impl fmt::Display for TokenCount {
85    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
86        write!(f, "{}", self.format())
87    }
88}
89
/// Context window usage information.
///
/// Tracks token counts and calculates usage percentage based on `current_usage`
/// from Claude Code's status line. The `current_usage` object reflects the actual
/// tokens being sent to the API in the current context window.
///
/// When `current_usage` is null (e.g., during /clear), all current_* fields are 0,
/// which correctly shows 0% context usage.
///
/// The `total_*` fields are session-cumulative and are NOT used for the
/// usage percentage; only the current_*/cache_* fields contribute (see
/// `context_tokens`).
#[derive(Debug, Clone, Copy, PartialEq, Default, Serialize, Deserialize)]
pub struct ContextUsage {
    /// Total input tokens across all turns (cumulative, for reference)
    pub total_input_tokens: TokenCount,

    /// Total output tokens across all turns (cumulative, for reference)
    pub total_output_tokens: TokenCount,

    /// Maximum context window size for the model
    pub context_window_size: u32,

    /// Current context's input tokens (from current_usage.input_tokens)
    pub current_input_tokens: TokenCount,

    /// Current context's output tokens (from current_usage.output_tokens)
    pub current_output_tokens: TokenCount,

    /// Tokens written to cache (from current_usage.cache_creation_input_tokens)
    pub cache_creation_tokens: TokenCount,

    /// Tokens read from cache - this is the bulk of context (from current_usage.cache_read_input_tokens)
    pub cache_read_tokens: TokenCount,
}
121
122impl ContextUsage {
123    /// Creates a new ContextUsage with default values.
124    pub fn new(context_window_size: u32) -> Self {
125        Self {
126            context_window_size,
127            ..Default::default()
128        }
129    }
130
131    /// Calculates the context tokens currently in use.
132    ///
133    /// This is the actual context being sent to the API, calculated from:
134    /// - cache_read_tokens: Previously cached context being reused
135    /// - current_input_tokens: New input tokens in this turn
136    /// - cache_creation_tokens: New tokens being written to cache
137    ///
138    /// When current_usage is null (e.g., after /clear), these are all 0.
139    pub fn context_tokens(&self) -> TokenCount {
140        self.cache_read_tokens
141            .saturating_add(self.current_input_tokens)
142            .saturating_add(self.cache_creation_tokens)
143    }
144
145    /// Calculates the total tokens used (cumulative, for reference).
146    pub fn total_tokens(&self) -> TokenCount {
147        self.total_input_tokens
148            .saturating_add(self.total_output_tokens)
149    }
150
151    /// Returns the percentage of context window used (0.0 to 100.0).
152    ///
153    /// Uses context_tokens() which reflects the actual tokens in the current
154    /// context window (from current_usage). When current_usage is null,
155    /// this returns 0%.
156    pub fn usage_percentage(&self) -> f64 {
157        if self.context_window_size == 0 {
158            return 0.0;
159        }
160        let usage = self.context_tokens().as_u64() as f64 / self.context_window_size as f64;
161        (usage * 100.0).min(100.0)
162    }
163
164    /// Returns true if context usage is above the warning threshold (80%).
165    pub fn is_warning(&self) -> bool {
166        self.usage_percentage() >= 80.0
167    }
168
169    /// Returns true if context usage is critical (>90%).
170    pub fn is_critical(&self) -> bool {
171        self.usage_percentage() >= 90.0
172    }
173
174    /// Returns true if exceeds 200K tokens (Claude Code's extended context marker).
175    pub fn exceeds_200k(&self) -> bool {
176        self.context_tokens().as_u64() > 200_000
177    }
178
179    /// Returns the remaining tokens before hitting context limit.
180    pub fn remaining_tokens(&self) -> TokenCount {
181        let used = self.context_tokens().as_u64();
182        let limit = self.context_window_size as u64;
183        TokenCount::new(limit.saturating_sub(used))
184    }
185
186    /// Formats usage for display (e.g., "45.2% (26.4K/200K)").
187    pub fn format(&self) -> String {
188        format!(
189            "{:.1}% ({}/{})",
190            self.usage_percentage(),
191            self.context_tokens().format(),
192            TokenCount::new(self.context_window_size as u64).format()
193        )
194    }
195
196    /// Formats usage compactly (e.g., "45%").
197    pub fn format_compact(&self) -> String {
198        format!("{:.0}%", self.usage_percentage())
199    }
200}
201
202impl fmt::Display for ContextUsage {
203    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
204        write!(f, "{}", self.format())
205    }
206}
207
/// Warning level for context usage.
///
/// Derives `Ord`, so levels compare by severity:
/// Normal < Elevated < Warning < Critical.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum ContextWarningLevel {
    /// No warning needed (below 60%)
    Normal,
    /// Usage is elevated but not critical (60-80%)
    Elevated,
    /// Usage is high, should consider compacting (80-90%)
    Warning,
    /// Usage is critical, action needed (>90%)
    Critical,
}
220
/// Service for analyzing context usage and generating warnings.
///
/// Stateless: all functionality lives in associated functions that take
/// the usage snapshot as an argument.
pub struct ContextAnalyzer;
223
224impl ContextAnalyzer {
225    /// Analyzes context usage and returns warning level.
226    pub fn analyze(context: &ContextUsage) -> ContextWarningLevel {
227        let percentage = context.usage_percentage();
228        if percentage >= 90.0 {
229            ContextWarningLevel::Critical
230        } else if percentage >= 80.0 {
231            ContextWarningLevel::Warning
232        } else if percentage >= 60.0 {
233            ContextWarningLevel::Elevated
234        } else {
235            ContextWarningLevel::Normal
236        }
237    }
238
239    /// Generates a warning message if applicable.
240    pub fn warning_message(context: &ContextUsage) -> Option<String> {
241        match Self::analyze(context) {
242            ContextWarningLevel::Critical => Some(format!(
243                "CRITICAL: Context at {:.0}%. Consider /compact or starting new conversation.",
244                context.usage_percentage()
245            )),
246            ContextWarningLevel::Warning => Some(format!(
247                "Warning: Context at {:.0}%. Approaching limit.",
248                context.usage_percentage()
249            )),
250            ContextWarningLevel::Elevated => Some(format!(
251                "Note: Context at {:.0}%.",
252                context.usage_percentage()
253            )),
254            ContextWarningLevel::Normal => None,
255        }
256    }
257
258    /// Estimates remaining "turns" based on average token usage per turn.
259    pub fn estimate_remaining_turns(
260        context: &ContextUsage,
261        avg_tokens_per_turn: u64,
262    ) -> Option<u64> {
263        if avg_tokens_per_turn == 0 {
264            return None;
265        }
266        let remaining = context.remaining_tokens().as_u64();
267        Some(remaining / avg_tokens_per_turn)
268    }
269
270    /// Calculates cache efficiency (cache reads vs total input).
271    pub fn cache_efficiency(context: &ContextUsage) -> f64 {
272        let total_input = context.total_input_tokens.as_u64();
273        if total_input == 0 {
274            return 0.0;
275        }
276        let cache_reads = context.cache_read_tokens.as_u64();
277        (cache_reads as f64 / total_input as f64) * 100.0
278    }
279}
280
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a ContextUsage with only cache reads and a window size set;
    /// all other counters stay at their zero defaults.
    fn cache_only(cache_read: u64, window: u32) -> ContextUsage {
        ContextUsage {
            cache_read_tokens: TokenCount::new(cache_read),
            context_window_size: window,
            ..Default::default()
        }
    }

    #[test]
    fn test_token_count_formatting() {
        // One case per suffix branch of TokenCount::format.
        assert_eq!(TokenCount::new(500).format(), "500");
        assert_eq!(TokenCount::new(5_000).format(), "5.0K");
        assert_eq!(TokenCount::new(50_000).format(), "50K");
        assert_eq!(TokenCount::new(1_500_000).format(), "1.5M");
    }

    #[test]
    fn test_usage_percentage_from_current_usage() {
        // cache_read (26000) + input (9) + cache_creation (31) = 26040,
        // and 26040 / 200000 = 13.02%.
        let snapshot = ContextUsage {
            context_window_size: 200_000,
            current_input_tokens: TokenCount::new(9),
            cache_creation_tokens: TokenCount::new(31),
            cache_read_tokens: TokenCount::new(26_000),
            ..Default::default()
        };
        assert_eq!(snapshot.context_tokens().as_u64(), 26_040);
        assert!((snapshot.usage_percentage() - 13.02).abs() < 0.01);
    }

    #[test]
    fn test_usage_percentage_zero_when_current_usage_null() {
        // After /clear, current_* fields are all zero even though the
        // cumulative totals persist — usage must read 0%.
        let snapshot = ContextUsage {
            total_input_tokens: TokenCount::new(10_000),
            total_output_tokens: TokenCount::new(1_000),
            context_window_size: 200_000,
            ..Default::default()
        };
        assert!((snapshot.usage_percentage() - 0.0).abs() < 0.01);
    }

    #[test]
    fn test_warning_thresholds() {
        // 50% usage: no warnings.
        let at_half = cache_only(100_000, 200_000);
        assert!(!at_half.is_warning());
        assert!(!at_half.is_critical());
        assert_eq!(
            ContextAnalyzer::analyze(&at_half),
            ContextWarningLevel::Normal
        );

        // 80% usage: warning, not yet critical.
        let at_eighty = cache_only(160_000, 200_000);
        assert!(at_eighty.is_warning());
        assert!(!at_eighty.is_critical());
        assert_eq!(
            ContextAnalyzer::analyze(&at_eighty),
            ContextWarningLevel::Warning
        );

        // 95% usage: both flags set.
        let near_full = cache_only(190_000, 200_000);
        assert!(near_full.is_warning());
        assert!(near_full.is_critical());
        assert_eq!(
            ContextAnalyzer::analyze(&near_full),
            ContextWarningLevel::Critical
        );
    }

    #[test]
    fn test_remaining_tokens() {
        // 100K of a 200K window used leaves 100K remaining.
        let snapshot = cache_only(100_000, 200_000);
        assert_eq!(snapshot.remaining_tokens().as_u64(), 100_000);
    }

    #[test]
    fn test_context_tokens_calculation() {
        // 25000 + 500 + 100 = 25600.
        let snapshot = ContextUsage {
            context_window_size: 200_000,
            cache_read_tokens: TokenCount::new(25_000),
            current_input_tokens: TokenCount::new(500),
            cache_creation_tokens: TokenCount::new(100),
            ..Default::default()
        };
        assert_eq!(snapshot.context_tokens().as_u64(), 25_600);
    }
}