entrenar/dashboard/wasm/
run.rs1use std::collections::HashMap;
4use std::sync::{Arc, Mutex};
5use wasm_bindgen::prelude::*;
6
7use crate::storage::{ExperimentStorage, RunStatus};
8
9use super::storage::IndexedDbStorage;
10
11#[wasm_bindgen]
15pub struct WasmRun {
16 run_id: String,
17 experiment_id: String,
18 storage: Arc<Mutex<IndexedDbStorage>>,
19 step_counters: HashMap<String, u64>,
20 finished: bool,
21}
22
23#[wasm_bindgen]
24impl WasmRun {
25 #[wasm_bindgen(constructor)]
27 pub fn new(experiment_name: &str) -> std::result::Result<WasmRun, JsValue> {
28 let mut storage = IndexedDbStorage::new();
29
30 let experiment_id = storage
31 .create_experiment(experiment_name, None)
32 .map_err(|e| JsValue::from_str(&e.to_string()))?;
33
34 let run_id =
35 storage.create_run(&experiment_id).map_err(|e| JsValue::from_str(&e.to_string()))?;
36
37 storage.start_run(&run_id).map_err(|e| JsValue::from_str(&e.to_string()))?;
38
39 Ok(Self {
40 run_id,
41 experiment_id,
42 storage: Arc::new(Mutex::new(storage)),
43 step_counters: HashMap::new(),
44 finished: false,
45 })
46 }
47
48 pub fn log_metric(&mut self, key: &str, value: f64) -> std::result::Result<(), JsValue> {
50 if self.finished {
51 return Err(JsValue::from_str("Cannot log to finished run"));
52 }
53
54 let step = *self.step_counters.get(key).unwrap_or(&0);
55 self.log_metric_at(key, step, value)?;
56 self.step_counters.insert(key.to_string(), step + 1);
57 Ok(())
58 }
59
60 pub fn log_metric_at(
62 &mut self,
63 key: &str,
64 step: u64,
65 value: f64,
66 ) -> std::result::Result<(), JsValue> {
67 if self.finished {
68 return Err(JsValue::from_str("Cannot log to finished run"));
69 }
70
71 self.storage
72 .lock()
73 .map_err(|e| JsValue::from_str(&e.to_string()))?
74 .log_metric(&self.run_id, key, step, value)
75 .map_err(|e| JsValue::from_str(&e.to_string()))?;
76
77 Ok(())
78 }
79
80 pub fn get_metrics_json(&self) -> std::result::Result<String, JsValue> {
82 let storage = self.storage.lock().map_err(|e| JsValue::from_str(&e.to_string()))?;
83
84 let keys = storage.list_metric_keys(&self.run_id);
85 let mut metrics: HashMap<String, Vec<serde_json::Value>> = HashMap::new();
86
87 for key in keys {
88 if let Ok(points) = storage.get_metrics(&self.run_id, &key) {
89 let values: Vec<serde_json::Value> = points
90 .iter()
91 .map(|p| {
92 serde_json::json!({
93 "step": p.step,
94 "value": p.value,
95 "timestamp": p.timestamp.to_rfc3339()
96 })
97 })
98 .collect();
99 metrics.insert(key, values);
100 }
101 }
102
103 serde_json::to_string(&metrics).map_err(|e| JsValue::from_str(&e.to_string()))
104 }
105
106 pub fn subscribe_metrics(&self, _callback: &js_sys::Function) {
110 }
114
115 pub fn run_id(&self) -> String {
117 self.run_id.clone()
118 }
119
120 pub fn experiment_id(&self) -> String {
122 self.experiment_id.clone()
123 }
124
125 pub fn current_step(&self, key: &str) -> u64 {
127 *self.step_counters.get(key).unwrap_or(&0)
128 }
129
130 pub fn is_finished(&self) -> bool {
132 self.finished
133 }
134
135 pub fn finish(&mut self) -> std::result::Result<(), JsValue> {
137 if self.finished {
138 return Ok(());
139 }
140
141 self.storage
142 .lock()
143 .map_err(|e| JsValue::from_str(&e.to_string()))?
144 .complete_run(&self.run_id, RunStatus::Success)
145 .map_err(|e| JsValue::from_str(&e.to_string()))?;
146
147 self.finished = true;
148 Ok(())
149 }
150
151 pub fn fail(&mut self) -> std::result::Result<(), JsValue> {
153 if self.finished {
154 return Ok(());
155 }
156
157 self.storage
158 .lock()
159 .map_err(|e| JsValue::from_str(&e.to_string()))?
160 .complete_run(&self.run_id, RunStatus::Failed)
161 .map_err(|e| JsValue::from_str(&e.to_string()))?;
162
163 self.finished = true;
164 Ok(())
165 }
166}