Skip to main content

oxi/
footer_data.rs

//! Footer data provider for the TUI status display.
//!
//! Provides utilities for gathering and formatting footer data,
//! such as model info, token usage, and keybinding hints.

6use std::collections::HashMap;
7use std::time::{Duration, Instant};
8
9/// Footer data containing all status information
10#[derive(Debug, Clone, Default)]
11pub struct FooterData {
12    /// Current model name
13    pub model_name: Option<String>,
14    /// Input tokens used
15    pub input_tokens: u32,
16    /// Output tokens used
17    pub output_tokens: u32,
18    /// Cached tokens
19    pub cached_tokens: Option<u32>,
20    /// Estimated cost in USD
21    pub estimated_cost: Option<f64>,
22    /// Current git branch
23    pub git_branch: Option<String>,
24    /// Session duration
25    pub session_duration: Duration,
26    /// Keybinding hints
27    pub keybinding_hints: Vec<KeybindingHint>,
28    /// Extension status messages
29    pub extension_statuses: HashMap<String, String>,
30    /// Available provider count
31    pub available_providers: usize,
32}
33
34impl FooterData {
35    /// Create empty footer data
36    pub fn new() -> Self {
37        Self::default()
38    }
39
40    /// Create with model name
41    pub fn with_model(mut self, model: &str) -> Self {
42        self.model_name = Some(model.to_string());
43        self
44    }
45
46    /// Create with token usage
47    pub fn with_tokens(mut self, input: u32, output: u32) -> Self {
48        self.input_tokens = input;
49        self.output_tokens = output;
50        self
51    }
52
53    /// Create with git branch
54    pub fn with_git_branch(mut self, branch: Option<String>) -> Self {
55        self.git_branch = branch;
56        self
57    }
58
59    /// Format footer as display string
60    pub fn format(&self) -> String {
61        let mut parts = Vec::new();
62
63        // Model name
64        if let Some(model) = &self.model_name {
65            parts.push(format!("Model: {}", model));
66        }
67
68        // Token usage
69        if self.input_tokens > 0 || self.output_tokens > 0 {
70            let mut tokens_str = format!("Tokens: {}/{}", self.input_tokens, self.output_tokens);
71            if let Some(cached) = self.cached_tokens {
72                tokens_str.push_str(&format!(" (+{} cached)", cached));
73            }
74            parts.push(tokens_str);
75        }
76
77        // Cost
78        if let Some(cost) = self.estimated_cost {
79            parts.push(format!("Cost: ${:.4}", cost));
80        }
81
82        // Git branch
83        if let Some(branch) = &self.git_branch {
84            parts.push(format!("Branch: {}", branch));
85        }
86
87        // Session duration
88        if self.session_duration.as_secs() > 0 {
89            parts.push(format!("Duration: {}", format_duration(self.session_duration)));
90        }
91
92        parts.join(" | ")
93    }
94
95    /// Get total tokens
96    pub fn total_tokens(&self) -> u32 {
97        self.input_tokens + self.output_tokens
98    }
99
100    /// Check if there is any data to display
101    pub fn is_empty(&self) -> bool {
102        self.model_name.is_none()
103            && self.input_tokens == 0
104            && self.output_tokens == 0
105            && self.git_branch.is_none()
106            && self.session_duration.is_zero()
107            && self.extension_statuses.is_empty()
108    }
109}
110
/// A keybinding hint for display, pairing a key sequence with its action.
///
/// Derives `PartialEq`/`Eq`/`Hash` so hints can be compared and deduplicated.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct KeybindingHint {
    /// The key sequence as displayed, e.g. `"Ctrl+C"`.
    pub keys: String,
    /// Description of the action the keys trigger.
    pub description: String,
}

impl KeybindingHint {
    /// Create a hint from a key sequence and its action description.
    pub fn new(keys: &str, description: &str) -> Self {
        Self {
            keys: keys.to_string(),
            description: description.to_string(),
        }
    }
}
128
/// Session timer measuring monotonic elapsed time since creation or the last reset.
#[derive(Debug, Clone)]
pub struct SessionTimer {
    /// Moment the timer was created or last reset.
    start: Instant,
}

impl SessionTimer {
    /// Create a new timer starting now.
    pub fn new() -> Self {
        Self {
            start: Instant::now(),
        }
    }

    /// Time elapsed since the timer was created or last reset.
    pub fn elapsed(&self) -> Duration {
        self.start.elapsed()
    }

    /// Restart the timer from now.
    pub fn reset(&mut self) {
        self.start = Instant::now();
    }
}

impl Default for SessionTimer {
    /// Same as [`SessionTimer::new`]: the timer starts at construction time.
    fn default() -> Self {
        Self::new()
    }
}
158
/// Render a duration as a compact, human-readable string.
///
/// At most two units are shown and sub-second precision is dropped.
/// Examples: `"42s"`, `"1m 30s"`, `"1h 1m"`, `"1d 0h"`.
pub fn format_duration(duration: Duration) -> String {
    const MINUTE: u64 = 60;
    const HOUR: u64 = 60 * MINUTE;
    const DAY: u64 = 24 * HOUR;

    match duration.as_secs() {
        secs if secs < MINUTE => format!("{}s", secs),
        secs if secs < HOUR => format!("{}m {}s", secs / MINUTE, secs % MINUTE),
        secs if secs < DAY => format!("{}h {}m", secs / HOUR, (secs % HOUR) / MINUTE),
        secs => format!("{}d {}h", secs / DAY, (secs % DAY) / HOUR),
    }
}
183
/// Rough cost estimator for different model families.
///
/// Prices are in USD per one million tokens. A tier matches when its pattern
/// is a case-insensitive substring of the model name; for example, `"claude"`
/// matches `"claude-3.5-sonnet"`. When several patterns match, the longest one
/// wins. The result therefore never depends on iteration order, and input and
/// output prices always come from the same tier.
///
/// Note: the built-in prices are rough estimates based on public pricing and
/// may be out of date. [`CostEstimator::with_price`] overrides or extends them.
#[derive(Debug, Clone)]
pub struct CostEstimator {
    /// Known price tiers in insertion order; patterns are stored lowercase.
    tiers: Vec<PriceTier>,
}

/// A single pricing tier of [`CostEstimator`].
#[derive(Debug, Clone)]
struct PriceTier {
    /// Lowercase substring matched against the lowercased model name.
    pattern: String,
    /// USD per 1M input tokens.
    input_per_m: f64,
    /// USD per 1M output tokens.
    output_per_m: f64,
}

impl CostEstimator {
    /// Number of tokens each price is quoted for.
    const TOKENS_PER_PRICE_UNIT: f64 = 1_000_000.0;

    /// Create a new cost estimator with the built-in default prices.
    pub fn new() -> Self {
        Self { tiers: Vec::new() }
            // Anthropic models (approximate)
            .with_price("claude", 3.0, 15.0)
            // OpenAI models
            .with_price("gpt-4", 30.0, 60.0)
            .with_price("gpt-3.5", 0.5, 1.5)
            // Google models
            .with_price("gemini", 0.125, 0.5)
    }

    /// Add a price tier, or replace the existing tier with the same pattern.
    ///
    /// `pattern` is compared case-insensitively. `input_per_m` and
    /// `output_per_m` are USD per one million tokens.
    pub fn with_price(mut self, pattern: &str, input_per_m: f64, output_per_m: f64) -> Self {
        let tier = PriceTier {
            pattern: pattern.to_lowercase(),
            input_per_m,
            output_per_m,
        };
        let existing = self.tiers.iter().position(|t| t.pattern == tier.pattern);
        match existing {
            Some(index) => self.tiers[index] = tier,
            None => self.tiers.push(tier),
        }
        self
    }

    /// Estimate the USD cost of the given token usage for `model`.
    ///
    /// Returns `None` when no price tier matches the model name.
    pub fn estimate(&self, model: &str, input_tokens: u32, output_tokens: u32) -> Option<f64> {
        let model_lower = model.to_lowercase();

        // The most specific (longest) matching pattern wins; among equal
        // lengths `max_by_key` keeps the later tier, which is still deterministic.
        let tier = self
            .tiers
            .iter()
            .filter(|t| model_lower.contains(t.pattern.as_str()))
            .max_by_key(|t| t.pattern.len())?;

        let input_cost = f64::from(input_tokens) / Self::TOKENS_PER_PRICE_UNIT * tier.input_per_m;
        let output_cost =
            f64::from(output_tokens) / Self::TOKENS_PER_PRICE_UNIT * tier.output_per_m;
        Some(input_cost + output_cost)
    }
}

impl Default for CostEstimator {
    /// Same as [`CostEstimator::new`]: built-in default prices.
    fn default() -> Self {
        Self::new()
    }
}
252
253/// Footer data provider trait
254pub trait FooterDataProvider: Send + Sync {
255    /// Get the current footer data
256    fn get_footer_data(&self) -> FooterData;
257
258    /// Get the model name
259    fn get_model_name(&self) -> Option<String>;
260
261    /// Get git branch
262    fn get_git_branch(&self) -> Option<String>;
263
264    /// Get token counts
265    fn get_token_counts(&self) -> (u32, u32);
266
267    /// Get session duration
268    fn get_session_duration(&self) -> Duration;
269
270    /// Get keybinding hints
271    fn get_keybinding_hints(&self) -> Vec<KeybindingHint>;
272}
273
274/// Simple footer data provider implementation
275pub struct SimpleFooterDataProvider {
276    model_name: Option<String>,
277    git_branch: Option<String>,
278    input_tokens: u32,
279    output_tokens: u32,
280    cached_tokens: Option<u32>,
281    session_timer: SessionTimer,
282    keybinding_hints: Vec<KeybindingHint>,
283    extension_statuses: HashMap<String, String>,
284    available_providers: usize,
285}
286
287impl SimpleFooterDataProvider {
288    /// Create a new provider
289    pub fn new() -> Self {
290        Self {
291            model_name: None,
292            git_branch: None,
293            input_tokens: 0,
294            output_tokens: 0,
295            cached_tokens: None,
296            session_timer: SessionTimer::new(),
297            keybinding_hints: Vec::new(),
298            extension_statuses: HashMap::new(),
299            available_providers: 0,
300        }
301    }
302
303    /// Set the model name
304    pub fn with_model(mut self, model: Option<String>) -> Self {
305        self.model_name = model;
306        self
307    }
308
309    /// Set the git branch
310    pub fn with_git_branch(mut self, branch: Option<String>) -> Self {
311        self.git_branch = branch;
312        self
313    }
314
315    /// Set token counts
316    pub fn with_tokens(mut self, input: u32, output: u32) -> Self {
317        self.input_tokens = input;
318        self.output_tokens = output;
319        self
320    }
321
322    /// Add a keybinding hint
323    pub fn add_hint(mut self, keys: &str, description: &str) -> Self {
324        self.keybinding_hints.push(KeybindingHint::new(keys, description));
325        self
326    }
327
328    /// Set the available provider count
329    pub fn with_providers(mut self, count: usize) -> Self {
330        self.available_providers = count;
331        self
332    }
333
334    /// Update token counts
335    pub fn update_tokens(&mut self, input: u32, output: u32) {
336        self.input_tokens = input;
337        self.output_tokens = output;
338    }
339
340    /// Add an extension status
341    pub fn set_extension_status(&mut self, key: &str, status: Option<&str>) {
342        if let Some(s) = status {
343            self.extension_statuses.insert(key.to_string(), s.to_string());
344        } else {
345            self.extension_statuses.remove(key);
346        }
347    }
348}
349
350impl Default for SimpleFooterDataProvider {
351    fn default() -> Self {
352        Self::new()
353    }
354}
355
356impl FooterDataProvider for SimpleFooterDataProvider {
357    fn get_footer_data(&self) -> FooterData {
358        let mut data = FooterData {
359            model_name: self.model_name.clone(),
360            input_tokens: self.input_tokens,
361            output_tokens: self.output_tokens,
362            cached_tokens: self.cached_tokens,
363            git_branch: self.git_branch.clone(),
364            session_duration: self.session_timer.elapsed(),
365            keybinding_hints: self.keybinding_hints.clone(),
366            extension_statuses: self.extension_statuses.clone(),
367            available_providers: self.available_providers,
368            estimated_cost: None,
369        };
370
371        // Calculate cost if model is available
372        if let Some(ref model) = self.model_name {
373            let cost_estimator = CostEstimator::new();
374            data.estimated_cost = cost_estimator.estimate(
375                model,
376                self.input_tokens,
377                self.output_tokens,
378            );
379        }
380
381        data
382    }
383
384    fn get_model_name(&self) -> Option<String> {
385        self.model_name.clone()
386    }
387
388    fn get_git_branch(&self) -> Option<String> {
389        self.git_branch.clone()
390    }
391
392    fn get_token_counts(&self) -> (u32, u32) {
393        (self.input_tokens, self.output_tokens)
394    }
395
396    fn get_session_duration(&self) -> Duration {
397        self.session_timer.elapsed()
398    }
399
400    fn get_keybinding_hints(&self) -> Vec<KeybindingHint> {
401        self.keybinding_hints.clone()
402    }
403}
404
/// Tracks free-form status messages keyed by extension name.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ExtensionStatusTracker {
    /// Status message per extension.
    statuses: HashMap<String, String>,
}

impl ExtensionStatusTracker {
    /// Create a tracker with no statuses.
    pub fn new() -> Self {
        Self::default()
    }

    /// Set (insert or replace) the status message for `extension`.
    pub fn set(&mut self, extension: &str, status: &str) {
        self.statuses.insert(extension.to_string(), status.to_string());
    }

    /// Remove the status for `extension`; does nothing if none is set.
    pub fn clear(&mut self, extension: &str) {
        self.statuses.remove(extension);
    }

    /// Status message for a single extension, if one is set.
    pub fn get(&self, extension: &str) -> Option<&str> {
        self.statuses.get(extension).map(String::as_str)
    }

    /// All statuses keyed by extension name (iteration order is unspecified).
    pub fn get_all(&self) -> &HashMap<String, String> {
        &self.statuses
    }
}
435
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_footer_data_new() {
        let data = FooterData::new();
        assert!(data.is_empty());
    }

    #[test]
    fn test_footer_data_with_model() {
        let data = FooterData::new().with_model("claude-3.5-sonnet");
        assert_eq!(data.model_name, Some("claude-3.5-sonnet".to_string()));
    }

    #[test]
    fn test_footer_data_with_tokens() {
        let data = FooterData::new().with_tokens(1000, 500);
        assert_eq!(data.input_tokens, 1000);
        assert_eq!(data.output_tokens, 500);
    }

    #[test]
    fn test_footer_data_format() {
        let data = FooterData::new()
            .with_model("gpt-4")
            .with_tokens(100, 50);
        let formatted = data.format();
        assert!(formatted.contains("gpt-4"));
        assert!(formatted.contains("100/50"));
    }

    #[test]
    fn test_footer_data_format_empty() {
        assert_eq!(FooterData::new().format(), "");
    }

    #[test]
    fn test_footer_data_format_all_segments() {
        let data = FooterData {
            cached_tokens: Some(25),
            estimated_cost: Some(0.125),
            session_duration: Duration::from_secs(90),
            ..FooterData::new()
                .with_model("gpt-4")
                .with_tokens(100, 50)
                .with_git_branch(Some("main".to_string()))
        };
        assert_eq!(
            data.format(),
            "Model: gpt-4 | Tokens: 100/50 (+25 cached) | Cost: $0.1250 | Branch: main | Duration: 1m 30s"
        );
    }

    #[test]
    fn test_footer_data_total_tokens() {
        let data = FooterData::new().with_tokens(100, 50);
        assert_eq!(data.total_tokens(), 150);
    }

    #[test]
    fn test_session_timer() {
        let timer = SessionTimer::new();
        std::thread::sleep(Duration::from_millis(10));
        let elapsed = timer.elapsed();
        assert!(elapsed.as_millis() >= 10);
    }

    #[test]
    fn test_session_timer_reset() {
        let mut timer = SessionTimer::new();
        std::thread::sleep(Duration::from_millis(10));
        // A timer started before the reset must always read at least as long
        // as the reset one. Unlike an absolute wall-clock bound, this cannot
        // fail when the test thread is descheduled.
        let reference = SessionTimer::new();
        timer.reset();
        assert!(timer.elapsed() <= reference.elapsed());
    }

    #[test]
    fn test_format_duration() {
        assert_eq!(format_duration(Duration::from_secs(30)), "30s");
        assert_eq!(format_duration(Duration::from_secs(90)), "1m 30s");
        assert_eq!(format_duration(Duration::from_secs(3661)), "1h 1m");
        assert_eq!(format_duration(Duration::from_secs(86401)), "1d 0h");
    }

    #[test]
    fn test_format_duration_boundaries() {
        assert_eq!(format_duration(Duration::from_secs(0)), "0s");
        assert_eq!(format_duration(Duration::from_secs(59)), "59s");
        assert_eq!(format_duration(Duration::from_secs(60)), "1m 0s");
        assert_eq!(format_duration(Duration::from_secs(3599)), "59m 59s");
        assert_eq!(format_duration(Duration::from_secs(3600)), "1h 0m");
        assert_eq!(format_duration(Duration::from_secs(86399)), "23h 59m");
        assert_eq!(format_duration(Duration::from_secs(86400)), "1d 0h");
    }

    #[test]
    fn test_cost_estimator() {
        let estimator = CostEstimator::new();

        // Claude: $3 / 1M input + $15 / 1M output.
        let cost = estimator
            .estimate("claude-3.5-sonnet", 1_000_000, 1_000_000)
            .expect("claude should have a price tier");
        assert!((cost - 18.0).abs() < 1e-9);

        // GPT-3.5: $0.5 / 1M input + $1.5 / 1M output.
        let cost = estimator
            .estimate("gpt-3.5-turbo", 1_000_000, 1_000_000)
            .expect("gpt-3.5 should have a price tier");
        assert!((cost - 2.0).abs() < 1e-9);

        // Test with unknown model
        let cost = estimator.estimate("unknown-model", 1000, 500);
        assert!(cost.is_none());
    }

    #[test]
    fn test_keybinding_hint() {
        let hint = KeybindingHint::new("Ctrl+C", "Cancel");
        assert_eq!(hint.keys, "Ctrl+C");
        assert_eq!(hint.description, "Cancel");
    }

    #[test]
    fn test_simple_provider() {
        let provider = SimpleFooterDataProvider::new()
            .with_model(Some("gpt-4".to_string()))
            .with_tokens(100, 50);

        assert_eq!(provider.get_model_name(), Some("gpt-4".to_string()));
        assert_eq!(provider.get_token_counts(), (100, 50));
    }

    #[test]
    fn test_simple_provider_footer_data() {
        let provider = SimpleFooterDataProvider::new()
            .with_model(Some("claude".to_string()))
            .with_tokens(1000, 500);

        let footer = provider.get_footer_data();
        assert_eq!(footer.model_name, Some("claude".to_string()));
        assert!(footer.estimated_cost.is_some());
    }

    #[test]
    fn test_simple_provider_unknown_model_has_no_cost() {
        let provider = SimpleFooterDataProvider::new()
            .with_model(Some("mystery-model".to_string()))
            .with_tokens(1000, 500);
        assert!(provider.get_footer_data().estimated_cost.is_none());
    }

    #[test]
    fn test_simple_provider_hints_statuses_and_updates() {
        let mut provider = SimpleFooterDataProvider::new()
            .add_hint("Ctrl+C", "Cancel")
            .with_providers(3)
            .with_git_branch(Some("main".to_string()));
        provider.update_tokens(7, 9);
        provider.set_extension_status("ext", Some("busy"));

        let footer = provider.get_footer_data();
        assert_eq!((footer.input_tokens, footer.output_tokens), (7, 9));
        assert_eq!(footer.available_providers, 3);
        assert_eq!(footer.git_branch, Some("main".to_string()));
        assert_eq!(footer.keybinding_hints.len(), 1);
        assert_eq!(footer.keybinding_hints[0].keys, "Ctrl+C");
        assert_eq!(footer.extension_statuses.get("ext"), Some(&"busy".to_string()));
        assert_eq!(provider.get_keybinding_hints().len(), 1);

        provider.set_extension_status("ext", None);
        assert!(provider.get_footer_data().extension_statuses.is_empty());
    }

    #[test]
    fn test_extension_status_tracker() {
        let mut tracker = ExtensionStatusTracker::new();

        tracker.set("my-extension", "Working...");
        assert_eq!(tracker.get_all().get("my-extension"), Some(&"Working...".to_string()));

        tracker.clear("my-extension");
        assert!(tracker.get_all().get("my-extension").is_none());
    }

    #[test]
    fn test_footer_data_with_git_branch() {
        let data = FooterData::new().with_git_branch(Some("main".to_string()));
        assert_eq!(data.git_branch, Some("main".to_string()));
    }

    #[test]
    fn test_footer_data_extension_statuses() {
        let mut data = FooterData::new();
        data.extension_statuses.insert("ext1".to_string(), "status1".to_string());
        assert_eq!(data.extension_statuses.len(), 1);
    }
}