1use crate::prelude::*;
7use std::fmt::Write;
8use std::path::Path;
9
10pub fn handle_model_register(
12 registry: &Registry,
13 name: &str,
14 artifact: &Path,
15 version: &ModelVersion,
16 description: Option<&str>,
17) -> Result<ModelId> {
18 let data = std::fs::read(artifact)?;
19 let card = ModelCard::new(description.unwrap_or_default());
20 registry.register_model(name, version, &data, card)
21}
22
23pub fn format_model_info(model: &Model) -> String {
25 let mut out = String::new();
26 let _ = writeln!(out, "Model: {}:{}", model.name, model.version);
27 let _ = writeln!(out, " ID: {}", model.id);
28 let _ = writeln!(out, " Stage: {}", model.stage);
29 let _ = writeln!(out, " Created: {}", model.created_at);
30 let _ = writeln!(out, " Description: {}", model.card.description);
31 let _ = writeln!(out, " Size: {} bytes", model.content_address.size());
32 let _ = writeln!(out, " Hash: {}", model.content_address.hash_hex());
33 if !model.card.metrics.is_empty() {
34 out.push_str(" Metrics:\n");
35 for (k, v) in &model.card.metrics {
36 let _ = writeln!(out, " {k}: {v}");
37 }
38 }
39 out
40}
41
42pub fn handle_data_register(
44 registry: &Registry,
45 name: &str,
46 data_path: &Path,
47 version: &DatasetVersion,
48 purpose: Option<&str>,
49) -> Result<DatasetId> {
50 let content = std::fs::read(data_path)?;
51 let datasheet = Datasheet::new(purpose.unwrap_or_default());
52 registry.register_dataset(name, version, &content, datasheet)
53}
54
55pub fn format_dataset_info(dataset: &Dataset) -> String {
57 let mut out = String::new();
58 let _ = writeln!(out, "Dataset: {}:{}", dataset.name, dataset.version);
59 let _ = writeln!(out, " ID: {}", dataset.id);
60 let _ = writeln!(out, " Created: {}", dataset.created_at);
61 let _ = writeln!(out, " Purpose: {}", dataset.datasheet.purpose);
62 let _ = writeln!(out, " Size: {} bytes", dataset.content_address.size());
63 let _ = writeln!(out, " Hash: {}", dataset.content_address.hash_hex());
64 out
65}
66
67pub fn format_recipe_info(recipe: &TrainingRecipe) -> String {
69 let mut out = String::new();
70 let _ = writeln!(out, "Recipe: {}:{}", recipe.name, recipe.version);
71 let _ = writeln!(out, " ID: {}", recipe.id);
72 let _ = writeln!(out, " Description: {}", recipe.description);
73 let _ = writeln!(out, " Created: {}", recipe.created_at);
74 out.push_str(" Hyperparameters:\n");
75 let _ = writeln!(out, " Learning rate: {}", recipe.hyperparameters.learning_rate);
76 let _ = writeln!(out, " Batch size: {}", recipe.hyperparameters.batch_size);
77 let _ = writeln!(out, " Epochs: {}", recipe.hyperparameters.epochs);
78 out
79}
80
81pub fn format_stats(stats: &StorageStats) -> String {
83 let mut out = String::new();
84 out.push_str("Registry Statistics:\n");
85 let _ = writeln!(out, " Models: {}", stats.model_count);
86 let _ = writeln!(out, " Datasets: {}", stats.dataset_count);
87 let _ = writeln!(out, " Recipes: {}", stats.recipe_count);
88 let _ = writeln!(out, " Objects: {}", stats.object_count);
89 let _ = writeln!(out, " Size: {} bytes", stats.total_size_bytes);
90 out
91}
92
93pub fn format_run_info(run: &ExperimentRun) -> String {
95 contract_pre_display_format!();
96 let mut out = String::new();
97 let _ = writeln!(out, "Run: {}", run.run_id);
98 let _ = writeln!(out, " Status: {}", run.status);
99 let _ = writeln!(out, " Started: {}", run.started_at);
100 if let Some(finished) = run.finished_at {
101 let _ = writeln!(out, " Finished: {finished}");
102 }
103 if !run.metrics.is_empty() {
104 out.push_str(" Final metrics:\n");
105 let mut latest: std::collections::HashMap<&str, f64> = std::collections::HashMap::new();
106 for m in &run.metrics {
107 latest.insert(&m.name, m.value);
108 }
109 for (k, v) in latest {
110 let _ = writeln!(out, " {k}: {v}");
111 }
112 }
113 out
114}
115
116pub fn find_best_run<'a>(
118 runs: &'a [ExperimentRun],
119 metric: &str,
120 minimize: bool,
121) -> Option<(&'a ExperimentRun, f64)> {
122 contract_pre_configuration!(runs);
123 runs.iter()
124 .filter(|r| r.status == RunStatus::Completed)
125 .filter_map(|r| r.get_metric(metric).map(|v| (r, v)))
126 .min_by(|(_, a), (_, b)| {
127 if minimize {
128 a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
129 } else {
130 b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal)
131 }
132 })
133}
134
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Opens a fresh registry rooted in a new temporary directory.
    ///
    /// The `TempDir` is returned alongside the registry so the caller keeps
    /// it alive for the duration of the test.
    fn setup() -> (TempDir, Registry) {
        let dir = TempDir::new().unwrap();
        let config = RegistryConfig::new(dir.path());
        let registry = Registry::open(config).unwrap();
        (dir, registry)
    }

    /// Builds a completed run with a single logged metric.
    fn create_run_with_metric(name: &str, value: f64) -> ExperimentRun {
        let mut run = ExperimentRun::new(Hyperparameters::default());
        run.log_metric(name, value, 0);
        run.complete();
        run
    }

    #[test]
    fn test_handle_model_register() {
        let (dir, registry) = setup();
        let artifact = dir.path().join("model.bin");
        std::fs::write(&artifact, b"weights").unwrap();

        let id = handle_model_register(
            &registry,
            "test",
            &artifact,
            &ModelVersion::new(1, 0, 0),
            Some("Test model"),
        )
        .unwrap();

        assert!(!id.to_string().is_empty());
    }

    #[test]
    fn test_handle_data_register() {
        let (dir, registry) = setup();
        let data = dir.path().join("data.csv");
        std::fs::write(&data, b"a,b,c").unwrap();

        let id = handle_data_register(
            &registry,
            "test-data",
            &data,
            &DatasetVersion::new(1, 0, 0),
            Some("Test data"),
        )
        .unwrap();

        assert!(!id.to_string().is_empty());
    }

    #[test]
    fn test_format_stats() {
        let stats = StorageStats {
            model_count: 5,
            dataset_count: 3,
            recipe_count: 2,
            object_count: 10,
            total_size_bytes: 1024,
        };
        let out = format_stats(&stats);
        assert!(out.contains("Models: 5"));
        assert!(out.contains("Datasets: 3"));
    }

    #[test]
    fn test_find_best_run_maximize() {
        let runs = vec![
            create_run_with_metric("auc", 0.8),
            create_run_with_metric("auc", 0.95),
            create_run_with_metric("auc", 0.85),
        ];
        let best = find_best_run(&runs, "auc", false);
        assert!(best.is_some());
        assert!((best.unwrap().1 - 0.95).abs() < 1e-9);
    }

    #[test]
    fn test_find_best_run_minimize() {
        let runs = vec![
            create_run_with_metric("loss", 0.5),
            create_run_with_metric("loss", 0.1),
            create_run_with_metric("loss", 0.3),
        ];
        let best = find_best_run(&runs, "loss", true);
        assert!(best.is_some());
        assert!((best.unwrap().1 - 0.1).abs() < 1e-9);
    }

    #[test]
    fn test_format_model_info() {
        let (dir, registry) = setup();
        let artifact = dir.path().join("m.bin");
        std::fs::write(&artifact, b"data").unwrap();

        let card = ModelCard::builder().description("Test").metrics([("acc", 0.9)]).build();
        registry.register_model("fmt-test", &ModelVersion::new(1, 0, 0), b"data", card).unwrap();

        let model = registry.get_model("fmt-test", &ModelVersion::new(1, 0, 0)).unwrap();
        let out = format_model_info(&model);
        assert!(out.contains("fmt-test:1.0.0"));
        assert!(out.contains("Stage:"));
        assert!(out.contains("acc: 0.9"));
    }

    #[test]
    fn test_format_dataset_info() {
        let (_dir, registry) = setup();
        let datasheet = Datasheet::new("Test purpose");
        registry
            .register_dataset("fmt-data", &DatasetVersion::new(1, 0, 0), b"csv", datasheet)
            .unwrap();

        let ds = registry.get_dataset("fmt-data", &DatasetVersion::new(1, 0, 0)).unwrap();
        let out = format_dataset_info(&ds);
        assert!(out.contains("fmt-data:1.0.0"));
        assert!(out.contains("Purpose: Test purpose"));
    }

    #[test]
    fn test_format_recipe_info() {
        let (_dir, registry) = setup();
        let recipe = TrainingRecipe::builder()
            .name("fmt-recipe")
            .version(RecipeVersion::new(1, 0, 0))
            .description("Test recipe")
            .hyperparameters(
                Hyperparameters::builder().learning_rate(0.01).batch_size(64).epochs(5).build(),
            )
            .build();
        registry.register_recipe(&recipe).unwrap();

        let r = registry.get_recipe("fmt-recipe", &RecipeVersion::new(1, 0, 0)).unwrap();
        let out = format_recipe_info(&r);
        assert!(out.contains("fmt-recipe:1.0.0"));
        assert!(out.contains("Batch size: 64"));
    }

    #[test]
    fn test_format_run_info() {
        let mut run = ExperimentRun::new(Hyperparameters::default());
        run.log_metric("loss", 0.5, 0);
        run.log_metric("loss", 0.2, 100);
        run.complete();

        // The later log entry for "loss" must be the one that is reported.
        let out = format_run_info(&run);
        assert!(out.contains("Status: completed"));
        assert!(out.contains("loss: 0.2"));
    }

    #[test]
    fn test_find_best_run_no_matches() {
        let runs = vec![create_run_with_metric("auc", 0.8)];
        let best = find_best_run(&runs, "nonexistent", false);
        assert!(best.is_none());
    }

    #[test]
    fn test_find_best_run_empty() {
        let runs: Vec<ExperimentRun> = vec![];
        let best = find_best_run(&runs, "auc", false);
        assert!(best.is_none());
    }
}