//! `codetether_agent/tui/token_display.rs` — enhanced token usage display
//! (costs, TPS, cache hit rate, context warnings) for the TUI status bar.
1use crate::telemetry::{ContextLimit, CostEstimate, TOKEN_USAGE, TokenUsageSnapshot};
2use crate::tui::theme::Theme;
3use ratatui::{
4    style::{Color, Modifier, Style},
5    text::{Line, Span},
6};
7
/// Enhanced token usage display with costs and warnings.
pub struct TokenDisplay {
    /// Per-model pricing cache: model name → ($ per million input tokens,
    /// $ per million output tokens).
    ///
    /// Context limits are NOT stored here — those come from the canonical
    /// [`crate::provider::limits::context_window_for_model`].
    ///
    /// NOTE(review): this map is never populated or read within this file —
    /// pricing is delegated to `crate::provider::pricing` — confirm whether
    /// the field is still needed.
    model_pricing: std::collections::HashMap<String, (f64, f64)>,
}
14
15impl TokenDisplay {
16    pub fn new() -> Self {
17        Self {
18            model_pricing: std::collections::HashMap::new(),
19        }
20    }
21
22    /// Get context limit for a model by delegating to the canonical
23    /// [`crate::provider::limits::context_window_for_model`].
24    ///
25    /// Returns `None` only for the zero-length model string edge case;
26    /// callers that relied on the old `HashMap::get` returning `None`
27    /// for unknown models will now get the default 128 000 instead,
28    /// which is strictly more correct (all models have *some* context
29    /// window).
30    pub fn get_context_limit(&self, model: &str) -> Option<u64> {
31        if model.is_empty() {
32            return None;
33        }
34        Some(crate::provider::limits::context_window_for_model(model) as u64)
35    }
36
37    /// Get pricing for a model (returns $ per million tokens for input/output).
38    ///
39    /// Delegates to the canonical
40    /// [`crate::provider::pricing::pricing_for_model`] so costs in the TUI
41    /// stay in sync with the cost-guardrail enforcement path.
42    fn get_model_pricing(&self, model: &str) -> (f64, f64) {
43        crate::provider::pricing::pricing_for_model(model)
44    }
45
46    /// Calculate cost for a model given input and output token counts
47    pub fn calculate_cost_for_tokens(
48        &self,
49        model: &str,
50        input_tokens: u64,
51        output_tokens: u64,
52    ) -> CostEstimate {
53        let (input_price, output_price) = self.get_model_pricing(model);
54        CostEstimate::from_tokens(
55            &crate::telemetry::TokenCounts::new(input_tokens, output_tokens),
56            input_price,
57            output_price,
58        )
59    }
60
61    /// Create status bar content with token usage
62    pub fn create_status_bar(&self, theme: &Theme) -> Line<'_> {
63        let global_snapshot = TOKEN_USAGE.global_snapshot();
64        let model_snapshots = TOKEN_USAGE.model_snapshots();
65
66        let total_tokens = global_snapshot.totals.total();
67        let session_cost = self.calculate_session_cost();
68        let tps_display = self.get_tps_display();
69
70        let mut spans = Vec::new();
71
72        // Help indicator
73        spans.push(Span::styled(
74            " ? ",
75            Style::default()
76                .fg(theme.status_bar_foreground.to_color())
77                .bg(theme.status_bar_background.to_color()),
78        ));
79        spans.push(Span::raw(" Help "));
80
81        // Switch agent
82        spans.push(Span::styled(
83            " Tab ",
84            Style::default()
85                .fg(theme.status_bar_foreground.to_color())
86                .bg(theme.status_bar_background.to_color()),
87        ));
88        spans.push(Span::raw(" Switch Agent "));
89
90        // Quit
91        spans.push(Span::styled(
92            " Ctrl+C ",
93            Style::default()
94                .fg(theme.status_bar_foreground.to_color())
95                .bg(theme.status_bar_background.to_color()),
96        ));
97        spans.push(Span::raw(" Quit "));
98
99        // Token usage
100        spans.push(Span::styled(
101            format!(" Tokens: {} ", total_tokens),
102            Style::default().fg(theme.timestamp_color.to_color()),
103        ));
104
105        // TPS (tokens per second)
106        if let Some(tps) = tps_display {
107            spans.push(Span::styled(
108                format!(" TPS: {} ", tps),
109                Style::default().fg(Color::Cyan),
110            ));
111        }
112
113        // Cost — colorized based on the configured cost guardrails so
114        // users get an immediate visual when they cross warn / hard limit
115        // thresholds (see `CODETETHER_COST_WARN_USD` /
116        // `CODETETHER_COST_LIMIT_USD`).
117        let cost_style = match crate::session::helper::cost_guard::cost_guard_level() {
118            crate::session::helper::cost_guard::CostGuardLevel::OverLimit => {
119                Style::default().fg(Color::Red).add_modifier(Modifier::BOLD)
120            }
121            crate::session::helper::cost_guard::CostGuardLevel::OverWarn => Style::default()
122                .fg(Color::Yellow)
123                .add_modifier(Modifier::BOLD),
124            crate::session::helper::cost_guard::CostGuardLevel::Ok => {
125                Style::default().fg(theme.timestamp_color.to_color())
126            }
127        };
128        spans.push(Span::styled(
129            format!(" Cost: {} ", session_cost.format_smart()),
130            cost_style,
131        ));
132
133        // Prompt-cache hit rate — surfaces whether Anthropic/Bedrock
134        // prompt caching is actually saving input tokens. Only shown when
135        // the active model has recorded any cache activity at all.
136        if let Some(cache_pct) = self.get_cache_hit_rate(&model_snapshots) {
137            spans.push(Span::styled(
138                format!(" Cache: {:.0}% ", cache_pct),
139                Style::default().fg(Color::Green),
140            ));
141        }
142
143        // Context warning if active model is near limit
144        if let Some(warning) = self.get_context_warning(&model_snapshots) {
145            spans.push(Span::styled(
146                format!(" {} ", warning),
147                Style::default().fg(Color::Red).add_modifier(Modifier::BOLD),
148            ));
149        }
150
151        Line::from(spans)
152    }
153
154    /// Calculate total session cost across all models
155    pub fn calculate_session_cost(&self) -> CostEstimate {
156        let model_snapshots = TOKEN_USAGE.model_snapshots();
157        let mut total = CostEstimate::default();
158
159        for snapshot in model_snapshots {
160            let model_cost = self.calculate_cost_for_tokens(
161                &snapshot.name,
162                snapshot.totals.input,
163                snapshot.totals.output,
164            );
165            total.input_cost += model_cost.input_cost;
166            total.output_cost += model_cost.output_cost;
167            total.total_cost += model_cost.total_cost;
168        }
169
170        total
171    }
172
173    /// Get context warning for the active model based on the **current
174    /// turn's** prompt size (what the next request will send), not
175    /// cumulative lifetime tokens.
176    ///
177    /// This matters because the agent loop re-sends the growing
178    /// conversation on every step and the RLM layer compresses history
179    /// behind the scenes — users need to see the real edge of the
180    /// context window, not a bogus 6000% figure summed over all turns.
181    /// This matters because the agent loop re-sends the growing
182    /// conversation on every step and the RLM layer compresses history
183    /// behind the scenes — users need to see the real edge of the
184    /// context window, not a bogus 6000% figure summed over all turns.
185    fn get_context_warning(&self, model_snapshots: &[TokenUsageSnapshot]) -> Option<String> {
186        if model_snapshots.is_empty() {
187            return None;
188        }
189
190        let active_model = model_snapshots.iter().max_by_key(|s| s.totals.total())?;
191        let limit = self.get_context_limit(&active_model.name)?;
192
193        // Prefer the last turn's actual prompt size; fall back to cumulative
194        // only if no turn has been recorded yet (first-render race).
195        let used = crate::telemetry::TOKEN_USAGE
196            .last_prompt_tokens_for(&active_model.name)
197            .unwrap_or_else(|| active_model.totals.total().min(limit));
198
199        let context = ContextLimit::new(used, limit);
200
201        if context.percentage >= 90.0 {
202            Some(format!("🛑 Context: {:.0}%", context.percentage))
203        } else if context.percentage >= 75.0 {
204            Some(format!("⚠️ Context: {:.0}%", context.percentage))
205        } else if context.percentage >= 50.0 {
206            Some(format!("Context: {:.0}%", context.percentage))
207        } else {
208            None
209        }
210    }
211
212    /// Aggregate prompt-cache hit rate across all recorded models.
213    ///
214    /// Defined as `cache_read / (cache_read + full_price_input)` × 100,
215    /// i.e. what fraction of billable input was served from the cache.
216    /// Returns `None` when no cache activity has been recorded (which is
217    /// the common case for providers that don't support it).
218    fn get_cache_hit_rate(&self, model_snapshots: &[TokenUsageSnapshot]) -> Option<f64> {
219        let mut full_input: u64 = 0;
220        let mut cache_read: u64 = 0;
221        for s in model_snapshots {
222            full_input += s.prompt_tokens;
223            let (cr, _cw) = crate::telemetry::TOKEN_USAGE.cache_usage_for(&s.name);
224            cache_read += cr;
225        }
226        let denom = full_input + cache_read;
227        if cache_read == 0 || denom == 0 {
228            return None;
229        }
230        Some(cache_read as f64 * 100.0 / denom as f64)
231    }
232
233    /// Get TPS (tokens per second) display string from provider metrics
234    fn get_tps_display(&self) -> Option<String> {
235        use crate::telemetry::PROVIDER_METRICS;
236
237        let snapshots = PROVIDER_METRICS.all_snapshots();
238        if snapshots.is_empty() {
239            return None;
240        }
241
242        // Find the provider with the most recent activity
243        let most_active = snapshots
244            .iter()
245            .filter(|s| s.avg_tps > 0.0)
246            .max_by(|a, b| {
247                a.total_output_tokens
248                    .partial_cmp(&b.total_output_tokens)
249                    .unwrap_or(std::cmp::Ordering::Equal)
250            })?;
251
252        // Format TPS nicely
253        let tps = most_active.avg_tps;
254        let formatted = if tps >= 100.0 {
255            format!("{:.0}", tps)
256        } else if tps >= 10.0 {
257            format!("{:.1}", tps)
258        } else {
259            format!("{:.2}", tps)
260        };
261
262        Some(formatted)
263    }
264
265    /// Create detailed token usage display
266    pub fn create_detailed_display(&self) -> Vec<String> {
267        use crate::telemetry::PROVIDER_METRICS;
268
269        let mut lines = Vec::new();
270        let global_snapshot = TOKEN_USAGE.global_snapshot();
271        let model_snapshots = TOKEN_USAGE.model_snapshots();
272
273        lines.push("".to_string());
274        lines.push("  TOKEN USAGE & COSTS".to_string());
275        lines.push("  ===================".to_string());
276        lines.push("".to_string());
277
278        // Global totals
279        let total_cost = self.calculate_session_cost();
280        lines.push(format!(
281            "  Total: {} tokens ({} requests) - {}",
282            global_snapshot.totals.total(),
283            global_snapshot.request_count,
284            total_cost.format_currency()
285        ));
286        lines.push(format!(
287            "  Current: {} in / {} out",
288            global_snapshot.totals.input, global_snapshot.totals.output
289        ));
290        lines.push("".to_string());
291
292        // Per-model breakdown
293        if !model_snapshots.is_empty() {
294            lines.push("  BY MODEL:".to_string());
295
296            for snapshot in model_snapshots.iter().take(5) {
297                let model_cost = self.calculate_cost_for_tokens(
298                    &snapshot.name,
299                    snapshot.totals.input,
300                    snapshot.totals.output,
301                );
302                lines.push(format!(
303                    "    {}: {} tokens ({} requests) - {}",
304                    snapshot.name,
305                    snapshot.totals.total(),
306                    snapshot.request_count,
307                    model_cost.format_currency()
308                ));
309
310                // Context limit info
311                if let Some(limit) = self.get_context_limit(&snapshot.name) {
312                    let context = ContextLimit::new(snapshot.totals.total(), limit);
313                    if context.percentage >= 50.0 {
314                        lines.push(format!(
315                            "      Context: {:.1}% of {} tokens",
316                            context.percentage, limit
317                        ));
318                    }
319                }
320
321                // Prompt-cache stats (Anthropic / Bedrock).
322                let (cache_read, cache_write) =
323                    crate::telemetry::TOKEN_USAGE.cache_usage_for(&snapshot.name);
324                if cache_read > 0 || cache_write > 0 {
325                    let denom = snapshot.prompt_tokens + cache_read;
326                    let hit_pct = if denom > 0 {
327                        cache_read as f64 * 100.0 / denom as f64
328                    } else {
329                        0.0
330                    };
331                    lines.push(format!(
332                        "      Cache: {} read / {} write ({:.1}% hit)",
333                        cache_read, cache_write, hit_pct
334                    ));
335                }
336            }
337
338            if model_snapshots.len() > 5 {
339                lines.push(format!(
340                    "    ... and {} more models",
341                    model_snapshots.len() - 5
342                ));
343            }
344            lines.push("".to_string());
345        }
346
347        // Provider performance metrics (TPS, latency)
348        let provider_snapshots = PROVIDER_METRICS.all_snapshots();
349        if !provider_snapshots.is_empty() {
350            lines.push("  PROVIDER PERFORMANCE:".to_string());
351
352            for snapshot in provider_snapshots.iter().take(5) {
353                if snapshot.request_count > 0 {
354                    lines.push(format!(
355                        "    {}: {:.1} avg TPS | {:.0}ms avg latency | {} reqs",
356                        snapshot.provider,
357                        snapshot.avg_tps,
358                        snapshot.avg_latency_ms,
359                        snapshot.request_count
360                    ));
361
362                    // Show p50/p95 if we have enough requests
363                    if snapshot.request_count >= 5 {
364                        lines.push(format!(
365                            "      p50: {:.1} TPS / {:.0}ms | p95: {:.1} TPS / {:.0}ms",
366                            snapshot.p50_tps,
367                            snapshot.p50_latency_ms,
368                            snapshot.p95_tps,
369                            snapshot.p95_latency_ms
370                        ));
371                    }
372                }
373            }
374            lines.push("".to_string());
375        }
376
377        // Cost estimates
378        lines.push("  COST ESTIMATES:".to_string());
379        lines.push(format!(
380            "    Session total: {}",
381            total_cost.format_currency()
382        ));
383        lines.push("    Based on approximate pricing".to_string());
384
385        lines
386    }
387}
388
389impl Default for TokenDisplay {
390    fn default() -> Self {
391        Self::new()
392    }
393}