Skip to main content

entrenar/monitor/wasm/
collector.rs

1//! WASM-compatible metrics collector.
2
3#[cfg(target_arch = "wasm32")]
4use wasm_bindgen::prelude::*;
5
6use crate::monitor::{Metric, MetricStats, MetricsCollector};
7use std::collections::HashMap;
8
9/// WASM-compatible metrics collector.
10///
11/// Wraps MetricsCollector with JavaScript-friendly API.
12#[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
13#[derive(Debug)]
14pub struct WasmMetricsCollector {
15    inner: MetricsCollector,
16}
17
18#[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
19impl WasmMetricsCollector {
20    /// Create a new metrics collector.
21    #[cfg_attr(target_arch = "wasm32", wasm_bindgen(constructor))]
22    pub fn new() -> Self {
23        Self { inner: MetricsCollector::new() }
24    }
25
26    /// Record a loss value.
27    #[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
28    pub fn record_loss(&mut self, value: f64) {
29        self.inner.record(Metric::Loss, value);
30    }
31
32    /// Record an accuracy value.
33    #[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
34    pub fn record_accuracy(&mut self, value: f64) {
35        self.inner.record(Metric::Accuracy, value);
36    }
37
38    /// Record a learning rate value.
39    #[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
40    pub fn record_learning_rate(&mut self, value: f64) {
41        self.inner.record(Metric::LearningRate, value);
42    }
43
44    /// Record a gradient norm value.
45    #[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
46    pub fn record_gradient_norm(&mut self, value: f64) {
47        self.inner.record(Metric::GradientNorm, value);
48    }
49
50    /// Record a custom metric.
51    #[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
52    pub fn record_custom(&mut self, name: &str, value: f64) {
53        self.inner.record(Metric::Custom(name.to_string()), value);
54    }
55
56    /// Get number of recorded metrics.
57    #[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
58    pub fn count(&self) -> usize {
59        self.inner.count()
60    }
61
62    /// Check if collector is empty.
63    #[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
64    pub fn is_empty(&self) -> bool {
65        self.inner.is_empty()
66    }
67
68    /// Clear all recorded metrics.
69    #[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
70    pub fn clear(&mut self) {
71        self.inner.clear();
72    }
73
74    /// Get summary as JSON string.
75    #[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
76    pub fn summary_json(&self) -> String {
77        let summary = self.inner.summary();
78        let json_map: HashMap<String, WasmMetricStats> = summary
79            .into_iter()
80            .map(|(k, v)| (k.as_str().to_string(), WasmMetricStats::from(v)))
81            .collect();
82        serde_json::to_string(&json_map).unwrap_or_else(|_err| "{}".to_string())
83    }
84
85    /// Get loss statistics.
86    #[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
87    pub fn loss_mean(&self) -> f64 {
88        self.inner.summary().get(&Metric::Loss).map_or(f64::NAN, |s| s.mean)
89    }
90
91    /// Get accuracy statistics.
92    #[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
93    pub fn accuracy_mean(&self) -> f64 {
94        self.inner.summary().get(&Metric::Accuracy).map_or(f64::NAN, |s| s.mean)
95    }
96
97    /// Get loss values as a Vec for JavaScript Float64Array conversion.
98    pub fn loss_values(&self) -> Vec<f64> {
99        self.inner
100            .to_records()
101            .iter()
102            .filter(|r| r.metric == Metric::Loss)
103            .map(|r| r.value)
104            .collect()
105    }
106
107    /// Get accuracy values as a Vec for JavaScript Float64Array conversion.
108    pub fn accuracy_values(&self) -> Vec<f64> {
109        self.inner
110            .to_records()
111            .iter()
112            .filter(|r| r.metric == Metric::Accuracy)
113            .map(|r| r.value)
114            .collect()
115    }
116
117    /// Get all timestamps as milliseconds since epoch.
118    pub fn timestamps(&self) -> Vec<u64> {
119        self.inner.to_records().iter().map(|r| r.timestamp).collect()
120    }
121
122    /// Get loss standard deviation.
123    #[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
124    pub fn loss_std(&self) -> f64 {
125        self.inner.summary().get(&Metric::Loss).map_or(f64::NAN, |s| s.std)
126    }
127
128    /// Get accuracy standard deviation.
129    #[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
130    pub fn accuracy_std(&self) -> f64 {
131        self.inner.summary().get(&Metric::Accuracy).map_or(f64::NAN, |s| s.std)
132    }
133
134    /// Check if NaN was detected in loss.
135    #[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
136    pub fn loss_has_nan(&self) -> bool {
137        self.inner.summary().get(&Metric::Loss).is_some_and(|s| s.has_nan)
138    }
139
140    /// Check if Inf was detected in loss.
141    #[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
142    pub fn loss_has_inf(&self) -> bool {
143        self.inner.summary().get(&Metric::Loss).is_some_and(|s| s.has_inf)
144    }
145}
146
147impl Default for WasmMetricsCollector {
148    fn default() -> Self {
149        Self::new()
150    }
151}
152
153/// WASM-compatible metric statistics (JSON-serializable).
154#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
155pub(crate) struct WasmMetricStats {
156    count: usize,
157    mean: f64,
158    std: f64,
159    min: f64,
160    max: f64,
161    has_nan: bool,
162    has_inf: bool,
163}
164
165impl From<MetricStats> for WasmMetricStats {
166    fn from(s: MetricStats) -> Self {
167        Self {
168            count: s.count,
169            mean: s.mean,
170            std: s.std,
171            min: s.min,
172            max: s.max,
173            has_nan: s.has_nan,
174            has_inf: s.has_inf,
175        }
176    }
177}
178
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_wasm_collector_new() {
        let collector = WasmMetricsCollector::new();
        assert!(collector.is_empty());
        assert_eq!(collector.count(), 0);
    }

    #[test]
    fn test_wasm_collector_record_loss() {
        let mut collector = WasmMetricsCollector::new();
        collector.record_loss(0.5);
        collector.record_loss(0.3);
        assert_eq!(collector.count(), 2);
        assert!((collector.loss_mean() - 0.4).abs() < 1e-6);
    }

    #[test]
    fn test_wasm_collector_record_accuracy() {
        let mut collector = WasmMetricsCollector::new();
        collector.record_accuracy(0.8);
        collector.record_accuracy(0.9);
        assert_eq!(collector.count(), 2);
        assert!((collector.accuracy_mean() - 0.85).abs() < 1e-6);
    }

    #[test]
    fn test_wasm_collector_record_custom() {
        let mut collector = WasmMetricsCollector::new();
        collector.record_custom("perplexity", 15.5);
        assert_eq!(collector.count(), 1);
    }

    #[test]
    fn test_wasm_collector_clear() {
        let mut collector = WasmMetricsCollector::new();
        collector.record_loss(0.5);
        collector.record_accuracy(0.8);
        assert_eq!(collector.count(), 2);
        collector.clear();
        assert!(collector.is_empty());
    }

    #[test]
    fn test_wasm_collector_summary_json() {
        let mut collector = WasmMetricsCollector::new();
        collector.record_loss(0.5);
        collector.record_loss(0.3);

        let json = collector.summary_json();
        assert!(json.contains("loss"));
        assert!(json.contains("mean"));

        // Parse and validate
        let parsed: serde_json::Value =
            serde_json::from_str(&json).expect("JSON deserialization should succeed");
        assert!(parsed.get("loss").is_some());

        // Check the serialized statistics themselves, not just the presence of keys.
        let loss = &parsed["loss"];
        assert_eq!(loss["count"].as_u64(), Some(2));
        let mean = loss["mean"].as_f64().expect("loss mean should be a JSON number");
        assert!((mean - 0.4).abs() < 1e-6);
        assert_eq!(loss["has_nan"].as_bool(), Some(false));
        assert_eq!(loss["has_inf"].as_bool(), Some(false));
    }

    #[test]
    fn test_wasm_collector_missing_metric() {
        let collector = WasmMetricsCollector::new();
        assert!(collector.loss_mean().is_nan());
        assert!(collector.accuracy_mean().is_nan());
    }

    #[test]
    fn test_wasm_metric_stats_from() {
        // Every field gets a distinct value (and the two flags differ) so that an
        // accidentally swapped assignment in the conversion is caught.
        let stats = MetricStats {
            count: 10,
            mean: 0.5,
            std: 0.1,
            min: 0.2,
            max: 0.8,
            sum: 5.0,
            has_nan: false,
            has_inf: true,
        };

        let wasm_stats = WasmMetricStats::from(stats);
        assert_eq!(wasm_stats.count, 10);
        assert!((wasm_stats.mean - 0.5).abs() < 1e-6);
        assert!((wasm_stats.std - 0.1).abs() < 1e-6);
        assert!((wasm_stats.min - 0.2).abs() < 1e-6);
        assert!((wasm_stats.max - 0.8).abs() < 1e-6);
        assert!(!wasm_stats.has_nan);
        assert!(wasm_stats.has_inf);
    }

    #[test]
    fn test_wasm_collector_all_metric_types() {
        let mut collector = WasmMetricsCollector::new();
        collector.record_loss(0.5);
        collector.record_accuracy(0.8);
        collector.record_learning_rate(0.001);
        collector.record_gradient_norm(1.5);
        collector.record_custom("perplexity", 15.5);

        assert_eq!(collector.count(), 5);

        let json = collector.summary_json();
        assert!(json.contains("loss"));
        assert!(json.contains("accuracy"));
        assert!(json.contains("learning_rate"));
        assert!(json.contains("gradient_norm"));
        assert!(json.contains("perplexity"));
    }

    #[test]
    fn test_wasm_collector_default() {
        let collector = WasmMetricsCollector::default();
        assert!(collector.is_empty());
    }

    #[test]
    fn test_wasm_collector_loss_values() {
        let mut collector = WasmMetricsCollector::new();
        collector.record_loss(0.5);
        collector.record_loss(0.3);
        collector.record_accuracy(0.8); // Should not be in loss values

        let values = collector.loss_values();
        assert_eq!(values.len(), 2);
        assert!((values[0] - 0.5).abs() < 1e-6);
        assert!((values[1] - 0.3).abs() < 1e-6);
    }

    #[test]
    fn test_wasm_collector_accuracy_values() {
        let mut collector = WasmMetricsCollector::new();
        collector.record_accuracy(0.8);
        collector.record_accuracy(0.9);
        collector.record_loss(0.5); // Should not be in accuracy values

        let values = collector.accuracy_values();
        assert_eq!(values.len(), 2);
        assert!((values[0] - 0.8).abs() < 1e-6);
        assert!((values[1] - 0.9).abs() < 1e-6);
    }

    #[test]
    fn test_wasm_collector_timestamps() {
        let mut collector = WasmMetricsCollector::new();
        collector.record_loss(0.5);
        collector.record_loss(0.3);

        let timestamps = collector.timestamps();
        assert_eq!(timestamps.len(), 2);
        assert!(timestamps[0] > 0);
        assert!(timestamps[1] >= timestamps[0]);
    }

    #[test]
    fn test_wasm_collector_timestamps_not_filtered_by_metric() {
        let mut collector = WasmMetricsCollector::new();
        collector.record_loss(0.5);
        collector.record_accuracy(0.8);

        // `timestamps` covers every record, whereas `loss_values` is filtered to loss only.
        assert_eq!(collector.timestamps().len(), 2);
        assert_eq!(collector.loss_values().len(), 1);
    }

    #[test]
    fn test_wasm_collector_loss_std() {
        let mut collector = WasmMetricsCollector::new();
        collector.record_loss(0.2);
        collector.record_loss(0.4);
        collector.record_loss(0.6);
        collector.record_loss(0.8);

        let std = collector.loss_std();
        assert!(std > 0.0);
        assert!(std < 1.0);
    }

    #[test]
    fn test_wasm_collector_accuracy_std() {
        let mut collector = WasmMetricsCollector::new();
        collector.record_accuracy(0.8);
        collector.record_accuracy(0.9);

        let std = collector.accuracy_std();
        assert!(std > 0.0);
        assert!(std < 1.0);
    }

    #[test]
    fn test_wasm_collector_nan_detection() {
        let mut collector = WasmMetricsCollector::new();
        collector.record_loss(0.5);
        assert!(!collector.loss_has_nan());

        collector.record_loss(f64::NAN);
        assert!(collector.loss_has_nan());
    }

    #[test]
    fn test_wasm_collector_inf_detection() {
        let mut collector = WasmMetricsCollector::new();
        collector.record_loss(0.5);
        assert!(!collector.loss_has_inf());

        collector.record_loss(f64::INFINITY);
        assert!(collector.loss_has_inf());
    }

    #[test]
    fn test_wasm_collector_empty_std() {
        let collector = WasmMetricsCollector::new();
        assert!(collector.loss_std().is_nan());
        assert!(collector.accuracy_std().is_nan());
    }

    #[test]
    fn test_wasm_collector_empty_flags_and_vectors() {
        let collector = WasmMetricsCollector::new();
        assert!(!collector.loss_has_nan());
        assert!(!collector.loss_has_inf());
        assert!(collector.loss_values().is_empty());
        assert!(collector.accuracy_values().is_empty());
        assert!(collector.timestamps().is_empty());
    }
}
370
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        /// Property: every recorded value is counted.
        ///
        /// The strategy only yields finite values; NaN/Inf handling is exercised by the unit
        /// tests instead.
        #[test]
        fn prop_record_increases_count(values in prop::collection::vec(-1e10f64..1e10, 1..100)) {
            let mut collector = WasmMetricsCollector::new();
            for v in &values {
                collector.record_loss(*v);
            }
            prop_assert_eq!(collector.count(), values.len());
        }

        /// Property: Mean is always within bounds of recorded values
        #[test]
        fn prop_mean_within_bounds(values in prop::collection::vec(0.0f64..100.0, 2..50)) {
            let mut collector = WasmMetricsCollector::new();
            for v in &values {
                collector.record_loss(*v);
            }

            let mean = collector.loss_mean();
            let min = values.iter().copied().fold(f64::INFINITY, f64::min);
            let max = values.iter().copied().fold(f64::NEG_INFINITY, f64::max);

            prop_assert!(mean >= min - 1e-6);
            prop_assert!(mean <= max + 1e-6);
        }

        /// Property: Standard deviation is non-negative
        #[test]
        fn prop_std_non_negative(values in prop::collection::vec(0.0f64..100.0, 2..50)) {
            let mut collector = WasmMetricsCollector::new();
            for v in &values {
                collector.record_loss(*v);
            }

            let std = collector.loss_std();
            // NaN is tolerated: depending on how variance is computed, rounding may push it
            // slightly below zero before the square root.
            prop_assert!(std >= 0.0 || std.is_nan());
        }

        /// Property: `loss_values` returns exactly the recorded losses, in recording order.
        #[test]
        fn prop_loss_values_round_trip(values in prop::collection::vec(-1e10f64..1e10, 0..50)) {
            let mut collector = WasmMetricsCollector::new();
            for v in &values {
                collector.record_loss(*v);
            }
            prop_assert_eq!(collector.loss_values(), values);
        }
    }
}