Skip to main content

entrenar/finetune/
reproducibility.rs

1//! Reproducibility configuration and experiment locking
2//!
3//! Ensures scientific reproducibility of fine-tuning experiments.
4
5use serde::{Deserialize, Serialize};
6use std::path::Path;
7
/// Reproducibility configuration.
///
/// Groups the random seed and the determinism switches that
/// [`ReproducibilityConfig::apply`] translates into process environment variables.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReproducibilityConfig {
    /// Random seed for all operations; `apply` exports it as `PYTHONHASHSEED`.
    pub seed: u64,
    /// Use deterministic algorithms (may be slower).
    ///
    /// NOTE(review): `apply` does not read this flag — confirm which component
    /// is expected to consume it.
    pub deterministic_algorithms: bool,
    /// Whether cuDNN benchmark mode is enabled.
    ///
    /// `false` (the default) makes `apply` export `CUDNN_BENCHMARK=0`; `true`
    /// leaves that variable untouched.
    pub cudnn_benchmark: bool,
    /// Enable cuDNN deterministic mode; when `true`, `apply` exports
    /// `CUDNN_DETERMINISTIC=1`.
    pub cudnn_deterministic: bool,
}
20
21impl Default for ReproducibilityConfig {
22    fn default() -> Self {
23        Self {
24            seed: 42,
25            deterministic_algorithms: true,
26            cudnn_benchmark: false,
27            cudnn_deterministic: true,
28        }
29    }
30}
31
32impl ReproducibilityConfig {
33    /// Create config with specific seed
34    #[must_use]
35    pub const fn with_seed(seed: u64) -> Self {
36        Self {
37            seed,
38            deterministic_algorithms: true,
39            cudnn_benchmark: false,
40            cudnn_deterministic: true,
41        }
42    }
43
44    /// Disable deterministic mode (faster but not reproducible)
45    #[must_use]
46    pub const fn non_deterministic(mut self) -> Self {
47        self.deterministic_algorithms = false;
48        self.cudnn_benchmark = true;
49        self.cudnn_deterministic = false;
50        self
51    }
52
53    /// Apply reproducibility settings to environment
54    #[allow(clippy::disallowed_methods)] // Intentional: CUDA env vars for reproducibility
55    pub fn apply(&self) {
56        // Set environment variables for PyTorch/CUDA if used
57        std::env::set_var("PYTHONHASHSEED", self.seed.to_string());
58        std::env::set_var("CUBLAS_WORKSPACE_CONFIG", ":4096:8");
59
60        if self.cudnn_deterministic {
61            std::env::set_var("CUDNN_DETERMINISTIC", "1");
62        }
63
64        if !self.cudnn_benchmark {
65            std::env::set_var("CUDNN_BENCHMARK", "0");
66        }
67    }
68}
69
/// Experiment lockfile for reproducibility.
///
/// A serializable snapshot of the environment an experiment ran in. It is
/// written and read as YAML by [`ExperimentLock::save`] and
/// [`ExperimentLock::load`], and compared against the current environment by
/// [`ExperimentLock::verify`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExperimentLock {
    /// Experiment ID supplied by the caller of `new`.
    pub experiment_id: String,
    /// Creation timestamp (ISO 8601); `new` fills it with the current UTC time
    /// formatted as RFC 3339.
    pub timestamp: String,
    /// Git commit hash reported by `git rev-parse HEAD`, or `None` when the
    /// command could not be run successfully.
    pub git_commit: Option<String>,
    /// Rust version as printed by `rustc --version`, or `unknown` when the
    /// compiler could not be launched.
    pub rust_version: String,
    /// CUDA version (if available): the first line of `nvcc --version` that
    /// contains `release`.
    pub cuda_version: Option<String>,
    /// cuDNN version (if available); `new` always leaves this as `None`.
    pub cudnn_version: Option<String>,
    /// Versions of selected dependencies read from `Cargo.lock` in the current
    /// working directory; empty when that file cannot be read.
    pub dependencies: Vec<DependencyVersion>,
    /// Reproducibility config; `new` starts from `ReproducibilityConfig::default()`.
    pub reproducibility: ReproducibilityConfig,
    /// Config checksum; empty until set with `with_config_checksum`.
    pub config_checksum: String,
    /// Model checksum, if one was recorded with `with_model_checksum`.
    pub model_checksum: Option<String>,
    /// Dataset checksum, if one was recorded with `with_dataset_checksum`.
    pub dataset_checksum: Option<String>,
}
96
/// Dependency version info.
///
/// One name/version pair, stored in [`ExperimentLock::dependencies`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DependencyVersion {
    /// Crate/package name.
    pub name: String,
    /// Version string, kept verbatim as text (no parsing or validation here).
    pub version: String,
}
105
106impl ExperimentLock {
107    /// Create new experiment lock
108    #[must_use]
109    pub fn new(experiment_id: impl Into<String>) -> Self {
110        Self {
111            experiment_id: experiment_id.into(),
112            timestamp: chrono::Utc::now().to_rfc3339(),
113            git_commit: Self::get_git_commit(),
114            rust_version: Self::get_rust_version(),
115            cuda_version: Self::get_cuda_version(),
116            cudnn_version: None,
117            dependencies: Self::get_dependencies(),
118            reproducibility: ReproducibilityConfig::default(),
119            config_checksum: String::new(),
120            model_checksum: None,
121            dataset_checksum: None,
122        }
123    }
124
125    /// Set reproducibility config
126    #[must_use]
127    pub fn with_reproducibility(mut self, config: ReproducibilityConfig) -> Self {
128        self.reproducibility = config;
129        self
130    }
131
132    /// Set config checksum
133    #[must_use]
134    pub fn with_config_checksum(mut self, checksum: impl Into<String>) -> Self {
135        self.config_checksum = checksum.into();
136        self
137    }
138
139    /// Set model checksum
140    #[must_use]
141    pub fn with_model_checksum(mut self, checksum: impl Into<String>) -> Self {
142        self.model_checksum = Some(checksum.into());
143        self
144    }
145
146    /// Set dataset checksum
147    #[must_use]
148    pub fn with_dataset_checksum(mut self, checksum: impl Into<String>) -> Self {
149        self.dataset_checksum = Some(checksum.into());
150        self
151    }
152
153    /// Save lockfile to path
154    ///
155    /// # Errors
156    ///
157    /// Returns error if file cannot be written.
158    pub fn save(&self, path: &Path) -> Result<(), std::io::Error> {
159        let yaml = serde_yaml::to_string(self).map_err(std::io::Error::other)?;
160        std::fs::write(path, yaml)
161    }
162
163    /// Load lockfile from path
164    ///
165    /// # Errors
166    ///
167    /// Returns error if file cannot be read or parsed.
168    pub fn load(path: &Path) -> Result<Self, std::io::Error> {
169        let content = std::fs::read_to_string(path)?;
170        serde_yaml::from_str(&content)
171            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
172    }
173
174    /// Get current git commit hash
175    fn get_git_commit() -> Option<String> {
176        std::process::Command::new("git")
177            .args(["rev-parse", "HEAD"])
178            .output()
179            .ok()
180            .filter(|o| o.status.success())
181            .map(|o| String::from_utf8_lossy(&o.stdout).trim().to_string())
182    }
183
184    /// Get Rust version
185    fn get_rust_version() -> String {
186        std::process::Command::new("rustc").arg("--version").output().ok().map_or_else(
187            || "unknown".into(),
188            |o| String::from_utf8_lossy(&o.stdout).trim().to_string(),
189        )
190    }
191
192    /// Get CUDA version
193    fn get_cuda_version() -> Option<String> {
194        std::process::Command::new("nvcc")
195            .arg("--version")
196            .output()
197            .ok()
198            .filter(|o| o.status.success())
199            .and_then(|o| {
200                let stdout = String::from_utf8_lossy(&o.stdout);
201                stdout.lines().find(|l| l.contains("release")).map(|l| l.trim().to_string())
202            })
203    }
204
205    /// Get key dependencies from Cargo.lock
206    fn get_dependencies() -> Vec<DependencyVersion> {
207        // Read from Cargo.lock if available
208        let cargo_lock = Path::new("Cargo.lock");
209        if !cargo_lock.exists() {
210            return Vec::new();
211        }
212
213        // Parse relevant dependencies
214        let key_deps = ["entrenar", "trueno", "serde", "ndarray"];
215        let mut deps = Vec::new();
216
217        if let Ok(content) = std::fs::read_to_string(cargo_lock) {
218            let mut current_name = String::new();
219            for line in content.lines() {
220                if line.starts_with("name = ") {
221                    current_name = line
222                        .strip_prefix("name = \"")
223                        .and_then(|s| s.strip_suffix('"'))
224                        .unwrap_or("")
225                        .to_string();
226                } else if line.starts_with("version = ")
227                    && !current_name.is_empty()
228                    && key_deps.contains(&current_name.as_str())
229                {
230                    let version = line
231                        .strip_prefix("version = \"")
232                        .and_then(|s| s.strip_suffix('"'))
233                        .unwrap_or("")
234                        .to_string();
235                    deps.push(DependencyVersion { name: current_name.clone(), version });
236                }
237            }
238        }
239
240        deps
241    }
242
243    /// Verify current environment matches lockfile
244    #[must_use]
245    pub fn verify(&self) -> VerificationResult {
246        let mut result = VerificationResult::default();
247
248        // Check git commit
249        if let Some(ref expected) = self.git_commit {
250            if let Some(current) = Self::get_git_commit() {
251                if &current != expected {
252                    result.git_mismatch = Some((expected.clone(), current));
253                }
254            }
255        }
256
257        // Check Rust version
258        let current_rust = Self::get_rust_version();
259        if current_rust != self.rust_version {
260            result.rust_mismatch = Some((self.rust_version.clone(), current_rust));
261        }
262
263        // Check CUDA version
264        if let Some(ref expected) = self.cuda_version {
265            if let Some(current) = Self::get_cuda_version() {
266                if &current != expected {
267                    result.cuda_mismatch = Some((expected.clone(), current));
268                }
269            }
270        }
271
272        result
273    }
274}
275
/// Verification result.
///
/// Each field is `None` when the corresponding check found no difference (or
/// was skipped) and `Some((expected, actual))` when a mismatch was detected.
#[allow(clippy::struct_field_names)] // every field intentionally ends in `_mismatch`
#[derive(Debug, Clone, Default)]
pub struct VerificationResult {
    /// Git commit mismatch (expected, actual)
    pub git_mismatch: Option<(String, String)>,
    /// Rust version mismatch (expected, actual)
    pub rust_mismatch: Option<(String, String)>,
    /// CUDA version mismatch (expected, actual)
    pub cuda_mismatch: Option<(String, String)>,
}

impl VerificationResult {
    /// Number of leading characters of a commit hash shown in warnings.
    const SHORT_HASH_CHARS: usize = 8;

    /// Returns `true` when no mismatch of any kind was recorded.
    #[must_use]
    pub fn passed(&self) -> bool {
        self.git_mismatch.is_none() && self.rust_mismatch.is_none() && self.cuda_mismatch.is_none()
    }

    /// Returns one human-readable warning per recorded mismatch, in the order
    /// git, Rust, CUDA.
    ///
    /// Commit hashes are abbreviated to their first eight characters; version
    /// strings are shown in full.
    #[must_use]
    pub fn warnings(&self) -> Vec<String> {
        let mut warnings = Vec::new();

        if let Some((expected, actual)) = &self.git_mismatch {
            warnings.push(format!(
                "Git commit mismatch: expected {}, got {}",
                Self::short_hash(expected),
                Self::short_hash(actual)
            ));
        }

        if let Some((expected, actual)) = &self.rust_mismatch {
            warnings.push(format!("Rust version mismatch: expected {expected}, got {actual}"));
        }

        if let Some((expected, actual)) = &self.cuda_mismatch {
            warnings.push(format!("CUDA version mismatch: expected {expected}, got {actual}"));
        }

        warnings
    }

    /// Returns at most the first [`Self::SHORT_HASH_CHARS`] characters of `hash`.
    ///
    /// The cut is made on a character boundary, so a value containing
    /// multi-byte characters (e.g. from a hand-edited lockfile) cannot cause a
    /// slicing panic. For ASCII hashes this is the first eight bytes.
    fn short_hash(hash: &str) -> &str {
        match hash.char_indices().nth(Self::SHORT_HASH_CHARS) {
            Some((end, _)) => &hash[..end],
            None => hash,
        }
    }
}
319
/// Unit tests for the reproducibility configuration, the experiment lockfile
/// and the verification result.
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration is deterministic with seed 42.
    #[test]
    fn test_reproducibility_config_default() {
        let config = ReproducibilityConfig::default();
        assert_eq!(config.seed, 42);
        assert!(config.deterministic_algorithms);
        assert!(!config.cudnn_benchmark);
        assert!(config.cudnn_deterministic);
    }

    /// `with_seed` keeps the deterministic defaults and only changes the seed.
    #[test]
    fn test_reproducibility_config_with_seed() {
        let config = ReproducibilityConfig::with_seed(123);
        assert_eq!(config.seed, 123);
        assert!(config.deterministic_algorithms);
    }

    /// `non_deterministic` flips all three determinism flags.
    #[test]
    fn test_reproducibility_config_non_deterministic() {
        let config = ReproducibilityConfig::default().non_deterministic();
        assert!(!config.deterministic_algorithms);
        assert!(config.cudnn_benchmark);
        assert!(!config.cudnn_deterministic);
    }

    /// A new lock carries the given id and non-empty probe results.
    #[test]
    fn test_experiment_lock_new() {
        let lock = ExperimentLock::new("test-001");
        assert_eq!(lock.experiment_id, "test-001");
        assert!(!lock.timestamp.is_empty());
        assert!(!lock.rust_version.is_empty());
    }

    /// The checksum builders store the values they are given.
    #[test]
    fn test_experiment_lock_with_checksums() {
        let lock = ExperimentLock::new("test")
            .with_config_checksum("abc123")
            .with_model_checksum("def456")
            .with_dataset_checksum("ghi789");

        assert_eq!(lock.config_checksum, "abc123");
        assert_eq!(lock.model_checksum, Some("def456".into()));
        assert_eq!(lock.dataset_checksum, Some("ghi789".into()));
    }

    /// A lock survives a YAML round trip in memory.
    #[test]
    fn test_experiment_lock_serialization() {
        let lock =
            ExperimentLock::new("test").with_reproducibility(ReproducibilityConfig::with_seed(100));

        let yaml = serde_yaml::to_string(&lock).expect("YAML serialization should succeed");
        assert!(yaml.contains("experiment_id: test"));
        assert!(yaml.contains("seed: 100"));

        let restored: ExperimentLock =
            serde_yaml::from_str(&yaml).expect("YAML deserialization should succeed");
        assert_eq!(restored.experiment_id, "test");
        assert_eq!(restored.reproducibility.seed, 100);
    }

    /// An empty result passes and produces no warnings.
    #[test]
    fn test_verification_result_passed() {
        let result = VerificationResult::default();
        assert!(result.passed());
        assert!(result.warnings().is_empty());
    }

    /// A single mismatch fails verification and yields exactly one warning.
    #[test]
    fn test_verification_result_with_mismatches() {
        let result = VerificationResult {
            git_mismatch: Some(("abc123".into(), "def456".into())),
            rust_mismatch: None,
            cuda_mismatch: None,
        };

        assert!(!result.passed());
        assert_eq!(result.warnings().len(), 1);
        assert!(result.warnings()[0].contains("Git commit"));
    }

    /// `DependencyVersion` serializes both of its fields.
    #[test]
    fn test_dependency_version() {
        let dep = DependencyVersion { name: "entrenar".into(), version: "0.5.6".into() };

        let json = serde_json::to_string(&dep).expect("JSON serialization should succeed");
        assert!(json.contains("entrenar"));
        assert!(json.contains("0.5.6"));
    }

    /// `apply` exports the seed and the determinism variables.
    ///
    /// This mutates the process environment; no other test in this module reads
    /// or writes these variables.
    #[test]
    fn test_reproducibility_config_apply() {
        let config = ReproducibilityConfig::with_seed(12345);
        config.apply();

        // Verify environment variables were set
        assert_eq!(std::env::var("PYTHONHASHSEED").expect("PYTHONHASHSEED should be set"), "12345");
        assert_eq!(
            std::env::var("CUBLAS_WORKSPACE_CONFIG")
                .expect("CUBLAS_WORKSPACE_CONFIG should be set"),
            ":4096:8"
        );
        // `with_seed` enables deterministic mode and disables benchmark mode,
        // so both cuDNN variables must have been exported as well.
        assert_eq!(
            std::env::var("CUDNN_DETERMINISTIC").expect("CUDNN_DETERMINISTIC should be set"),
            "1"
        );
        assert_eq!(std::env::var("CUDNN_BENCHMARK").expect("CUDNN_BENCHMARK should be set"), "0");
    }

    /// A lock survives being written to and read back from disk.
    #[test]
    fn test_experiment_lock_save_load() {
        let lock = ExperimentLock::new("save-load-test")
            .with_reproducibility(ReproducibilityConfig::with_seed(999))
            .with_config_checksum("sha256:test");

        // Include the process id so concurrent test processes sharing the same
        // temp directory do not overwrite each other's file.
        let path =
            std::env::temp_dir().join(format!("entrenar_test_lock_{}.yaml", std::process::id()));

        // Save
        lock.save(&path).expect("save should succeed");

        // Load
        let loaded = ExperimentLock::load(&path);

        // Cleanup before asserting so a failed assertion does not leak the file
        let _ = std::fs::remove_file(&path);

        let loaded = loaded.expect("load should succeed");
        assert_eq!(loaded.experiment_id, "save-load-test");
        assert_eq!(loaded.reproducibility.seed, 999);
        assert_eq!(loaded.config_checksum, "sha256:test");
    }

    /// `verify` runs against the live environment and returns a coherent result.
    #[test]
    fn test_experiment_lock_verify() {
        let lock = ExperimentLock::new("verify-test");
        let result = lock.verify();
        // Whatever the environment looks like, "passed" must mean "no warnings".
        assert_eq!(result.passed(), result.warnings().is_empty());
    }

    /// Every recorded mismatch produces its own warning.
    #[test]
    fn test_verification_result_multiple_warnings() {
        let result = VerificationResult {
            git_mismatch: Some(("abc12345".into(), "def67890".into())),
            rust_mismatch: Some(("1.70.0".into(), "1.75.0".into())),
            cuda_mismatch: Some(("12.0".into(), "12.1".into())),
        };

        assert!(!result.passed());
        let warnings = result.warnings();
        assert_eq!(warnings.len(), 3);
        assert!(warnings.iter().any(|w| w.contains("Git")));
        assert!(warnings.iter().any(|w| w.contains("Rust")));
        assert!(warnings.iter().any(|w| w.contains("CUDA")));
    }

    /// Checksums in `sha256:`-prefixed form are stored verbatim.
    #[test]
    fn test_experiment_lock_with_all_checksums() {
        let lock = ExperimentLock::new("checksum-test")
            .with_config_checksum("sha256:config")
            .with_model_checksum("sha256:model")
            .with_dataset_checksum("sha256:dataset");

        assert_eq!(lock.config_checksum, "sha256:config");
        assert_eq!(lock.model_checksum, Some("sha256:model".into()));
        assert_eq!(lock.dataset_checksum, Some("sha256:dataset".into()));
    }

    /// The YAML output contains the expected top-level keys and default seed.
    #[test]
    fn test_experiment_lock_yaml_format() {
        let lock =
            ExperimentLock::new("yaml-test").with_reproducibility(ReproducibilityConfig::default());

        let yaml = serde_yaml::to_string(&lock).expect("YAML serialization should succeed");
        assert!(yaml.contains("experiment_id"));
        assert!(yaml.contains("timestamp"));
        assert!(yaml.contains("reproducibility"));
        assert!(yaml.contains("seed: 42"));
    }
}