// entrenar/train/tui/progress.rs
use std::time::Instant;
6
/// Scalar Kalman filter used to smooth per-step durations into an ETA.
///
/// The state is a single filtered value plus its error covariance; the two
/// noise terms are set at construction and never changed by the filter.
#[derive(Clone, Debug)]
pub struct KalmanEta {
    /// Filtered duration of one step. `eta_seconds` multiplies this by the
    /// number of remaining steps, so it is interpreted as seconds per step.
    estimate: f64,
    /// Error covariance of `estimate`; recomputed on every `update`.
    error_cov: f64,
    /// Amount added to the error covariance in the predict phase of each
    /// update.
    process_noise: f64,
    /// Noise term in the Kalman gain denominator: the larger it is, the less
    /// a single measurement moves the estimate.
    measurement_noise: f64,
}
19
20impl Default for KalmanEta {
21 fn default() -> Self {
22 Self { estimate: 1.0, error_cov: 1.0, process_noise: 0.01, measurement_noise: 0.1 }
23 }
24}
25
26impl KalmanEta {
27 pub fn new() -> Self {
29 Self::default()
30 }
31
32 pub fn update(&mut self, measured_duration: f64) {
34 let predicted_estimate = self.estimate;
36 let predicted_error = self.error_cov + self.process_noise;
37
38 let kalman_gain = predicted_error / (predicted_error + self.measurement_noise);
40 self.estimate = predicted_estimate + kalman_gain * (measured_duration - predicted_estimate);
41 self.error_cov = (1.0 - kalman_gain) * predicted_error;
42 }
43
44 pub fn eta_seconds(&self, remaining_steps: usize) -> f64 {
46 self.estimate * remaining_steps as f64
47 }
48
49 pub fn eta_string(&self, remaining_steps: usize) -> String {
51 let secs = self.eta_seconds(remaining_steps);
52 format_duration(secs)
53 }
54}
55
/// Formats a duration given in seconds as a short human-readable string.
///
/// * below one minute: `"{seconds}s"`, rounded to zero decimals (so a value
///   such as `59.9` renders as `"60s"`)
/// * below one hour: `"{minutes}m {seconds:02}s"`, both floored
/// * otherwise: `"{hours}h {minutes:02}m"`, both floored
///
/// Degenerate inputs get a stable rendering instead of leaking float
/// artefacts such as `NaNh NaNm` or `-0s`: NaN and infinities yield `"--"`,
/// zero and negative values yield `"0s"`.
pub fn format_duration(secs: f64) -> String {
    const MINUTE: f64 = 60.0;
    const HOUR: f64 = 3600.0;

    if !secs.is_finite() {
        return "--".to_string();
    }
    // Also catches -0.0, which would otherwise print with a minus sign.
    if secs <= 0.0 {
        return "0s".to_string();
    }

    if secs < MINUTE {
        format!("{secs:.0}s")
    } else if secs < HOUR {
        let mins = (secs / MINUTE).floor();
        let s = (secs % MINUTE).floor();
        format!("{mins}m {s:02.0}s")
    } else {
        let hours = (secs / HOUR).floor();
        let mins = ((secs % HOUR) / MINUTE).floor();
        format!("{hours}h {mins:02.0}m")
    }
}
70
71#[derive(Debug, Clone)]
73pub struct ProgressBar {
74 total: usize,
76 current: usize,
78 width: usize,
80 fill_char: char,
82 empty_char: char,
84 kalman: KalmanEta,
86 last_step_time: Option<Instant>,
88}
89
90impl ProgressBar {
91 pub fn new(total: usize, width: usize) -> Self {
93 Self {
94 total,
95 current: 0,
96 width,
97 fill_char: '█',
98 empty_char: '░',
99 kalman: KalmanEta::new(),
100 last_step_time: None,
101 }
102 }
103
104 pub fn update(&mut self, current: usize) {
106 let now = Instant::now();
107 if let Some(last_time) = self.last_step_time {
108 let elapsed = now.duration_since(last_time).as_secs_f64();
109 let steps = current.saturating_sub(self.current);
110 if steps > 0 {
111 let per_step = elapsed / steps as f64;
112 self.kalman.update(per_step);
113 }
114 }
115 self.current = current;
116 self.last_step_time = Some(now);
117 }
118
119 pub fn percent(&self) -> f32 {
121 if self.total == 0 {
122 return 100.0;
123 }
124 (self.current as f32 / self.total as f32) * 100.0
125 }
126
127 pub fn render(&self) -> String {
129 let percent = self.percent();
130 let filled = ((percent / 100.0) * self.width as f32).round() as usize;
131 let empty = self.width.saturating_sub(filled);
132
133 let bar: String = std::iter::repeat_n(self.fill_char, filled)
134 .chain(std::iter::repeat_n(self.empty_char, empty))
135 .collect();
136
137 let remaining = self.total.saturating_sub(self.current);
138 let eta = self.kalman.eta_string(remaining);
139
140 format!("[{bar}] {percent:>5.1}% │ ETA: {eta}")
141 }
142}
143
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every `(seconds, expected)` pair is formatted as expected.
    fn check_durations(cases: &[(f64, &str)]) {
        for &(secs, expected) in cases {
            assert_eq!(format_duration(secs), expected, "format_duration({secs})");
        }
    }

    #[test]
    fn test_format_duration_seconds() {
        // Sub-minute values are rounded to zero decimals, hence 59.9 -> "60s".
        check_durations(&[(30.0, "30s"), (59.9, "60s")]);
    }

    #[test]
    fn test_format_duration_minutes() {
        check_durations(&[(60.0, "1m 00s"), (90.0, "1m 30s"), (3599.0, "59m 59s")]);
    }

    #[test]
    fn test_format_duration_hours() {
        check_durations(&[(3600.0, "1h 00m"), (5400.0, "1h 30m"), (7200.0, "2h 00m")]);
    }

    #[test]
    fn test_kalman_eta_new() {
        let fresh = KalmanEta::new();
        assert_eq!(fresh.estimate, 1.0);
    }

    #[test]
    fn test_kalman_eta_update() {
        // One measurement of 0.5 pulls the estimate from its initial 1.0
        // towards 0.5 without reaching it.
        let mut filter = KalmanEta::new();
        filter.update(0.5);
        assert!(filter.estimate < 1.0);
        assert!(filter.estimate > 0.5);
    }

    #[test]
    fn test_kalman_eta_seconds() {
        let fresh = KalmanEta::new();
        assert_eq!(fresh.eta_seconds(10), 10.0);
    }

    #[test]
    fn test_progress_bar_new() {
        assert_eq!(ProgressBar::new(100, 20).percent(), 0.0);
    }

    #[test]
    fn test_progress_bar_percent() {
        let mut half_done = ProgressBar::new(100, 20);
        half_done.current = 50;
        assert_eq!(half_done.percent(), 50.0);
    }

    #[test]
    fn test_progress_bar_percent_zero_total() {
        assert_eq!(ProgressBar::new(0, 20).percent(), 100.0);
    }

    #[test]
    fn test_progress_bar_render() {
        let rendered = ProgressBar::new(100, 10).render();
        for needle in ["[", "]", "ETA:"] {
            assert!(rendered.contains(needle), "missing {needle:?} in {rendered:?}");
        }
    }
}