Skip to main content

entrenar/dashboard/wasm/
run.rs

1//! WASM-compatible run wrapper.
2
3use std::collections::HashMap;
4use std::sync::{Arc, Mutex};
5use wasm_bindgen::prelude::*;
6
7use crate::storage::{ExperimentStorage, RunStatus};
8
9use super::storage::IndexedDbStorage;
10
11/// WASM-compatible run wrapper.
12///
13/// Provides a JavaScript-friendly API for training runs.
14#[wasm_bindgen]
15pub struct WasmRun {
16    run_id: String,
17    experiment_id: String,
18    storage: Arc<Mutex<IndexedDbStorage>>,
19    step_counters: HashMap<String, u64>,
20    finished: bool,
21}
22
23#[wasm_bindgen]
24impl WasmRun {
25    /// Create a new run in a new experiment.
26    #[wasm_bindgen(constructor)]
27    pub fn new(experiment_name: &str) -> std::result::Result<WasmRun, JsValue> {
28        let mut storage = IndexedDbStorage::new();
29
30        let experiment_id = storage
31            .create_experiment(experiment_name, None)
32            .map_err(|e| JsValue::from_str(&e.to_string()))?;
33
34        let run_id =
35            storage.create_run(&experiment_id).map_err(|e| JsValue::from_str(&e.to_string()))?;
36
37        storage.start_run(&run_id).map_err(|e| JsValue::from_str(&e.to_string()))?;
38
39        Ok(Self {
40            run_id,
41            experiment_id,
42            storage: Arc::new(Mutex::new(storage)),
43            step_counters: HashMap::new(),
44            finished: false,
45        })
46    }
47
48    /// Log a metric value, auto-incrementing the step.
49    pub fn log_metric(&mut self, key: &str, value: f64) -> std::result::Result<(), JsValue> {
50        if self.finished {
51            return Err(JsValue::from_str("Cannot log to finished run"));
52        }
53
54        let step = *self.step_counters.get(key).unwrap_or(&0);
55        self.log_metric_at(key, step, value)?;
56        self.step_counters.insert(key.to_string(), step + 1);
57        Ok(())
58    }
59
60    /// Log a metric value at a specific step.
61    pub fn log_metric_at(
62        &mut self,
63        key: &str,
64        step: u64,
65        value: f64,
66    ) -> std::result::Result<(), JsValue> {
67        if self.finished {
68            return Err(JsValue::from_str("Cannot log to finished run"));
69        }
70
71        self.storage
72            .lock()
73            .map_err(|e| JsValue::from_str(&e.to_string()))?
74            .log_metric(&self.run_id, key, step, value)
75            .map_err(|e| JsValue::from_str(&e.to_string()))?;
76
77        Ok(())
78    }
79
80    /// Get all metrics as a JSON string.
81    pub fn get_metrics_json(&self) -> std::result::Result<String, JsValue> {
82        let storage = self.storage.lock().map_err(|e| JsValue::from_str(&e.to_string()))?;
83
84        let keys = storage.list_metric_keys(&self.run_id);
85        let mut metrics: HashMap<String, Vec<serde_json::Value>> = HashMap::new();
86
87        for key in keys {
88            if let Ok(points) = storage.get_metrics(&self.run_id, &key) {
89                let values: Vec<serde_json::Value> = points
90                    .iter()
91                    .map(|p| {
92                        serde_json::json!({
93                            "step": p.step,
94                            "value": p.value,
95                            "timestamp": p.timestamp.to_rfc3339()
96                        })
97                    })
98                    .collect();
99                metrics.insert(key, values);
100            }
101        }
102
103        serde_json::to_string(&metrics).map_err(|e| JsValue::from_str(&e.to_string()))
104    }
105
106    /// Subscribe to metric updates via a JavaScript callback.
107    ///
108    /// The callback receives (key: string, value: number) for each update.
109    pub fn subscribe_metrics(&self, _callback: &js_sys::Function) {
110        // In a full implementation, this would store the callback
111        // and invoke it when metrics are logged.
112        // For now, this is a placeholder showing the API.
113    }
114
115    /// Get the run ID.
116    pub fn run_id(&self) -> String {
117        self.run_id.clone()
118    }
119
120    /// Get the experiment ID.
121    pub fn experiment_id(&self) -> String {
122        self.experiment_id.clone()
123    }
124
125    /// Get current step for a metric key.
126    pub fn current_step(&self, key: &str) -> u64 {
127        *self.step_counters.get(key).unwrap_or(&0)
128    }
129
130    /// Check if the run is finished.
131    pub fn is_finished(&self) -> bool {
132        self.finished
133    }
134
135    /// Finish the run with success status.
136    pub fn finish(&mut self) -> std::result::Result<(), JsValue> {
137        if self.finished {
138            return Ok(());
139        }
140
141        self.storage
142            .lock()
143            .map_err(|e| JsValue::from_str(&e.to_string()))?
144            .complete_run(&self.run_id, RunStatus::Success)
145            .map_err(|e| JsValue::from_str(&e.to_string()))?;
146
147        self.finished = true;
148        Ok(())
149    }
150
151    /// Finish the run with failed status.
152    pub fn fail(&mut self) -> std::result::Result<(), JsValue> {
153        if self.finished {
154            return Ok(());
155        }
156
157        self.storage
158            .lock()
159            .map_err(|e| JsValue::from_str(&e.to_string()))?
160            .complete_run(&self.run_id, RunStatus::Failed)
161            .map_err(|e| JsValue::from_str(&e.to_string()))?;
162
163        self.finished = true;
164        Ok(())
165    }
166}