Skip to main content

codetether_agent/tui/
token_display.rs

1use crate::telemetry::{ContextLimit, CostEstimate, TOKEN_USAGE, TokenUsageSnapshot};
2use crate::tui::theme::Theme;
3use ratatui::{
4    style::{Color, Modifier, Style},
5    text::{Line, Span},
6};
7
/// Renders token usage, estimated cost, and context-window warnings for the TUI.
///
/// Holds a lookup table from known model names to their context-window sizes
/// (in tokens); the table is populated by [`TokenDisplay::new`].
#[derive(Debug, Clone)]
pub struct TokenDisplay {
    /// Context-window size in tokens, keyed by model name.
    model_context_limits: std::collections::HashMap<String, u64>,
}
12
13impl TokenDisplay {
14    pub fn new() -> Self {
15        let mut limits = std::collections::HashMap::new();
16
17        // Common model context limits
18        limits.insert("gpt-4".to_string(), 128_000);
19        limits.insert("gpt-4-turbo".to_string(), 128_000);
20        limits.insert("gpt-4o".to_string(), 128_000);
21        limits.insert("gpt-4o-mini".to_string(), 128_000);
22        limits.insert("claude-3-5-sonnet".to_string(), 200_000);
23        limits.insert("claude-3-5-haiku".to_string(), 200_000);
24        limits.insert("claude-3-opus".to_string(), 200_000);
25        limits.insert("claude-opus-4-6".to_string(), 200_000);
26        limits.insert("gemini-2.0-flash".to_string(), 1_000_000);
27        limits.insert("gemini-1.5-flash".to_string(), 1_000_000);
28        limits.insert("gemini-1.5-pro".to_string(), 2_000_000);
29        limits.insert("k1.5".to_string(), 200_000);
30        limits.insert("k1.6".to_string(), 200_000);
31
32        Self {
33            model_context_limits: limits,
34        }
35    }
36
37    /// Get context limit for a model
38    pub fn get_context_limit(&self, model: &str) -> Option<u64> {
39        self.model_context_limits.get(model).copied()
40    }
41
42    /// Get pricing for a model (returns $ per million tokens for input/output)
43    fn get_model_pricing(&self, model: &str) -> (f64, f64) {
44        match model.to_lowercase().as_str() {
45            m if m.contains("gpt-4o-mini") => (0.15, 0.60), // $0.15 / $0.60 per million
46            m if m.contains("gpt-4o") => (2.50, 10.00),     // $2.50 / $10.00 per million
47            m if m.contains("gpt-4-turbo") => (10.00, 30.00), // $10 / $30 per million
48            m if m.contains("gpt-4") => (30.00, 60.00),     // $30 / $60 per million
49            m if m.contains("claude-3-5-sonnet") => (3.00, 15.00), // $3 / $15 per million
50            m if m.contains("claude-3-5-haiku") => (0.80, 4.00), // $0.80 / $4 per million
51            m if m.contains("claude-opus") => (5.00, 25.00), // $5 / $25 per million (Bedrock Opus 4.6)
52            m if m.contains("gemini-2.0-flash") => (0.075, 0.30), // $0.075 / $0.30 per million
53            m if m.contains("gemini-1.5-flash") => (0.075, 0.30), // $0.075 / $0.30 per million
54            m if m.contains("gemini-1.5-pro") => (1.25, 5.00), // $1.25 / $5 per million
55            m if m.contains("glm-4") => (0.50, 0.50),        // ZhipuAI GLM-4 ~$0.50/million
56            m if m.contains("k1.5") => (8.00, 8.00),         // Moonshot K1.5
57            m if m.contains("k1.6") => (6.00, 6.00),         // Moonshot K1.6
58            _ => (1.00, 3.00),                               // Default fallback
59        }
60    }
61
62    /// Calculate cost for a model given input and output token counts
63    pub fn calculate_cost_for_tokens(
64        &self,
65        model: &str,
66        input_tokens: u64,
67        output_tokens: u64,
68    ) -> CostEstimate {
69        let (input_price, output_price) = self.get_model_pricing(model);
70        CostEstimate::from_tokens(
71            &crate::telemetry::TokenCounts::new(input_tokens, output_tokens),
72            input_price,
73            output_price,
74        )
75    }
76
77    /// Create status bar content with token usage
78    pub fn create_status_bar(&self, theme: &Theme) -> Line<'_> {
79        let global_snapshot = TOKEN_USAGE.global_snapshot();
80        let model_snapshots = TOKEN_USAGE.model_snapshots();
81
82        let total_tokens = global_snapshot.totals.total();
83        let session_cost = self.calculate_session_cost();
84        let tps_display = self.get_tps_display();
85
86        let mut spans = Vec::new();
87
88        // Help indicator
89        spans.push(Span::styled(
90            " ? ",
91            Style::default()
92                .fg(theme.status_bar_foreground.to_color())
93                .bg(theme.status_bar_background.to_color()),
94        ));
95        spans.push(Span::raw(" Help "));
96
97        // Switch agent
98        spans.push(Span::styled(
99            " Tab ",
100            Style::default()
101                .fg(theme.status_bar_foreground.to_color())
102                .bg(theme.status_bar_background.to_color()),
103        ));
104        spans.push(Span::raw(" Switch Agent "));
105
106        // Quit
107        spans.push(Span::styled(
108            " Ctrl+C ",
109            Style::default()
110                .fg(theme.status_bar_foreground.to_color())
111                .bg(theme.status_bar_background.to_color()),
112        ));
113        spans.push(Span::raw(" Quit "));
114
115        // Token usage
116        spans.push(Span::styled(
117            format!(" Tokens: {} ", total_tokens),
118            Style::default().fg(theme.timestamp_color.to_color()),
119        ));
120
121        // TPS (tokens per second)
122        if let Some(tps) = tps_display {
123            spans.push(Span::styled(
124                format!(" TPS: {} ", tps),
125                Style::default().fg(Color::Cyan),
126            ));
127        }
128
129        // Cost
130        spans.push(Span::styled(
131            format!(" Cost: {} ", session_cost.format_smart()),
132            Style::default().fg(theme.timestamp_color.to_color()),
133        ));
134
135        // Context warning if active model is near limit
136        if let Some(warning) = self.get_context_warning(&model_snapshots) {
137            spans.push(Span::styled(
138                format!(" {} ", warning),
139                Style::default().fg(Color::Red).add_modifier(Modifier::BOLD),
140            ));
141        }
142
143        Line::from(spans)
144    }
145
146    /// Calculate total session cost across all models
147    pub fn calculate_session_cost(&self) -> CostEstimate {
148        let model_snapshots = TOKEN_USAGE.model_snapshots();
149        let mut total = CostEstimate::default();
150
151        for snapshot in model_snapshots {
152            let model_cost = self.calculate_cost_for_tokens(
153                &snapshot.name,
154                snapshot.totals.input,
155                snapshot.totals.output,
156            );
157            total.input_cost += model_cost.input_cost;
158            total.output_cost += model_cost.output_cost;
159            total.total_cost += model_cost.total_cost;
160        }
161
162        total
163    }
164
165    /// Get context warning for active model
166    fn get_context_warning(&self, model_snapshots: &[TokenUsageSnapshot]) -> Option<String> {
167        if model_snapshots.is_empty() {
168            return None;
169        }
170
171        // Use the model with highest usage as "active"
172        let active_model = model_snapshots.iter().max_by_key(|s| s.totals.total())?;
173
174        if let Some(limit) = self.get_context_limit(&active_model.name) {
175            let context = ContextLimit::new(active_model.totals.total(), limit);
176
177            if context.percentage >= 75.0 {
178                return Some(format!("⚠️ Context: {:.1}%", context.percentage));
179            }
180        }
181
182        None
183    }
184
185    /// Get TPS (tokens per second) display string from provider metrics
186    fn get_tps_display(&self) -> Option<String> {
187        use crate::telemetry::PROVIDER_METRICS;
188
189        let snapshots = PROVIDER_METRICS.all_snapshots();
190        if snapshots.is_empty() {
191            return None;
192        }
193
194        // Find the provider with the most recent activity
195        let most_active = snapshots
196            .iter()
197            .filter(|s| s.avg_tps > 0.0)
198            .max_by(|a, b| {
199                a.total_output_tokens
200                    .partial_cmp(&b.total_output_tokens)
201                    .unwrap_or(std::cmp::Ordering::Equal)
202            })?;
203
204        // Format TPS nicely
205        let tps = most_active.avg_tps;
206        let formatted = if tps >= 100.0 {
207            format!("{:.0}", tps)
208        } else if tps >= 10.0 {
209            format!("{:.1}", tps)
210        } else {
211            format!("{:.2}", tps)
212        };
213
214        Some(formatted)
215    }
216
217    /// Create detailed token usage display
218    pub fn create_detailed_display(&self) -> Vec<String> {
219        use crate::telemetry::PROVIDER_METRICS;
220
221        let mut lines = Vec::new();
222        let global_snapshot = TOKEN_USAGE.global_snapshot();
223        let model_snapshots = TOKEN_USAGE.model_snapshots();
224
225        lines.push("".to_string());
226        lines.push("  TOKEN USAGE & COSTS".to_string());
227        lines.push("  ===================".to_string());
228        lines.push("".to_string());
229
230        // Global totals
231        let total_cost = self.calculate_session_cost();
232        lines.push(format!(
233            "  Total: {} tokens ({} requests) - {}",
234            global_snapshot.totals.total(),
235            global_snapshot.request_count,
236            total_cost.format_currency()
237        ));
238        lines.push(format!(
239            "  Current: {} in / {} out",
240            global_snapshot.totals.input, global_snapshot.totals.output
241        ));
242        lines.push("".to_string());
243
244        // Per-model breakdown
245        if !model_snapshots.is_empty() {
246            lines.push("  BY MODEL:".to_string());
247
248            for snapshot in model_snapshots.iter().take(5) {
249                let model_cost = self.calculate_cost_for_tokens(
250                    &snapshot.name,
251                    snapshot.totals.input,
252                    snapshot.totals.output,
253                );
254                lines.push(format!(
255                    "    {}: {} tokens ({} requests) - {}",
256                    snapshot.name,
257                    snapshot.totals.total(),
258                    snapshot.request_count,
259                    model_cost.format_currency()
260                ));
261
262                // Context limit info
263                if let Some(limit) = self.get_context_limit(&snapshot.name) {
264                    let context = ContextLimit::new(snapshot.totals.total(), limit);
265                    if context.percentage >= 50.0 {
266                        lines.push(format!(
267                            "      Context: {:.1}% of {} tokens",
268                            context.percentage, limit
269                        ));
270                    }
271                }
272            }
273
274            if model_snapshots.len() > 5 {
275                lines.push(format!(
276                    "    ... and {} more models",
277                    model_snapshots.len() - 5
278                ));
279            }
280            lines.push("".to_string());
281        }
282
283        // Provider performance metrics (TPS, latency)
284        let provider_snapshots = PROVIDER_METRICS.all_snapshots();
285        if !provider_snapshots.is_empty() {
286            lines.push("  PROVIDER PERFORMANCE:".to_string());
287
288            for snapshot in provider_snapshots.iter().take(5) {
289                if snapshot.request_count > 0 {
290                    lines.push(format!(
291                        "    {}: {:.1} avg TPS | {:.0}ms avg latency | {} reqs",
292                        snapshot.provider,
293                        snapshot.avg_tps,
294                        snapshot.avg_latency_ms,
295                        snapshot.request_count
296                    ));
297
298                    // Show p50/p95 if we have enough requests
299                    if snapshot.request_count >= 5 {
300                        lines.push(format!(
301                            "      p50: {:.1} TPS / {:.0}ms | p95: {:.1} TPS / {:.0}ms",
302                            snapshot.p50_tps,
303                            snapshot.p50_latency_ms,
304                            snapshot.p95_tps,
305                            snapshot.p95_latency_ms
306                        ));
307                    }
308                }
309            }
310            lines.push("".to_string());
311        }
312
313        // Cost estimates
314        lines.push("  COST ESTIMATES:".to_string());
315        lines.push(format!(
316            "    Session total: {}",
317            total_cost.format_currency()
318        ));
319        lines.push("    Based on approximate pricing".to_string());
320
321        lines
322    }
323}
324
325impl Default for TokenDisplay {
326    fn default() -> Self {
327        Self::new()
328    }
329}