// Source: entrenar/monitor/tui/state.rs

//! Training State for IPC (SPEC-FT-001 Section 10.1)
//!
//! Atomic state updates written by the trainer, read by the TUI monitor.
//! Uses a JSON file as the IPC mechanism for simplicity and portability.

use serde::{Deserialize, Serialize};
use std::fs::{self, File};
use std::io::{BufReader, BufWriter};
use std::path::Path;
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};

12/// Process using GPU resources
13#[derive(Debug, Clone, Serialize, Deserialize, Default)]
14pub struct GpuProcessInfo {
15    /// Process ID
16    pub pid: u32,
17    /// Full path to executable
18    pub exe_path: String,
19    /// GPU memory used by this process in MB
20    pub gpu_memory_mb: u64,
21    /// CPU usage percentage (0-100)
22    pub cpu_percent: f32,
23    /// Resident set size (RSS) in MB
24    pub rss_mb: u64,
25}
26
27/// GPU telemetry snapshot (NVML-inspired)
28#[derive(Debug, Clone, Serialize, Deserialize, Default)]
29pub struct GpuTelemetry {
30    /// GPU device name (e.g., "RTX 4090")
31    pub device_name: String,
32    /// GPU utilization percentage (0-100)
33    pub utilization_percent: f32,
34    /// VRAM used in GB
35    pub vram_used_gb: f32,
36    /// VRAM total in GB
37    pub vram_total_gb: f32,
38    /// Temperature in Celsius
39    pub temperature_celsius: f32,
40    /// Power draw in watts
41    pub power_watts: f32,
42    /// Power limit in watts
43    pub power_limit_watts: f32,
44    /// Processes using GPU
45    #[serde(default)]
46    pub processes: Vec<GpuProcessInfo>,
47}
48
49impl GpuTelemetry {
50    /// VRAM utilization as percentage
51    pub fn vram_percent(&self) -> f32 {
52        if self.vram_total_gb > 0.0 {
53            (self.vram_used_gb / self.vram_total_gb) * 100.0
54        } else {
55            0.0
56        }
57    }
58
59    /// Check if thermal throttling is likely
60    pub fn is_thermal_throttling(&self) -> bool {
61        self.temperature_celsius > 83.0
62    }
63
64    /// Check if power limited
65    pub fn is_power_limited(&self) -> bool {
66        self.power_watts >= self.power_limit_watts * 0.95
67    }
68}
69
70/// Sample peek for live decoding visualization
71#[derive(Debug, Clone, Serialize, Deserialize, Default)]
72pub struct SamplePeek {
73    /// Input function code (truncated for display)
74    pub input_preview: String,
75    /// Target test code (truncated for display)
76    pub target_preview: String,
77    /// Generated test code (truncated for display)
78    pub generated_preview: String,
79    /// Token match percentage (0-100)
80    pub token_match_percent: f32,
81}
82
83/// Complete training state snapshot
84#[derive(Debug, Clone, Serialize, Deserialize)]
85pub struct TrainingSnapshot {
86    /// Unix timestamp in milliseconds
87    pub timestamp_ms: u64,
88    /// Current epoch (1-indexed)
89    pub epoch: usize,
90    /// Total epochs
91    pub total_epochs: usize,
92    /// Current step within epoch
93    pub step: usize,
94    /// Total steps per epoch
95    pub steps_per_epoch: usize,
96    /// Current loss value
97    pub loss: f32,
98    /// Loss history (last N values for sparkline)
99    pub loss_history: Vec<f32>,
100    /// Current learning rate
101    pub learning_rate: f32,
102    /// Learning rate history (per-step, for epoch summaries)
103    #[serde(default)]
104    pub lr_history: Vec<f32>,
105    /// Gradient norm
106    pub gradient_norm: f32,
107    /// Training accuracy (0.0 to 1.0)
108    #[serde(default)]
109    pub accuracy: f32,
110    /// Throughput in tokens per second
111    pub tokens_per_second: f32,
112    /// Throughput in samples per second
113    #[serde(default)]
114    pub samples_per_second: f32,
115    /// Training start timestamp (ms)
116    pub start_timestamp_ms: u64,
117    /// GPU telemetry (optional)
118    pub gpu: Option<GpuTelemetry>,
119    /// Sample peek (optional)
120    pub sample: Option<SamplePeek>,
121    /// Training status
122    pub status: TrainingStatus,
123    /// Experiment name/ID
124    pub experiment_id: String,
125    /// Model name
126    pub model_name: String,
127    /// Model path (e.g., path to .safetensors or .gguf)
128    #[serde(default)]
129    pub model_path: String,
130    /// Optimizer name (e.g., "AdamW", "SGD")
131    #[serde(default)]
132    pub optimizer_name: String,
133    /// Batch size
134    #[serde(default)]
135    pub batch_size: usize,
136    /// Checkpoint path (where checkpoints are saved)
137    #[serde(default)]
138    pub checkpoint_path: String,
139    /// Executable path (path to training binary)
140    #[serde(default)]
141    pub executable_path: String,
142}
143
144/// Training status enum
145#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
146pub enum TrainingStatus {
147    /// Training is initializing
148    Initializing,
149    /// Training is running
150    Running,
151    /// Training is paused
152    Paused,
153    /// Training completed successfully
154    Completed,
155    /// Training failed with error
156    Failed(String),
157}
158
159impl Default for TrainingSnapshot {
160    fn default() -> Self {
161        let now =
162            SystemTime::now().duration_since(UNIX_EPOCH).map(|d| d.as_millis() as u64).unwrap_or(0);
163
164        Self {
165            timestamp_ms: now,
166            epoch: 0,
167            total_epochs: 0,
168            step: 0,
169            steps_per_epoch: 0,
170            loss: 0.0,
171            loss_history: Vec::new(),
172            learning_rate: 0.0,
173            lr_history: Vec::new(),
174            gradient_norm: 0.0,
175            accuracy: 0.0,
176            tokens_per_second: 0.0,
177            samples_per_second: 0.0,
178            start_timestamp_ms: now,
179            gpu: None,
180            sample: None,
181            status: TrainingStatus::Initializing,
182            experiment_id: String::new(),
183            model_name: String::new(),
184            model_path: String::new(),
185            optimizer_name: String::new(),
186            batch_size: 0,
187            checkpoint_path: String::new(),
188            executable_path: String::new(),
189        }
190    }
191}
192
193impl TrainingSnapshot {
194    /// Calculate elapsed time since training start
195    /// Uses the snapshot's timestamp_ms for deterministic/reproducible output (ENT-140)
196    pub fn elapsed(&self) -> Duration {
197        Duration::from_millis(self.timestamp_ms.saturating_sub(self.start_timestamp_ms))
198    }
199
200    /// Calculate estimated remaining time
201    pub fn estimated_remaining(&self) -> Option<Duration> {
202        if self.tokens_per_second <= 0.0 {
203            return None;
204        }
205
206        let total_steps = self.total_epochs * self.steps_per_epoch;
207        let completed_steps = (self.epoch.saturating_sub(1)) * self.steps_per_epoch + self.step;
208
209        if completed_steps == 0 || total_steps == 0 {
210            return None;
211        }
212
213        let progress = completed_steps as f64 / total_steps as f64;
214        if progress >= 1.0 {
215            return Some(Duration::ZERO);
216        }
217
218        let elapsed_ms = self.timestamp_ms.saturating_sub(self.start_timestamp_ms);
219        let total_estimated_ms = (elapsed_ms as f64 / progress) as u64;
220        let remaining_ms = total_estimated_ms.saturating_sub(elapsed_ms);
221
222        Some(Duration::from_millis(remaining_ms))
223    }
224
225    /// Global step (epoch * steps_per_epoch + step)
226    pub fn global_step(&self) -> usize {
227        (self.epoch.saturating_sub(1)) * self.steps_per_epoch + self.step
228    }
229
230    /// Progress percentage (0-100)
231    pub fn progress_percent(&self) -> f32 {
232        let total = self.total_epochs * self.steps_per_epoch;
233        if total == 0 {
234            return 0.0;
235        }
236        (self.global_step() as f32 / total as f32) * 100.0
237    }
238
239    /// Compute loss trend from recent history
240    ///
241    /// Returns:
242    /// - `LossTrend::Decreasing` if loss is going down (good)
243    /// - `LossTrend::Stable` if loss is plateauing
244    /// - `LossTrend::Increasing` if loss is going up (bad)
245    /// - `LossTrend::Unknown` if not enough data
246    pub fn loss_trend(&self) -> LossTrend {
247        // Need at least 5 samples to compute trend
248        if self.loss_history.len() < 5 {
249            return LossTrend::Unknown;
250        }
251
252        // Use last 10 samples (or all if less)
253        let window = self.loss_history.len().min(10);
254        let recent = &self.loss_history[self.loss_history.len() - window..];
255
256        // Compare first half vs second half
257        let mid = window / 2;
258        let first_half: f32 = recent[..mid].iter().sum::<f32>() / mid as f32;
259        let second_half: f32 = recent[mid..].iter().sum::<f32>() / (window - mid) as f32;
260
261        // Calculate relative change
262        let change = (second_half - first_half) / first_half.abs().max(1e-6);
263
264        // Threshold: 2% change considered significant
265        const THRESHOLD: f32 = 0.02;
266
267        if change < -THRESHOLD {
268            LossTrend::Decreasing
269        } else if change > THRESHOLD {
270            LossTrend::Increasing
271        } else {
272            LossTrend::Stable
273        }
274    }
275}
276
/// Loss trend direction, as reported by `TrainingSnapshot::loss_trend`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LossTrend {
    /// Loss is decreasing (good)
    Decreasing,
    /// Loss is stable/plateauing
    Stable,
    /// Loss is increasing (bad)
    Increasing,
    /// Not enough data to determine
    Unknown,
}

impl LossTrend {
    /// Get the Unicode arrow used to display this trend (`?` for `Unknown`).
    pub fn arrow(&self) -> &'static str {
        match self {
            Self::Decreasing => "↓",
            Self::Stable => "→",
            Self::Increasing => "↑",
            Self::Unknown => "?",
        }
    }

    /// Get a lowercase one-word description of this trend.
    pub fn description(&self) -> &'static str {
        match self {
            Self::Decreasing => "decreasing",
            Self::Stable => "stable",
            Self::Increasing => "increasing",
            Self::Unknown => "unknown",
        }
    }
}

312/// Training state manager for IPC
313///
314/// Handles atomic read/write of training snapshots via JSON file.
315pub struct TrainingState {
316    /// Path to the state file
317    state_path: std::path::PathBuf,
318    /// Last read snapshot (cached)
319    last_snapshot: Option<TrainingSnapshot>,
320    /// Last modification time
321    last_modified: Option<std::time::SystemTime>,
322}
323
324impl TrainingState {
325    /// Create a new training state manager
326    ///
327    /// # Arguments
328    ///
329    /// * `experiment_dir` - Path to the experiment directory
330    pub fn new<P: AsRef<Path>>(experiment_dir: P) -> Self {
331        let state_path = experiment_dir.as_ref().join("training_state.json");
332        Self { state_path, last_snapshot: None, last_modified: None }
333    }
334
335    /// Write a training snapshot atomically
336    ///
337    /// Uses staged write + rename for atomicity.
338    pub fn write(&self, snapshot: &TrainingSnapshot) -> std::io::Result<()> {
339        // Ensure parent directory exists
340        if let Some(parent) = self.state_path.parent() {
341            fs::create_dir_all(parent)?;
342        }
343
344        // Stage writes via intermediate file for atomic persistence
345        let temp_path = self.state_path.with_extension("json.tmp");
346        let file = File::create(&temp_path)?;
347        let writer = BufWriter::new(file);
348        serde_json::to_writer_pretty(writer, snapshot)
349            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
350
351        // Atomic rename
352        fs::rename(&temp_path, &self.state_path)?;
353
354        Ok(())
355    }
356
357    /// Read the current training snapshot
358    ///
359    /// Returns `None` if the state file doesn't exist yet.
360    pub fn read(&mut self) -> std::io::Result<Option<TrainingSnapshot>> {
361        if !self.state_path.exists() {
362            return Ok(None);
363        }
364
365        // Check if file was modified since last read
366        let metadata = fs::metadata(&self.state_path)?;
367        let modified = metadata.modified()?;
368
369        if self.last_modified == Some(modified) {
370            // Return cached snapshot if file hasn't changed
371            return Ok(self.last_snapshot.clone());
372        }
373
374        // Read and parse
375        let file = File::open(&self.state_path)?;
376        let reader = BufReader::new(file);
377        let snapshot: TrainingSnapshot = serde_json::from_reader(reader)
378            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
379
380        // Update cache
381        self.last_snapshot = Some(snapshot.clone());
382        self.last_modified = Some(modified);
383
384        Ok(Some(snapshot))
385    }
386
387    /// Check if training state file exists
388    pub fn exists(&self) -> bool {
389        self.state_path.exists()
390    }
391
392    /// Get the state file path
393    pub fn path(&self) -> &Path {
394        &self.state_path
395    }
396
397    /// Wait for the state file to appear (with timeout)
398    pub fn wait_for_state(&mut self, timeout: Duration) -> std::io::Result<bool> {
399        let start = Instant::now();
400        while start.elapsed() < timeout {
401            if self.exists() {
402                return Ok(true);
403            }
404            std::thread::sleep(Duration::from_millis(100));
405        }
406        Ok(false)
407    }
408}
409
#[cfg(test)]
mod tests {
    //! Unit and property-based tests for the snapshot types and the
    //! file-backed `TrainingState` manager.

    use super::*;
    use proptest::prelude::*;
    use tempfile::TempDir;

    #[test]
    fn test_training_snapshot_default() {
        let snapshot = TrainingSnapshot::default();
        assert_eq!(snapshot.epoch, 0);
        assert_eq!(snapshot.status, TrainingStatus::Initializing);
    }

    #[test]
    fn test_training_snapshot_progress() {
        let snapshot = TrainingSnapshot {
            epoch: 2,
            total_epochs: 10,
            step: 50,
            steps_per_epoch: 100,
            ..Default::default()
        };

        // Global step = (2-1) * 100 + 50 = 150
        assert_eq!(snapshot.global_step(), 150);

        // Progress = 150 / 1000 = 15%
        assert!((snapshot.progress_percent() - 15.0).abs() < 0.01);
    }

    #[test]
    fn test_gpu_telemetry_vram_percent() {
        let gpu = GpuTelemetry { vram_used_gb: 4.0, vram_total_gb: 24.0, ..Default::default() };
        assert!((gpu.vram_percent() - 16.67).abs() < 0.1);
    }

    #[test]
    fn test_training_state_write_read() {
        let temp_dir = TempDir::new().expect("temp file creation should succeed");
        let mut state = TrainingState::new(temp_dir.path());

        let snapshot = TrainingSnapshot {
            epoch: 5,
            total_epochs: 10,
            loss: 0.42,
            status: TrainingStatus::Running,
            ..Default::default()
        };

        state.write(&snapshot).expect("file write should succeed");
        assert!(state.exists());

        let read_snapshot =
            state.read().expect("file read should succeed").expect("snapshot should be present");
        assert_eq!(read_snapshot.epoch, 5);
        assert!((read_snapshot.loss - 0.42).abs() < 0.001);
    }

    #[test]
    fn test_training_state_caching() {
        let temp_dir = TempDir::new().expect("temp file creation should succeed");
        let mut state = TrainingState::new(temp_dir.path());

        let snapshot = TrainingSnapshot { epoch: 1, ..Default::default() };

        state.write(&snapshot).expect("file write should succeed");

        // First read
        let _ = state.read().expect("file read should succeed");

        // Second read should return cached
        let cached =
            state.read().expect("file read should succeed").expect("snapshot should be present");
        assert_eq!(cached.epoch, 1);
    }

    // Property-based tests for monitoring integration (ENT-121)

    proptest! {
        /// TrainingSnapshot JSON serialization round-trip
        #[test]
        fn prop_snapshot_json_roundtrip(
            epoch in 1usize..1000,
            total_epochs in 1usize..100,
            step in 0usize..10000,
            steps_per_epoch in 1usize..10000,
            loss in 0.0f32..100.0,
            learning_rate in 1e-10f32..1.0,
            gradient_norm in 0.0f32..1000.0,
            tokens_per_second in 0.0f32..10000.0,
        ) {
            let snapshot = TrainingSnapshot {
                timestamp_ms: 12345678,
                epoch,
                total_epochs,
                step,
                steps_per_epoch,
                loss,
                loss_history: vec![loss * 1.1, loss * 1.05, loss],
                learning_rate,
                lr_history: vec![learning_rate; 3],
                gradient_norm,
                accuracy: 0.0,
                tokens_per_second,
                samples_per_second: 0.0,
                start_timestamp_ms: 12345000,
                gpu: None,
                sample: None,
                status: TrainingStatus::Running,
                experiment_id: "test".to_string(),
                model_name: "model".to_string(),
                model_path: String::new(),
                optimizer_name: "AdamW".to_string(),
                batch_size: 4,
                checkpoint_path: String::new(),
                executable_path: String::new(),
            };

            // Serialize
            let json = serde_json::to_string(&snapshot).expect("JSON serialization should succeed");

            // Deserialize
            let restored: TrainingSnapshot = serde_json::from_str(&json).expect("JSON deserialization should succeed");

            // Verify all fields preserved
            prop_assert_eq!(restored.epoch, epoch);
            prop_assert_eq!(restored.total_epochs, total_epochs);
            prop_assert_eq!(restored.step, step);
            prop_assert_eq!(restored.steps_per_epoch, steps_per_epoch);
            prop_assert!((restored.loss - loss).abs() < 1e-5);
            prop_assert!((restored.learning_rate - learning_rate).abs() < 1e-10);
            prop_assert!((restored.gradient_norm - gradient_norm).abs() < 1e-5);
        }

        /// Loss trend detection is consistent with loss history direction
        #[test]
        fn prop_loss_trend_consistent(
            base_loss in 1.0f32..10.0,
            trend_factor in -0.1f32..0.1,
        ) {
            // Generate 10 loss values with consistent trend
            // Positive factor = loss going up, Negative factor = loss going down
            let history: Vec<f32> = (0..10)
                .map(|i| base_loss + (i as f32 * trend_factor))
                .collect();

            let snapshot = TrainingSnapshot {
                loss_history: history,
                ..Default::default()
            };

            let trend = snapshot.loss_trend();

            // With 10 samples and consistent trend, we should detect it
            // Positive factor = loss increasing over time = LossTrend::Increasing
            // Negative factor = loss decreasing over time = LossTrend::Decreasing
            if trend_factor > 0.05 {
                prop_assert_eq!(trend, LossTrend::Increasing);
            } else if trend_factor < -0.05 {
                prop_assert_eq!(trend, LossTrend::Decreasing);
            }
            // Stable if |trend_factor| <= 0.05 (within threshold)
        }

        /// GPU telemetry VRAM percentage is always 0-100
        #[test]
        fn prop_gpu_vram_percent_bounded(
            vram_used in 0.0f32..100.0,
            vram_total in 1.0f32..100.0,
        ) {
            let gpu = GpuTelemetry {
                vram_used_gb: vram_used.min(vram_total),
                vram_total_gb: vram_total,
                ..Default::default()
            };

            let percent = gpu.vram_percent();
            prop_assert!(percent >= 0.0);
            prop_assert!(percent <= 100.0);
        }

        /// Progress percent is always 0-100
        #[test]
        fn prop_progress_percent_bounded(
            epoch in 1usize..100,
            total_epochs in 1usize..100,
            step in 0usize..1000,
            steps_per_epoch in 1usize..1000,
        ) {
            let epoch = epoch.min(total_epochs);
            let step = step.min(steps_per_epoch);

            let snapshot = TrainingSnapshot {
                epoch,
                total_epochs,
                step,
                steps_per_epoch,
                ..Default::default()
            };

            let progress = snapshot.progress_percent();
            prop_assert!(progress >= 0.0);
            prop_assert!(progress <= 100.0);
        }

        /// State file write/read preserves all data
        #[test]
        fn prop_state_file_roundtrip(
            epoch in 1usize..100,
            loss in 0.0f32..100.0,
            lr in 1e-6f32..0.1,
        ) {
            let temp_dir = TempDir::new().expect("temp file creation should succeed");
            let mut state = TrainingState::new(temp_dir.path());

            let snapshot = TrainingSnapshot {
                epoch,
                total_epochs: 10,
                loss,
                learning_rate: lr,
                status: TrainingStatus::Running,
                ..Default::default()
            };

            state.write(&snapshot).expect("file write should succeed");

            // Clear cache and re-read
            state.last_modified = None;
            let restored = state.read().expect("file read should succeed").expect("snapshot should be present");

            prop_assert_eq!(restored.epoch, epoch);
            prop_assert!((restored.loss - loss).abs() < 1e-5);
            prop_assert!((restored.learning_rate - lr).abs() < 1e-10);
        }
    }

    // ── Additional coverage tests ──

    #[test]
    fn test_gpu_telemetry_vram_percent_zero_total() {
        let gpu = GpuTelemetry { vram_used_gb: 4.0, vram_total_gb: 0.0, ..Default::default() };
        assert!((gpu.vram_percent() - 0.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_gpu_telemetry_thermal_throttling() {
        let gpu = GpuTelemetry { temperature_celsius: 84.0, ..Default::default() };
        assert!(gpu.is_thermal_throttling());

        let gpu2 = GpuTelemetry { temperature_celsius: 83.0, ..Default::default() };
        assert!(!gpu2.is_thermal_throttling());

        let gpu3 = GpuTelemetry { temperature_celsius: 70.0, ..Default::default() };
        assert!(!gpu3.is_thermal_throttling());
    }

    #[test]
    fn test_gpu_telemetry_power_limited() {
        let gpu =
            GpuTelemetry { power_watts: 380.0, power_limit_watts: 400.0, ..Default::default() };
        assert!(gpu.is_power_limited()); // 380 >= 400*0.95 = 380

        let gpu2 =
            GpuTelemetry { power_watts: 300.0, power_limit_watts: 400.0, ..Default::default() };
        assert!(!gpu2.is_power_limited());
    }

    #[test]
    fn test_training_snapshot_elapsed() {
        let snapshot =
            TrainingSnapshot { start_timestamp_ms: 1000, timestamp_ms: 6000, ..Default::default() };
        assert_eq!(snapshot.elapsed(), Duration::from_millis(5000));
    }

    #[test]
    fn test_training_snapshot_elapsed_same_time() {
        let snapshot =
            TrainingSnapshot { start_timestamp_ms: 5000, timestamp_ms: 5000, ..Default::default() };
        assert_eq!(snapshot.elapsed(), Duration::ZERO);
    }

    #[test]
    fn test_training_snapshot_estimated_remaining_none_zero_tps() {
        let snapshot = TrainingSnapshot { tokens_per_second: 0.0, ..Default::default() };
        assert!(snapshot.estimated_remaining().is_none());
    }

    #[test]
    fn test_training_snapshot_estimated_remaining_none_zero_steps() {
        let snapshot = TrainingSnapshot {
            tokens_per_second: 100.0,
            total_epochs: 0,
            steps_per_epoch: 0,
            ..Default::default()
        };
        assert!(snapshot.estimated_remaining().is_none());
    }

    #[test]
    fn test_training_snapshot_estimated_remaining_completed() {
        let snapshot = TrainingSnapshot {
            tokens_per_second: 100.0,
            epoch: 10,
            total_epochs: 10,
            step: 100,
            steps_per_epoch: 100,
            start_timestamp_ms: 1000,
            timestamp_ms: 11000,
            ..Default::default()
        };
        let remaining = snapshot.estimated_remaining();
        assert!(remaining.is_some());
        assert_eq!(remaining.unwrap(), Duration::ZERO);
    }

    #[test]
    fn test_training_snapshot_estimated_remaining_halfway() {
        let snapshot = TrainingSnapshot {
            tokens_per_second: 100.0,
            epoch: 5,
            total_epochs: 10,
            step: 50,
            steps_per_epoch: 100,
            start_timestamp_ms: 0,
            timestamp_ms: 10000,
            ..Default::default()
        };
        let remaining = snapshot.estimated_remaining();
        assert!(remaining.is_some());
        // 45% complete at 10s elapsed -> ~12s remaining
        let rem_ms = remaining.unwrap().as_millis();
        assert!(rem_ms > 5000 && rem_ms < 30000);
    }

    #[test]
    fn test_training_snapshot_global_step() {
        let snapshot =
            TrainingSnapshot { epoch: 3, steps_per_epoch: 100, step: 42, ..Default::default() };
        // global_step = (3-1)*100 + 42 = 242
        assert_eq!(snapshot.global_step(), 242);
    }

    #[test]
    fn test_training_snapshot_global_step_first_epoch() {
        let snapshot =
            TrainingSnapshot { epoch: 1, steps_per_epoch: 50, step: 10, ..Default::default() };
        // global_step = (1-1)*50 + 10 = 10
        assert_eq!(snapshot.global_step(), 10);
    }

    #[test]
    fn test_training_snapshot_progress_zero() {
        let snapshot =
            TrainingSnapshot { total_epochs: 0, steps_per_epoch: 0, ..Default::default() };
        assert!((snapshot.progress_percent() - 0.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_loss_trend_unknown_few_samples() {
        let snapshot = TrainingSnapshot { loss_history: vec![1.0, 2.0, 3.0], ..Default::default() };
        assert_eq!(snapshot.loss_trend(), LossTrend::Unknown);
    }

    #[test]
    fn test_loss_trend_decreasing() {
        let snapshot = TrainingSnapshot {
            loss_history: vec![10.0, 9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0],
            ..Default::default()
        };
        assert_eq!(snapshot.loss_trend(), LossTrend::Decreasing);
    }

    #[test]
    fn test_loss_trend_increasing() {
        let snapshot = TrainingSnapshot {
            loss_history: vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
            ..Default::default()
        };
        assert_eq!(snapshot.loss_trend(), LossTrend::Increasing);
    }

    #[test]
    fn test_loss_trend_stable() {
        let snapshot = TrainingSnapshot {
            loss_history: vec![5.0, 5.01, 4.99, 5.0, 5.01, 4.99, 5.0, 5.01, 4.99, 5.0],
            ..Default::default()
        };
        assert_eq!(snapshot.loss_trend(), LossTrend::Stable);
    }

    #[test]
    fn test_loss_trend_arrow() {
        assert_eq!(LossTrend::Decreasing.arrow(), "\u{2193}");
        assert_eq!(LossTrend::Stable.arrow(), "\u{2192}");
        assert_eq!(LossTrend::Increasing.arrow(), "\u{2191}");
        assert_eq!(LossTrend::Unknown.arrow(), "?");
    }

    #[test]
    fn test_loss_trend_description() {
        assert_eq!(LossTrend::Decreasing.description(), "decreasing");
        assert_eq!(LossTrend::Stable.description(), "stable");
        assert_eq!(LossTrend::Increasing.description(), "increasing");
        assert_eq!(LossTrend::Unknown.description(), "unknown");
    }

    #[test]
    fn test_training_state_new_path() {
        let state = TrainingState::new("/tmp/test-exp");
        assert_eq!(state.path(), std::path::Path::new("/tmp/test-exp/training_state.json"));
    }

    #[test]
    fn test_training_state_exists_missing() {
        let temp_dir = TempDir::new().expect("temp file creation should succeed");
        let state = TrainingState::new(temp_dir.path().join("nonexistent"));
        assert!(!state.exists());
    }

    #[test]
    fn test_training_state_read_missing_file() {
        let temp_dir = TempDir::new().expect("temp file creation should succeed");
        let mut state = TrainingState::new(temp_dir.path().join("nonexistent"));
        let result = state.read().expect("should not error for missing file");
        assert!(result.is_none());
    }

    #[test]
    fn test_training_state_wait_for_state_already_exists() {
        let temp_dir = TempDir::new().expect("temp file creation should succeed");
        let mut state = TrainingState::new(temp_dir.path());
        let snapshot = TrainingSnapshot::default();
        state.write(&snapshot).expect("write should succeed");

        let found = state.wait_for_state(Duration::from_millis(100)).expect("ok");
        assert!(found);
    }

    #[test]
    fn test_training_state_wait_for_state_timeout() {
        let temp_dir = TempDir::new().expect("temp file creation should succeed");
        let mut state = TrainingState::new(temp_dir.path().join("never-exists"));
        let found = state.wait_for_state(Duration::from_millis(200)).expect("ok");
        assert!(!found);
    }

    #[test]
    fn test_gpu_process_info_default() {
        let info = GpuProcessInfo::default();
        assert_eq!(info.pid, 0);
        assert!(info.exe_path.is_empty());
        assert_eq!(info.gpu_memory_mb, 0);
    }

    #[test]
    fn test_sample_peek_default() {
        let sample = SamplePeek::default();
        assert!(sample.input_preview.is_empty());
        assert!(sample.target_preview.is_empty());
        assert!(sample.generated_preview.is_empty());
        assert!((sample.token_match_percent - 0.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_training_status_equality() {
        assert_eq!(TrainingStatus::Running, TrainingStatus::Running);
        assert_eq!(TrainingStatus::Completed, TrainingStatus::Completed);
        assert_ne!(TrainingStatus::Running, TrainingStatus::Completed);
        assert_eq!(
            TrainingStatus::Failed("a".to_string()),
            TrainingStatus::Failed("a".to_string())
        );
        assert_ne!(
            TrainingStatus::Failed("a".to_string()),
            TrainingStatus::Failed("b".to_string())
        );
    }

    #[test]
    fn test_gpu_telemetry_serde_roundtrip() {
        let gpu = GpuTelemetry {
            device_name: "RTX 4090".to_string(),
            utilization_percent: 95.0,
            vram_used_gb: 20.0,
            vram_total_gb: 24.0,
            temperature_celsius: 72.0,
            power_watts: 350.0,
            power_limit_watts: 400.0,
            processes: vec![GpuProcessInfo {
                pid: 1234,
                exe_path: "/usr/bin/python3".to_string(),
                gpu_memory_mb: 19000,
                cpu_percent: 50.0,
                rss_mb: 4096,
            }],
        };
        let json = serde_json::to_string(&gpu).expect("serialize");
        let restored: GpuTelemetry = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(restored.device_name, "RTX 4090");
        assert_eq!(restored.processes.len(), 1);
        assert_eq!(restored.processes[0].pid, 1234);
    }
}