1use serde::{Deserialize, Serialize};
6use std::path::Path;
7
/// Settings that control run-to-run determinism of an experiment.
///
/// Stored inside [`ExperimentLock`]; because serde serializes fields in
/// declaration order, the field order below is the key order in the YAML lock.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReproducibilityConfig {
    /// Seed value; [`ReproducibilityConfig::apply`] exports it as `PYTHONHASHSEED`.
    pub seed: u64,
    /// Whether deterministic algorithms are requested.
    /// NOTE(review): not read by `apply` in this file — presumably consumed elsewhere; confirm.
    pub deterministic_algorithms: bool,
    /// Whether cuDNN benchmark mode is enabled; `apply` derives `CUDNN_BENCHMARK` from it.
    pub cudnn_benchmark: bool,
    /// Whether cuDNN deterministic mode is requested; `apply` derives `CUDNN_DETERMINISTIC` from it.
    pub cudnn_deterministic: bool,
}
20
21impl Default for ReproducibilityConfig {
22 fn default() -> Self {
23 Self {
24 seed: 42,
25 deterministic_algorithms: true,
26 cudnn_benchmark: false,
27 cudnn_deterministic: true,
28 }
29 }
30}
31
32impl ReproducibilityConfig {
33 #[must_use]
35 pub const fn with_seed(seed: u64) -> Self {
36 Self {
37 seed,
38 deterministic_algorithms: true,
39 cudnn_benchmark: false,
40 cudnn_deterministic: true,
41 }
42 }
43
44 #[must_use]
46 pub const fn non_deterministic(mut self) -> Self {
47 self.deterministic_algorithms = false;
48 self.cudnn_benchmark = true;
49 self.cudnn_deterministic = false;
50 self
51 }
52
53 #[allow(clippy::disallowed_methods)] pub fn apply(&self) {
56 std::env::set_var("PYTHONHASHSEED", self.seed.to_string());
58 std::env::set_var("CUBLAS_WORKSPACE_CONFIG", ":4096:8");
59
60 if self.cudnn_deterministic {
61 std::env::set_var("CUDNN_DETERMINISTIC", "1");
62 }
63
64 if !self.cudnn_benchmark {
65 std::env::set_var("CUDNN_BENCHMARK", "0");
66 }
67 }
68}
69
/// Snapshot of the environment and inputs of one experiment run, persisted as YAML.
///
/// Build one with [`ExperimentLock::new`] plus the `with_*` methods, write it
/// with [`ExperimentLock::save`], and later compare it with the live
/// environment via [`ExperimentLock::verify`]. Field order is the serialized
/// key order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExperimentLock {
    /// Caller-supplied identifier of the experiment.
    pub experiment_id: String,
    /// Creation time as an RFC 3339 string in UTC.
    pub timestamp: String,
    /// Output of `git rev-parse HEAD` at creation time, if it could be obtained.
    pub git_commit: Option<String>,
    /// Output of `rustc --version` at creation time; `"unknown"` if `rustc` could not be run.
    pub rust_version: String,
    /// The line of `nvcc --version` output containing `release`, if `nvcc` was available.
    pub cuda_version: Option<String>,
    /// cuDNN version; `ExperimentLock::new` does not detect it and leaves this `None`.
    pub cudnn_version: Option<String>,
    /// Locked versions of selected crates read from `Cargo.lock`.
    pub dependencies: Vec<DependencyVersion>,
    /// Reproducibility settings recorded for the run.
    pub reproducibility: ReproducibilityConfig,
    /// Checksum of the experiment configuration; empty until `with_config_checksum` is called.
    pub config_checksum: String,
    /// Checksum of the model, if one was recorded.
    pub model_checksum: Option<String>,
    /// Checksum of the dataset, if one was recorded.
    pub dataset_checksum: Option<String>,
}
96
/// Name and locked version of one crate, as read from `Cargo.lock`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DependencyVersion {
    /// Crate name (the `name` field of a Cargo.lock package entry).
    pub name: String,
    /// Version string (the `version` field of the same entry).
    pub version: String,
}
105
106impl ExperimentLock {
107 #[must_use]
109 pub fn new(experiment_id: impl Into<String>) -> Self {
110 Self {
111 experiment_id: experiment_id.into(),
112 timestamp: chrono::Utc::now().to_rfc3339(),
113 git_commit: Self::get_git_commit(),
114 rust_version: Self::get_rust_version(),
115 cuda_version: Self::get_cuda_version(),
116 cudnn_version: None,
117 dependencies: Self::get_dependencies(),
118 reproducibility: ReproducibilityConfig::default(),
119 config_checksum: String::new(),
120 model_checksum: None,
121 dataset_checksum: None,
122 }
123 }
124
125 #[must_use]
127 pub fn with_reproducibility(mut self, config: ReproducibilityConfig) -> Self {
128 self.reproducibility = config;
129 self
130 }
131
132 #[must_use]
134 pub fn with_config_checksum(mut self, checksum: impl Into<String>) -> Self {
135 self.config_checksum = checksum.into();
136 self
137 }
138
139 #[must_use]
141 pub fn with_model_checksum(mut self, checksum: impl Into<String>) -> Self {
142 self.model_checksum = Some(checksum.into());
143 self
144 }
145
146 #[must_use]
148 pub fn with_dataset_checksum(mut self, checksum: impl Into<String>) -> Self {
149 self.dataset_checksum = Some(checksum.into());
150 self
151 }
152
153 pub fn save(&self, path: &Path) -> Result<(), std::io::Error> {
159 let yaml = serde_yaml::to_string(self).map_err(std::io::Error::other)?;
160 std::fs::write(path, yaml)
161 }
162
163 pub fn load(path: &Path) -> Result<Self, std::io::Error> {
169 let content = std::fs::read_to_string(path)?;
170 serde_yaml::from_str(&content)
171 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
172 }
173
174 fn get_git_commit() -> Option<String> {
176 std::process::Command::new("git")
177 .args(["rev-parse", "HEAD"])
178 .output()
179 .ok()
180 .filter(|o| o.status.success())
181 .map(|o| String::from_utf8_lossy(&o.stdout).trim().to_string())
182 }
183
184 fn get_rust_version() -> String {
186 std::process::Command::new("rustc").arg("--version").output().ok().map_or_else(
187 || "unknown".into(),
188 |o| String::from_utf8_lossy(&o.stdout).trim().to_string(),
189 )
190 }
191
192 fn get_cuda_version() -> Option<String> {
194 std::process::Command::new("nvcc")
195 .arg("--version")
196 .output()
197 .ok()
198 .filter(|o| o.status.success())
199 .and_then(|o| {
200 let stdout = String::from_utf8_lossy(&o.stdout);
201 stdout.lines().find(|l| l.contains("release")).map(|l| l.trim().to_string())
202 })
203 }
204
205 fn get_dependencies() -> Vec<DependencyVersion> {
207 let cargo_lock = Path::new("Cargo.lock");
209 if !cargo_lock.exists() {
210 return Vec::new();
211 }
212
213 let key_deps = ["entrenar", "trueno", "serde", "ndarray"];
215 let mut deps = Vec::new();
216
217 if let Ok(content) = std::fs::read_to_string(cargo_lock) {
218 let mut current_name = String::new();
219 for line in content.lines() {
220 if line.starts_with("name = ") {
221 current_name = line
222 .strip_prefix("name = \"")
223 .and_then(|s| s.strip_suffix('"'))
224 .unwrap_or("")
225 .to_string();
226 } else if line.starts_with("version = ")
227 && !current_name.is_empty()
228 && key_deps.contains(¤t_name.as_str())
229 {
230 let version = line
231 .strip_prefix("version = \"")
232 .and_then(|s| s.strip_suffix('"'))
233 .unwrap_or("")
234 .to_string();
235 deps.push(DependencyVersion { name: current_name.clone(), version });
236 }
237 }
238 }
239
240 deps
241 }
242
243 #[must_use]
245 pub fn verify(&self) -> VerificationResult {
246 let mut result = VerificationResult::default();
247
248 if let Some(ref expected) = self.git_commit {
250 if let Some(current) = Self::get_git_commit() {
251 if ¤t != expected {
252 result.git_mismatch = Some((expected.clone(), current));
253 }
254 }
255 }
256
257 let current_rust = Self::get_rust_version();
259 if current_rust != self.rust_version {
260 result.rust_mismatch = Some((self.rust_version.clone(), current_rust));
261 }
262
263 if let Some(ref expected) = self.cuda_version {
265 if let Some(current) = Self::get_cuda_version() {
266 if ¤t != expected {
267 result.cuda_mismatch = Some((expected.clone(), current));
268 }
269 }
270 }
271
272 result
273 }
274}
275
/// Differences found by [`ExperimentLock::verify`] between a lock and the current environment.
///
/// Each field holds `(expected, actual)`: the value recorded in the lock and
/// the value observed now. `None` means no mismatch was detected.
// Every field intentionally shares the `_mismatch` suffix.
#[allow(clippy::struct_field_names)]
#[derive(Debug, Clone, Default)]
pub struct VerificationResult {
    /// Recorded vs. current git commit hash.
    pub git_mismatch: Option<(String, String)>,
    /// Recorded vs. current `rustc --version` output.
    pub rust_mismatch: Option<(String, String)>,
    /// Recorded vs. current CUDA version line.
    pub cuda_mismatch: Option<(String, String)>,
}

impl VerificationResult {
    /// Returns `true` when no mismatch was recorded.
    #[must_use]
    pub fn passed(&self) -> bool {
        self.git_mismatch.is_none() && self.rust_mismatch.is_none() && self.cuda_mismatch.is_none()
    }

    /// Renders each recorded mismatch as a human-readable warning.
    ///
    /// Warnings appear in the order git, Rust, CUDA. Git hashes are abbreviated
    /// to their first 8 characters.
    #[must_use]
    pub fn warnings(&self) -> Vec<String> {
        let mut warnings = Vec::new();

        if let Some((expected, actual)) = &self.git_mismatch {
            warnings.push(format!(
                "Git commit mismatch: expected {}, got {}",
                Self::short_hash(expected),
                Self::short_hash(actual)
            ));
        }

        if let Some((expected, actual)) = &self.rust_mismatch {
            warnings.push(format!("Rust version mismatch: expected {expected}, got {actual}"));
        }

        if let Some((expected, actual)) = &self.cuda_mismatch {
            warnings.push(format!("CUDA version mismatch: expected {expected}, got {actual}"));
        }

        warnings
    }

    /// Returns at most the first 8 characters of `s`.
    ///
    /// Counts `char`s, not bytes: byte-slicing (`&s[..8]`) panics when byte 8
    /// falls inside a multi-byte character, which can happen with a hand-edited
    /// lock file or lossily decoded `git` output.
    fn short_hash(s: &str) -> &str {
        s.char_indices().nth(8).map_or(s, |(idx, _)| &s[..idx])
    }
}
319
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_reproducibility_config_default() {
        let config = ReproducibilityConfig::default();
        assert_eq!(config.seed, 42);
        assert!(config.deterministic_algorithms);
        assert!(!config.cudnn_benchmark);
        assert!(config.cudnn_deterministic);
    }

    #[test]
    fn test_reproducibility_config_with_seed() {
        let config = ReproducibilityConfig::with_seed(123);
        assert_eq!(config.seed, 123);
        assert!(config.deterministic_algorithms);
    }

    #[test]
    fn test_reproducibility_config_non_deterministic() {
        let config = ReproducibilityConfig::default().non_deterministic();
        assert!(!config.deterministic_algorithms);
        assert!(config.cudnn_benchmark);
        assert!(!config.cudnn_deterministic);
    }

    #[test]
    fn test_experiment_lock_new() {
        let lock = ExperimentLock::new("test-001");
        assert_eq!(lock.experiment_id, "test-001");
        assert!(!lock.timestamp.is_empty());
        assert!(!lock.rust_version.is_empty());
    }

    #[test]
    fn test_experiment_lock_with_checksums() {
        let lock = ExperimentLock::new("test")
            .with_config_checksum("abc123")
            .with_model_checksum("def456")
            .with_dataset_checksum("ghi789");

        assert_eq!(lock.config_checksum, "abc123");
        assert_eq!(lock.model_checksum, Some("def456".into()));
        assert_eq!(lock.dataset_checksum, Some("ghi789".into()));
    }

    #[test]
    fn test_experiment_lock_serialization() {
        let lock =
            ExperimentLock::new("test").with_reproducibility(ReproducibilityConfig::with_seed(100));

        let yaml = serde_yaml::to_string(&lock).expect("YAML serialization should succeed");
        assert!(yaml.contains("experiment_id: test"));
        assert!(yaml.contains("seed: 100"));

        let restored: ExperimentLock =
            serde_yaml::from_str(&yaml).expect("YAML deserialization should succeed");
        assert_eq!(restored.experiment_id, "test");
        assert_eq!(restored.reproducibility.seed, 100);
    }

    #[test]
    fn test_verification_result_passed() {
        let result = VerificationResult::default();
        assert!(result.passed());
        assert!(result.warnings().is_empty());
    }

    #[test]
    fn test_verification_result_with_mismatches() {
        let result = VerificationResult {
            git_mismatch: Some(("abc123".into(), "def456".into())),
            rust_mismatch: None,
            cuda_mismatch: None,
        };

        assert!(!result.passed());
        let warnings = result.warnings();
        assert_eq!(warnings.len(), 1);
        assert!(warnings[0].contains("Git commit"));
    }

    #[test]
    fn test_dependency_version() {
        let dep = DependencyVersion { name: "entrenar".into(), version: "0.5.6".into() };

        let json = serde_json::to_string(&dep).expect("JSON serialization should succeed");
        assert!(json.contains("entrenar"));
        assert!(json.contains("0.5.6"));
    }

    #[test]
    fn test_reproducibility_config_apply() {
        // Mutates process-wide environment variables; no other test in this
        // module reads or writes these keys.
        let config = ReproducibilityConfig::with_seed(12345);
        config.apply();

        assert_eq!(
            std::env::var("PYTHONHASHSEED").expect("PYTHONHASHSEED should be set"),
            "12345"
        );
        assert_eq!(
            std::env::var("CUBLAS_WORKSPACE_CONFIG").expect("CUBLAS_WORKSPACE_CONFIG should be set"),
            ":4096:8"
        );
    }

    #[test]
    fn test_experiment_lock_save_load() {
        let lock = ExperimentLock::new("save-load-test")
            .with_reproducibility(ReproducibilityConfig::with_seed(999))
            .with_config_checksum("sha256:test");

        // Include the process id so concurrent test runs do not clobber each other's file.
        let path = std::env::temp_dir().join(format!("test_lock_{}.yaml", std::process::id()));

        lock.save(&path).expect("save should succeed");
        let loaded = ExperimentLock::load(&path);
        // Clean up before asserting so a failed assertion does not leave the file behind.
        let _ = std::fs::remove_file(&path);

        let loaded = loaded.expect("load should succeed");
        assert_eq!(loaded.experiment_id, "save-load-test");
        assert_eq!(loaded.reproducibility.seed, 999);
        assert_eq!(loaded.config_checksum, "sha256:test");
    }

    #[test]
    fn test_experiment_lock_verify() {
        // The outcome depends on the host toolchain, so only check internal consistency.
        let lock = ExperimentLock::new("verify-test");
        let result = lock.verify();
        assert_eq!(result.passed(), result.warnings().is_empty());
    }

    #[test]
    fn test_verification_result_multiple_warnings() {
        let result = VerificationResult {
            git_mismatch: Some(("abc12345".into(), "def67890".into())),
            rust_mismatch: Some(("1.70.0".into(), "1.75.0".into())),
            cuda_mismatch: Some(("12.0".into(), "12.1".into())),
        };

        assert!(!result.passed());
        let warnings = result.warnings();
        assert_eq!(warnings.len(), 3);
        assert!(warnings.iter().any(|w| w.contains("Git")));
        assert!(warnings.iter().any(|w| w.contains("Rust")));
        assert!(warnings.iter().any(|w| w.contains("CUDA")));
    }

    #[test]
    fn test_experiment_lock_with_all_checksums() {
        let lock = ExperimentLock::new("checksum-test")
            .with_config_checksum("sha256:config")
            .with_model_checksum("sha256:model")
            .with_dataset_checksum("sha256:dataset");

        assert_eq!(lock.config_checksum, "sha256:config");
        assert_eq!(lock.model_checksum, Some("sha256:model".into()));
        assert_eq!(lock.dataset_checksum, Some("sha256:dataset".into()));
    }

    #[test]
    fn test_experiment_lock_yaml_format() {
        let lock =
            ExperimentLock::new("yaml-test").with_reproducibility(ReproducibilityConfig::default());

        let yaml = serde_yaml::to_string(&lock).expect("YAML serialization should succeed");
        assert!(yaml.contains("experiment_id"));
        assert!(yaml.contains("timestamp"));
        assert!(yaml.contains("reproducibility"));
        assert!(yaml.contains("seed: 42"));
    }
}