Skip to main content

codetether_agent/tui/
token_display.rs

1use crate::telemetry::{ContextLimit, CostEstimate, TOKEN_USAGE, TokenUsageSnapshot};
2use crate::tui::theme::Theme;
3use ratatui::{
4    style::{Color, Modifier, Style},
5    text::{Line, Span},
6};
7
/// Enhanced token usage display with costs and warnings.
///
/// A stateless unit struct: usage figures are read from the global telemetry
/// registries each time a display method runs, so instances are free to
/// create and copy.
#[derive(Debug, Clone, Copy)]
pub struct TokenDisplay;
10
11impl TokenDisplay {
12    pub fn new() -> Self {
13        Self
14    }
15
16    /// Get context limit for a model by delegating to the canonical
17    /// [`crate::provider::limits::context_window_for_model`].
18    ///
19    /// Returns `None` only for the zero-length model string edge case;
20    /// callers that relied on the old `HashMap::get` returning `None`
21    /// for unknown models will now get the default 128 000 instead,
22    /// which is strictly more correct (all models have *some* context
23    /// window).
24    pub fn get_context_limit(&self, model: &str) -> Option<u64> {
25        if model.is_empty() {
26            return None;
27        }
28        Some(crate::provider::limits::context_window_for_model(model) as u64)
29    }
30
31    /// Get pricing for a model (returns $ per million tokens for input/output).
32    ///
33    /// Delegates to the canonical
34    /// [`crate::provider::pricing::pricing_for_model`] so costs in the TUI
35    /// stay in sync with the cost-guardrail enforcement path.
36    fn get_model_pricing(&self, model: &str) -> (f64, f64) {
37        crate::provider::pricing::pricing_for_model(model)
38    }
39
40    /// Calculate cost for a model given input and output token counts
41    pub fn calculate_cost_for_tokens(
42        &self,
43        model: &str,
44        input_tokens: u64,
45        output_tokens: u64,
46    ) -> CostEstimate {
47        let (input_price, output_price) = self.get_model_pricing(model);
48        CostEstimate::from_tokens(
49            &crate::telemetry::TokenCounts::new(input_tokens, output_tokens),
50            input_price,
51            output_price,
52        )
53    }
54
55    /// Create status bar content with token usage
56    pub fn create_status_bar(&self, theme: &Theme) -> Line<'_> {
57        let global_snapshot = TOKEN_USAGE.global_snapshot();
58        let model_snapshots = TOKEN_USAGE.model_snapshots();
59
60        let total_tokens = global_snapshot.totals.total();
61        let session_cost = self.calculate_session_cost();
62        let tps_display = self.get_tps_display();
63
64        let mut spans = Vec::new();
65
66        // Help indicator
67        spans.push(Span::styled(
68            " ? ",
69            Style::default()
70                .fg(theme.status_bar_foreground.to_color())
71                .bg(theme.status_bar_background.to_color()),
72        ));
73        spans.push(Span::raw(" Help "));
74
75        // Switch agent
76        spans.push(Span::styled(
77            " Tab ",
78            Style::default()
79                .fg(theme.status_bar_foreground.to_color())
80                .bg(theme.status_bar_background.to_color()),
81        ));
82        spans.push(Span::raw(" Switch Agent "));
83
84        // Quit
85        spans.push(Span::styled(
86            " Ctrl+C ",
87            Style::default()
88                .fg(theme.status_bar_foreground.to_color())
89                .bg(theme.status_bar_background.to_color()),
90        ));
91        spans.push(Span::raw(" Quit "));
92
93        // Token usage
94        spans.push(Span::styled(
95            format!(" Tokens: {} ", total_tokens),
96            Style::default().fg(theme.timestamp_color.to_color()),
97        ));
98
99        // TPS (tokens per second)
100        if let Some(tps) = tps_display {
101            spans.push(Span::styled(
102                format!(" TPS: {} ", tps),
103                Style::default().fg(Color::Cyan),
104            ));
105        }
106
107        // Cost — colorized based on the configured cost guardrails so
108        // users get an immediate visual when they cross warn / hard limit
109        // thresholds (see `CODETETHER_COST_WARN_USD` /
110        // `CODETETHER_COST_LIMIT_USD`).
111        let cost_style = match crate::session::helper::cost_guard::cost_guard_level() {
112            crate::session::helper::cost_guard::CostGuardLevel::OverLimit => {
113                Style::default().fg(Color::Red).add_modifier(Modifier::BOLD)
114            }
115            crate::session::helper::cost_guard::CostGuardLevel::OverWarn => Style::default()
116                .fg(Color::Yellow)
117                .add_modifier(Modifier::BOLD),
118            crate::session::helper::cost_guard::CostGuardLevel::Ok => {
119                Style::default().fg(theme.timestamp_color.to_color())
120            }
121        };
122        spans.push(Span::styled(
123            format!(" Cost: {} ", session_cost.format_smart()),
124            cost_style,
125        ));
126
127        // Prompt-cache hit rate — surfaces whether Anthropic/Bedrock
128        // prompt caching is actually saving input tokens. Only shown when
129        // the active model has recorded any cache activity at all.
130        if let Some(cache_pct) = self.get_cache_hit_rate(&model_snapshots) {
131            spans.push(Span::styled(
132                format!(" Cache: {:.0}% ", cache_pct),
133                Style::default().fg(Color::Green),
134            ));
135        }
136
137        // Context warning if active model is near limit
138        if let Some(warning) = self.get_context_warning(&model_snapshots) {
139            spans.push(Span::styled(
140                format!(" {} ", warning),
141                Style::default().fg(Color::Red).add_modifier(Modifier::BOLD),
142            ));
143        }
144
145        Line::from(spans)
146    }
147
148    /// Calculate total session cost across all models
149    pub fn calculate_session_cost(&self) -> CostEstimate {
150        let model_snapshots = TOKEN_USAGE.model_snapshots();
151        let mut total = CostEstimate::default();
152
153        for snapshot in model_snapshots {
154            let model_cost = self.calculate_cost_for_tokens(
155                &snapshot.name,
156                snapshot.totals.input,
157                snapshot.totals.output,
158            );
159            total.input_cost += model_cost.input_cost;
160            total.output_cost += model_cost.output_cost;
161            total.total_cost += model_cost.total_cost;
162        }
163
164        total
165    }
166
167    /// Get context warning for the active model based on the **current
168    /// turn's** prompt size (what the next request will send), not
169    /// cumulative lifetime tokens.
170    ///
171    /// This matters because the agent loop re-sends the growing
172    /// conversation on every step and the RLM layer compresses history
173    /// behind the scenes — users need to see the real edge of the
174    /// context window, not a bogus 6000% figure summed over all turns.
175    /// This matters because the agent loop re-sends the growing
176    /// conversation on every step and the RLM layer compresses history
177    /// behind the scenes — users need to see the real edge of the
178    /// context window, not a bogus 6000% figure summed over all turns.
179    fn get_context_warning(&self, model_snapshots: &[TokenUsageSnapshot]) -> Option<String> {
180        if model_snapshots.is_empty() {
181            return None;
182        }
183
184        let active_model = model_snapshots.iter().max_by_key(|s| s.totals.total())?;
185        let limit = self.get_context_limit(&active_model.name)?;
186
187        // Prefer the last turn's actual prompt size; fall back to cumulative
188        // only if no turn has been recorded yet (first-render race).
189        let used = crate::telemetry::TOKEN_USAGE
190            .last_prompt_tokens_for(&active_model.name)
191            .unwrap_or_else(|| active_model.totals.total().min(limit));
192
193        let context = ContextLimit::new(used, limit);
194
195        if context.percentage >= 90.0 {
196            Some(format!("🛑 Context: {:.0}%", context.percentage))
197        } else if context.percentage >= 75.0 {
198            Some(format!("⚠️ Context: {:.0}%", context.percentage))
199        } else if context.percentage >= 50.0 {
200            Some(format!("Context: {:.0}%", context.percentage))
201        } else {
202            None
203        }
204    }
205
206    /// Aggregate prompt-cache hit rate across all recorded models.
207    ///
208    /// Defined as `cache_read / (cache_read + full_price_input)` × 100,
209    /// i.e. what fraction of billable input was served from the cache.
210    /// Returns `None` when no cache activity has been recorded (which is
211    /// the common case for providers that don't support it).
212    fn get_cache_hit_rate(&self, model_snapshots: &[TokenUsageSnapshot]) -> Option<f64> {
213        let mut full_input: u64 = 0;
214        let mut cache_read: u64 = 0;
215        for s in model_snapshots {
216            full_input += s.prompt_tokens;
217            let (cr, _cw) = crate::telemetry::TOKEN_USAGE.cache_usage_for(&s.name);
218            cache_read += cr;
219        }
220        let denom = full_input + cache_read;
221        if cache_read == 0 || denom == 0 {
222            return None;
223        }
224        Some(cache_read as f64 * 100.0 / denom as f64)
225    }
226
227    /// Get TPS (tokens per second) display string from provider metrics
228    fn get_tps_display(&self) -> Option<String> {
229        use crate::telemetry::PROVIDER_METRICS;
230
231        let snapshots = PROVIDER_METRICS.all_snapshots();
232        if snapshots.is_empty() {
233            return None;
234        }
235
236        // Find the provider with the most recent activity
237        let most_active = snapshots
238            .iter()
239            .filter(|s| s.avg_tps > 0.0)
240            .max_by(|a, b| {
241                a.total_output_tokens
242                    .partial_cmp(&b.total_output_tokens)
243                    .unwrap_or(std::cmp::Ordering::Equal)
244            })?;
245
246        // Format TPS nicely
247        let tps = most_active.avg_tps;
248        let formatted = if tps >= 100.0 {
249            format!("{:.0}", tps)
250        } else if tps >= 10.0 {
251            format!("{:.1}", tps)
252        } else {
253            format!("{:.2}", tps)
254        };
255
256        Some(formatted)
257    }
258
259    /// Create detailed token usage display
260    pub fn create_detailed_display(&self) -> Vec<String> {
261        use crate::telemetry::PROVIDER_METRICS;
262
263        let mut lines = Vec::new();
264        let global_snapshot = TOKEN_USAGE.global_snapshot();
265        let model_snapshots = TOKEN_USAGE.model_snapshots();
266
267        lines.push("".to_string());
268        lines.push("  TOKEN USAGE & COSTS".to_string());
269        lines.push("  ===================".to_string());
270        lines.push("".to_string());
271
272        // Global totals
273        let total_cost = self.calculate_session_cost();
274        lines.push(format!(
275            "  Total: {} tokens ({} requests) - {}",
276            global_snapshot.totals.total(),
277            global_snapshot.request_count,
278            total_cost.format_currency()
279        ));
280        lines.push(format!(
281            "  Current: {} in / {} out",
282            global_snapshot.totals.input, global_snapshot.totals.output
283        ));
284        lines.push("".to_string());
285
286        // Per-model breakdown
287        if !model_snapshots.is_empty() {
288            lines.push("  BY MODEL:".to_string());
289
290            for snapshot in model_snapshots.iter().take(5) {
291                let model_cost = self.calculate_cost_for_tokens(
292                    &snapshot.name,
293                    snapshot.totals.input,
294                    snapshot.totals.output,
295                );
296                lines.push(format!(
297                    "    {}: {} tokens ({} requests) - {}",
298                    snapshot.name,
299                    snapshot.totals.total(),
300                    snapshot.request_count,
301                    model_cost.format_currency()
302                ));
303
304                // Context limit info
305                if let Some(limit) = self.get_context_limit(&snapshot.name) {
306                    let context = ContextLimit::new(snapshot.totals.total(), limit);
307                    if context.percentage >= 50.0 {
308                        lines.push(format!(
309                            "      Context: {:.1}% of {} tokens",
310                            context.percentage, limit
311                        ));
312                    }
313                }
314
315                // Prompt-cache stats (Anthropic / Bedrock).
316                let (cache_read, cache_write) =
317                    crate::telemetry::TOKEN_USAGE.cache_usage_for(&snapshot.name);
318                if cache_read > 0 || cache_write > 0 {
319                    let denom = snapshot.prompt_tokens + cache_read;
320                    let hit_pct = if denom > 0 {
321                        cache_read as f64 * 100.0 / denom as f64
322                    } else {
323                        0.0
324                    };
325                    lines.push(format!(
326                        "      Cache: {} read / {} write ({:.1}% hit)",
327                        cache_read, cache_write, hit_pct
328                    ));
329                }
330            }
331
332            if model_snapshots.len() > 5 {
333                lines.push(format!(
334                    "    ... and {} more models",
335                    model_snapshots.len() - 5
336                ));
337            }
338            lines.push("".to_string());
339        }
340
341        // Provider performance metrics (TPS, latency)
342        let provider_snapshots = PROVIDER_METRICS.all_snapshots();
343        if !provider_snapshots.is_empty() {
344            lines.push("  PROVIDER PERFORMANCE:".to_string());
345
346            for snapshot in provider_snapshots.iter().take(5) {
347                if snapshot.request_count > 0 {
348                    lines.push(format!(
349                        "    {}: {:.1} avg TPS | {:.0}ms avg latency | {} reqs",
350                        snapshot.provider,
351                        snapshot.avg_tps,
352                        snapshot.avg_latency_ms,
353                        snapshot.request_count
354                    ));
355
356                    // Show p50/p95 if we have enough requests
357                    if snapshot.request_count >= 5 {
358                        lines.push(format!(
359                            "      p50: {:.1} TPS / {:.0}ms | p95: {:.1} TPS / {:.0}ms",
360                            snapshot.p50_tps,
361                            snapshot.p50_latency_ms,
362                            snapshot.p95_tps,
363                            snapshot.p95_latency_ms
364                        ));
365                    }
366                }
367            }
368            lines.push("".to_string());
369        }
370
371        // Cost estimates
372        lines.push("  COST ESTIMATES:".to_string());
373        lines.push(format!(
374            "    Session total: {}",
375            total_cost.format_currency()
376        ));
377        lines.push("    Based on approximate pricing".to_string());
378
379        lines
380    }
381}
382
383impl Default for TokenDisplay {
384    fn default() -> Self {
385        Self::new()
386    }
387}