// atm_core/context.rs

1//! Context window and token tracking.
2
3use serde::{Deserialize, Serialize};
4use std::fmt;
5use std::ops::{Add, AddAssign};
6
/// Represents a count of tokens.
///
/// Used for input tokens, output tokens, cache tokens.
///
/// Newtype over `u64` so token counts can't be confused with other numeric
/// quantities. `#[serde(transparent)]` makes it serialize/deserialize as a
/// bare number rather than a one-field wrapper object.
#[derive(
    Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default, Serialize, Deserialize,
)]
#[serde(transparent)]
pub struct TokenCount(u64);
15
16impl TokenCount {
17    /// Creates a new TokenCount.
18    pub const fn new(count: u64) -> Self {
19        Self(count)
20    }
21
22    /// Creates a zero TokenCount.
23    pub const fn zero() -> Self {
24        Self(0)
25    }
26
27    /// Returns the raw count.
28    pub const fn as_u64(&self) -> u64 {
29        self.0
30    }
31
32    /// Returns true if count is zero.
33    pub const fn is_zero(&self) -> bool {
34        self.0 == 0
35    }
36
37    /// Formats the token count for display.
38    ///
39    /// Uses K/M suffixes for large numbers.
40    pub fn format(&self) -> String {
41        if self.0 < 1_000 {
42            format!("{}", self.0)
43        } else if self.0 < 10_000 {
44            format!("{:.1}K", self.0 as f64 / 1_000.0)
45        } else if self.0 < 1_000_000 {
46            format!("{}K", self.0 / 1_000)
47        } else {
48            format!("{:.1}M", self.0 as f64 / 1_000_000.0)
49        }
50    }
51
52    /// Saturating addition.
53    pub fn saturating_add(self, other: Self) -> Self {
54        Self(self.0.saturating_add(other.0))
55    }
56}
57
58impl Add for TokenCount {
59    type Output = Self;
60
61    fn add(self, other: Self) -> Self {
62        Self(self.0.saturating_add(other.0))
63    }
64}
65
66impl AddAssign for TokenCount {
67    fn add_assign(&mut self, other: Self) {
68        self.0 = self.0.saturating_add(other.0);
69    }
70}
71
72impl From<u64> for TokenCount {
73    fn from(n: u64) -> Self {
74        Self(n)
75    }
76}
77
78impl From<u32> for TokenCount {
79    fn from(n: u32) -> Self {
80        Self(n as u64)
81    }
82}
83
84impl fmt::Display for TokenCount {
85    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
86        write!(f, "{}", self.format())
87    }
88}
89
/// Context window usage information.
///
/// Tracks token counts and calculates usage percentage based on `current_usage`
/// from Claude Code's status line. The `current_usage` object reflects the actual
/// tokens being sent to the API in the current context window.
///
/// When `current_usage` is null (e.g., during /clear), all current_* fields are 0,
/// which correctly shows 0% context usage. The cumulative `total_*` fields do
/// not feed the percentage calculation; they are kept for reference only.
#[derive(Debug, Clone, Copy, PartialEq, Default, Serialize, Deserialize)]
pub struct ContextUsage {
    /// Total input tokens across all turns (cumulative, for reference;
    /// not part of the usage-percentage calculation)
    pub total_input_tokens: TokenCount,

    /// Total output tokens across all turns (cumulative, for reference)
    pub total_output_tokens: TokenCount,

    /// Maximum context window size for the model, in tokens
    pub context_window_size: u32,

    /// Current context's input tokens (from current_usage.input_tokens)
    pub current_input_tokens: TokenCount,

    /// Current context's output tokens (from current_usage.output_tokens)
    pub current_output_tokens: TokenCount,

    /// Tokens written to cache (from current_usage.cache_creation_input_tokens)
    pub cache_creation_tokens: TokenCount,

    /// Tokens read from cache - this is the bulk of context (from current_usage.cache_read_input_tokens)
    pub cache_read_tokens: TokenCount,
}
121
122impl ContextUsage {
123    /// Creates a new ContextUsage with default values.
124    pub fn new(context_window_size: u32) -> Self {
125        Self {
126            context_window_size,
127            ..Default::default()
128        }
129    }
130
131    /// Calculates the context tokens currently in use.
132    ///
133    /// This is the actual context being sent to the API, calculated from:
134    /// - cache_read_tokens: Previously cached context being reused
135    /// - current_input_tokens: New input tokens in this turn
136    /// - cache_creation_tokens: New tokens being written to cache
137    ///
138    /// When current_usage is null (e.g., after /clear), these are all 0.
139    pub fn context_tokens(&self) -> TokenCount {
140        self.cache_read_tokens
141            .saturating_add(self.current_input_tokens)
142            .saturating_add(self.cache_creation_tokens)
143    }
144
145    /// Calculates the total tokens used (cumulative, for reference).
146    pub fn total_tokens(&self) -> TokenCount {
147        self.total_input_tokens
148            .saturating_add(self.total_output_tokens)
149    }
150
151    /// Returns the percentage of context window used (0.0 to 100.0).
152    ///
153    /// Uses context_tokens() which reflects the actual tokens in the current
154    /// context window (from current_usage). When current_usage is null,
155    /// this returns 0%.
156    pub fn usage_percentage(&self) -> f64 {
157        if self.context_window_size == 0 {
158            return 0.0;
159        }
160        let usage = self.context_tokens().as_u64() as f64 / self.context_window_size as f64;
161        (usage * 100.0).min(100.0)
162    }
163
164    /// Returns true if context usage is above the warning threshold (80%).
165    pub fn is_warning(&self) -> bool {
166        self.usage_percentage() >= 80.0
167    }
168
169    /// Returns true if context usage is critical (>90%).
170    pub fn is_critical(&self) -> bool {
171        self.usage_percentage() >= 90.0
172    }
173
174    /// Returns true if exceeds 200K tokens (Claude Code's extended context marker).
175    pub fn exceeds_200k(&self) -> bool {
176        self.context_tokens().as_u64() > 200_000
177    }
178
179    /// Returns the remaining tokens before hitting context limit.
180    pub fn remaining_tokens(&self) -> TokenCount {
181        let used = self.context_tokens().as_u64();
182        let limit = self.context_window_size as u64;
183        TokenCount::new(limit.saturating_sub(used))
184    }
185
186    /// Formats usage for display (e.g., "45.2% (26.4K/200K)").
187    pub fn format(&self) -> String {
188        format!(
189            "{:.1}% ({}/{})",
190            self.usage_percentage(),
191            self.context_tokens().format(),
192            TokenCount::new(self.context_window_size as u64).format()
193        )
194    }
195
196    /// Formats usage compactly (e.g., "45%").
197    pub fn format_compact(&self) -> String {
198        format!("{:.0}%", self.usage_percentage())
199    }
200}
201
202impl fmt::Display for ContextUsage {
203    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
204        write!(f, "{}", self.format())
205    }
206}
207
/// Warning level for context usage.
///
/// Variant order is significant: the derived `Ord`/`PartialOrd` give
/// `Normal < Elevated < Warning < Critical`, so levels can be compared
/// by severity. Do not reorder variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum ContextWarningLevel {
    /// No warning needed (below 60%)
    Normal,
    /// Usage is elevated but not critical (60-80%)
    Elevated,
    /// Usage is high, should consider compacting (80-90%)
    Warning,
    /// Usage is critical, action needed (>90%)
    Critical,
}
220
/// Service for analyzing context usage and generating warnings.
///
/// Stateless: all functionality is exposed as associated functions, so the
/// unit struct is purely a namespace.
pub struct ContextAnalyzer;
223
224impl ContextAnalyzer {
225    /// Analyzes context usage and returns warning level.
226    pub fn analyze(context: &ContextUsage) -> ContextWarningLevel {
227        let percentage = context.usage_percentage();
228        if percentage >= 90.0 {
229            ContextWarningLevel::Critical
230        } else if percentage >= 80.0 {
231            ContextWarningLevel::Warning
232        } else if percentage >= 60.0 {
233            ContextWarningLevel::Elevated
234        } else {
235            ContextWarningLevel::Normal
236        }
237    }
238
239    /// Generates a warning message if applicable.
240    pub fn warning_message(context: &ContextUsage) -> Option<String> {
241        match Self::analyze(context) {
242            ContextWarningLevel::Critical => Some(format!(
243                "CRITICAL: Context at {:.0}%. Consider /compact or starting new conversation.",
244                context.usage_percentage()
245            )),
246            ContextWarningLevel::Warning => Some(format!(
247                "Warning: Context at {:.0}%. Approaching limit.",
248                context.usage_percentage()
249            )),
250            ContextWarningLevel::Elevated => Some(format!(
251                "Note: Context at {:.0}%.",
252                context.usage_percentage()
253            )),
254            ContextWarningLevel::Normal => None,
255        }
256    }
257
258    /// Estimates remaining "turns" based on average token usage per turn.
259    pub fn estimate_remaining_turns(context: &ContextUsage, avg_tokens_per_turn: u64) -> Option<u64> {
260        if avg_tokens_per_turn == 0 {
261            return None;
262        }
263        let remaining = context.remaining_tokens().as_u64();
264        Some(remaining / avg_tokens_per_turn)
265    }
266
267    /// Calculates cache efficiency (cache reads vs total input).
268    pub fn cache_efficiency(context: &ContextUsage) -> f64 {
269        let total_input = context.total_input_tokens.as_u64();
270        if total_input == 0 {
271            return 0.0;
272        }
273        let cache_reads = context.cache_read_tokens.as_u64();
274        (cache_reads as f64 / total_input as f64) * 100.0
275    }
276}
277
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a usage snapshot against a 200K window from the three
    /// current-context counters.
    fn usage_200k(cache_read: u64, input: u64, cache_creation: u64) -> ContextUsage {
        ContextUsage {
            cache_read_tokens: TokenCount::new(cache_read),
            current_input_tokens: TokenCount::new(input),
            cache_creation_tokens: TokenCount::new(cache_creation),
            context_window_size: 200_000,
            ..Default::default()
        }
    }

    #[test]
    fn test_token_count_formatting() {
        let cases = [
            (500, "500"),
            (5_000, "5.0K"),
            (50_000, "50K"),
            (1_500_000, "1.5M"),
        ];
        for (raw, expected) in cases {
            assert_eq!(TokenCount::new(raw).format(), expected);
        }
    }

    #[test]
    fn test_usage_percentage_from_current_usage() {
        // 26000 cache_read + 9 input + 31 cache_creation = 26040 context
        // tokens; 26040 / 200000 = 13.02%.
        let usage = usage_200k(26_000, 9, 31);
        assert_eq!(usage.context_tokens().as_u64(), 26_040);
        assert!((usage.usage_percentage() - 13.02).abs() < 0.01);
    }

    #[test]
    fn test_usage_percentage_zero_when_current_usage_null() {
        // When current_usage is null, the current_* fields are all 0;
        // cumulative totals alone must not count toward context usage,
        // so this correctly shows 0% after /clear.
        let usage = ContextUsage {
            total_input_tokens: TokenCount::new(10_000),
            total_output_tokens: TokenCount::new(1_000),
            context_window_size: 200_000,
            ..Default::default()
        };
        assert!((usage.usage_percentage() - 0.0).abs() < 0.01);
    }

    #[test]
    fn test_warning_thresholds() {
        // 50% usage: no warnings.
        let normal = usage_200k(100_000, 0, 0);
        assert!(!normal.is_warning());
        assert!(!normal.is_critical());
        assert_eq!(ContextAnalyzer::analyze(&normal), ContextWarningLevel::Normal);

        // 80% usage: warning but not critical.
        let warning = usage_200k(160_000, 0, 0);
        assert!(warning.is_warning());
        assert!(!warning.is_critical());
        assert_eq!(
            ContextAnalyzer::analyze(&warning),
            ContextWarningLevel::Warning
        );

        // 95% usage: both flags set, level Critical.
        let critical = usage_200k(190_000, 0, 0);
        assert!(critical.is_warning());
        assert!(critical.is_critical());
        assert_eq!(
            ContextAnalyzer::analyze(&critical),
            ContextWarningLevel::Critical
        );
    }

    #[test]
    fn test_remaining_tokens() {
        // 100K used of a 200K window leaves 100K.
        let usage = usage_200k(100_000, 0, 0);
        assert_eq!(usage.remaining_tokens().as_u64(), 100_000);
    }

    #[test]
    fn test_context_tokens_calculation() {
        // 25000 + 500 + 100 = 25600
        let usage = usage_200k(25_000, 500, 100);
        assert_eq!(usage.context_tokens().as_u64(), 25_600);
    }
}