Skip to main content

axonml_tui/views/
training.rs

1//! Training View - Monitor Training Progress
2//!
3//! Displays real-time training metrics including epochs, loss, accuracy,
4//! learning rate, and includes a sparkline for loss history.
5//!
6//! @version 0.1.0
7//! @author AutomataNexus Development Team
8
9use std::path::Path;
10
11use ratatui::{
12    layout::{Alignment, Constraint, Direction, Layout, Rect},
13    style::Style,
14    text::{Line, Span},
15    widgets::{Block, Borders, Gauge, Paragraph, Sparkline},
16    Frame,
17};
18
19use crate::theme::AxonmlTheme;
20
21// =============================================================================
22// Types
23// =============================================================================
24
/// Lifecycle state of a training session, shown in the view's header.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TrainingStatus {
    /// Not actively training.
    Idle,
    /// Training is progressing.
    Running,
    /// Progress is suspended; can be resumed back to `Running`.
    Paused,
    /// The run has reached its configured number of epochs.
    Completed,
    /// The run ended unsuccessfully.
    Failed,
}
34
35impl TrainingStatus {
36    fn as_str(&self) -> &'static str {
37        match self {
38            TrainingStatus::Idle => "Idle",
39            TrainingStatus::Running => "Running",
40            TrainingStatus::Paused => "Paused",
41            TrainingStatus::Completed => "Completed",
42            TrainingStatus::Failed => "Failed",
43        }
44    }
45
46    fn style(&self) -> Style {
47        match self {
48            TrainingStatus::Idle => AxonmlTheme::muted(),
49            TrainingStatus::Running => AxonmlTheme::success(),
50            TrainingStatus::Paused => AxonmlTheme::warning(),
51            TrainingStatus::Completed => AxonmlTheme::info(),
52            TrainingStatus::Failed => AxonmlTheme::error(),
53        }
54    }
55}
56
/// Metrics recorded for one epoch of training.
#[derive(Debug, Clone)]
pub struct EpochMetrics {
    /// Epoch number as reported by the data source (demo data or log file).
    pub epoch: usize,
    /// Training loss for the epoch.
    pub train_loss: f32,
    /// Validation loss, when the epoch reported one.
    pub val_loss: Option<f32>,
    /// Training accuracy as a fraction; the view multiplies by 100 for display.
    pub train_acc: Option<f32>,
    /// Validation accuracy as a fraction; the view multiplies by 100 for display.
    pub val_acc: Option<f32>,
    /// Learning rate in effect for the epoch.
    pub learning_rate: f64,
    /// Time the epoch took, in seconds.
    pub duration_secs: f32,
}
68
/// Static configuration of a training run.
#[derive(Debug, Clone)]
pub struct TrainingConfig {
    /// Number of epochs the run is configured for.
    pub total_epochs: usize,
    /// Samples per batch.
    pub batch_size: usize,
    /// Optimizer name, displayed verbatim.
    pub optimizer: String,
    /// Starting learning rate; shown when no epoch has reported one yet.
    pub initial_lr: f64,
    /// Model name, displayed verbatim.
    pub model_name: String,
}
78
79/// Training session information
80#[derive(Debug, Clone)]
81pub struct TrainingSession {
82    pub config: TrainingConfig,
83    pub current_epoch: usize,
84    pub current_batch: usize,
85    pub total_batches: usize,
86    pub status: TrainingStatus,
87    pub metrics_history: Vec<EpochMetrics>,
88    pub loss_history: Vec<u64>,  // Scaled for sparkline
89    pub best_val_loss: Option<f32>,
90    pub best_epoch: Option<usize>,
91    pub elapsed_secs: f64,
92    pub eta_secs: Option<f64>,
93}
94
95// =============================================================================
96// Training View
97// =============================================================================
98
99/// Training progress view state
100pub struct TrainingView {
101    /// Current training session
102    pub session: Option<TrainingSession>,
103
104    /// Show detailed metrics
105    pub show_details: bool,
106}
107
108impl TrainingView {
109    /// Create a new training view with demo data
110    pub fn new() -> Self {
111        let mut view = Self {
112            session: None,
113            show_details: false,
114        };
115
116        // Load demo data
117        view.load_demo_session();
118        view
119    }
120
121    /// Load a demo training session for visualization
122    pub fn load_demo_session(&mut self) {
123        let metrics_history = vec![
124            EpochMetrics { epoch: 1, train_loss: 2.312, val_loss: Some(2.298), train_acc: Some(0.112), val_acc: Some(0.118), learning_rate: 0.001, duration_secs: 45.2 },
125            EpochMetrics { epoch: 2, train_loss: 1.845, val_loss: Some(1.756), train_acc: Some(0.342), val_acc: Some(0.358), learning_rate: 0.001, duration_secs: 44.8 },
126            EpochMetrics { epoch: 3, train_loss: 1.234, val_loss: Some(1.189), train_acc: Some(0.567), val_acc: Some(0.582), learning_rate: 0.001, duration_secs: 45.1 },
127            EpochMetrics { epoch: 4, train_loss: 0.856, val_loss: Some(0.823), train_acc: Some(0.712), val_acc: Some(0.724), learning_rate: 0.001, duration_secs: 44.9 },
128            EpochMetrics { epoch: 5, train_loss: 0.612, val_loss: Some(0.598), train_acc: Some(0.798), val_acc: Some(0.805), learning_rate: 0.0005, duration_secs: 45.3 },
129            EpochMetrics { epoch: 6, train_loss: 0.478, val_loss: Some(0.489), train_acc: Some(0.845), val_acc: Some(0.842), learning_rate: 0.0005, duration_secs: 45.0 },
130            EpochMetrics { epoch: 7, train_loss: 0.389, val_loss: Some(0.412), train_acc: Some(0.878), val_acc: Some(0.869), learning_rate: 0.0005, duration_secs: 44.7 },
131            EpochMetrics { epoch: 8, train_loss: 0.321, val_loss: Some(0.358), train_acc: Some(0.902), val_acc: Some(0.891), learning_rate: 0.00025, duration_secs: 45.2 },
132        ];
133
134        // Scale loss values for sparkline (0-100 range)
135        let loss_history: Vec<u64> = metrics_history
136            .iter()
137            .map(|m| ((2.5 - m.train_loss) / 2.5 * 100.0).max(0.0) as u64)
138            .collect();
139
140        let session = TrainingSession {
141            config: TrainingConfig {
142                total_epochs: 20,
143                batch_size: 64,
144                optimizer: "Adam".to_string(),
145                initial_lr: 0.001,
146                model_name: "mnist_classifier".to_string(),
147            },
148            current_epoch: 8,
149            current_batch: 720,
150            total_batches: 938,
151            status: TrainingStatus::Running,
152            metrics_history,
153            loss_history,
154            best_val_loss: Some(0.358),
155            best_epoch: Some(8),
156            elapsed_secs: 361.2,
157            eta_secs: Some(540.0),
158        };
159
160        self.session = Some(session);
161    }
162
163    /// Toggle detailed view
164    pub fn toggle_details(&mut self) {
165        self.show_details = !self.show_details;
166    }
167
168    /// Pause/resume training
169    pub fn toggle_pause(&mut self) {
170        if let Some(session) = &mut self.session {
171            session.status = match session.status {
172                TrainingStatus::Running => TrainingStatus::Paused,
173                TrainingStatus::Paused => TrainingStatus::Running,
174                other => other,
175            };
176        }
177    }
178
179    /// Watch a training log file
180    pub fn watch_log(&mut self, path: &Path) -> Result<(), String> {
181        // Try to parse the training log file
182        match self.parse_training_log(path) {
183            Ok(session) => {
184                self.session = Some(session);
185                Ok(())
186            }
187            Err(e) => {
188                // Log file not found or invalid - fall back to demo
189                eprintln!("Warning: Could not parse training log: {}", e);
190                self.load_demo_session();
191                Err(e)
192            }
193        }
194    }
195
196    /// Parse a training log file to extract metrics
197    fn parse_training_log(&self, path: &Path) -> Result<TrainingSession, String> {
198        let content = std::fs::read_to_string(path)
199            .map_err(|e| format!("Failed to read log file: {}", e))?;
200
201        let mut metrics_history = Vec::new();
202        let mut config = TrainingConfig {
203            total_epochs: 10,
204            batch_size: 32,
205            optimizer: "Adam".to_string(),
206            initial_lr: 0.001,
207            model_name: "model".to_string(),
208        };
209
210        // Parse JSON log format (one JSON object per line)
211        for line in content.lines() {
212            if line.trim().is_empty() {
213                continue;
214            }
215
216            // Try to parse as JSON
217            if let Ok(json) = serde_json::from_str::<serde_json::Value>(line) {
218                // Extract epoch info
219                if let Some(epoch) = json.get("epoch").and_then(|v| v.as_u64()) {
220                    let train_loss = json.get("train_loss")
221                        .and_then(|v| v.as_f64())
222                        .unwrap_or(0.0) as f32;
223                    let val_loss = json.get("val_loss")
224                        .and_then(|v| v.as_f64())
225                        .map(|v| v as f32);
226                    let train_acc = json.get("train_acc")
227                        .and_then(|v| v.as_f64())
228                        .map(|v| v as f32);
229                    let val_acc = json.get("val_acc")
230                        .and_then(|v| v.as_f64())
231                        .map(|v| v as f32);
232                    let learning_rate = json.get("learning_rate")
233                        .and_then(|v| v.as_f64())
234                        .unwrap_or(0.001);
235                    let duration = json.get("duration_secs")
236                        .and_then(|v| v.as_f64())
237                        .unwrap_or(0.0) as f32;
238
239                    metrics_history.push(EpochMetrics {
240                        epoch: epoch as usize,
241                        train_loss,
242                        val_loss,
243                        train_acc,
244                        val_acc,
245                        learning_rate,
246                        duration_secs: duration,
247                    });
248                }
249
250                // Extract config info
251                if let Some(total) = json.get("total_epochs").and_then(|v| v.as_u64()) {
252                    config.total_epochs = total as usize;
253                }
254                if let Some(bs) = json.get("batch_size").and_then(|v| v.as_u64()) {
255                    config.batch_size = bs as usize;
256                }
257                if let Some(opt) = json.get("optimizer").and_then(|v| v.as_str()) {
258                    config.optimizer = opt.to_string();
259                }
260                if let Some(name) = json.get("model_name").and_then(|v| v.as_str()) {
261                    config.model_name = name.to_string();
262                }
263            }
264        }
265
266        if metrics_history.is_empty() {
267            return Err("No training metrics found in log file".to_string());
268        }
269
270        // Calculate derived values
271        let current_epoch = metrics_history.last().map(|m| m.epoch).unwrap_or(1);
272        let loss_history: Vec<u64> = metrics_history
273            .iter()
274            .map(|m| {
275                let max_loss = 3.0f32;
276                ((max_loss - m.train_loss.min(max_loss)) / max_loss * 100.0) as u64
277            })
278            .collect();
279
280        let (best_val_loss, best_epoch) = metrics_history
281            .iter()
282            .filter_map(|m| m.val_loss.map(|v| (v, m.epoch)))
283            .min_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal))
284            .unwrap_or((0.0, 1));
285
286        let total_duration: f32 = metrics_history.iter().map(|m| m.duration_secs).sum();
287        let avg_epoch_time = total_duration / metrics_history.len() as f32;
288        let total_epochs = config.total_epochs;
289        let remaining_epochs = total_epochs.saturating_sub(current_epoch);
290        let eta = (remaining_epochs as f32 * avg_epoch_time) as f64;
291
292        Ok(TrainingSession {
293            config,
294            current_epoch,
295            current_batch: 0,
296            total_batches: 100,
297            status: if current_epoch >= total_epochs {
298                TrainingStatus::Completed
299            } else {
300                TrainingStatus::Idle
301            },
302            metrics_history,
303            loss_history,
304            best_val_loss: Some(best_val_loss),
305            best_epoch: Some(best_epoch),
306            elapsed_secs: total_duration as f64,
307            eta_secs: if eta > 0.0 { Some(eta) } else { None },
308        })
309    }
310
311    /// Scroll up in the metrics history
312    pub fn scroll_up(&mut self) {
313        // Reserved for scrolling through history
314    }
315
316    /// Scroll down in the metrics history
317    pub fn scroll_down(&mut self) {
318        // Reserved for scrolling through history
319    }
320
321    /// Refresh training data
322    pub fn refresh(&mut self) {
323        // Reload demo data for now
324        self.load_demo_session();
325    }
326
327    /// Tick update for real-time animation
328    pub fn tick(&mut self) {
329        // In real implementation, would update metrics from training process
330        // For demo, we simulate small progress updates
331        if let Some(session) = &mut self.session {
332            if session.status == TrainingStatus::Running {
333                // Simulate batch progress
334                if session.current_batch < session.total_batches {
335                    session.current_batch += 1;
336                } else {
337                    // Next epoch
338                    session.current_batch = 0;
339                    if session.current_epoch < session.config.total_epochs {
340                        session.current_epoch += 1;
341                    }
342                }
343                session.elapsed_secs += 0.1;
344            }
345        }
346    }
347
348    /// Render the training view
349    pub fn render(&mut self, frame: &mut Frame, area: Rect) {
350        if let Some(session) = &self.session.clone() {
351            let chunks = Layout::default()
352                .direction(Direction::Vertical)
353                .constraints([
354                    Constraint::Length(7),  // Header with status
355                    Constraint::Length(3),  // Epoch progress bar
356                    Constraint::Length(3),  // Batch progress bar
357                    Constraint::Min(8),     // Metrics and sparkline
358                    Constraint::Length(10), // Epoch history
359                ])
360                .split(area);
361
362            self.render_header(frame, chunks[0], session);
363            self.render_epoch_progress(frame, chunks[1], session);
364            self.render_batch_progress(frame, chunks[2], session);
365            self.render_metrics(frame, chunks[3], session);
366            self.render_history(frame, chunks[4], session);
367        } else {
368            self.render_empty(frame, area);
369        }
370    }
371
372    fn render_header(&self, frame: &mut Frame, area: Rect, session: &TrainingSession) {
373        let status_style = session.status.style();
374
375        let header_text = vec![
376            Line::from(vec![
377                Span::styled("Model: ", AxonmlTheme::muted()),
378                Span::styled(&session.config.model_name, AxonmlTheme::title()),
379                Span::raw("  "),
380                Span::styled("Status: ", AxonmlTheme::muted()),
381                Span::styled(session.status.as_str(), status_style),
382            ]),
383            Line::from(vec![
384                Span::styled("Optimizer: ", AxonmlTheme::muted()),
385                Span::styled(&session.config.optimizer, AxonmlTheme::accent()),
386                Span::raw("  "),
387                Span::styled("Batch Size: ", AxonmlTheme::muted()),
388                Span::styled(session.config.batch_size.to_string(), AxonmlTheme::metric_value()),
389                Span::raw("  "),
390                Span::styled("LR: ", AxonmlTheme::muted()),
391                Span::styled(
392                    format!("{:.6}", session.metrics_history.last().map(|m| m.learning_rate).unwrap_or(session.config.initial_lr)),
393                    AxonmlTheme::metric_value(),
394                ),
395            ]),
396            Line::from(vec![
397                Span::styled("Elapsed: ", AxonmlTheme::muted()),
398                Span::styled(format_duration(session.elapsed_secs), AxonmlTheme::accent()),
399                Span::raw("  "),
400                Span::styled("ETA: ", AxonmlTheme::muted()),
401                Span::styled(
402                    session.eta_secs.map(format_duration).unwrap_or_else(|| "--:--".to_string()),
403                    AxonmlTheme::accent(),
404                ),
405            ]),
406            Line::from(vec![
407                Span::styled("Best Val Loss: ", AxonmlTheme::muted()),
408                Span::styled(
409                    session.best_val_loss.map(|v| format!("{:.4}", v)).unwrap_or_else(|| "-".to_string()),
410                    AxonmlTheme::success(),
411                ),
412                Span::raw("  "),
413                Span::styled("@ Epoch: ", AxonmlTheme::muted()),
414                Span::styled(
415                    session.best_epoch.map(|e| e.to_string()).unwrap_or_else(|| "-".to_string()),
416                    AxonmlTheme::success(),
417                ),
418            ]),
419        ];
420
421        let header = Paragraph::new(header_text)
422            .block(
423                Block::default()
424                    .borders(Borders::ALL)
425                    .border_style(AxonmlTheme::border())
426                    .title(Span::styled(" Training Session ", AxonmlTheme::header())),
427            );
428
429        frame.render_widget(header, area);
430    }
431
432    fn render_epoch_progress(&self, frame: &mut Frame, area: Rect, session: &TrainingSession) {
433        let progress = session.current_epoch as f64 / session.config.total_epochs as f64;
434
435        let gauge = Gauge::default()
436            .block(
437                Block::default()
438                    .borders(Borders::ALL)
439                    .border_style(AxonmlTheme::border())
440                    .title(Span::styled(
441                        format!(" Epoch {}/{} ", session.current_epoch, session.config.total_epochs),
442                        AxonmlTheme::epoch(),
443                    )),
444            )
445            .gauge_style(AxonmlTheme::graph_primary())
446            .ratio(progress)
447            .label(format!("{:.1}%", progress * 100.0));
448
449        frame.render_widget(gauge, area);
450    }
451
452    fn render_batch_progress(&self, frame: &mut Frame, area: Rect, session: &TrainingSession) {
453        let progress = session.current_batch as f64 / session.total_batches as f64;
454
455        let gauge = Gauge::default()
456            .block(
457                Block::default()
458                    .borders(Borders::ALL)
459                    .border_style(AxonmlTheme::border())
460                    .title(Span::styled(
461                        format!(" Batch {}/{} ", session.current_batch, session.total_batches),
462                        AxonmlTheme::muted(),
463                    )),
464            )
465            .gauge_style(AxonmlTheme::graph_secondary())
466            .ratio(progress)
467            .label(format!("{:.1}%", progress * 100.0));
468
469        frame.render_widget(gauge, area);
470    }
471
472    fn render_metrics(&self, frame: &mut Frame, area: Rect, session: &TrainingSession) {
473        let chunks = Layout::default()
474            .direction(Direction::Horizontal)
475            .constraints([
476                Constraint::Percentage(50),  // Current metrics
477                Constraint::Percentage(50),  // Loss sparkline
478            ])
479            .split(area);
480
481        // Current metrics
482        let latest = session.metrics_history.last();
483        let metrics_text = if let Some(m) = latest {
484            let loss_style = if m.train_loss < 0.5 {
485                AxonmlTheme::loss_good()
486            } else if m.train_loss < 1.0 {
487                AxonmlTheme::loss_neutral()
488            } else {
489                AxonmlTheme::loss_bad()
490            };
491
492            vec![
493                Line::from(vec![
494                    Span::styled("Train Loss:  ", AxonmlTheme::metric_label()),
495                    Span::styled(format!("{:.4}", m.train_loss), loss_style),
496                ]),
497                Line::from(vec![
498                    Span::styled("Val Loss:    ", AxonmlTheme::metric_label()),
499                    Span::styled(
500                        m.val_loss.map(|v| format!("{:.4}", v)).unwrap_or_else(|| "-".to_string()),
501                        AxonmlTheme::metric_value(),
502                    ),
503                ]),
504                Line::from(vec![
505                    Span::styled("Train Acc:   ", AxonmlTheme::metric_label()),
506                    Span::styled(
507                        m.train_acc.map(|v| format!("{:.2}%", v * 100.0)).unwrap_or_else(|| "-".to_string()),
508                        AxonmlTheme::success(),
509                    ),
510                ]),
511                Line::from(vec![
512                    Span::styled("Val Acc:     ", AxonmlTheme::metric_label()),
513                    Span::styled(
514                        m.val_acc.map(|v| format!("{:.2}%", v * 100.0)).unwrap_or_else(|| "-".to_string()),
515                        AxonmlTheme::success(),
516                    ),
517                ]),
518            ]
519        } else {
520            vec![Line::from(Span::styled("No metrics yet", AxonmlTheme::muted()))]
521        };
522
523        let metrics = Paragraph::new(metrics_text)
524            .block(
525                Block::default()
526                    .borders(Borders::ALL)
527                    .border_style(AxonmlTheme::border_focused())
528                    .title(Span::styled(" Current Metrics ", AxonmlTheme::header())),
529            );
530
531        frame.render_widget(metrics, chunks[0]);
532
533        // Loss sparkline
534        let sparkline = Sparkline::default()
535            .block(
536                Block::default()
537                    .borders(Borders::ALL)
538                    .border_style(AxonmlTheme::border())
539                    .title(Span::styled(" Loss Trend (inverted) ", AxonmlTheme::header())),
540            )
541            .data(&session.loss_history)
542            .style(AxonmlTheme::graph_primary());
543
544        frame.render_widget(sparkline, chunks[1]);
545    }
546
547    fn render_history(&self, frame: &mut Frame, area: Rect, session: &TrainingSession) {
548        let history_lines: Vec<Line> = session
549            .metrics_history
550            .iter()
551            .rev()
552            .take(6)
553            .map(|m| {
554                Line::from(vec![
555                    Span::styled(format!("Epoch {:>2} ", m.epoch), AxonmlTheme::epoch()),
556                    Span::styled("Loss: ", AxonmlTheme::muted()),
557                    Span::styled(format!("{:.4}", m.train_loss), AxonmlTheme::metric_value()),
558                    Span::raw(" / "),
559                    Span::styled(
560                        m.val_loss.map(|v| format!("{:.4}", v)).unwrap_or_else(|| "-".to_string()),
561                        AxonmlTheme::accent(),
562                    ),
563                    Span::raw("  "),
564                    Span::styled("Acc: ", AxonmlTheme::muted()),
565                    Span::styled(
566                        m.train_acc.map(|v| format!("{:.1}%", v * 100.0)).unwrap_or_else(|| "-".to_string()),
567                        AxonmlTheme::success(),
568                    ),
569                    Span::raw(" / "),
570                    Span::styled(
571                        m.val_acc.map(|v| format!("{:.1}%", v * 100.0)).unwrap_or_else(|| "-".to_string()),
572                        AxonmlTheme::success(),
573                    ),
574                    Span::raw("  "),
575                    Span::styled(format!("({:.1}s)", m.duration_secs), AxonmlTheme::muted()),
576                ])
577            })
578            .collect();
579
580        let history = Paragraph::new(history_lines)
581            .block(
582                Block::default()
583                    .borders(Borders::ALL)
584                    .border_style(AxonmlTheme::border())
585                    .title(Span::styled(" Epoch History (recent) ", AxonmlTheme::header())),
586            );
587
588        frame.render_widget(history, area);
589    }
590
591    fn render_empty(&self, frame: &mut Frame, area: Rect) {
592        let text = vec![
593            Line::from(""),
594            Line::from(Span::styled(
595                "No training session active",
596                AxonmlTheme::muted(),
597            )),
598            Line::from(""),
599            Line::from(Span::styled(
600                "Press 't' to start training",
601                AxonmlTheme::info(),
602            )),
603            Line::from(Span::styled(
604                "or load a model first with 'o'",
605                AxonmlTheme::muted(),
606            )),
607        ];
608
609        let paragraph = Paragraph::new(text)
610            .block(
611                Block::default()
612                    .borders(Borders::ALL)
613                    .border_style(AxonmlTheme::border())
614                    .title(Span::styled(" Training ", AxonmlTheme::header())),
615            )
616            .alignment(Alignment::Center);
617
618        frame.render_widget(paragraph, area);
619    }
620}
621
622impl Default for TrainingView {
623    fn default() -> Self {
624        Self::new()
625    }
626}
627
628// =============================================================================
629// Helpers
630// =============================================================================
631
/// Formats a duration in seconds as `MM:SS`, or as `HH:MM:SS` once it reaches
/// one hour.
///
/// Fractional seconds are truncated. Negative and NaN inputs render as `00:00`,
/// because the float-to-integer `as` cast saturates (NaN becomes 0).
fn format_duration(secs: f64) -> String {
    let whole = secs as u64;
    let (h, rem) = (whole / 3600, whole % 3600);
    let (m, s) = (rem / 60, rem % 60);

    match h {
        0 => format!("{:02}:{:02}", m, s),
        _ => format!("{:02}:{:02}:{:02}", h, m, s),
    }
}