Skip to main content

codetether_agent/tui/
token_display.rs

1use crate::telemetry::{ContextLimit, CostEstimate, TOKEN_USAGE, TokenUsageSnapshot};
2use crate::tui::theme::Theme;
3use ratatui::{
4    style::{Color, Modifier, Style},
5    text::{Line, Span},
6};
7
/// Token-usage presenter for the TUI.
///
/// Renders token totals, estimated costs, throughput (TPS) and
/// context-window warnings from the global telemetry counters.
pub struct TokenDisplay {
    /// Optional per-model pricing data.
    ///
    /// Context-window sizes are deliberately *not* kept here; they come from
    /// the canonical [`crate::provider::limits::context_window_for_model`].
    ///
    /// NOTE(review): no method in this file reads this map; the `(f64, f64)`
    /// pair presumably mirrors the `(input, output)` $-per-million-token
    /// prices used elsewhere — confirm before relying on it.
    model_pricing: ::std::collections::HashMap<String, (f64, f64)>,
}
14
15impl TokenDisplay {
16    pub fn new() -> Self {
17        Self {
18            model_pricing: std::collections::HashMap::new(),
19        }
20    }
21
22    /// Get context limit for a model by delegating to the canonical
23    /// [`crate::provider::limits::context_window_for_model`].
24    ///
25    /// Returns `None` only for the zero-length model string edge case;
26    /// callers that relied on the old `HashMap::get` returning `None`
27    /// for unknown models will now get the default 128 000 instead,
28    /// which is strictly more correct (all models have *some* context
29    /// window).
30    pub fn get_context_limit(&self, model: &str) -> Option<u64> {
31        if model.is_empty() {
32            return None;
33        }
34        Some(crate::provider::limits::context_window_for_model(model) as u64)
35    }
36
37    /// Get pricing for a model (returns $ per million tokens for input/output)
38    fn get_model_pricing(&self, model: &str) -> (f64, f64) {
39        match model.to_lowercase().as_str() {
40            m if m.contains("gpt-4o-mini") => (0.15, 0.60), // $0.15 / $0.60 per million
41            m if m.contains("gpt-4o") => (2.50, 10.00),     // $2.50 / $10.00 per million
42            m if m.contains("gpt-4-turbo") => (10.00, 30.00), // $10 / $30 per million
43            m if m.contains("gpt-4") => (30.00, 60.00),     // $30 / $60 per million
44            m if m.contains("claude-3-5-sonnet") => (3.00, 15.00), // $3 / $15 per million
45            m if m.contains("claude-3-5-haiku") => (0.80, 4.00), // $0.80 / $4 per million
46            m if m.contains("claude-opus") => (5.00, 25.00), // $5 / $25 per million (Bedrock Opus 4.6)
47            m if m.contains("gemini-2.0-flash") => (0.075, 0.30), // $0.075 / $0.30 per million
48            m if m.contains("gemini-1.5-flash") => (0.075, 0.30), // $0.075 / $0.30 per million
49            m if m.contains("gemini-1.5-pro") => (1.25, 5.00), // $1.25 / $5 per million
50            m if m.contains("glm-4") => (0.50, 0.50),        // ZhipuAI GLM-4 ~$0.50/million
51            m if m.contains("k1.5") => (8.00, 8.00),         // Moonshot K1.5
52            m if m.contains("k1.6") => (6.00, 6.00),         // Moonshot K1.6
53            _ => (1.00, 3.00),                               // Default fallback
54        }
55    }
56
57    /// Calculate cost for a model given input and output token counts
58    pub fn calculate_cost_for_tokens(
59        &self,
60        model: &str,
61        input_tokens: u64,
62        output_tokens: u64,
63    ) -> CostEstimate {
64        let (input_price, output_price) = self.get_model_pricing(model);
65        CostEstimate::from_tokens(
66            &crate::telemetry::TokenCounts::new(input_tokens, output_tokens),
67            input_price,
68            output_price,
69        )
70    }
71
72    /// Create status bar content with token usage
73    pub fn create_status_bar(&self, theme: &Theme) -> Line<'_> {
74        let global_snapshot = TOKEN_USAGE.global_snapshot();
75        let model_snapshots = TOKEN_USAGE.model_snapshots();
76
77        let total_tokens = global_snapshot.totals.total();
78        let session_cost = self.calculate_session_cost();
79        let tps_display = self.get_tps_display();
80
81        let mut spans = Vec::new();
82
83        // Help indicator
84        spans.push(Span::styled(
85            " ? ",
86            Style::default()
87                .fg(theme.status_bar_foreground.to_color())
88                .bg(theme.status_bar_background.to_color()),
89        ));
90        spans.push(Span::raw(" Help "));
91
92        // Switch agent
93        spans.push(Span::styled(
94            " Tab ",
95            Style::default()
96                .fg(theme.status_bar_foreground.to_color())
97                .bg(theme.status_bar_background.to_color()),
98        ));
99        spans.push(Span::raw(" Switch Agent "));
100
101        // Quit
102        spans.push(Span::styled(
103            " Ctrl+C ",
104            Style::default()
105                .fg(theme.status_bar_foreground.to_color())
106                .bg(theme.status_bar_background.to_color()),
107        ));
108        spans.push(Span::raw(" Quit "));
109
110        // Token usage
111        spans.push(Span::styled(
112            format!(" Tokens: {} ", total_tokens),
113            Style::default().fg(theme.timestamp_color.to_color()),
114        ));
115
116        // TPS (tokens per second)
117        if let Some(tps) = tps_display {
118            spans.push(Span::styled(
119                format!(" TPS: {} ", tps),
120                Style::default().fg(Color::Cyan),
121            ));
122        }
123
124        // Cost
125        spans.push(Span::styled(
126            format!(" Cost: {} ", session_cost.format_smart()),
127            Style::default().fg(theme.timestamp_color.to_color()),
128        ));
129
130        // Context warning if active model is near limit
131        if let Some(warning) = self.get_context_warning(&model_snapshots) {
132            spans.push(Span::styled(
133                format!(" {} ", warning),
134                Style::default().fg(Color::Red).add_modifier(Modifier::BOLD),
135            ));
136        }
137
138        Line::from(spans)
139    }
140
141    /// Calculate total session cost across all models
142    pub fn calculate_session_cost(&self) -> CostEstimate {
143        let model_snapshots = TOKEN_USAGE.model_snapshots();
144        let mut total = CostEstimate::default();
145
146        for snapshot in model_snapshots {
147            let model_cost = self.calculate_cost_for_tokens(
148                &snapshot.name,
149                snapshot.totals.input,
150                snapshot.totals.output,
151            );
152            total.input_cost += model_cost.input_cost;
153            total.output_cost += model_cost.output_cost;
154            total.total_cost += model_cost.total_cost;
155        }
156
157        total
158    }
159
160    /// Get context warning for active model
161    fn get_context_warning(&self, model_snapshots: &[TokenUsageSnapshot]) -> Option<String> {
162        if model_snapshots.is_empty() {
163            return None;
164        }
165
166        // Use the model with highest usage as "active"
167        let active_model = model_snapshots.iter().max_by_key(|s| s.totals.total())?;
168
169        if let Some(limit) = self.get_context_limit(&active_model.name) {
170            let context = ContextLimit::new(active_model.totals.total(), limit);
171
172            if context.percentage >= 75.0 {
173                return Some(format!("⚠️ Context: {:.1}%", context.percentage));
174            }
175        }
176
177        None
178    }
179
180    /// Get TPS (tokens per second) display string from provider metrics
181    fn get_tps_display(&self) -> Option<String> {
182        use crate::telemetry::PROVIDER_METRICS;
183
184        let snapshots = PROVIDER_METRICS.all_snapshots();
185        if snapshots.is_empty() {
186            return None;
187        }
188
189        // Find the provider with the most recent activity
190        let most_active = snapshots
191            .iter()
192            .filter(|s| s.avg_tps > 0.0)
193            .max_by(|a, b| {
194                a.total_output_tokens
195                    .partial_cmp(&b.total_output_tokens)
196                    .unwrap_or(std::cmp::Ordering::Equal)
197            })?;
198
199        // Format TPS nicely
200        let tps = most_active.avg_tps;
201        let formatted = if tps >= 100.0 {
202            format!("{:.0}", tps)
203        } else if tps >= 10.0 {
204            format!("{:.1}", tps)
205        } else {
206            format!("{:.2}", tps)
207        };
208
209        Some(formatted)
210    }
211
212    /// Create detailed token usage display
213    pub fn create_detailed_display(&self) -> Vec<String> {
214        use crate::telemetry::PROVIDER_METRICS;
215
216        let mut lines = Vec::new();
217        let global_snapshot = TOKEN_USAGE.global_snapshot();
218        let model_snapshots = TOKEN_USAGE.model_snapshots();
219
220        lines.push("".to_string());
221        lines.push("  TOKEN USAGE & COSTS".to_string());
222        lines.push("  ===================".to_string());
223        lines.push("".to_string());
224
225        // Global totals
226        let total_cost = self.calculate_session_cost();
227        lines.push(format!(
228            "  Total: {} tokens ({} requests) - {}",
229            global_snapshot.totals.total(),
230            global_snapshot.request_count,
231            total_cost.format_currency()
232        ));
233        lines.push(format!(
234            "  Current: {} in / {} out",
235            global_snapshot.totals.input, global_snapshot.totals.output
236        ));
237        lines.push("".to_string());
238
239        // Per-model breakdown
240        if !model_snapshots.is_empty() {
241            lines.push("  BY MODEL:".to_string());
242
243            for snapshot in model_snapshots.iter().take(5) {
244                let model_cost = self.calculate_cost_for_tokens(
245                    &snapshot.name,
246                    snapshot.totals.input,
247                    snapshot.totals.output,
248                );
249                lines.push(format!(
250                    "    {}: {} tokens ({} requests) - {}",
251                    snapshot.name,
252                    snapshot.totals.total(),
253                    snapshot.request_count,
254                    model_cost.format_currency()
255                ));
256
257                // Context limit info
258                if let Some(limit) = self.get_context_limit(&snapshot.name) {
259                    let context = ContextLimit::new(snapshot.totals.total(), limit);
260                    if context.percentage >= 50.0 {
261                        lines.push(format!(
262                            "      Context: {:.1}% of {} tokens",
263                            context.percentage, limit
264                        ));
265                    }
266                }
267            }
268
269            if model_snapshots.len() > 5 {
270                lines.push(format!(
271                    "    ... and {} more models",
272                    model_snapshots.len() - 5
273                ));
274            }
275            lines.push("".to_string());
276        }
277
278        // Provider performance metrics (TPS, latency)
279        let provider_snapshots = PROVIDER_METRICS.all_snapshots();
280        if !provider_snapshots.is_empty() {
281            lines.push("  PROVIDER PERFORMANCE:".to_string());
282
283            for snapshot in provider_snapshots.iter().take(5) {
284                if snapshot.request_count > 0 {
285                    lines.push(format!(
286                        "    {}: {:.1} avg TPS | {:.0}ms avg latency | {} reqs",
287                        snapshot.provider,
288                        snapshot.avg_tps,
289                        snapshot.avg_latency_ms,
290                        snapshot.request_count
291                    ));
292
293                    // Show p50/p95 if we have enough requests
294                    if snapshot.request_count >= 5 {
295                        lines.push(format!(
296                            "      p50: {:.1} TPS / {:.0}ms | p95: {:.1} TPS / {:.0}ms",
297                            snapshot.p50_tps,
298                            snapshot.p50_latency_ms,
299                            snapshot.p95_tps,
300                            snapshot.p95_latency_ms
301                        ));
302                    }
303                }
304            }
305            lines.push("".to_string());
306        }
307
308        // Cost estimates
309        lines.push("  COST ESTIMATES:".to_string());
310        lines.push(format!(
311            "    Session total: {}",
312            total_cost.format_currency()
313        ));
314        lines.push("    Based on approximate pricing".to_string());
315
316        lines
317    }
318}
319
320impl Default for TokenDisplay {
321    fn default() -> Self {
322        Self::new()
323    }
324}