1use std::path::Path;
10
11use ratatui::{
12 layout::{Alignment, Constraint, Direction, Layout, Rect},
13 style::Style,
14 text::{Line, Span},
15 widgets::{Block, Borders, Gauge, Paragraph, Sparkline},
16 Frame,
17};
18
19use crate::theme::AxonmlTheme;
20
/// Lifecycle state of a training session, shown in the header's "Status:" field.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TrainingStatus {
    /// No training is progressing (also used for a parsed log that has not reached its final epoch).
    Idle,
    /// Training is advancing; only this state is advanced by `TrainingView::tick`.
    Running,
    /// Training has been paused via `TrainingView::toggle_pause`.
    Paused,
    /// All configured epochs have been reached.
    Completed,
    /// Training terminated with an error.
    Failed,
}
34
35impl TrainingStatus {
36 fn as_str(&self) -> &'static str {
37 match self {
38 TrainingStatus::Idle => "Idle",
39 TrainingStatus::Running => "Running",
40 TrainingStatus::Paused => "Paused",
41 TrainingStatus::Completed => "Completed",
42 TrainingStatus::Failed => "Failed",
43 }
44 }
45
46 fn style(&self) -> Style {
47 match self {
48 TrainingStatus::Idle => AxonmlTheme::muted(),
49 TrainingStatus::Running => AxonmlTheme::success(),
50 TrainingStatus::Paused => AxonmlTheme::warning(),
51 TrainingStatus::Completed => AxonmlTheme::info(),
52 TrainingStatus::Failed => AxonmlTheme::error(),
53 }
54 }
55}
56
/// Metrics recorded for a single completed epoch.
#[derive(Clone, Debug)]
pub struct EpochMetrics {
    /// Epoch number as reported by the source (the demo data starts at 1).
    pub epoch: usize,
    /// Training-set loss for this epoch.
    pub train_loss: f32,
    /// Validation-set loss, if it was recorded.
    pub val_loss: Option<f32>,
    /// Training accuracy as a fraction (rendered multiplied by 100 as a percentage).
    pub train_acc: Option<f32>,
    /// Validation accuracy as a fraction (rendered multiplied by 100 as a percentage).
    pub val_acc: Option<f32>,
    /// Learning rate in effect during this epoch.
    pub learning_rate: f64,
    /// Wall-clock duration of the epoch, in seconds.
    pub duration_secs: f32,
}
68
/// Static, run-level configuration of a training session.
#[derive(Clone, Debug)]
pub struct TrainingConfig {
    /// Number of epochs the run is configured for.
    pub total_epochs: usize,
    /// Mini-batch size.
    pub batch_size: usize,
    /// Optimizer name as displayed (e.g. "Adam").
    pub optimizer: String,
    /// Learning rate shown when no per-epoch learning rate is available.
    pub initial_lr: f64,
    /// Model name shown in the session header.
    pub model_name: String,
}
78
79#[derive(Debug, Clone)]
81pub struct TrainingSession {
82 pub config: TrainingConfig,
83 pub current_epoch: usize,
84 pub current_batch: usize,
85 pub total_batches: usize,
86 pub status: TrainingStatus,
87 pub metrics_history: Vec<EpochMetrics>,
88 pub loss_history: Vec<u64>, pub best_val_loss: Option<f32>,
90 pub best_epoch: Option<usize>,
91 pub elapsed_secs: f64,
92 pub eta_secs: Option<f64>,
93}
94
95pub struct TrainingView {
101 pub session: Option<TrainingSession>,
103
104 pub show_details: bool,
106}
107
108impl TrainingView {
109 pub fn new() -> Self {
111 let mut view = Self {
112 session: None,
113 show_details: false,
114 };
115
116 view.load_demo_session();
118 view
119 }
120
121 pub fn load_demo_session(&mut self) {
123 let metrics_history = vec![
124 EpochMetrics { epoch: 1, train_loss: 2.312, val_loss: Some(2.298), train_acc: Some(0.112), val_acc: Some(0.118), learning_rate: 0.001, duration_secs: 45.2 },
125 EpochMetrics { epoch: 2, train_loss: 1.845, val_loss: Some(1.756), train_acc: Some(0.342), val_acc: Some(0.358), learning_rate: 0.001, duration_secs: 44.8 },
126 EpochMetrics { epoch: 3, train_loss: 1.234, val_loss: Some(1.189), train_acc: Some(0.567), val_acc: Some(0.582), learning_rate: 0.001, duration_secs: 45.1 },
127 EpochMetrics { epoch: 4, train_loss: 0.856, val_loss: Some(0.823), train_acc: Some(0.712), val_acc: Some(0.724), learning_rate: 0.001, duration_secs: 44.9 },
128 EpochMetrics { epoch: 5, train_loss: 0.612, val_loss: Some(0.598), train_acc: Some(0.798), val_acc: Some(0.805), learning_rate: 0.0005, duration_secs: 45.3 },
129 EpochMetrics { epoch: 6, train_loss: 0.478, val_loss: Some(0.489), train_acc: Some(0.845), val_acc: Some(0.842), learning_rate: 0.0005, duration_secs: 45.0 },
130 EpochMetrics { epoch: 7, train_loss: 0.389, val_loss: Some(0.412), train_acc: Some(0.878), val_acc: Some(0.869), learning_rate: 0.0005, duration_secs: 44.7 },
131 EpochMetrics { epoch: 8, train_loss: 0.321, val_loss: Some(0.358), train_acc: Some(0.902), val_acc: Some(0.891), learning_rate: 0.00025, duration_secs: 45.2 },
132 ];
133
134 let loss_history: Vec<u64> = metrics_history
136 .iter()
137 .map(|m| ((2.5 - m.train_loss) / 2.5 * 100.0).max(0.0) as u64)
138 .collect();
139
140 let session = TrainingSession {
141 config: TrainingConfig {
142 total_epochs: 20,
143 batch_size: 64,
144 optimizer: "Adam".to_string(),
145 initial_lr: 0.001,
146 model_name: "mnist_classifier".to_string(),
147 },
148 current_epoch: 8,
149 current_batch: 720,
150 total_batches: 938,
151 status: TrainingStatus::Running,
152 metrics_history,
153 loss_history,
154 best_val_loss: Some(0.358),
155 best_epoch: Some(8),
156 elapsed_secs: 361.2,
157 eta_secs: Some(540.0),
158 };
159
160 self.session = Some(session);
161 }
162
163 pub fn toggle_details(&mut self) {
165 self.show_details = !self.show_details;
166 }
167
168 pub fn toggle_pause(&mut self) {
170 if let Some(session) = &mut self.session {
171 session.status = match session.status {
172 TrainingStatus::Running => TrainingStatus::Paused,
173 TrainingStatus::Paused => TrainingStatus::Running,
174 other => other,
175 };
176 }
177 }
178
179 pub fn watch_log(&mut self, path: &Path) -> Result<(), String> {
181 match self.parse_training_log(path) {
183 Ok(session) => {
184 self.session = Some(session);
185 Ok(())
186 }
187 Err(e) => {
188 eprintln!("Warning: Could not parse training log: {}", e);
190 self.load_demo_session();
191 Err(e)
192 }
193 }
194 }
195
196 fn parse_training_log(&self, path: &Path) -> Result<TrainingSession, String> {
198 let content = std::fs::read_to_string(path)
199 .map_err(|e| format!("Failed to read log file: {}", e))?;
200
201 let mut metrics_history = Vec::new();
202 let mut config = TrainingConfig {
203 total_epochs: 10,
204 batch_size: 32,
205 optimizer: "Adam".to_string(),
206 initial_lr: 0.001,
207 model_name: "model".to_string(),
208 };
209
210 for line in content.lines() {
212 if line.trim().is_empty() {
213 continue;
214 }
215
216 if let Ok(json) = serde_json::from_str::<serde_json::Value>(line) {
218 if let Some(epoch) = json.get("epoch").and_then(|v| v.as_u64()) {
220 let train_loss = json.get("train_loss")
221 .and_then(|v| v.as_f64())
222 .unwrap_or(0.0) as f32;
223 let val_loss = json.get("val_loss")
224 .and_then(|v| v.as_f64())
225 .map(|v| v as f32);
226 let train_acc = json.get("train_acc")
227 .and_then(|v| v.as_f64())
228 .map(|v| v as f32);
229 let val_acc = json.get("val_acc")
230 .and_then(|v| v.as_f64())
231 .map(|v| v as f32);
232 let learning_rate = json.get("learning_rate")
233 .and_then(|v| v.as_f64())
234 .unwrap_or(0.001);
235 let duration = json.get("duration_secs")
236 .and_then(|v| v.as_f64())
237 .unwrap_or(0.0) as f32;
238
239 metrics_history.push(EpochMetrics {
240 epoch: epoch as usize,
241 train_loss,
242 val_loss,
243 train_acc,
244 val_acc,
245 learning_rate,
246 duration_secs: duration,
247 });
248 }
249
250 if let Some(total) = json.get("total_epochs").and_then(|v| v.as_u64()) {
252 config.total_epochs = total as usize;
253 }
254 if let Some(bs) = json.get("batch_size").and_then(|v| v.as_u64()) {
255 config.batch_size = bs as usize;
256 }
257 if let Some(opt) = json.get("optimizer").and_then(|v| v.as_str()) {
258 config.optimizer = opt.to_string();
259 }
260 if let Some(name) = json.get("model_name").and_then(|v| v.as_str()) {
261 config.model_name = name.to_string();
262 }
263 }
264 }
265
266 if metrics_history.is_empty() {
267 return Err("No training metrics found in log file".to_string());
268 }
269
270 let current_epoch = metrics_history.last().map(|m| m.epoch).unwrap_or(1);
272 let loss_history: Vec<u64> = metrics_history
273 .iter()
274 .map(|m| {
275 let max_loss = 3.0f32;
276 ((max_loss - m.train_loss.min(max_loss)) / max_loss * 100.0) as u64
277 })
278 .collect();
279
280 let (best_val_loss, best_epoch) = metrics_history
281 .iter()
282 .filter_map(|m| m.val_loss.map(|v| (v, m.epoch)))
283 .min_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal))
284 .unwrap_or((0.0, 1));
285
286 let total_duration: f32 = metrics_history.iter().map(|m| m.duration_secs).sum();
287 let avg_epoch_time = total_duration / metrics_history.len() as f32;
288 let total_epochs = config.total_epochs;
289 let remaining_epochs = total_epochs.saturating_sub(current_epoch);
290 let eta = (remaining_epochs as f32 * avg_epoch_time) as f64;
291
292 Ok(TrainingSession {
293 config,
294 current_epoch,
295 current_batch: 0,
296 total_batches: 100,
297 status: if current_epoch >= total_epochs {
298 TrainingStatus::Completed
299 } else {
300 TrainingStatus::Idle
301 },
302 metrics_history,
303 loss_history,
304 best_val_loss: Some(best_val_loss),
305 best_epoch: Some(best_epoch),
306 elapsed_secs: total_duration as f64,
307 eta_secs: if eta > 0.0 { Some(eta) } else { None },
308 })
309 }
310
311 pub fn scroll_up(&mut self) {
313 }
315
316 pub fn scroll_down(&mut self) {
318 }
320
321 pub fn refresh(&mut self) {
323 self.load_demo_session();
325 }
326
327 pub fn tick(&mut self) {
329 if let Some(session) = &mut self.session {
332 if session.status == TrainingStatus::Running {
333 if session.current_batch < session.total_batches {
335 session.current_batch += 1;
336 } else {
337 session.current_batch = 0;
339 if session.current_epoch < session.config.total_epochs {
340 session.current_epoch += 1;
341 }
342 }
343 session.elapsed_secs += 0.1;
344 }
345 }
346 }
347
348 pub fn render(&mut self, frame: &mut Frame, area: Rect) {
350 if let Some(session) = &self.session.clone() {
351 let chunks = Layout::default()
352 .direction(Direction::Vertical)
353 .constraints([
354 Constraint::Length(7), Constraint::Length(3), Constraint::Length(3), Constraint::Min(8), Constraint::Length(10), ])
360 .split(area);
361
362 self.render_header(frame, chunks[0], session);
363 self.render_epoch_progress(frame, chunks[1], session);
364 self.render_batch_progress(frame, chunks[2], session);
365 self.render_metrics(frame, chunks[3], session);
366 self.render_history(frame, chunks[4], session);
367 } else {
368 self.render_empty(frame, area);
369 }
370 }
371
372 fn render_header(&self, frame: &mut Frame, area: Rect, session: &TrainingSession) {
373 let status_style = session.status.style();
374
375 let header_text = vec![
376 Line::from(vec![
377 Span::styled("Model: ", AxonmlTheme::muted()),
378 Span::styled(&session.config.model_name, AxonmlTheme::title()),
379 Span::raw(" "),
380 Span::styled("Status: ", AxonmlTheme::muted()),
381 Span::styled(session.status.as_str(), status_style),
382 ]),
383 Line::from(vec![
384 Span::styled("Optimizer: ", AxonmlTheme::muted()),
385 Span::styled(&session.config.optimizer, AxonmlTheme::accent()),
386 Span::raw(" "),
387 Span::styled("Batch Size: ", AxonmlTheme::muted()),
388 Span::styled(session.config.batch_size.to_string(), AxonmlTheme::metric_value()),
389 Span::raw(" "),
390 Span::styled("LR: ", AxonmlTheme::muted()),
391 Span::styled(
392 format!("{:.6}", session.metrics_history.last().map(|m| m.learning_rate).unwrap_or(session.config.initial_lr)),
393 AxonmlTheme::metric_value(),
394 ),
395 ]),
396 Line::from(vec![
397 Span::styled("Elapsed: ", AxonmlTheme::muted()),
398 Span::styled(format_duration(session.elapsed_secs), AxonmlTheme::accent()),
399 Span::raw(" "),
400 Span::styled("ETA: ", AxonmlTheme::muted()),
401 Span::styled(
402 session.eta_secs.map(format_duration).unwrap_or_else(|| "--:--".to_string()),
403 AxonmlTheme::accent(),
404 ),
405 ]),
406 Line::from(vec![
407 Span::styled("Best Val Loss: ", AxonmlTheme::muted()),
408 Span::styled(
409 session.best_val_loss.map(|v| format!("{:.4}", v)).unwrap_or_else(|| "-".to_string()),
410 AxonmlTheme::success(),
411 ),
412 Span::raw(" "),
413 Span::styled("@ Epoch: ", AxonmlTheme::muted()),
414 Span::styled(
415 session.best_epoch.map(|e| e.to_string()).unwrap_or_else(|| "-".to_string()),
416 AxonmlTheme::success(),
417 ),
418 ]),
419 ];
420
421 let header = Paragraph::new(header_text)
422 .block(
423 Block::default()
424 .borders(Borders::ALL)
425 .border_style(AxonmlTheme::border())
426 .title(Span::styled(" Training Session ", AxonmlTheme::header())),
427 );
428
429 frame.render_widget(header, area);
430 }
431
432 fn render_epoch_progress(&self, frame: &mut Frame, area: Rect, session: &TrainingSession) {
433 let progress = session.current_epoch as f64 / session.config.total_epochs as f64;
434
435 let gauge = Gauge::default()
436 .block(
437 Block::default()
438 .borders(Borders::ALL)
439 .border_style(AxonmlTheme::border())
440 .title(Span::styled(
441 format!(" Epoch {}/{} ", session.current_epoch, session.config.total_epochs),
442 AxonmlTheme::epoch(),
443 )),
444 )
445 .gauge_style(AxonmlTheme::graph_primary())
446 .ratio(progress)
447 .label(format!("{:.1}%", progress * 100.0));
448
449 frame.render_widget(gauge, area);
450 }
451
452 fn render_batch_progress(&self, frame: &mut Frame, area: Rect, session: &TrainingSession) {
453 let progress = session.current_batch as f64 / session.total_batches as f64;
454
455 let gauge = Gauge::default()
456 .block(
457 Block::default()
458 .borders(Borders::ALL)
459 .border_style(AxonmlTheme::border())
460 .title(Span::styled(
461 format!(" Batch {}/{} ", session.current_batch, session.total_batches),
462 AxonmlTheme::muted(),
463 )),
464 )
465 .gauge_style(AxonmlTheme::graph_secondary())
466 .ratio(progress)
467 .label(format!("{:.1}%", progress * 100.0));
468
469 frame.render_widget(gauge, area);
470 }
471
472 fn render_metrics(&self, frame: &mut Frame, area: Rect, session: &TrainingSession) {
473 let chunks = Layout::default()
474 .direction(Direction::Horizontal)
475 .constraints([
476 Constraint::Percentage(50), Constraint::Percentage(50), ])
479 .split(area);
480
481 let latest = session.metrics_history.last();
483 let metrics_text = if let Some(m) = latest {
484 let loss_style = if m.train_loss < 0.5 {
485 AxonmlTheme::loss_good()
486 } else if m.train_loss < 1.0 {
487 AxonmlTheme::loss_neutral()
488 } else {
489 AxonmlTheme::loss_bad()
490 };
491
492 vec![
493 Line::from(vec![
494 Span::styled("Train Loss: ", AxonmlTheme::metric_label()),
495 Span::styled(format!("{:.4}", m.train_loss), loss_style),
496 ]),
497 Line::from(vec![
498 Span::styled("Val Loss: ", AxonmlTheme::metric_label()),
499 Span::styled(
500 m.val_loss.map(|v| format!("{:.4}", v)).unwrap_or_else(|| "-".to_string()),
501 AxonmlTheme::metric_value(),
502 ),
503 ]),
504 Line::from(vec![
505 Span::styled("Train Acc: ", AxonmlTheme::metric_label()),
506 Span::styled(
507 m.train_acc.map(|v| format!("{:.2}%", v * 100.0)).unwrap_or_else(|| "-".to_string()),
508 AxonmlTheme::success(),
509 ),
510 ]),
511 Line::from(vec![
512 Span::styled("Val Acc: ", AxonmlTheme::metric_label()),
513 Span::styled(
514 m.val_acc.map(|v| format!("{:.2}%", v * 100.0)).unwrap_or_else(|| "-".to_string()),
515 AxonmlTheme::success(),
516 ),
517 ]),
518 ]
519 } else {
520 vec![Line::from(Span::styled("No metrics yet", AxonmlTheme::muted()))]
521 };
522
523 let metrics = Paragraph::new(metrics_text)
524 .block(
525 Block::default()
526 .borders(Borders::ALL)
527 .border_style(AxonmlTheme::border_focused())
528 .title(Span::styled(" Current Metrics ", AxonmlTheme::header())),
529 );
530
531 frame.render_widget(metrics, chunks[0]);
532
533 let sparkline = Sparkline::default()
535 .block(
536 Block::default()
537 .borders(Borders::ALL)
538 .border_style(AxonmlTheme::border())
539 .title(Span::styled(" Loss Trend (inverted) ", AxonmlTheme::header())),
540 )
541 .data(&session.loss_history)
542 .style(AxonmlTheme::graph_primary());
543
544 frame.render_widget(sparkline, chunks[1]);
545 }
546
547 fn render_history(&self, frame: &mut Frame, area: Rect, session: &TrainingSession) {
548 let history_lines: Vec<Line> = session
549 .metrics_history
550 .iter()
551 .rev()
552 .take(6)
553 .map(|m| {
554 Line::from(vec![
555 Span::styled(format!("Epoch {:>2} ", m.epoch), AxonmlTheme::epoch()),
556 Span::styled("Loss: ", AxonmlTheme::muted()),
557 Span::styled(format!("{:.4}", m.train_loss), AxonmlTheme::metric_value()),
558 Span::raw(" / "),
559 Span::styled(
560 m.val_loss.map(|v| format!("{:.4}", v)).unwrap_or_else(|| "-".to_string()),
561 AxonmlTheme::accent(),
562 ),
563 Span::raw(" "),
564 Span::styled("Acc: ", AxonmlTheme::muted()),
565 Span::styled(
566 m.train_acc.map(|v| format!("{:.1}%", v * 100.0)).unwrap_or_else(|| "-".to_string()),
567 AxonmlTheme::success(),
568 ),
569 Span::raw(" / "),
570 Span::styled(
571 m.val_acc.map(|v| format!("{:.1}%", v * 100.0)).unwrap_or_else(|| "-".to_string()),
572 AxonmlTheme::success(),
573 ),
574 Span::raw(" "),
575 Span::styled(format!("({:.1}s)", m.duration_secs), AxonmlTheme::muted()),
576 ])
577 })
578 .collect();
579
580 let history = Paragraph::new(history_lines)
581 .block(
582 Block::default()
583 .borders(Borders::ALL)
584 .border_style(AxonmlTheme::border())
585 .title(Span::styled(" Epoch History (recent) ", AxonmlTheme::header())),
586 );
587
588 frame.render_widget(history, area);
589 }
590
591 fn render_empty(&self, frame: &mut Frame, area: Rect) {
592 let text = vec![
593 Line::from(""),
594 Line::from(Span::styled(
595 "No training session active",
596 AxonmlTheme::muted(),
597 )),
598 Line::from(""),
599 Line::from(Span::styled(
600 "Press 't' to start training",
601 AxonmlTheme::info(),
602 )),
603 Line::from(Span::styled(
604 "or load a model first with 'o'",
605 AxonmlTheme::muted(),
606 )),
607 ];
608
609 let paragraph = Paragraph::new(text)
610 .block(
611 Block::default()
612 .borders(Borders::ALL)
613 .border_style(AxonmlTheme::border())
614 .title(Span::styled(" Training ", AxonmlTheme::header())),
615 )
616 .alignment(Alignment::Center);
617
618 frame.render_widget(paragraph, area);
619 }
620}
621
622impl Default for TrainingView {
623 fn default() -> Self {
624 Self::new()
625 }
626}
627
/// Formats a duration in seconds as `MM:SS`, or `HH:MM:SS` once it reaches an hour.
///
/// Fractional seconds are truncated. Negative and NaN inputs saturate to 0
/// through the `as u64` cast.
fn format_duration(secs: f64) -> String {
    let whole = secs as u64;
    let (hours, within_hour) = (whole / 3600, whole % 3600);
    let (minutes, seconds) = (within_hour / 60, within_hour % 60);

    match hours {
        0 => format!("{:02}:{:02}", minutes, seconds),
        _ => format!("{:02}:{:02}:{:02}", hours, minutes, seconds),
    }
}