1use anyhow::Result;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::path::Path;
5
6#[derive(Debug, Clone, Serialize, Deserialize, Default)]
8pub struct ExperimentMeta {
9 #[serde(default)]
10 pub name: Option<String>,
11 #[serde(default)]
12 pub project: Option<String>,
13 #[serde(default)]
14 pub status: Option<String>,
15 #[serde(default)]
16 pub start_time: Option<String>,
17 #[serde(default)]
18 pub end_time: Option<String>,
19 #[serde(default)]
20 pub total_steps: Option<u64>,
21 #[serde(default)]
22 pub best_metrics: HashMap<String, f64>,
23
24 #[serde(default)]
26 pub seed: Option<u64>,
27 #[serde(default)]
28 pub config_hash: Option<String>,
29 #[serde(default)]
30 pub config: Option<serde_json::Value>,
31
32 #[serde(default)]
34 pub tags: Vec<String>,
35
36 #[serde(default)]
38 pub git_hash: Option<String>,
39 #[serde(default)]
40 pub git_dirty: Option<bool>,
41 #[serde(default)]
42 pub hostname: Option<String>,
43 #[serde(default)]
44 pub gpu_model: Option<String>,
45 #[serde(default)]
46 pub python_version: Option<String>,
47 #[serde(default)]
48 pub pytorch_version: Option<String>,
49}
50
51pub fn read_meta(dir: &Path) -> Option<ExperimentMeta> {
53 let meta_path = dir.join("meta.json");
54 let content = std::fs::read_to_string(&meta_path).ok()?;
55 serde_json::from_str(&content).ok()
56}
57
58pub fn read_tags(dir: &Path) -> Vec<String> {
60 read_meta(dir).map(|m| m.tags).unwrap_or_default()
61}
62
63pub fn update_tags(dir: &Path, tags: &[String]) -> Result<()> {
66 let meta_path = dir.join("meta.json");
67 let mut doc: serde_json::Value = if meta_path.exists() {
68 let content = std::fs::read_to_string(&meta_path)?;
69 serde_json::from_str(&content)?
70 } else {
71 serde_json::json!({})
72 };
73
74 doc["tags"] = serde_json::json!(tags);
75
76 let tmp_path = dir.join("meta.json.tmp");
77 std::fs::write(&tmp_path, serde_json::to_string_pretty(&doc)?)?;
78 std::fs::rename(&tmp_path, &meta_path)?;
79 Ok(())
80}
81
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    /// Write `contents` verbatim to `meta.json` inside `dir`.
    fn write_meta(dir: &std::path::Path, contents: &str) {
        fs::write(dir.join("meta.json"), contents).unwrap();
    }

    /// Parse the raw `meta.json` inside `dir` as untyped JSON.
    fn read_raw(dir: &std::path::Path) -> serde_json::Value {
        let content = fs::read_to_string(dir.join("meta.json")).unwrap();
        serde_json::from_str(&content).unwrap()
    }

    #[test]
    fn test_read_valid_meta() {
        let dir = tempfile::tempdir().unwrap();
        write_meta(
            dir.path(),
            r#"{
                "name": "exp-001",
                "project": "test",
                "status": "done",
                "total_steps": 1000,
                "best_metrics": {"loss": 0.05, "psnr": 28.4},
                "git_hash": "abc123"
            }"#,
        );

        let result = read_meta(dir.path()).unwrap();
        assert_eq!(result.name.as_deref(), Some("exp-001"));
        assert_eq!(result.project.as_deref(), Some("test"));
        assert_eq!(result.status.as_deref(), Some("done"));
        assert_eq!(result.total_steps, Some(1000));
        // Small absolute tolerances: don't assume the JSON float parser is
        // correctly rounded to the last bit.
        assert!((result.best_metrics["loss"] - 0.05).abs() < 1e-12);
        assert!((result.best_metrics["psnr"] - 28.4).abs() < 1e-9);
        assert_eq!(result.git_hash.as_deref(), Some("abc123"));
        assert!(result.tags.is_empty());
    }

    #[test]
    fn test_read_missing_meta() {
        let dir = tempfile::tempdir().unwrap();
        assert!(read_meta(dir.path()).is_none());
    }

    #[test]
    fn test_read_partial_meta() {
        let dir = tempfile::tempdir().unwrap();
        write_meta(dir.path(), r#"{"name": "partial"}"#);

        let result = read_meta(dir.path()).unwrap();
        assert_eq!(result.name.as_deref(), Some("partial"));
        assert!(result.project.is_none());
        assert!(result.best_metrics.is_empty());
        assert!(result.tags.is_empty());
    }

    #[test]
    fn test_read_malformed_meta() {
        let dir = tempfile::tempdir().unwrap();
        write_meta(dir.path(), "{not json");

        assert!(read_meta(dir.path()).is_none());
        assert!(read_tags(dir.path()).is_empty());
    }

    #[test]
    fn test_update_tags_creates_meta() {
        let dir = tempfile::tempdir().unwrap();
        let tags = vec!["best".to_string(), "baseline".to_string()];
        update_tags(dir.path(), &tags).unwrap();

        assert!(dir.path().join("meta.json").exists());
        assert_eq!(read_tags(dir.path()), tags);
    }

    #[test]
    fn test_update_tags_replaces_existing_tags() {
        let dir = tempfile::tempdir().unwrap();
        write_meta(dir.path(), r#"{"tags": ["old", "stale"]}"#);

        update_tags(dir.path(), &["new".to_string()]).unwrap();

        assert_eq!(read_tags(dir.path()), vec!["new".to_string()]);
    }

    #[test]
    fn test_update_tags_with_empty_list_writes_empty_array() {
        let dir = tempfile::tempdir().unwrap();
        write_meta(dir.path(), r#"{"tags": ["old"]}"#);

        let no_tags: Vec<String> = Vec::new();
        update_tags(dir.path(), &no_tags).unwrap();

        assert!(read_tags(dir.path()).is_empty());
        assert_eq!(read_raw(dir.path())["tags"], serde_json::json!([]));
    }

    #[test]
    fn test_update_tags_preserves_existing_fields() {
        let dir = tempfile::tempdir().unwrap();
        write_meta(dir.path(), r#"{"name": "exp-001", "custom_field": 42}"#);

        update_tags(dir.path(), &["v1".to_string()]).unwrap();

        assert_eq!(read_tags(dir.path()), vec!["v1".to_string()]);
        let doc = read_raw(dir.path());
        assert_eq!(doc["custom_field"], serde_json::json!(42));
        assert_eq!(doc["name"], serde_json::json!("exp-001"));
    }

    #[test]
    fn test_update_tags_leaves_no_temp_file() {
        let dir = tempfile::tempdir().unwrap();
        update_tags(dir.path(), &["v1".to_string()]).unwrap();

        assert!(!dir.path().join("meta.json.tmp").exists());
    }

    #[test]
    fn test_read_tags_missing_meta() {
        let dir = tempfile::tempdir().unwrap();
        assert!(read_tags(dir.path()).is_empty());
    }
}