Skip to main content

pacha/cli/
mod.rs

1//! CLI command handlers.
2//!
3//! This module contains the business logic for CLI commands,
4//! separated from argument parsing for testability.
5
6use crate::prelude::*;
7use std::fmt::Write;
8use std::path::Path;
9
10/// Handle model commands.
11pub fn handle_model_register(
12    registry: &Registry,
13    name: &str,
14    artifact: &Path,
15    version: &ModelVersion,
16    description: Option<&str>,
17) -> Result<ModelId> {
18    let data = std::fs::read(artifact)?;
19    let card = ModelCard::new(description.unwrap_or_default());
20    registry.register_model(name, version, &data, card)
21}
22
23/// Format model info for display.
24pub fn format_model_info(model: &Model) -> String {
25    let mut out = String::new();
26    let _ = writeln!(out, "Model: {}:{}", model.name, model.version);
27    let _ = writeln!(out, "  ID:          {}", model.id);
28    let _ = writeln!(out, "  Stage:       {}", model.stage);
29    let _ = writeln!(out, "  Created:     {}", model.created_at);
30    let _ = writeln!(out, "  Description: {}", model.card.description);
31    let _ = writeln!(out, "  Size:        {} bytes", model.content_address.size());
32    let _ = writeln!(out, "  Hash:        {}", model.content_address.hash_hex());
33    if !model.card.metrics.is_empty() {
34        out.push_str("  Metrics:\n");
35        for (k, v) in &model.card.metrics {
36            let _ = writeln!(out, "    {k}: {v}");
37        }
38    }
39    out
40}
41
42/// Handle dataset commands.
43pub fn handle_data_register(
44    registry: &Registry,
45    name: &str,
46    data_path: &Path,
47    version: &DatasetVersion,
48    purpose: Option<&str>,
49) -> Result<DatasetId> {
50    let content = std::fs::read(data_path)?;
51    let datasheet = Datasheet::new(purpose.unwrap_or_default());
52    registry.register_dataset(name, version, &content, datasheet)
53}
54
55/// Format dataset info for display.
56pub fn format_dataset_info(dataset: &Dataset) -> String {
57    let mut out = String::new();
58    let _ = writeln!(out, "Dataset: {}:{}", dataset.name, dataset.version);
59    let _ = writeln!(out, "  ID:      {}", dataset.id);
60    let _ = writeln!(out, "  Created: {}", dataset.created_at);
61    let _ = writeln!(out, "  Purpose: {}", dataset.datasheet.purpose);
62    let _ = writeln!(out, "  Size:    {} bytes", dataset.content_address.size());
63    let _ = writeln!(out, "  Hash:    {}", dataset.content_address.hash_hex());
64    out
65}
66
67/// Format recipe info for display.
68pub fn format_recipe_info(recipe: &TrainingRecipe) -> String {
69    let mut out = String::new();
70    let _ = writeln!(out, "Recipe: {}:{}", recipe.name, recipe.version);
71    let _ = writeln!(out, "  ID:          {}", recipe.id);
72    let _ = writeln!(out, "  Description: {}", recipe.description);
73    let _ = writeln!(out, "  Created:     {}", recipe.created_at);
74    out.push_str("  Hyperparameters:\n");
75    let _ = writeln!(out, "    Learning rate: {}", recipe.hyperparameters.learning_rate);
76    let _ = writeln!(out, "    Batch size:    {}", recipe.hyperparameters.batch_size);
77    let _ = writeln!(out, "    Epochs:        {}", recipe.hyperparameters.epochs);
78    out
79}
80
81/// Format storage stats for display.
82pub fn format_stats(stats: &StorageStats) -> String {
83    let mut out = String::new();
84    out.push_str("Registry Statistics:\n");
85    let _ = writeln!(out, "  Models:   {}", stats.model_count);
86    let _ = writeln!(out, "  Datasets: {}", stats.dataset_count);
87    let _ = writeln!(out, "  Recipes:  {}", stats.recipe_count);
88    let _ = writeln!(out, "  Objects:  {}", stats.object_count);
89    let _ = writeln!(out, "  Size:     {} bytes", stats.total_size_bytes);
90    out
91}
92
93/// Format run info for display.
94pub fn format_run_info(run: &ExperimentRun) -> String {
95    contract_pre_display_format!();
96    let mut out = String::new();
97    let _ = writeln!(out, "Run: {}", run.run_id);
98    let _ = writeln!(out, "  Status:  {}", run.status);
99    let _ = writeln!(out, "  Started: {}", run.started_at);
100    if let Some(finished) = run.finished_at {
101        let _ = writeln!(out, "  Finished: {finished}");
102    }
103    if !run.metrics.is_empty() {
104        out.push_str("  Final metrics:\n");
105        let mut latest: std::collections::HashMap<&str, f64> = std::collections::HashMap::new();
106        for m in &run.metrics {
107            latest.insert(&m.name, m.value);
108        }
109        for (k, v) in latest {
110            let _ = writeln!(out, "    {k}: {v}");
111        }
112    }
113    out
114}
115
116/// Find best run by metric.
117pub fn find_best_run<'a>(
118    runs: &'a [ExperimentRun],
119    metric: &str,
120    minimize: bool,
121) -> Option<(&'a ExperimentRun, f64)> {
122    contract_pre_configuration!(runs);
123    runs.iter()
124        .filter(|r| r.status == RunStatus::Completed)
125        .filter_map(|r| r.get_metric(metric).map(|v| (r, v)))
126        .min_by(|(_, a), (_, b)| {
127            if minimize {
128                a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
129            } else {
130                b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal)
131            }
132        })
133}
134
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Open a fresh registry rooted in a temporary directory.
    ///
    /// The `TempDir` is returned so the caller keeps it alive for the whole
    /// test; dropping it removes the directory backing the registry.
    fn setup() -> (TempDir, Registry) {
        let dir = TempDir::new().unwrap();
        let config = RegistryConfig::new(dir.path());
        let registry = Registry::open(config).unwrap();
        (dir, registry)
    }

    /// Build a completed run that logged `name = value` at step 0.
    fn create_run_with_metric(name: &str, value: f64) -> ExperimentRun {
        let mut run = ExperimentRun::new(Hyperparameters::default());
        run.log_metric(name, value, 0);
        run.complete();
        run
    }

    #[test]
    fn test_handle_model_register() {
        let (dir, registry) = setup();
        let artifact = dir.path().join("model.bin");
        std::fs::write(&artifact, b"weights").unwrap();

        let id = handle_model_register(
            &registry,
            "test",
            &artifact,
            &ModelVersion::new(1, 0, 0),
            Some("Test model"),
        )
        .unwrap();

        assert!(!id.to_string().is_empty());
    }

    #[test]
    fn test_handle_data_register() {
        let (dir, registry) = setup();
        let data = dir.path().join("data.csv");
        std::fs::write(&data, b"a,b,c").unwrap();

        let id = handle_data_register(
            &registry,
            "test-data",
            &data,
            &DatasetVersion::new(1, 0, 0),
            Some("Test data"),
        )
        .unwrap();

        assert!(!id.to_string().is_empty());
    }

    #[test]
    fn test_format_stats() {
        let stats = StorageStats {
            model_count: 5,
            dataset_count: 3,
            recipe_count: 2,
            object_count: 10,
            total_size_bytes: 1024,
        };
        let out = format_stats(&stats);
        // Pin the whole report, not just a couple of lines, so label or
        // alignment regressions are caught.
        assert_eq!(
            out,
            concat!(
                "Registry Statistics:\n",
                "  Models:   5\n",
                "  Datasets: 3\n",
                "  Recipes:  2\n",
                "  Objects:  10\n",
                "  Size:     1024 bytes\n"
            )
        );
    }

    #[test]
    fn test_find_best_run_maximize() {
        let runs = vec![
            create_run_with_metric("auc", 0.8),
            create_run_with_metric("auc", 0.95),
            create_run_with_metric("auc", 0.85),
        ];
        let best = find_best_run(&runs, "auc", false);
        assert!(best.is_some());
        assert!((best.unwrap().1 - 0.95).abs() < 1e-9);
    }

    #[test]
    fn test_find_best_run_minimize() {
        let runs = vec![
            create_run_with_metric("loss", 0.5),
            create_run_with_metric("loss", 0.1),
            create_run_with_metric("loss", 0.3),
        ];
        let best = find_best_run(&runs, "loss", true);
        assert!(best.is_some());
        assert!((best.unwrap().1 - 0.1).abs() < 1e-9);
    }

    #[test]
    fn test_find_best_run_tie_prefers_first() {
        let runs = vec![
            create_run_with_metric("auc", 0.9),
            create_run_with_metric("auc", 0.9),
        ];
        for &minimize in &[true, false] {
            let (best, value) = find_best_run(&runs, "auc", minimize).unwrap();
            // On equal values the earliest run must win in both directions.
            assert!(std::ptr::eq(best, &runs[0]));
            assert!((value - 0.9).abs() < 1e-9);
        }
    }

    #[test]
    fn test_format_model_info() {
        // The model bytes are passed directly to the registry, so no artifact
        // file is needed; `_dir` only keeps the registry directory alive.
        let (_dir, registry) = setup();

        let card = ModelCard::builder().description("Test").metrics([("acc", 0.9)]).build();
        registry.register_model("fmt-test", &ModelVersion::new(1, 0, 0), b"data", card).unwrap();

        let model = registry.get_model("fmt-test", &ModelVersion::new(1, 0, 0)).unwrap();
        let out = format_model_info(&model);
        assert!(out.contains("fmt-test:1.0.0"));
        assert!(out.contains("Stage:"));
        assert!(out.contains("acc: 0.9"));
    }

    #[test]
    fn test_format_dataset_info() {
        let (_dir, registry) = setup();
        let datasheet = Datasheet::new("Test purpose");
        registry
            .register_dataset("fmt-data", &DatasetVersion::new(1, 0, 0), b"csv", datasheet)
            .unwrap();

        let ds = registry.get_dataset("fmt-data", &DatasetVersion::new(1, 0, 0)).unwrap();
        let out = format_dataset_info(&ds);
        assert!(out.contains("fmt-data:1.0.0"));
        assert!(out.contains("Purpose: Test purpose"));
    }

    #[test]
    fn test_format_recipe_info() {
        let (_dir, registry) = setup();
        let recipe = TrainingRecipe::builder()
            .name("fmt-recipe")
            .version(RecipeVersion::new(1, 0, 0))
            .description("Test recipe")
            .hyperparameters(
                Hyperparameters::builder().learning_rate(0.01).batch_size(64).epochs(5).build(),
            )
            .build();
        registry.register_recipe(&recipe).unwrap();

        let r = registry.get_recipe("fmt-recipe", &RecipeVersion::new(1, 0, 0)).unwrap();
        let out = format_recipe_info(&r);
        assert!(out.contains("fmt-recipe:1.0.0"));
        assert!(out.contains("Batch size:    64"));
    }

    #[test]
    fn test_format_run_info() {
        let mut run = ExperimentRun::new(Hyperparameters::default());
        run.log_metric("loss", 0.5, 0);
        run.log_metric("loss", 0.2, 100);
        run.complete();

        let out = format_run_info(&run);
        assert!(out.contains("Status:  completed"));
        assert!(out.contains("loss: 0.2"));
    }

    #[test]
    fn test_find_best_run_no_matches() {
        let runs = vec![create_run_with_metric("auc", 0.8)];
        let best = find_best_run(&runs, "nonexistent", false);
        assert!(best.is_none());
    }

    #[test]
    fn test_find_best_run_empty() {
        let runs: Vec<ExperimentRun> = vec![];
        let best = find_best_run(&runs, "auc", false);
        assert!(best.is_none());
    }
}