entrenar/monitor/wasm/
collector.rs1#[cfg(target_arch = "wasm32")]
4use wasm_bindgen::prelude::*;
5
6use crate::monitor::{Metric, MetricStats, MetricsCollector};
7use std::collections::HashMap;
8
9#[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
13#[derive(Debug)]
14pub struct WasmMetricsCollector {
15 inner: MetricsCollector,
16}
17
18#[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
19impl WasmMetricsCollector {
20 #[cfg_attr(target_arch = "wasm32", wasm_bindgen(constructor))]
22 pub fn new() -> Self {
23 Self { inner: MetricsCollector::new() }
24 }
25
26 #[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
28 pub fn record_loss(&mut self, value: f64) {
29 self.inner.record(Metric::Loss, value);
30 }
31
32 #[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
34 pub fn record_accuracy(&mut self, value: f64) {
35 self.inner.record(Metric::Accuracy, value);
36 }
37
38 #[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
40 pub fn record_learning_rate(&mut self, value: f64) {
41 self.inner.record(Metric::LearningRate, value);
42 }
43
44 #[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
46 pub fn record_gradient_norm(&mut self, value: f64) {
47 self.inner.record(Metric::GradientNorm, value);
48 }
49
50 #[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
52 pub fn record_custom(&mut self, name: &str, value: f64) {
53 self.inner.record(Metric::Custom(name.to_string()), value);
54 }
55
56 #[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
58 pub fn count(&self) -> usize {
59 self.inner.count()
60 }
61
62 #[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
64 pub fn is_empty(&self) -> bool {
65 self.inner.is_empty()
66 }
67
68 #[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
70 pub fn clear(&mut self) {
71 self.inner.clear();
72 }
73
74 #[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
76 pub fn summary_json(&self) -> String {
77 let summary = self.inner.summary();
78 let json_map: HashMap<String, WasmMetricStats> = summary
79 .into_iter()
80 .map(|(k, v)| (k.as_str().to_string(), WasmMetricStats::from(v)))
81 .collect();
82 serde_json::to_string(&json_map).unwrap_or_else(|_err| "{}".to_string())
83 }
84
85 #[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
87 pub fn loss_mean(&self) -> f64 {
88 self.inner.summary().get(&Metric::Loss).map_or(f64::NAN, |s| s.mean)
89 }
90
91 #[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
93 pub fn accuracy_mean(&self) -> f64 {
94 self.inner.summary().get(&Metric::Accuracy).map_or(f64::NAN, |s| s.mean)
95 }
96
97 pub fn loss_values(&self) -> Vec<f64> {
99 self.inner
100 .to_records()
101 .iter()
102 .filter(|r| r.metric == Metric::Loss)
103 .map(|r| r.value)
104 .collect()
105 }
106
107 pub fn accuracy_values(&self) -> Vec<f64> {
109 self.inner
110 .to_records()
111 .iter()
112 .filter(|r| r.metric == Metric::Accuracy)
113 .map(|r| r.value)
114 .collect()
115 }
116
117 pub fn timestamps(&self) -> Vec<u64> {
119 self.inner.to_records().iter().map(|r| r.timestamp).collect()
120 }
121
122 #[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
124 pub fn loss_std(&self) -> f64 {
125 self.inner.summary().get(&Metric::Loss).map_or(f64::NAN, |s| s.std)
126 }
127
128 #[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
130 pub fn accuracy_std(&self) -> f64 {
131 self.inner.summary().get(&Metric::Accuracy).map_or(f64::NAN, |s| s.std)
132 }
133
134 #[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
136 pub fn loss_has_nan(&self) -> bool {
137 self.inner.summary().get(&Metric::Loss).is_some_and(|s| s.has_nan)
138 }
139
140 #[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
142 pub fn loss_has_inf(&self) -> bool {
143 self.inner.summary().get(&Metric::Loss).is_some_and(|s| s.has_inf)
144 }
145}
146
147impl Default for WasmMetricsCollector {
148 fn default() -> Self {
149 Self::new()
150 }
151}
152
153#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
155pub(crate) struct WasmMetricStats {
156 count: usize,
157 mean: f64,
158 std: f64,
159 min: f64,
160 max: f64,
161 has_nan: bool,
162 has_inf: bool,
163}
164
165impl From<MetricStats> for WasmMetricStats {
166 fn from(s: MetricStats) -> Self {
167 Self {
168 count: s.count,
169 mean: s.mean,
170 std: s.std,
171 min: s.min,
172 max: s.max,
173 has_nan: s.has_nan,
174 has_inf: s.has_inf,
175 }
176 }
177}
178
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_wasm_collector_new() {
        let collector = WasmMetricsCollector::new();
        assert!(collector.is_empty());
        assert_eq!(collector.count(), 0);
    }

    #[test]
    fn test_wasm_collector_record_loss() {
        let mut collector = WasmMetricsCollector::new();
        collector.record_loss(0.5);
        collector.record_loss(0.3);
        assert_eq!(collector.count(), 2);
        assert!((collector.loss_mean() - 0.4).abs() < 1e-6);
    }

    #[test]
    fn test_wasm_collector_record_accuracy() {
        let mut collector = WasmMetricsCollector::new();
        collector.record_accuracy(0.8);
        collector.record_accuracy(0.9);
        assert_eq!(collector.count(), 2);
        assert!((collector.accuracy_mean() - 0.85).abs() < 1e-6);
    }

    #[test]
    fn test_wasm_collector_record_custom() {
        let mut collector = WasmMetricsCollector::new();
        collector.record_custom("perplexity", 15.5);
        assert_eq!(collector.count(), 1);
    }

    #[test]
    fn test_wasm_collector_clear() {
        let mut collector = WasmMetricsCollector::new();
        collector.record_loss(0.5);
        collector.record_accuracy(0.8);
        assert_eq!(collector.count(), 2);
        collector.clear();
        assert!(collector.is_empty());
    }

    #[test]
    fn test_wasm_collector_summary_json() {
        let mut collector = WasmMetricsCollector::new();
        collector.record_loss(0.5);
        collector.record_loss(0.3);

        let json = collector.summary_json();
        assert!(json.contains("loss"));
        assert!(json.contains("mean"));

        let parsed: serde_json::Value =
            serde_json::from_str(&json).expect("JSON deserialization should succeed");
        assert!(parsed.get("loss").is_some());
    }

    #[test]
    fn test_wasm_collector_missing_metric() {
        let collector = WasmMetricsCollector::new();
        assert!(collector.loss_mean().is_nan());
        assert!(collector.accuracy_mean().is_nan());
    }

    #[test]
    fn test_wasm_metric_stats_from() {
        let stats = MetricStats {
            count: 10,
            mean: 0.5,
            std: 0.1,
            min: 0.2,
            max: 0.8,
            sum: 5.0,
            has_nan: false,
            has_inf: false,
        };

        let wasm_stats = WasmMetricStats::from(stats);
        assert_eq!(wasm_stats.count, 10);
        assert!((wasm_stats.mean - 0.5).abs() < 1e-6);
        assert!((wasm_stats.std - 0.1).abs() < 1e-6);
    }

    #[test]
    fn test_wasm_collector_all_metric_types() {
        let mut collector = WasmMetricsCollector::new();
        collector.record_loss(0.5);
        collector.record_accuracy(0.8);
        collector.record_learning_rate(0.001);
        collector.record_gradient_norm(1.5);
        collector.record_custom("perplexity", 15.5);

        assert_eq!(collector.count(), 5);

        let json = collector.summary_json();
        assert!(json.contains("loss"));
        assert!(json.contains("accuracy"));
        assert!(json.contains("learning_rate"));
        assert!(json.contains("gradient_norm"));
        assert!(json.contains("perplexity"));
    }

    #[test]
    fn test_wasm_collector_default() {
        let collector = WasmMetricsCollector::default();
        assert!(collector.is_empty());
    }

    #[test]
    fn test_wasm_collector_loss_values() {
        let mut collector = WasmMetricsCollector::new();
        collector.record_loss(0.5);
        collector.record_loss(0.3);
        collector.record_accuracy(0.8);

        let values = collector.loss_values();
        assert_eq!(values.len(), 2);
        assert!((values[0] - 0.5).abs() < 1e-6);
        assert!((values[1] - 0.3).abs() < 1e-6);
    }

    #[test]
    fn test_wasm_collector_accuracy_values() {
        let mut collector = WasmMetricsCollector::new();
        collector.record_accuracy(0.8);
        collector.record_accuracy(0.9);
        collector.record_loss(0.5);

        let values = collector.accuracy_values();
        assert_eq!(values.len(), 2);
        assert!((values[0] - 0.8).abs() < 1e-6);
        assert!((values[1] - 0.9).abs() < 1e-6);
    }

    #[test]
    fn test_wasm_collector_timestamps() {
        let mut collector = WasmMetricsCollector::new();
        collector.record_loss(0.5);
        collector.record_loss(0.3);

        let timestamps = collector.timestamps();
        assert_eq!(timestamps.len(), 2);
        assert!(timestamps[0] > 0);
        assert!(timestamps[1] >= timestamps[0]);
    }

    #[test]
    fn test_wasm_collector_loss_std() {
        let mut collector = WasmMetricsCollector::new();
        collector.record_loss(0.2);
        collector.record_loss(0.4);
        collector.record_loss(0.6);
        collector.record_loss(0.8);

        let std = collector.loss_std();
        assert!(std > 0.0);
        assert!(std < 1.0);
    }

    #[test]
    fn test_wasm_collector_nan_detection() {
        let mut collector = WasmMetricsCollector::new();
        collector.record_loss(0.5);
        assert!(!collector.loss_has_nan());

        collector.record_loss(f64::NAN);
        assert!(collector.loss_has_nan());
    }

    #[test]
    fn test_wasm_collector_inf_detection() {
        let mut collector = WasmMetricsCollector::new();
        collector.record_loss(0.5);
        assert!(!collector.loss_has_inf());

        collector.record_loss(f64::INFINITY);
        assert!(collector.loss_has_inf());
    }

    #[test]
    fn test_wasm_collector_empty_std() {
        let collector = WasmMetricsCollector::new();
        assert!(collector.loss_std().is_nan());
        assert!(collector.accuracy_std().is_nan());
    }

    // Edge case: with nothing recorded, the flag accessors fall back to
    // `false` and every series accessor yields an empty vector.
    #[test]
    fn test_wasm_collector_empty_flags_and_series() {
        let collector = WasmMetricsCollector::new();
        assert!(!collector.loss_has_nan());
        assert!(!collector.loss_has_inf());
        assert!(collector.loss_values().is_empty());
        assert!(collector.accuracy_values().is_empty());
        assert!(collector.timestamps().is_empty());
    }

    // Non-loss metrics must not show up in the loss series. `timestamps()` is
    // deliberately unfiltered, so it still covers every record.
    #[test]
    fn test_wasm_collector_other_metrics_do_not_leak_into_loss() {
        let mut collector = WasmMetricsCollector::new();
        collector.record_learning_rate(0.001);
        collector.record_gradient_norm(1.5);
        collector.record_custom("perplexity", 15.5);

        assert!(collector.loss_values().is_empty());
        assert!(collector.accuracy_values().is_empty());
        assert_eq!(collector.timestamps().len(), 3);
    }

    // After `clear()`, the per-metric series and stats behave like a fresh
    // collector.
    #[test]
    fn test_wasm_collector_clear_resets_series_and_stats() {
        let mut collector = WasmMetricsCollector::new();
        collector.record_loss(0.5);
        collector.record_loss(f64::NAN);
        collector.clear();

        assert!(collector.loss_values().is_empty());
        assert!(collector.loss_mean().is_nan());
        assert!(!collector.loss_has_nan());
    }
}
370
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        // Every recorded sample is counted exactly once. The strategy only
        // yields finite values, so no NaN/Inf filtering is needed (or possible)
        // here; the expected count is simply the number of inputs.
        #[test]
        fn prop_record_increases_count(values in prop::collection::vec(-1e10f64..1e10, 1..100)) {
            let mut collector = WasmMetricsCollector::new();
            for v in &values {
                collector.record_loss(*v);
            }
            prop_assert_eq!(collector.count(), values.len());
        }

        // The mean of the recorded losses lies within [min, max] of the inputs,
        // allowing a small tolerance for floating-point rounding.
        #[test]
        fn prop_mean_within_bounds(values in prop::collection::vec(0.0f64..100.0, 2..50)) {
            let mut collector = WasmMetricsCollector::new();
            for v in &values {
                collector.record_loss(*v);
            }

            let mean = collector.loss_mean();
            let min = values.iter().copied().fold(f64::INFINITY, f64::min);
            let max = values.iter().copied().fold(f64::NEG_INFINITY, f64::max);

            prop_assert!(mean >= min - 1e-6);
            prop_assert!(mean <= max + 1e-6);
        }

        // The standard deviation is never negative. NaN is tolerated because
        // rounding inside the variance computation could in principle produce
        // it for near-identical inputs.
        #[test]
        fn prop_std_non_negative(values in prop::collection::vec(0.0f64..100.0, 2..50)) {
            let mut collector = WasmMetricsCollector::new();
            for v in &values {
                collector.record_loss(*v);
            }

            let std = collector.loss_std();
            prop_assert!(std >= 0.0 || std.is_nan());
        }

        // `loss_values()` returns exactly the finite losses that were recorded,
        // in recording order.
        #[test]
        fn prop_loss_values_round_trip(values in prop::collection::vec(-1e10f64..1e10, 1..50)) {
            let mut collector = WasmMetricsCollector::new();
            for v in &values {
                collector.record_loss(*v);
            }
            prop_assert_eq!(collector.loss_values(), values);
        }
    }
}