Skip to main content

entrenar/cli/commands/
experiments.rs

1//! Experiment store CLI commands.
2//!
3//! Query experiments, runs, and metrics from the project-local SQLite store.
4
5use crate::cli::LogLevel;
6use crate::config::{ExperimentsArgs, ExperimentsCommand, OutputFormat};
7use crate::storage::{ExperimentStorage, SqliteBackend};
8
9pub fn run_experiments(args: ExperimentsArgs, log_level: LogLevel) -> Result<(), String> {
10    let store = SqliteBackend::open_project(&args.project)
11        .map_err(|e| format!("Failed to open experiment store: {e}"))?;
12
13    match args.command {
14        ExperimentsCommand::List => list_experiments(&store, &args.format, log_level),
15        ExperimentsCommand::Show { id } => show_experiment(&store, &id, &args.format),
16        ExperimentsCommand::Runs { experiment_id } => {
17            list_runs(&store, &experiment_id, &args.format)
18        }
19        ExperimentsCommand::Metrics { run_id, key } => {
20            show_metrics(&store, &run_id, &key, &args.format)
21        }
22        ExperimentsCommand::Delete { id } => delete_experiment(&store, &id, log_level),
23    }
24}
25
26fn list_experiments(
27    store: &SqliteBackend,
28    format: &OutputFormat,
29    _log_level: LogLevel,
30) -> Result<(), String> {
31    let experiments =
32        store.list_experiments().map_err(|e| format!("Failed to list experiments: {e}"))?;
33
34    if experiments.is_empty() {
35        eprintln!("No experiments found in {}", store.path());
36        return Ok(());
37    }
38
39    match format {
40        OutputFormat::Json => {
41            let json = serde_json::to_string_pretty(&experiments)
42                .map_err(|e| format!("JSON serialization failed: {e}"))?;
43            println!("{json}");
44        }
45        _ => {
46            println!("{:<20} {:<30} {:<24}", "ID", "NAME", "CREATED");
47            println!("{}", "-".repeat(74));
48            for exp in &experiments {
49                println!(
50                    "{:<20} {:<30} {:<24}",
51                    truncate(&exp.id, 18),
52                    truncate(&exp.name, 28),
53                    exp.created_at.format("%Y-%m-%d %H:%M:%S"),
54                );
55            }
56            println!("\n{} experiment(s)", experiments.len());
57        }
58    }
59
60    Ok(())
61}
62
63fn show_experiment(store: &SqliteBackend, id: &str, format: &OutputFormat) -> Result<(), String> {
64    let experiment =
65        store.get_experiment(id).map_err(|e| format!("Failed to get experiment: {e}"))?;
66
67    match format {
68        OutputFormat::Json => {
69            let json = serde_json::to_string_pretty(&experiment)
70                .map_err(|e| format!("JSON serialization failed: {e}"))?;
71            println!("{json}");
72        }
73        _ => {
74            println!("Experiment: {}", experiment.name);
75            println!("  ID:      {}", experiment.id);
76            println!("  Created: {}", experiment.created_at.format("%Y-%m-%d %H:%M:%S"));
77            println!("  Updated: {}", experiment.updated_at.format("%Y-%m-%d %H:%M:%S"));
78            if let Some(desc) = &experiment.description {
79                println!("  Desc:    {desc}");
80            }
81            if let Some(config) = &experiment.config {
82                println!("  Config:  {config}");
83            }
84
85            // Also show runs
86            if let Ok(runs) = store.list_runs(id) {
87                if !runs.is_empty() {
88                    println!("\n  Runs ({}):", runs.len());
89                    for run in &runs {
90                        println!(
91                            "    {:<18} {:?}  {}",
92                            truncate(&run.id, 16),
93                            run.status,
94                            run.start_time.format("%Y-%m-%d %H:%M:%S"),
95                        );
96                    }
97                }
98            }
99        }
100    }
101
102    Ok(())
103}
104
105fn list_runs(
106    store: &SqliteBackend,
107    experiment_id: &str,
108    format: &OutputFormat,
109) -> Result<(), String> {
110    let runs = store.list_runs(experiment_id).map_err(|e| format!("Failed to list runs: {e}"))?;
111
112    if runs.is_empty() {
113        eprintln!("No runs found for experiment {experiment_id}");
114        return Ok(());
115    }
116
117    match format {
118        OutputFormat::Json => {
119            let json = serde_json::to_string_pretty(&runs)
120                .map_err(|e| format!("JSON serialization failed: {e}"))?;
121            println!("{json}");
122        }
123        _ => {
124            println!("{:<20} {:<12} {:<24} {:<24}", "ID", "STATUS", "STARTED", "ENDED");
125            println!("{}", "-".repeat(80));
126            for run in &runs {
127                let end = run
128                    .end_time
129                    .map_or_else(|| "-".to_string(), |t| t.format("%Y-%m-%d %H:%M:%S").to_string());
130                println!(
131                    "{:<20} {:<12} {:<24} {:<24}",
132                    truncate(&run.id, 18),
133                    format!("{:?}", run.status),
134                    run.start_time.format("%Y-%m-%d %H:%M:%S"),
135                    end,
136                );
137            }
138            println!("\n{} run(s)", runs.len());
139        }
140    }
141
142    Ok(())
143}
144
145fn show_metrics(
146    store: &SqliteBackend,
147    run_id: &str,
148    key: &str,
149    format: &OutputFormat,
150) -> Result<(), String> {
151    let metrics =
152        store.get_metrics(run_id, key).map_err(|e| format!("Failed to get metrics: {e}"))?;
153
154    if metrics.is_empty() {
155        eprintln!("No metrics found for run {run_id}, key '{key}'");
156        return Ok(());
157    }
158
159    match format {
160        OutputFormat::Json => {
161            let json = serde_json::to_string_pretty(&metrics)
162                .map_err(|e| format!("JSON serialization failed: {e}"))?;
163            println!("{json}");
164        }
165        _ => {
166            println!("Metrics: {key} (run {run_id})");
167            println!("{:<8} {:<16} {:<24}", "STEP", "VALUE", "TIMESTAMP");
168            println!("{}", "-".repeat(48));
169            for point in &metrics {
170                println!(
171                    "{:<8} {:<16.6} {:<24}",
172                    point.step,
173                    point.value,
174                    point.timestamp.format("%Y-%m-%d %H:%M:%S"),
175                );
176            }
177            println!("\n{} point(s)", metrics.len());
178        }
179    }
180
181    Ok(())
182}
183
184fn delete_experiment(store: &SqliteBackend, id: &str, _log_level: LogLevel) -> Result<(), String> {
185    // Verify it exists first
186    store.get_experiment(id).map_err(|e| format!("Failed to find experiment: {e}"))?;
187
188    let conn = store.lock_conn().map_err(|e| format!("Lock error: {e}"))?;
189
190    // Delete in dependency order: metrics → params → artifacts → span_ids → runs → experiment
191    conn.execute(
192        "DELETE FROM metrics WHERE run_id IN (SELECT id FROM runs WHERE experiment_id = ?1)",
193        [id],
194    )
195    .map_err(|e| format!("Failed to delete metrics: {e}"))?;
196
197    conn.execute(
198        "DELETE FROM params WHERE run_id IN (SELECT id FROM runs WHERE experiment_id = ?1)",
199        [id],
200    )
201    .map_err(|e| format!("Failed to delete params: {e}"))?;
202
203    conn.execute(
204        "DELETE FROM artifacts WHERE run_id IN (SELECT id FROM runs WHERE experiment_id = ?1)",
205        [id],
206    )
207    .map_err(|e| format!("Failed to delete artifacts: {e}"))?;
208
209    conn.execute(
210        "DELETE FROM span_ids WHERE run_id IN (SELECT id FROM runs WHERE experiment_id = ?1)",
211        [id],
212    )
213    .map_err(|e| format!("Failed to delete span IDs: {e}"))?;
214
215    conn.execute("DELETE FROM runs WHERE experiment_id = ?1", [id])
216        .map_err(|e| format!("Failed to delete runs: {e}"))?;
217
218    conn.execute("DELETE FROM experiments WHERE id = ?1", [id])
219        .map_err(|e| format!("Failed to delete experiment: {e}"))?;
220
221    eprintln!("Deleted experiment {id}");
222    Ok(())
223}
224
/// Shortens `s` to at most `max` characters for fixed-width table output.
///
/// Strings of `max` characters or fewer are returned unchanged. Longer strings
/// are cut to `max - 3` characters followed by `"..."`. When `max` is below 3
/// there is no room for the ellipsis, so the string is cut to `max` characters.
///
/// Length is counted in `char`s rather than bytes, so the cut never lands
/// inside a multi-byte UTF-8 sequence and the function never panics.
fn truncate(s: &str, max: usize) -> String {
    const ELLIPSIS: &str = "...";

    if s.chars().count() <= max {
        return s.to_string();
    }

    match max.checked_sub(ELLIPSIS.len()) {
        Some(keep) => {
            let mut shortened: String = s.chars().take(keep).collect();
            shortened.push_str(ELLIPSIS);
            shortened
        }
        // Too narrow for an ellipsis: hard-cut instead.
        None => s.chars().take(max).collect(),
    }
}
232
#[cfg(test)]
mod tests {
    // Tests unwrap freely: a failed setup step should abort the test at once.
    #![allow(clippy::unwrap_used)]
    use super::*;
    use crate::storage::ExperimentStorage;

    /// Creates a scratch project directory under the system temp dir.
    ///
    /// The process id is part of the name so that concurrent test processes do
    /// not share (or delete) each other's directory.
    fn scratch_project(tag: &str) -> std::path::PathBuf {
        let dir = std::env::temp_dir().join(format!("ent_exp_{tag}_{}", std::process::id()));
        let _ = std::fs::create_dir_all(&dir);
        dir
    }

    #[test]
    fn test_truncate_short() {
        assert_eq!(truncate("hi", 10), "hi");
    }

    #[test]
    fn test_truncate_exact() {
        assert_eq!(truncate("hello", 5), "hello");
    }

    #[test]
    fn test_truncate_long() {
        assert_eq!(truncate("a very long name", 10), "a very ...");
    }

    #[test]
    fn test_truncate_boundary() {
        assert_eq!(truncate("abcde", 4), "a...");
    }

    #[test]
    fn test_list_experiments_empty() {
        let store = SqliteBackend::open_in_memory().unwrap();
        assert!(list_experiments(&store, &OutputFormat::Text, LogLevel::Normal).is_ok());
    }

    #[test]
    fn test_list_experiments_with_data() {
        let mut s = SqliteBackend::open_in_memory().unwrap();
        s.create_experiment("e1", None).unwrap();
        s.create_experiment("e2", None).unwrap();
        assert!(list_experiments(&s, &OutputFormat::Text, LogLevel::Normal).is_ok());
    }

    #[test]
    fn test_list_experiments_json() {
        let mut s = SqliteBackend::open_in_memory().unwrap();
        s.create_experiment("j", None).unwrap();
        assert!(list_experiments(&s, &OutputFormat::Json, LogLevel::Normal).is_ok());
    }

    #[test]
    fn test_show_experiment_text() {
        let mut s = SqliteBackend::open_in_memory().unwrap();
        let id = s.create_experiment("sh", None).unwrap();
        assert!(show_experiment(&s, &id, &OutputFormat::Text).is_ok());
    }

    #[test]
    fn test_show_experiment_json() {
        let mut s = SqliteBackend::open_in_memory().unwrap();
        let id = s.create_experiment("sj", None).unwrap();
        assert!(show_experiment(&s, &id, &OutputFormat::Json).is_ok());
    }

    #[test]
    fn test_show_experiment_not_found() {
        let s = SqliteBackend::open_in_memory().unwrap();
        assert!(show_experiment(&s, "x", &OutputFormat::Text).is_err());
    }

    #[test]
    fn test_show_experiment_with_runs() {
        let mut s = SqliteBackend::open_in_memory().unwrap();
        let eid = s.create_experiment("wr", None).unwrap();
        s.create_run(&eid).unwrap();
        assert!(show_experiment(&s, &eid, &OutputFormat::Text).is_ok());
    }

    #[test]
    fn test_list_runs_empty() {
        let mut s = SqliteBackend::open_in_memory().unwrap();
        let eid = s.create_experiment("nr", None).unwrap();
        assert!(list_runs(&s, &eid, &OutputFormat::Text).is_ok());
    }

    #[test]
    fn test_list_runs_with_data() {
        let mut s = SqliteBackend::open_in_memory().unwrap();
        let eid = s.create_experiment("hr", None).unwrap();
        let rid = s.create_run(&eid).unwrap();
        s.start_run(&rid).unwrap();
        assert!(list_runs(&s, &eid, &OutputFormat::Text).is_ok());
    }

    #[test]
    fn test_list_runs_json() {
        let mut s = SqliteBackend::open_in_memory().unwrap();
        let eid = s.create_experiment("rj", None).unwrap();
        s.create_run(&eid).unwrap();
        assert!(list_runs(&s, &eid, &OutputFormat::Json).is_ok());
    }

    #[test]
    fn test_list_runs_completed() {
        let mut s = SqliteBackend::open_in_memory().unwrap();
        let eid = s.create_experiment("c", None).unwrap();
        let rid = s.create_run(&eid).unwrap();
        s.start_run(&rid).unwrap();
        s.complete_run(&rid, crate::storage::RunStatus::Success).unwrap();
        assert!(list_runs(&s, &eid, &OutputFormat::Text).is_ok());
    }

    #[test]
    fn test_show_metrics_empty() {
        let mut s = SqliteBackend::open_in_memory().unwrap();
        let eid = s.create_experiment("me", None).unwrap();
        let rid = s.create_run(&eid).unwrap();
        s.start_run(&rid).unwrap();
        assert!(show_metrics(&s, &rid, "loss", &OutputFormat::Text).is_ok());
    }

    #[test]
    fn test_show_metrics_with_data() {
        let mut s = SqliteBackend::open_in_memory().unwrap();
        let eid = s.create_experiment("md", None).unwrap();
        let rid = s.create_run(&eid).unwrap();
        s.start_run(&rid).unwrap();
        s.log_metric(&rid, "loss", 0, 0.5).unwrap();
        s.log_metric(&rid, "loss", 1, 0.3).unwrap();
        assert!(show_metrics(&s, &rid, "loss", &OutputFormat::Text).is_ok());
    }

    #[test]
    fn test_show_metrics_json() {
        let mut s = SqliteBackend::open_in_memory().unwrap();
        let eid = s.create_experiment("mj", None).unwrap();
        let rid = s.create_run(&eid).unwrap();
        s.start_run(&rid).unwrap();
        s.log_metric(&rid, "a", 0, 0.9).unwrap();
        assert!(show_metrics(&s, &rid, "a", &OutputFormat::Json).is_ok());
    }

    #[test]
    fn test_delete_experiment_ok() {
        let mut s = SqliteBackend::open_in_memory().unwrap();
        let id = s.create_experiment("d", None).unwrap();
        assert!(delete_experiment(&s, &id, LogLevel::Normal).is_ok());
        // The experiment must actually be gone, not merely "no error".
        assert!(s.get_experiment(&id).is_err());
    }

    #[test]
    fn test_delete_experiment_not_found() {
        let s = SqliteBackend::open_in_memory().unwrap();
        assert!(delete_experiment(&s, "x", LogLevel::Normal).is_err());
    }

    #[test]
    fn test_delete_with_data() {
        let mut s = SqliteBackend::open_in_memory().unwrap();
        let eid = s.create_experiment("dd", None).unwrap();
        let rid = s.create_run(&eid).unwrap();
        s.start_run(&rid).unwrap();
        s.log_metric(&rid, "l", 0, 0.5).unwrap();
        assert!(delete_experiment(&s, &eid, LogLevel::Normal).is_ok());
        assert!(s.get_experiment(&eid).is_err());
        // Either the lookup fails for the deleted run or it yields no points.
        assert!(s.get_metrics(&rid, "l").map_or(true, |points| points.is_empty()));
    }

    #[test]
    fn test_run_experiments_list() {
        let d = scratch_project("l");
        let a = ExperimentsArgs {
            command: ExperimentsCommand::List,
            project: d.clone(),
            format: OutputFormat::Text,
        };
        let outcome = run_experiments(a, LogLevel::Normal);
        // Clean up before asserting so a failure does not leak the directory.
        let _ = std::fs::remove_dir_all(&d);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_run_experiments_show_nf() {
        let d = scratch_project("s");
        let a = ExperimentsArgs {
            command: ExperimentsCommand::Show { id: "x".into() },
            project: d.clone(),
            format: OutputFormat::Text,
        };
        let outcome = run_experiments(a, LogLevel::Normal);
        // Clean up before asserting so a failure does not leak the directory.
        let _ = std::fs::remove_dir_all(&d);
        assert!(outcome.is_err());
    }
}