1use super::state::{TrainingSnapshot, TrainingState, TrainingStatus};
8use std::io;
9use std::path::Path;
10use std::time::Duration;
11
/// Default refresh interval for the monitor, in milliseconds.
const DEFAULT_REFRESH_MS: u64 = 500;

/// Default cap on how many loss / learning-rate history entries a `TrainingStateWriter` keeps.
const LOSS_HISTORY_MAX: usize = 200;

/// User-tunable options for `TuiMonitor`.
///
/// NOTE(review): `TuiMonitor::run` currently forwards only `refresh_ms` (as the TUI tick rate);
/// `width`, `height`, `compact` and `exit_on_complete` are stored but not read there.
#[derive(Debug, Clone)]
pub struct TuiMonitorConfig {
    /// Refresh interval in milliseconds.
    pub refresh_ms: u64,
    /// Requested display width, in columns.
    pub width: usize,
    /// Requested display height, in rows.
    pub height: usize,
    /// Request a condensed layout.
    pub compact: bool,
    /// Request that the monitor exit once training completes.
    pub exit_on_complete: bool,
}

impl Default for TuiMonitorConfig {
    /// 500 ms refresh, an 80x24 layout, full (non-compact) view, exit when training completes.
    fn default() -> Self {
        let (width, height) = (80, 24);
        Self {
            exit_on_complete: true,
            compact: false,
            width,
            height,
            refresh_ms: DEFAULT_REFRESH_MS,
        }
    }
}
44
45pub struct TuiMonitor {
51 config: TuiMonitorConfig,
52 state: TrainingState,
53}
54
55impl TuiMonitor {
56 pub fn new<P: AsRef<Path>>(experiment_dir: P, config: TuiMonitorConfig) -> Self {
58 Self { config, state: TrainingState::new(experiment_dir) }
59 }
60
61 pub fn run(&mut self) -> io::Result<()> {
70 eprintln!("Waiting for training state file at {}...", self.state.path().display());
72
73 if !self.state.wait_for_state(Duration::from_secs(60))? {
74 eprintln!("Timeout waiting for training state file.");
75 return Ok(());
76 }
77
78 eprintln!("Connected to training session. Press 'q' or Ctrl+C to detach.\n");
79
80 let experiment_dir =
82 self.state.path().parent().unwrap_or(std::path::Path::new(".")).to_path_buf();
83 let dashboard = super::dashboard::TrainingDashboard::new(experiment_dir);
84
85 let config = presentar_terminal::TuiConfig {
87 tick_rate_ms: self.config.refresh_ms,
88 ..Default::default()
89 };
90 let mut app = presentar_terminal::TuiApp::new(dashboard)
91 .map_err(|e| io::Error::other(e.to_string()))?
92 .with_config(config);
93
94 app.run().map_err(|e| io::Error::other(e.to_string()))?;
95
96 eprintln!("\nDetached from training session. Training continues in background.");
97 Ok(())
98 }
99}
100
101#[cfg(test)]
102mod tests {
103 use super::*;
104
105 #[test]
106 fn test_training_status_match_all_variants() {
107 let statuses = [
108 TrainingStatus::Initializing,
109 TrainingStatus::Running,
110 TrainingStatus::Paused,
111 TrainingStatus::Completed,
112 TrainingStatus::Failed("oom".to_string()),
113 ];
114
115 for status in &statuses {
116 let description = match status {
118 TrainingStatus::Completed => "Completed Successfully".to_string(),
119 TrainingStatus::Failed(msg) => format!("FAILED - {msg}"),
120 TrainingStatus::Initializing | TrainingStatus::Running | TrainingStatus::Paused => {
121 "In progress".to_string()
122 }
123 };
124 assert!(!description.is_empty());
125 }
126 }
127
128 #[test]
129 fn test_tui_monitor_config_default() {
130 let config = TuiMonitorConfig::default();
131 assert_eq!(config.refresh_ms, 500);
132 assert_eq!(config.width, 80);
133 assert_eq!(config.height, 24);
134 assert!(!config.compact);
135 assert!(config.exit_on_complete);
136 }
137 #[test]
138 fn test_tui_monitor_config_custom() {
139 let config = TuiMonitorConfig {
140 refresh_ms: 1000,
141 width: 120,
142 height: 40,
143 compact: true,
144 exit_on_complete: false,
145 };
146 assert_eq!(config.refresh_ms, 1000);
147 assert_eq!(config.width, 120);
148 assert_eq!(config.height, 40);
149 assert!(config.compact);
150 assert!(!config.exit_on_complete);
151 }
152
153 #[test]
154 fn test_tui_monitor_config_clone() {
155 let config = TuiMonitorConfig::default();
156 let cloned = config.clone();
157 assert_eq!(config.refresh_ms, cloned.refresh_ms);
158 assert_eq!(config.width, cloned.width);
159 assert_eq!(config.compact, cloned.compact);
160 }
161
162 #[test]
163 fn test_tui_monitor_config_debug() {
164 let config = TuiMonitorConfig::default();
165 let debug = format!("{config:?}");
166 assert!(debug.contains("TuiMonitorConfig"));
167 assert!(debug.contains("500"));
168 }
169
170 #[test]
171 fn test_default_refresh_constant() {
172 assert_eq!(DEFAULT_REFRESH_MS, 500);
173 }
174
175 #[test]
176 fn test_loss_history_max_constant() {
177 assert_eq!(LOSS_HISTORY_MAX, 200);
178 }
179} pub struct TrainingStateWriter {
185 state: TrainingState,
186 snapshot: TrainingSnapshot,
187 history_max: usize,
188 console_progress: bool,
190}
191
192impl TrainingStateWriter {
193 pub fn new<P: AsRef<Path>>(experiment_dir: P, experiment_id: &str, model_name: &str) -> Self {
195 let mut snapshot = TrainingSnapshot::default();
196 snapshot.experiment_id = experiment_id.to_string();
197 snapshot.model_name = model_name.to_string();
198 snapshot.status = TrainingStatus::Initializing;
199
200 Self {
201 state: TrainingState::new(experiment_dir),
202 snapshot,
203 history_max: LOSS_HISTORY_MAX,
204 console_progress: false,
205 }
206 }
207
208 pub fn with_console_progress(mut self, enabled: bool) -> Self {
213 self.console_progress = enabled;
214 self
215 }
216
217 pub fn set_epochs(&mut self, total_epochs: usize, steps_per_epoch: usize) {
222 self.snapshot.total_epochs = total_epochs;
223 self.snapshot.steps_per_epoch = steps_per_epoch;
224 self.snapshot.epoch = 0;
226 self.snapshot.step = 0;
227 }
228
229 pub fn set_config(
231 &mut self,
232 optimizer_name: &str,
233 batch_size: usize,
234 model_path: &str,
235 checkpoint_path: &str,
236 ) {
237 self.snapshot.optimizer_name = optimizer_name.to_string();
238 self.snapshot.batch_size = batch_size;
239 self.snapshot.model_path = model_path.to_string();
240 self.snapshot.checkpoint_path = checkpoint_path.to_string();
241 if let Ok(exe) = std::env::current_exe() {
243 self.snapshot.executable_path = exe.display().to_string();
244 }
245 }
246
247 pub fn set_gpu(&mut self, device_name: &str, vram_total_gb: f32) {
252 self.snapshot.gpu = Some(super::state::GpuTelemetry {
253 device_name: device_name.to_string(),
254 vram_total_gb,
255 ..Default::default()
256 });
257 }
258
259 pub fn start(&mut self) -> io::Result<()> {
261 self.snapshot.status = TrainingStatus::Running;
262 let now = std::time::SystemTime::now()
263 .duration_since(std::time::UNIX_EPOCH)
264 .map(|d| d.as_millis() as u64)
265 .unwrap_or(0);
266 self.snapshot.start_timestamp_ms = now;
267 self.snapshot.timestamp_ms = now;
268 self.state.write(&self.snapshot)
269 }
270
271 pub fn update_step(
277 &mut self,
278 epoch: usize,
279 step: usize,
280 loss: f32,
281 learning_rate: f32,
282 gradient_norm: f32,
283 tokens_per_second: f32,
284 accuracy: f32,
285 ) -> io::Result<()> {
286 if self.snapshot.steps_per_epoch > 0 && step > self.snapshot.steps_per_epoch {
288 eprintln!(
289 "Warning: step {} exceeds steps_per_epoch {} - call set_epochs() at phase start",
290 step, self.snapshot.steps_per_epoch
291 );
292 }
293
294 if self.snapshot.total_epochs > 0 && epoch > self.snapshot.total_epochs {
296 eprintln!(
297 "Warning: epoch {} exceeds total_epochs {} - call set_epochs() at phase start",
298 epoch, self.snapshot.total_epochs
299 );
300 }
301
302 self.snapshot.epoch = epoch;
303 self.snapshot.step = step;
304 self.snapshot.loss = loss;
305 self.snapshot.learning_rate = learning_rate;
306 self.snapshot.gradient_norm = gradient_norm;
307 self.snapshot.tokens_per_second = tokens_per_second;
308 self.snapshot.accuracy = accuracy;
309 self.snapshot.samples_per_second = tokens_per_second;
310
311 self.snapshot.timestamp_ms = std::time::SystemTime::now()
313 .duration_since(std::time::UNIX_EPOCH)
314 .map(|d| d.as_millis() as u64)
315 .unwrap_or(0);
316
317 self.snapshot.loss_history.push(loss);
319 if self.snapshot.loss_history.len() > self.history_max {
320 self.snapshot.loss_history.remove(0);
321 }
322
323 self.snapshot.lr_history.push(learning_rate);
325 if self.snapshot.lr_history.len() > self.history_max {
326 self.snapshot.lr_history.remove(0);
327 }
328
329 if self.console_progress {
331 let log_every = (self.snapshot.steps_per_epoch / 10)
332 .max(10)
333 .min(self.snapshot.steps_per_epoch.max(1));
334 if step == 1 || step.is_multiple_of(log_every) || step == self.snapshot.steps_per_epoch
335 {
336 self.refresh_gpu_telemetry();
337 self.emit_console_progress();
338 }
339 }
340
341 self.state.write(&self.snapshot)
342 }
343
344 fn refresh_gpu_telemetry(&mut self) {
349 let device_name = match &self.snapshot.gpu {
350 Some(gpu) => gpu.device_name.clone(),
351 None => return,
352 };
353
354 let output = std::process::Command::new("nvidia-smi")
355 .args([
356 "--query-gpu=utilization.gpu,memory.used,memory.total,temperature.gpu,power.draw,power.limit",
357 "--format=csv,noheader,nounits",
358 ])
359 .output();
360
361 let output = match output {
362 Ok(o) if o.status.success() => o,
363 _ => return,
364 };
365
366 let stdout = String::from_utf8_lossy(&output.stdout);
367 let line = match stdout.lines().next() {
368 Some(l) => l.trim(),
369 None => return,
370 };
371 let fields: Vec<&str> = line.split(',').map(str::trim).collect();
372 if fields.len() < 6 {
373 return;
374 }
375
376 self.snapshot.gpu = Some(super::state::GpuTelemetry {
377 device_name,
378 utilization_percent: fields[0].parse().unwrap_or(0.0),
379 vram_used_gb: fields[1].parse::<f32>().unwrap_or(0.0) / 1024.0,
380 vram_total_gb: fields[2].parse::<f32>().unwrap_or(0.0) / 1024.0,
381 temperature_celsius: fields[3].parse().unwrap_or(0.0),
382 power_watts: fields[4].parse().unwrap_or(0.0),
383 power_limit_watts: fields[5].parse().unwrap_or(0.0),
384 processes: Vec::new(),
385 });
386 }
387
388 fn emit_console_progress(&self) {
390 let mut buf = Vec::new();
391 let mut writer =
392 super::headless::HeadlessWriter::new(&mut buf, super::headless::OutputFormat::Text);
393 let _ = writer.write(&self.snapshot);
394 if let Ok(s) = String::from_utf8(buf) {
395 print!(" {s}");
396 }
397 }
398
399 pub fn emit_epoch_summary(
401 &self,
402 epoch: usize,
403 total_epochs: usize,
404 train_loss: f32,
405 train_acc: f32,
406 val_loss: f32,
407 val_acc: f32,
408 epoch_secs: f32,
409 lr: f32,
410 is_best: bool,
411 ) {
412 if self.console_progress {
413 let best = if is_best { " *best*" } else { "" };
414 println!(
415 " Epoch {epoch}/{total_epochs} done in {epoch_secs:.0}s — \
416 train_loss: {train_loss:.4}, train_acc: {:.1}%, \
417 val_loss: {val_loss:.4}, val_acc: {:.1}%, LR: {lr:.2e}{best}",
418 train_acc * 100.0,
419 val_acc * 100.0,
420 );
421 }
422 }
423
424 pub fn emit_info(&self, msg: &str) {
426 if self.console_progress {
427 println!(" {msg}");
428 }
429 }
430
431 pub fn update_gpu(&mut self, gpu: super::state::GpuTelemetry) -> io::Result<()> {
433 self.snapshot.gpu = Some(gpu);
434 self.state.write(&self.snapshot)
435 }
436
437 pub fn update_sample(&mut self, sample: super::state::SamplePeek) -> io::Result<()> {
439 self.snapshot.sample = Some(sample);
440 self.state.write(&self.snapshot)
441 }
442
443 pub fn complete(&mut self) -> io::Result<()> {
445 self.snapshot.status = TrainingStatus::Completed;
446 self.state.write(&self.snapshot)
447 }
448
449 pub fn fail(&mut self, error: &str) -> io::Result<()> {
451 self.snapshot.status = TrainingStatus::Failed(error.to_string());
452 self.state.write(&self.snapshot)
453 }
454
455 pub fn state_path(&self) -> &Path {
457 self.state.path()
458 }
459}
460
#[cfg(test)]
mod state_writer_tests {
    use super::*;
    use tempfile::TempDir;

    /// Returns a fresh temporary directory and a writer rooted in it.
    ///
    /// Bind the `TempDir` to a named variable (e.g. `_temp_dir`, never `_`) so it lives as long
    /// as the writer; dropping it deletes the directory.
    fn new_writer() -> (TempDir, TrainingStateWriter) {
        let temp_dir = TempDir::new().expect("temp file creation should succeed");
        let writer = TrainingStateWriter::new(temp_dir.path(), "test-001", "test-model");
        (temp_dir, writer)
    }

    /// Reads back the snapshot most recently written under `dir`.
    fn read_snapshot(dir: &TempDir) -> TrainingSnapshot {
        let mut state = TrainingState::new(dir.path());
        state.read().expect("file read should succeed").expect("file read should succeed")
    }

    #[test]
    fn test_training_state_writer() {
        let (temp_dir, mut writer) = new_writer();
        writer.set_epochs(10, 100);
        writer.start().expect("file write should succeed");
        writer
            .update_step(1, 10, 0.5, 0.0002, 1.5, 1200.0, 0.75)
            .expect("file write should succeed");

        let snapshot = read_snapshot(&temp_dir);
        assert_eq!(snapshot.epoch, 1);
        assert_eq!(snapshot.step, 10);
        assert!((snapshot.loss - 0.5).abs() < 0.001);
        assert_eq!(snapshot.status, TrainingStatus::Running);
    }

    #[test]
    fn test_training_state_writer_complete() {
        let (temp_dir, mut writer) = new_writer();
        writer.start().expect("file write should succeed");
        writer.complete().expect("file write should succeed");

        let snapshot = read_snapshot(&temp_dir);
        assert_eq!(snapshot.status, TrainingStatus::Completed);
    }

    #[test]
    fn test_training_state_writer_fail() {
        let (temp_dir, mut writer) = new_writer();
        writer.start().expect("file write should succeed");
        writer.fail("OOM error").expect("file write should succeed");

        let snapshot = read_snapshot(&temp_dir);
        match snapshot.status {
            TrainingStatus::Failed(msg) => assert!(msg.contains("OOM")),
            _ => panic!("Expected Failed status"),
        }
    }

    #[test]
    fn test_loss_history_truncation() {
        let (temp_dir, mut writer) = new_writer();
        writer.history_max = 5;
        writer.start().expect("file write should succeed");

        for i in 0..10 {
            writer
                .update_step(1, i, i as f32 * 0.1, 0.0002, 1.5, 1200.0, 0.0)
                .expect("file write should succeed");
        }

        let snapshot = read_snapshot(&temp_dir);
        assert_eq!(snapshot.loss_history.len(), 5);
        // Losses 0.0..=0.9 were pushed; only the newest five (0.5..=0.9) survive.
        assert!((snapshot.loss_history[0] - 0.5).abs() < 0.001);
    }

    #[test]
    fn test_with_console_progress() {
        let (_temp_dir, writer) = new_writer();
        assert!(!writer.console_progress);
        let writer = writer.with_console_progress(true);
        assert!(writer.console_progress);
        let writer = writer.with_console_progress(false);
        assert!(!writer.console_progress);
    }

    #[test]
    fn test_set_epochs() {
        let (_temp_dir, mut writer) = new_writer();
        writer.set_epochs(10, 500);
        assert_eq!(writer.snapshot.total_epochs, 10);
        assert_eq!(writer.snapshot.steps_per_epoch, 500);
        assert_eq!(writer.snapshot.epoch, 0);
        assert_eq!(writer.snapshot.step, 0);
    }

    #[test]
    fn test_set_config() {
        let (_temp_dir, mut writer) = new_writer();
        writer.set_config("AdamW", 32, "/tmp/model", "/tmp/ckpt");
        assert_eq!(writer.snapshot.optimizer_name, "AdamW");
        assert_eq!(writer.snapshot.batch_size, 32);
        assert_eq!(writer.snapshot.model_path, "/tmp/model");
        assert_eq!(writer.snapshot.checkpoint_path, "/tmp/ckpt");
        assert!(!writer.snapshot.executable_path.is_empty());
    }

    #[test]
    fn test_set_gpu() {
        let (_temp_dir, mut writer) = new_writer();
        writer.set_gpu("RTX 4090", 24.0);
        let gpu = writer.snapshot.gpu.as_ref().expect("gpu should be set");
        assert_eq!(gpu.device_name, "RTX 4090");
        assert!((gpu.vram_total_gb - 24.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_update_step_stores_lr_history() {
        let (_temp_dir, mut writer) = new_writer();
        writer.set_epochs(2, 3);
        writer.start().expect("file write should succeed");

        writer.update_step(1, 1, 0.5, 0.001, 1.0, 100.0, 0.8).expect("file write should succeed");
        writer.update_step(1, 2, 0.4, 0.0005, 0.9, 110.0, 0.85).expect("file write should succeed");

        assert_eq!(writer.snapshot.lr_history.len(), 2);
        assert!((writer.snapshot.lr_history[0] - 0.001).abs() < 1e-6);
        assert!((writer.snapshot.lr_history[1] - 0.0005).abs() < 1e-6);
    }

    #[test]
    fn test_lr_history_truncation() {
        let (_temp_dir, mut writer) = new_writer();
        writer.history_max = 3;
        writer.set_epochs(1, 10);
        writer.start().expect("file write should succeed");

        for i in 0..5 {
            writer
                .update_step(1, i, 0.5, i as f32 * 0.001, 1.0, 100.0, 0.8)
                .expect("file write should succeed");
        }

        assert_eq!(writer.snapshot.lr_history.len(), 3);
        // Rates 0.000..=0.004 were pushed; the oldest surviving entry is 0.002.
        assert!((writer.snapshot.lr_history[0] - 0.002).abs() < 1e-6);
    }

    #[test]
    fn test_update_gpu() {
        let (temp_dir, mut writer) = new_writer();
        writer.start().expect("file write should succeed");

        let gpu = crate::monitor::tui::state::GpuTelemetry {
            device_name: "Test GPU".to_string(),
            utilization_percent: 95.0,
            vram_used_gb: 20.0,
            vram_total_gb: 24.0,
            temperature_celsius: 75.0,
            power_watts: 350.0,
            power_limit_watts: 400.0,
            processes: Vec::new(),
        };
        writer.update_gpu(gpu).expect("file write should succeed");

        let snapshot = read_snapshot(&temp_dir);
        let gpu = snapshot.gpu.expect("gpu should be present");
        assert_eq!(gpu.device_name, "Test GPU");
        assert!((gpu.utilization_percent - 95.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_update_sample() {
        let (temp_dir, mut writer) = new_writer();
        writer.start().expect("file write should succeed");

        let sample = crate::monitor::tui::state::SamplePeek {
            input_preview: "def hello():".to_string(),
            target_preview: "def test_hello():".to_string(),
            generated_preview: "def test_hello():".to_string(),
            token_match_percent: 95.0,
        };
        writer.update_sample(sample).expect("file write should succeed");

        let snapshot = read_snapshot(&temp_dir);
        let sample = snapshot.sample.expect("sample should be present");
        assert_eq!(sample.input_preview, "def hello():");
        assert!((sample.token_match_percent - 95.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_state_path() {
        let (_temp_dir, writer) = new_writer();
        let path = writer.state_path();
        assert!(path.to_str().unwrap_or("").contains("training_state"));
    }

    #[test]
    fn test_start_sets_timestamps() {
        let (_temp_dir, mut writer) = new_writer();
        writer.start().expect("file write should succeed");

        assert_eq!(writer.snapshot.status, TrainingStatus::Running);
        assert!(writer.snapshot.start_timestamp_ms > 0);
        assert!(writer.snapshot.timestamp_ms > 0);
    }

    #[test]
    fn test_emit_epoch_summary_noop_when_disabled() {
        // Smoke test: must not panic when console progress is off.
        let (_temp_dir, writer) = new_writer();
        writer.emit_epoch_summary(1, 10, 0.5, 0.8, 0.4, 0.85, 60.0, 0.001, true);
    }

    #[test]
    fn test_emit_info_noop_when_disabled() {
        // Smoke test: must not panic when console progress is off.
        let (_temp_dir, writer) = new_writer();
        writer.emit_info("Test message");
    }

    #[test]
    fn test_emit_console_progress_format() {
        // Smoke test: formatting a populated snapshot must not panic.
        let (_temp_dir, mut writer) = new_writer();
        writer.set_epochs(3, 100);
        writer.snapshot.epoch = 1;
        writer.snapshot.step = 50;
        writer.snapshot.loss = 2.5;
        writer.snapshot.learning_rate = 0.001;
        writer.snapshot.gradient_norm = 1.5;
        writer.snapshot.tokens_per_second = 1200.0;
        writer.snapshot.accuracy = 0.85;
        writer.snapshot.status = TrainingStatus::Running;
        writer.emit_console_progress();
    }

    #[test]
    fn test_update_step_updates_timestamp() {
        let (_temp_dir, mut writer) = new_writer();
        writer.set_epochs(1, 10);
        writer.start().expect("file write should succeed");

        let ts_before = writer.snapshot.timestamp_ms;
        std::thread::sleep(std::time::Duration::from_millis(2));
        writer.update_step(1, 1, 0.5, 0.001, 1.0, 100.0, 0.8).expect("file write should succeed");
        assert!(writer.snapshot.timestamp_ms >= ts_before);
    }
}