Skip to main content

entrenar/train/tui/
progress.rs

//! Progress Bar with Kalman-filtered ETA (ENT-058)
//!
//! Reference: Welch, G., & Bishop, G. (1995). "An Introduction to the Kalman Filter."
4
5use std::time::Instant;
6
/// Scalar Kalman filter that tracks the duration of a single step and
/// extrapolates it into an estimated time of arrival (ETA).
///
/// Because the tracked state is one number, every matrix of the textbook
/// filter collapses to a plain `f64`.
#[derive(Clone, Debug)]
pub struct KalmanEta {
    /// Current best estimate of one step's duration, in seconds.
    estimate: f64,
    /// Variance (uncertainty) of `estimate`, revised on every measurement.
    error_cov: f64,
    /// Variance added during each prediction step.
    process_noise: f64,
    /// Assumed variance of an individual step-duration measurement.
    measurement_noise: f64,
}
19
20impl Default for KalmanEta {
21    fn default() -> Self {
22        Self { estimate: 1.0, error_cov: 1.0, process_noise: 0.01, measurement_noise: 0.1 }
23    }
24}
25
26impl KalmanEta {
27    /// Create a new Kalman filter for ETA estimation.
28    pub fn new() -> Self {
29        Self::default()
30    }
31
32    /// Update with a new step duration measurement.
33    pub fn update(&mut self, measured_duration: f64) {
34        // Prediction step
35        let predicted_estimate = self.estimate;
36        let predicted_error = self.error_cov + self.process_noise;
37
38        // Update step
39        let kalman_gain = predicted_error / (predicted_error + self.measurement_noise);
40        self.estimate = predicted_estimate + kalman_gain * (measured_duration - predicted_estimate);
41        self.error_cov = (1.0 - kalman_gain) * predicted_error;
42    }
43
44    /// Get estimated time remaining for N steps.
45    pub fn eta_seconds(&self, remaining_steps: usize) -> f64 {
46        self.estimate * remaining_steps as f64
47    }
48
49    /// Format ETA as human-readable string.
50    pub fn eta_string(&self, remaining_steps: usize) -> String {
51        let secs = self.eta_seconds(remaining_steps);
52        format_duration(secs)
53    }
54}
55
/// Render a duration given in seconds as a short human-readable string.
///
/// - under a minute: whole seconds, rounded (`"30s"`, and `59.9` → `"60s"`);
/// - under an hour: minutes and zero-padded seconds, truncated (`"1m 05s"`);
/// - otherwise: hours and zero-padded minutes, truncated (`"2h 07m"`).
pub fn format_duration(secs: f64) -> String {
    const MINUTE: f64 = 60.0;
    const HOUR: f64 = 3600.0;

    match secs {
        s if s < MINUTE => format!("{s:.0}s"),
        s if s < HOUR => {
            let whole_minutes = (s / MINUTE).floor();
            let leftover_secs = (s % MINUTE).floor();
            format!("{whole_minutes}m {leftover_secs:02.0}s")
        }
        s => {
            let whole_hours = (s / HOUR).floor();
            let leftover_mins = ((s % HOUR) / MINUTE).floor();
            format!("{whole_hours}h {leftover_mins:02.0}m")
        }
    }
}
70
71/// Progress bar renderer.
72#[derive(Debug, Clone)]
73pub struct ProgressBar {
74    /// Total steps
75    total: usize,
76    /// Current step
77    current: usize,
78    /// Bar width in characters
79    width: usize,
80    /// Fill character
81    fill_char: char,
82    /// Empty character
83    empty_char: char,
84    /// Kalman filter for ETA
85    kalman: KalmanEta,
86    /// Last step time
87    last_step_time: Option<Instant>,
88}
89
90impl ProgressBar {
91    /// Create a new progress bar.
92    pub fn new(total: usize, width: usize) -> Self {
93        Self {
94            total,
95            current: 0,
96            width,
97            fill_char: '█',
98            empty_char: '░',
99            kalman: KalmanEta::new(),
100            last_step_time: None,
101        }
102    }
103
104    /// Update progress.
105    pub fn update(&mut self, current: usize) {
106        let now = Instant::now();
107        if let Some(last_time) = self.last_step_time {
108            let elapsed = now.duration_since(last_time).as_secs_f64();
109            let steps = current.saturating_sub(self.current);
110            if steps > 0 {
111                let per_step = elapsed / steps as f64;
112                self.kalman.update(per_step);
113            }
114        }
115        self.current = current;
116        self.last_step_time = Some(now);
117    }
118
119    /// Get progress percentage.
120    pub fn percent(&self) -> f32 {
121        if self.total == 0 {
122            return 100.0;
123        }
124        (self.current as f32 / self.total as f32) * 100.0
125    }
126
127    /// Render progress bar to string.
128    pub fn render(&self) -> String {
129        let percent = self.percent();
130        let filled = ((percent / 100.0) * self.width as f32).round() as usize;
131        let empty = self.width.saturating_sub(filled);
132
133        let bar: String = std::iter::repeat_n(self.fill_char, filled)
134            .chain(std::iter::repeat_n(self.empty_char, empty))
135            .collect();
136
137        let remaining = self.total.saturating_sub(self.current);
138        let eta = self.kalman.eta_string(remaining);
139
140        format!("[{bar}] {percent:>5.1}% │ ETA: {eta}")
141    }
142}
143
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks each `(seconds, expected)` pair against `format_duration`.
    fn check_durations(cases: &[(f64, &str)]) {
        for &(secs, expected) in cases {
            assert_eq!(format_duration(secs), expected);
        }
    }

    #[test]
    fn test_format_duration_seconds() {
        check_durations(&[(30.0, "30s"), (59.9, "60s")]);
    }

    #[test]
    fn test_format_duration_minutes() {
        check_durations(&[(60.0, "1m 00s"), (90.0, "1m 30s"), (3599.0, "59m 59s")]);
    }

    #[test]
    fn test_format_duration_hours() {
        check_durations(&[(3600.0, "1h 00m"), (5400.0, "1h 30m"), (7200.0, "2h 00m")]);
    }

    #[test]
    fn test_kalman_eta_new() {
        assert_eq!(KalmanEta::new().estimate, 1.0);
    }

    #[test]
    fn test_kalman_eta_update() {
        let mut filter = KalmanEta::new();
        filter.update(0.5);
        let smoothed = filter.estimate;
        // The estimate moves toward the measurement without reaching it.
        assert!(smoothed < 1.0);
        assert!(smoothed > 0.5);
    }

    #[test]
    fn test_kalman_eta_seconds() {
        assert_eq!(KalmanEta::new().eta_seconds(10), 10.0);
    }

    #[test]
    fn test_progress_bar_new() {
        assert_eq!(ProgressBar::new(100, 20).percent(), 0.0);
    }

    #[test]
    fn test_progress_bar_percent() {
        let halfway = ProgressBar { current: 50, ..ProgressBar::new(100, 20) };
        assert_eq!(halfway.percent(), 50.0);
    }

    #[test]
    fn test_progress_bar_percent_zero_total() {
        assert_eq!(ProgressBar::new(0, 20).percent(), 100.0);
    }

    #[test]
    fn test_progress_bar_render() {
        let rendered = ProgressBar::new(100, 10).render();
        for needle in ["[", "]", "ETA:"] {
            assert!(rendered.contains(needle));
        }
    }
}