1use serde::{Deserialize, Serialize};
7use std::fs::{self, File};
8use std::io::{BufReader, BufWriter};
9use std::path::Path;
10use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
11
/// Per-process record carried in [`GpuTelemetry::processes`].
///
/// This file never computes these values; the struct is plain data that is
/// serialized and deserialized with serde. Units are read off the field-name
/// suffixes and should be confirmed against the code that fills the struct in.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct GpuProcessInfo {
    /// Process identifier.
    pub pid: u32,
    /// Path of the process executable (presumably absolute — TODO confirm).
    pub exe_path: String,
    /// GPU memory attributed to the process (MB, per the suffix).
    pub gpu_memory_mb: u64,
    /// CPU usage of the process (presumably 0–100 per core — TODO confirm).
    pub cpu_percent: f32,
    /// Resident set size of the process (MB, per the suffix).
    pub rss_mb: u64,
}
26
/// Point-in-time GPU readings, attached to a [`TrainingSnapshot`] through its `gpu` field.
///
/// Units are taken from the field-name suffixes; this file only consumes the
/// values and does not show how they are sampled.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct GpuTelemetry {
    /// Device name string.
    pub device_name: String,
    /// GPU utilization (presumably 0–100 given the suffix — TODO confirm).
    pub utilization_percent: f32,
    /// VRAM in use (GB, per the suffix); numerator of `vram_percent()`.
    pub vram_used_gb: f32,
    /// VRAM capacity (GB, per the suffix); `vram_percent()` yields 0 when this is not positive.
    pub vram_total_gb: f32,
    /// Temperature (°C, per the suffix); `is_thermal_throttling()` compares it against 83.0.
    pub temperature_celsius: f32,
    /// Power draw (W, per the suffix).
    pub power_watts: f32,
    /// Power limit (W, per the suffix); `is_power_limited()` compares the draw against 95% of it.
    pub power_limit_watts: f32,
    /// Per-process entries; an input that omits this key deserializes to an empty list.
    #[serde(default)]
    pub processes: Vec<GpuProcessInfo>,
}
48
49impl GpuTelemetry {
50 pub fn vram_percent(&self) -> f32 {
52 if self.vram_total_gb > 0.0 {
53 (self.vram_used_gb / self.vram_total_gb) * 100.0
54 } else {
55 0.0
56 }
57 }
58
59 pub fn is_thermal_throttling(&self) -> bool {
61 self.temperature_celsius > 83.0
62 }
63
64 pub fn is_power_limited(&self) -> bool {
66 self.power_watts >= self.power_limit_watts * 0.95
67 }
68}
69
/// Text previews for one sample, attached to a [`TrainingSnapshot`] through its `sample` field.
///
/// The struct is plain serde data; this file does not show how the previews
/// are produced or truncated.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct SamplePeek {
    /// Preview of the sample's input text.
    pub input_preview: String,
    /// Preview of the expected (target) text.
    pub target_preview: String,
    /// Preview of the text the model generated.
    pub generated_preview: String,
    /// Token-level match figure (presumably generated vs. target, 0–100 — TODO confirm).
    pub token_match_percent: f32,
}
82
/// One serialized observation of a training run.
///
/// This is the payload that `TrainingState` writes to and reads from
/// `training_state.json`. Fields marked `#[serde(default)]` may be missing
/// from the input and then take their type's default value.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingSnapshot {
    /// Time of this snapshot in milliseconds; `Default` fills it with the time since the Unix
    /// epoch.
    pub timestamp_ms: u64,
    /// Current epoch; `global_step()` subtracts one from it, i.e. treats it as 1-based.
    pub epoch: usize,
    /// Number of epochs planned for the run.
    pub total_epochs: usize,
    /// Step within the current epoch.
    pub step: usize,
    /// Number of steps in one epoch.
    pub steps_per_epoch: usize,
    /// Loss value at this snapshot.
    pub loss: f32,
    /// Past loss values; `loss_trend()` reads the tail, so the newest entry is last.
    pub loss_history: Vec<f32>,
    /// Learning rate at this snapshot.
    pub learning_rate: f32,
    /// Past learning-rate values; empty when absent from the input.
    #[serde(default)]
    pub lr_history: Vec<f32>,
    /// Gradient norm at this snapshot.
    pub gradient_norm: f32,
    /// Accuracy figure (scale not shown in this file — TODO confirm); 0 when absent.
    #[serde(default)]
    pub accuracy: f32,
    /// Throughput in tokens per second; `estimated_remaining()` returns `None` unless it is
    /// positive.
    pub tokens_per_second: f32,
    /// Throughput in samples per second; 0 when absent from the input.
    #[serde(default)]
    pub samples_per_second: f32,
    /// Start time of the run in milliseconds, on the same clock as `timestamp_ms`.
    pub start_timestamp_ms: u64,
    /// GPU readings, if any were collected.
    pub gpu: Option<GpuTelemetry>,
    /// Sample preview, if one is available.
    pub sample: Option<SamplePeek>,
    /// Lifecycle state of the run.
    pub status: TrainingStatus,
    /// Identifier of the experiment.
    pub experiment_id: String,
    /// Name of the model being trained.
    pub model_name: String,
    /// Path of the model; empty when absent from the input.
    #[serde(default)]
    pub model_path: String,
    /// Name of the optimizer; empty when absent from the input.
    #[serde(default)]
    pub optimizer_name: String,
    /// Batch size; 0 when absent from the input.
    #[serde(default)]
    pub batch_size: usize,
    /// Path of the checkpoint; empty when absent from the input.
    #[serde(default)]
    pub checkpoint_path: String,
    /// Path of the executable (presumably the training process — TODO confirm); empty when
    /// absent from the input.
    #[serde(default)]
    pub executable_path: String,
}
143
/// Lifecycle state of a training run, stored in [`TrainingSnapshot::status`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum TrainingStatus {
    /// The state used by `TrainingSnapshot::default()`, before training reports progress.
    Initializing,
    /// Training is in progress.
    Running,
    /// Training is paused.
    Paused,
    /// Training finished.
    Completed,
    /// Training stopped with an error; the string is presumably the failure message.
    Failed(String),
}
158
159impl Default for TrainingSnapshot {
160 fn default() -> Self {
161 let now =
162 SystemTime::now().duration_since(UNIX_EPOCH).map(|d| d.as_millis() as u64).unwrap_or(0);
163
164 Self {
165 timestamp_ms: now,
166 epoch: 0,
167 total_epochs: 0,
168 step: 0,
169 steps_per_epoch: 0,
170 loss: 0.0,
171 loss_history: Vec::new(),
172 learning_rate: 0.0,
173 lr_history: Vec::new(),
174 gradient_norm: 0.0,
175 accuracy: 0.0,
176 tokens_per_second: 0.0,
177 samples_per_second: 0.0,
178 start_timestamp_ms: now,
179 gpu: None,
180 sample: None,
181 status: TrainingStatus::Initializing,
182 experiment_id: String::new(),
183 model_name: String::new(),
184 model_path: String::new(),
185 optimizer_name: String::new(),
186 batch_size: 0,
187 checkpoint_path: String::new(),
188 executable_path: String::new(),
189 }
190 }
191}
192
193impl TrainingSnapshot {
194 pub fn elapsed(&self) -> Duration {
197 Duration::from_millis(self.timestamp_ms.saturating_sub(self.start_timestamp_ms))
198 }
199
200 pub fn estimated_remaining(&self) -> Option<Duration> {
202 if self.tokens_per_second <= 0.0 {
203 return None;
204 }
205
206 let total_steps = self.total_epochs * self.steps_per_epoch;
207 let completed_steps = (self.epoch.saturating_sub(1)) * self.steps_per_epoch + self.step;
208
209 if completed_steps == 0 || total_steps == 0 {
210 return None;
211 }
212
213 let progress = completed_steps as f64 / total_steps as f64;
214 if progress >= 1.0 {
215 return Some(Duration::ZERO);
216 }
217
218 let elapsed_ms = self.timestamp_ms.saturating_sub(self.start_timestamp_ms);
219 let total_estimated_ms = (elapsed_ms as f64 / progress) as u64;
220 let remaining_ms = total_estimated_ms.saturating_sub(elapsed_ms);
221
222 Some(Duration::from_millis(remaining_ms))
223 }
224
225 pub fn global_step(&self) -> usize {
227 (self.epoch.saturating_sub(1)) * self.steps_per_epoch + self.step
228 }
229
230 pub fn progress_percent(&self) -> f32 {
232 let total = self.total_epochs * self.steps_per_epoch;
233 if total == 0 {
234 return 0.0;
235 }
236 (self.global_step() as f32 / total as f32) * 100.0
237 }
238
239 pub fn loss_trend(&self) -> LossTrend {
247 if self.loss_history.len() < 5 {
249 return LossTrend::Unknown;
250 }
251
252 let window = self.loss_history.len().min(10);
254 let recent = &self.loss_history[self.loss_history.len() - window..];
255
256 let mid = window / 2;
258 let first_half: f32 = recent[..mid].iter().sum::<f32>() / mid as f32;
259 let second_half: f32 = recent[mid..].iter().sum::<f32>() / (window - mid) as f32;
260
261 let change = (second_half - first_half) / first_half.abs().max(1e-6);
263
264 const THRESHOLD: f32 = 0.02;
266
267 if change < -THRESHOLD {
268 LossTrend::Decreasing
269 } else if change > THRESHOLD {
270 LossTrend::Increasing
271 } else {
272 LossTrend::Stable
273 }
274 }
275}
276
/// Direction of the recent loss, as classified by `TrainingSnapshot::loss_trend`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LossTrend {
    /// The loss is going down.
    Decreasing,
    /// The loss is neither clearly rising nor falling.
    Stable,
    /// The loss is going up.
    Increasing,
    /// No trend could be determined.
    Unknown,
}

impl LossTrend {
    /// Single-character glyph for the trend: `↓`, `→`, `↑`, or `?`.
    pub fn arrow(&self) -> &'static str {
        self.glyph_and_label().0
    }

    /// Lower-case word for the trend: `decreasing`, `stable`, `increasing`, or `unknown`.
    pub fn description(&self) -> &'static str {
        self.glyph_and_label().1
    }

    /// Single table pairing each variant with its glyph and its word, so the
    /// two public accessors cannot drift apart.
    fn glyph_and_label(self) -> (&'static str, &'static str) {
        match self {
            Self::Decreasing => ("↓", "decreasing"),
            Self::Stable => ("→", "stable"),
            Self::Increasing => ("↑", "increasing"),
            Self::Unknown => ("?", "unknown"),
        }
    }
}
311
/// Handle to an experiment's `training_state.json` file, with a read cache.
///
/// `write` replaces the file with a new snapshot; `read` parses it and
/// remembers the result together with the file's modification time, so an
/// unchanged file is not parsed again.
pub struct TrainingState {
    /// Location of the state file: `<experiment_dir>/training_state.json`.
    state_path: std::path::PathBuf,
    /// Snapshot returned by the most recent successful parse in `read`.
    last_snapshot: Option<TrainingSnapshot>,
    /// Modification time of the file at that parse; `None` until the first one.
    last_modified: Option<std::time::SystemTime>,
}
323
324impl TrainingState {
325 pub fn new<P: AsRef<Path>>(experiment_dir: P) -> Self {
331 let state_path = experiment_dir.as_ref().join("training_state.json");
332 Self { state_path, last_snapshot: None, last_modified: None }
333 }
334
335 pub fn write(&self, snapshot: &TrainingSnapshot) -> std::io::Result<()> {
339 if let Some(parent) = self.state_path.parent() {
341 fs::create_dir_all(parent)?;
342 }
343
344 let temp_path = self.state_path.with_extension("json.tmp");
346 let file = File::create(&temp_path)?;
347 let writer = BufWriter::new(file);
348 serde_json::to_writer_pretty(writer, snapshot)
349 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
350
351 fs::rename(&temp_path, &self.state_path)?;
353
354 Ok(())
355 }
356
357 pub fn read(&mut self) -> std::io::Result<Option<TrainingSnapshot>> {
361 if !self.state_path.exists() {
362 return Ok(None);
363 }
364
365 let metadata = fs::metadata(&self.state_path)?;
367 let modified = metadata.modified()?;
368
369 if self.last_modified == Some(modified) {
370 return Ok(self.last_snapshot.clone());
372 }
373
374 let file = File::open(&self.state_path)?;
376 let reader = BufReader::new(file);
377 let snapshot: TrainingSnapshot = serde_json::from_reader(reader)
378 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
379
380 self.last_snapshot = Some(snapshot.clone());
382 self.last_modified = Some(modified);
383
384 Ok(Some(snapshot))
385 }
386
387 pub fn exists(&self) -> bool {
389 self.state_path.exists()
390 }
391
392 pub fn path(&self) -> &Path {
394 &self.state_path
395 }
396
397 pub fn wait_for_state(&mut self, timeout: Duration) -> std::io::Result<bool> {
399 let start = Instant::now();
400 while start.elapsed() < timeout {
401 if self.exists() {
402 return Ok(true);
403 }
404 std::thread::sleep(Duration::from_millis(100));
405 }
406 Ok(false)
407 }
408}
409
// Unit and property tests for the snapshot types and the state-file handle.
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;
    use tempfile::TempDir;

    // A default snapshot starts at epoch 0 in the `Initializing` state.
    #[test]
    fn test_training_snapshot_default() {
        let snapshot = TrainingSnapshot::default();
        assert_eq!(snapshot.epoch, 0);
        assert_eq!(snapshot.status, TrainingStatus::Initializing);
    }

    // Checks `global_step` and `progress_percent` on a hand-computed example.
    #[test]
    fn test_training_snapshot_progress() {
        let mut snapshot = TrainingSnapshot::default();
        snapshot.epoch = 2;
        snapshot.total_epochs = 10;
        snapshot.step = 50;
        snapshot.steps_per_epoch = 100;

        // (2 - 1) * 100 + 50 = 150
        assert_eq!(snapshot.global_step(), 150);

        // 150 of 10 * 100 planned steps = 15%
        assert!((snapshot.progress_percent() - 15.0).abs() < 0.01);
    }

    // 4 GB of 24 GB is roughly 16.67%.
    #[test]
    fn test_gpu_telemetry_vram_percent() {
        let gpu = GpuTelemetry { vram_used_gb: 4.0, vram_total_gb: 24.0, ..Default::default() };
        assert!((gpu.vram_percent() - 16.67).abs() < 0.1);
    }

    // A snapshot written through `TrainingState` can be read back with the same values.
    #[test]
    fn test_training_state_write_read() {
        let temp_dir = TempDir::new().expect("temp file creation should succeed");
        let mut state = TrainingState::new(temp_dir.path());

        let snapshot = TrainingSnapshot {
            epoch: 5,
            total_epochs: 10,
            loss: 0.42,
            status: TrainingStatus::Running,
            ..Default::default()
        };

        state.write(&snapshot).expect("file write should succeed");
        assert!(state.exists());

        let read_snapshot =
            state.read().expect("file read should succeed").expect("file read should succeed");
        assert_eq!(read_snapshot.epoch, 5);
        assert!((read_snapshot.loss - 0.42).abs() < 0.001);
    }

    // Two reads of an untouched file return the same snapshot.
    #[test]
    fn test_training_state_caching() {
        let temp_dir = TempDir::new().expect("temp file creation should succeed");
        let mut state = TrainingState::new(temp_dir.path());

        let snapshot = TrainingSnapshot { epoch: 1, ..Default::default() };

        state.write(&snapshot).expect("file write should succeed");

        // First read parses the file and fills the cache.
        let _ = state.read().expect("file read should succeed");

        // The file is unchanged, so this read sees the same modification time
        // and is answered from the cache.
        let cached =
            state.read().expect("file read should succeed").expect("file read should succeed");
        assert_eq!(cached.epoch, 1);
    }

    proptest! {
        // Serializing a snapshot to a JSON string and parsing it back keeps
        // the counters exact and the floats within tolerance.
        #[test]
        fn prop_snapshot_json_roundtrip(
            epoch in 1usize..1000,
            total_epochs in 1usize..100,
            step in 0usize..10000,
            steps_per_epoch in 1usize..10000,
            loss in 0.0f32..100.0,
            learning_rate in 1e-10f32..1.0,
            gradient_norm in 0.0f32..1000.0,
            tokens_per_second in 0.0f32..10000.0,
        ) {
            let snapshot = TrainingSnapshot {
                timestamp_ms: 12345678,
                epoch,
                total_epochs,
                step,
                steps_per_epoch,
                loss,
                loss_history: vec![loss * 1.1, loss * 1.05, loss],
                learning_rate,
                lr_history: vec![learning_rate; 3],
                gradient_norm,
                accuracy: 0.0,
                tokens_per_second,
                samples_per_second: 0.0,
                start_timestamp_ms: 12345000,
                gpu: None,
                sample: None,
                status: TrainingStatus::Running,
                experiment_id: "test".to_string(),
                model_name: "model".to_string(),
                model_path: String::new(),
                optimizer_name: "AdamW".to_string(),
                batch_size: 4,
                checkpoint_path: String::new(),
                executable_path: String::new(),
            };

            // Serialize
            let json = serde_json::to_string(&snapshot).expect("JSON serialization should succeed");

            // Deserialize
            let restored: TrainingSnapshot = serde_json::from_str(&json).expect("JSON deserialization should succeed");

            // Compare the restored values against the generated inputs.
            prop_assert_eq!(restored.epoch, epoch);
            prop_assert_eq!(restored.total_epochs, total_epochs);
            prop_assert_eq!(restored.step, step);
            prop_assert_eq!(restored.steps_per_epoch, steps_per_epoch);
            prop_assert!((restored.loss - loss).abs() < 1e-5);
            prop_assert!((restored.learning_rate - learning_rate).abs() < 1e-10);
            prop_assert!((restored.gradient_norm - gradient_norm).abs() < 1e-5);
        }

        // A clearly sloped linear loss history is classified in the direction of the slope.
        #[test]
        fn prop_loss_trend_consistent(
            base_loss in 1.0f32..10.0,
            trend_factor in -0.1f32..0.1,
        ) {
            // Ten points on a straight line: base_loss + i * trend_factor.
            let history: Vec<f32> = (0..10)
                .map(|i| base_loss + (i as f32 * trend_factor))
                .collect();

            let snapshot = TrainingSnapshot {
                loss_history: history,
                ..Default::default()
            };

            let trend = snapshot.loss_trend();

            // Only slopes steeper than 0.05 per step are asserted; shallower
            // slopes may land on either side of the 2% threshold and are left
            // unchecked.
            if trend_factor > 0.05 {
                prop_assert_eq!(trend, LossTrend::Increasing);
            } else if trend_factor < -0.05 {
                prop_assert_eq!(trend, LossTrend::Decreasing);
            }
        }

        // With usage capped at capacity, the VRAM percentage stays within 0..=100.
        #[test]
        fn prop_gpu_vram_percent_bounded(
            vram_used in 0.0f32..100.0,
            vram_total in 1.0f32..100.0,
        ) {
            let gpu = GpuTelemetry {
                vram_used_gb: vram_used.min(vram_total),
                vram_total_gb: vram_total,
                ..Default::default()
            };

            let percent = gpu.vram_percent();
            prop_assert!(percent >= 0.0);
            prop_assert!(percent <= 100.0);
        }

        // With counters capped at the plan, the progress percentage stays within 0..=100.
        #[test]
        fn prop_progress_percent_bounded(
            epoch in 1usize..100,
            total_epochs in 1usize..100,
            step in 0usize..1000,
            steps_per_epoch in 1usize..1000,
        ) {
            // `progress_percent` itself does not clamp, so the inputs are clamped here.
            let epoch = epoch.min(total_epochs);
            let step = step.min(steps_per_epoch);

            let snapshot = TrainingSnapshot {
                epoch,
                total_epochs,
                step,
                steps_per_epoch,
                ..Default::default()
            };

            let progress = snapshot.progress_percent();
            prop_assert!(progress >= 0.0);
            prop_assert!(progress <= 100.0);
        }

        // A snapshot survives a round trip through the state file on disk.
        #[test]
        fn prop_state_file_roundtrip(
            epoch in 1usize..100,
            loss in 0.0f32..100.0,
            lr in 1e-6f32..0.1,
        ) {
            let temp_dir = TempDir::new().expect("temp file creation should succeed");
            let mut state = TrainingState::new(temp_dir.path());

            let snapshot = TrainingSnapshot {
                epoch,
                total_epochs: 10,
                loss,
                learning_rate: lr,
                status: TrainingStatus::Running,
                ..Default::default()
            };

            state.write(&snapshot).expect("file write should succeed");

            // Clear the recorded modification time so `read` cannot take the cached path.
            state.last_modified = None;
            let restored = state.read().expect("file read should succeed").expect("file read should succeed");

            prop_assert_eq!(restored.epoch, epoch);
            prop_assert!((restored.loss - loss).abs() < 1e-5);
            prop_assert!((restored.learning_rate - lr).abs() < 1e-10);
        }
    }

    // A zero VRAM capacity yields 0% instead of dividing by zero.
    #[test]
    fn test_gpu_telemetry_vram_percent_zero_total() {
        let gpu = GpuTelemetry { vram_used_gb: 4.0, vram_total_gb: 0.0, ..Default::default() };
        assert!((gpu.vram_percent() - 0.0).abs() < f32::EPSILON);
    }

    // The throttling check is strict: 84 triggers it, exactly 83 and below do not.
    #[test]
    fn test_gpu_telemetry_thermal_throttling() {
        let gpu = GpuTelemetry { temperature_celsius: 84.0, ..Default::default() };
        assert!(gpu.is_thermal_throttling());

        let gpu2 = GpuTelemetry { temperature_celsius: 83.0, ..Default::default() };
        assert!(!gpu2.is_thermal_throttling());

        let gpu3 = GpuTelemetry { temperature_celsius: 70.0, ..Default::default() };
        assert!(!gpu3.is_thermal_throttling());
    }

    // 380 W sits at the 95% mark of a 400 W limit; 300 W is well below it.
    #[test]
    fn test_gpu_telemetry_power_limited() {
        let gpu =
            GpuTelemetry { power_watts: 380.0, power_limit_watts: 400.0, ..Default::default() };
        assert!(gpu.is_power_limited());

        // 300 W is 75% of the limit.
        let gpu2 =
            GpuTelemetry { power_watts: 300.0, power_limit_watts: 400.0, ..Default::default() };
        assert!(!gpu2.is_power_limited());
    }

    // `elapsed` is the difference between the two timestamps.
    #[test]
    fn test_training_snapshot_elapsed() {
        let snapshot =
            TrainingSnapshot { start_timestamp_ms: 1000, timestamp_ms: 6000, ..Default::default() };
        assert_eq!(snapshot.elapsed(), Duration::from_millis(5000));
    }

    // Equal timestamps give a zero duration.
    #[test]
    fn test_training_snapshot_elapsed_same_time() {
        let snapshot =
            TrainingSnapshot { start_timestamp_ms: 5000, timestamp_ms: 5000, ..Default::default() };
        assert_eq!(snapshot.elapsed(), Duration::ZERO);
    }

    // Without positive throughput there is no estimate.
    #[test]
    fn test_training_snapshot_estimated_remaining_none_zero_tps() {
        let snapshot = TrainingSnapshot { tokens_per_second: 0.0, ..Default::default() };
        assert!(snapshot.estimated_remaining().is_none());
    }

    // Without any planned steps there is no estimate, even with positive throughput.
    #[test]
    fn test_training_snapshot_estimated_remaining_none_zero_steps() {
        let snapshot = TrainingSnapshot {
            tokens_per_second: 100.0,
            total_epochs: 0,
            steps_per_epoch: 0,
            ..Default::default()
        };
        assert!(snapshot.estimated_remaining().is_none());
    }

    // Once every planned step is done, the remaining time is exactly zero.
    #[test]
    fn test_training_snapshot_estimated_remaining_completed() {
        let snapshot = TrainingSnapshot {
            tokens_per_second: 100.0,
            epoch: 10,
            total_epochs: 10,
            step: 100,
            steps_per_epoch: 100,
            start_timestamp_ms: 1000,
            timestamp_ms: 11000,
            ..Default::default()
        };
        let remaining = snapshot.estimated_remaining();
        assert!(remaining.is_some());
        assert_eq!(remaining.unwrap(), Duration::ZERO);
    }

    // Part-way through the run the estimate falls in a plausible range.
    #[test]
    fn test_training_snapshot_estimated_remaining_halfway() {
        let snapshot = TrainingSnapshot {
            tokens_per_second: 100.0,
            epoch: 5,
            total_epochs: 10,
            step: 50,
            steps_per_epoch: 100,
            start_timestamp_ms: 0,
            timestamp_ms: 10000,
            ..Default::default()
        };
        let remaining = snapshot.estimated_remaining();
        assert!(remaining.is_some());
        let rem_ms = remaining.unwrap().as_millis();
        // 450 of 1000 steps took 10 s, so the linear estimate is about 12 s more.
        assert!(rem_ms > 5000 && rem_ms < 30000);
    }

    // Two finished epochs of 100 steps plus 42 steps into the third.
    #[test]
    fn test_training_snapshot_global_step() {
        let snapshot =
            TrainingSnapshot { epoch: 3, steps_per_epoch: 100, step: 42, ..Default::default() };
        // (3 - 1) * 100 + 42 = 242
        assert_eq!(snapshot.global_step(), 242);
    }

    // In the first epoch the global step equals the in-epoch step.
    #[test]
    fn test_training_snapshot_global_step_first_epoch() {
        let snapshot =
            TrainingSnapshot { epoch: 1, steps_per_epoch: 50, step: 10, ..Default::default() };
        // (1 - 1) * 50 + 10 = 10
        assert_eq!(snapshot.global_step(), 10);
    }

    // A zero-step plan reports 0% instead of dividing by zero.
    #[test]
    fn test_training_snapshot_progress_zero() {
        let snapshot =
            TrainingSnapshot { total_epochs: 0, steps_per_epoch: 0, ..Default::default() };
        assert!((snapshot.progress_percent() - 0.0).abs() < f32::EPSILON);
    }

    // Fewer than five loss values are not enough to classify a trend.
    #[test]
    fn test_loss_trend_unknown_few_samples() {
        let snapshot = TrainingSnapshot { loss_history: vec![1.0, 2.0, 3.0], ..Default::default() };
        assert_eq!(snapshot.loss_trend(), LossTrend::Unknown);
    }

    // A steadily falling history is classified as decreasing.
    #[test]
    fn test_loss_trend_decreasing() {
        let snapshot = TrainingSnapshot {
            loss_history: vec![10.0, 9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0],
            ..Default::default()
        };
        assert_eq!(snapshot.loss_trend(), LossTrend::Decreasing);
    }

    // A steadily rising history is classified as increasing.
    #[test]
    fn test_loss_trend_increasing() {
        let snapshot = TrainingSnapshot {
            loss_history: vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
            ..Default::default()
        };
        assert_eq!(snapshot.loss_trend(), LossTrend::Increasing);
    }

    // Small jitter around 5.0 stays inside the 2% band and is classified as stable.
    #[test]
    fn test_loss_trend_stable() {
        let snapshot = TrainingSnapshot {
            loss_history: vec![5.0, 5.01, 4.99, 5.0, 5.01, 4.99, 5.0, 5.01, 4.99, 5.0],
            ..Default::default()
        };
        assert_eq!(snapshot.loss_trend(), LossTrend::Stable);
    }

    // Each trend maps to its arrow glyph (written here as Unicode escapes).
    #[test]
    fn test_loss_trend_arrow() {
        assert_eq!(LossTrend::Decreasing.arrow(), "\u{2193}");
        assert_eq!(LossTrend::Stable.arrow(), "\u{2192}");
        assert_eq!(LossTrend::Increasing.arrow(), "\u{2191}");
        assert_eq!(LossTrend::Unknown.arrow(), "?");
    }

    // Each trend maps to its lower-case description.
    #[test]
    fn test_loss_trend_description() {
        assert_eq!(LossTrend::Decreasing.description(), "decreasing");
        assert_eq!(LossTrend::Stable.description(), "stable");
        assert_eq!(LossTrend::Increasing.description(), "increasing");
        assert_eq!(LossTrend::Unknown.description(), "unknown");
    }

    // The state file lives at `<experiment_dir>/training_state.json`.
    #[test]
    fn test_training_state_new_path() {
        let state = TrainingState::new("/tmp/test-exp");
        assert_eq!(state.path(), std::path::Path::new("/tmp/test-exp/training_state.json"));
    }

    // `exists` is false for a directory that was never created.
    #[test]
    fn test_training_state_exists_missing() {
        let temp_dir = TempDir::new().expect("temp file creation should succeed");
        let state = TrainingState::new(temp_dir.path().join("nonexistent"));
        assert!(!state.exists());
    }

    // Reading a missing state file is `Ok(None)`, not an error.
    #[test]
    fn test_training_state_read_missing_file() {
        let temp_dir = TempDir::new().expect("temp file creation should succeed");
        let mut state = TrainingState::new(temp_dir.path().join("nonexistent"));
        let result = state.read().expect("should not error for missing file");
        assert!(result.is_none());
    }

    // Waiting returns `true` when the file is already there.
    #[test]
    fn test_training_state_wait_for_state_already_exists() {
        let temp_dir = TempDir::new().expect("temp file creation should succeed");
        let mut state = TrainingState::new(temp_dir.path());
        let snapshot = TrainingSnapshot::default();
        state.write(&snapshot).expect("write should succeed");

        let found = state.wait_for_state(Duration::from_millis(100)).expect("ok");
        assert!(found);
    }

    // Waiting returns `false` once the timeout passes without the file appearing.
    #[test]
    fn test_training_state_wait_for_state_timeout() {
        let temp_dir = TempDir::new().expect("temp file creation should succeed");
        let mut state = TrainingState::new(temp_dir.path().join("never-exists"));
        let found = state.wait_for_state(Duration::from_millis(200)).expect("ok");
        assert!(!found);
    }

    // The derived default for `GpuProcessInfo` is all zeros and an empty path.
    #[test]
    fn test_gpu_process_info_default() {
        let info = GpuProcessInfo::default();
        assert_eq!(info.pid, 0);
        assert!(info.exe_path.is_empty());
        assert_eq!(info.gpu_memory_mb, 0);
    }

    // The derived default for `SamplePeek` is empty previews and a zero match figure.
    #[test]
    fn test_sample_peek_default() {
        let sample = SamplePeek::default();
        assert!(sample.input_preview.is_empty());
        assert!(sample.target_preview.is_empty());
        assert!(sample.generated_preview.is_empty());
        assert!((sample.token_match_percent - 0.0).abs() < f32::EPSILON);
    }

    // Status equality compares the variant and, for `Failed`, the message.
    #[test]
    fn test_training_status_equality() {
        assert_eq!(TrainingStatus::Running, TrainingStatus::Running);
        assert_eq!(TrainingStatus::Completed, TrainingStatus::Completed);
        assert_ne!(TrainingStatus::Running, TrainingStatus::Completed);
        assert_eq!(
            TrainingStatus::Failed("a".to_string()),
            TrainingStatus::Failed("a".to_string())
        );
        assert_ne!(
            TrainingStatus::Failed("a".to_string()),
            TrainingStatus::Failed("b".to_string())
        );
    }

    // GPU telemetry, including its nested process list, survives a JSON round trip.
    #[test]
    fn test_gpu_telemetry_serde_roundtrip() {
        let gpu = GpuTelemetry {
            device_name: "RTX 4090".to_string(),
            utilization_percent: 95.0,
            vram_used_gb: 20.0,
            vram_total_gb: 24.0,
            temperature_celsius: 72.0,
            power_watts: 350.0,
            power_limit_watts: 400.0,
            processes: vec![GpuProcessInfo {
                pid: 1234,
                exe_path: "/usr/bin/python3".to_string(),
                gpu_memory_mb: 19000,
                cpu_percent: 50.0,
                rss_mb: 4096,
            }],
        };
        let json = serde_json::to_string(&gpu).expect("serialize");
        let restored: GpuTelemetry = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(restored.device_name, "RTX 4090");
        assert_eq!(restored.processes.len(), 1);
        assert_eq!(restored.processes[0].pid, 1234);
    }
}