Skip to main content

oxiphysics_io/
machine_learning_io.rs

1#![allow(clippy::needless_range_loop)]
2// Copyright 2026 COOLJAPAN OU (Team KitaSan)
3// SPDX-License-Identifier: Apache-2.0
4
5//! Machine learning model I/O for the OxiPhysics engine.
6//!
7//! Covers:
8//! - ML model serialization (weights/biases as binary)
9//! - ONNX-like simplified format (op-graph with typed tensors)
10//! - PyTorch-like state dict (string-keyed tensor store)
11//! - Dataset I/O (training/validation split, shuffle)
12//! - Feature normalization parameter storage
13//! - Label encoding / decoding
14//! - Confusion matrix export
15//! - Training history (loss/accuracy per epoch)
16//! - Hyperparameter configuration
17//! - Model checkpoint with metadata
18
19use std::collections::HashMap;
20
21// ---------------------------------------------------------------------------
22// Tensor — flat f64 storage with shape
23// ---------------------------------------------------------------------------
24
/// A multi-dimensional tensor stored as a flat `Vec<f64>` in row-major order.
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq)]
pub struct Tensor {
    /// Shape of the tensor (row-major).
    pub shape: Vec<usize>,
    /// Flat data in row-major order.
    pub data: Vec<f64>,
}

impl Tensor {
    /// Construct a tensor with the given shape and flat data.
    ///
    /// # Panics
    /// Panics if the length of `data` does not match the product of `shape`.
    pub fn new(shape: Vec<usize>, data: Vec<f64>) -> Self {
        let expected: usize = shape.iter().product();
        assert_eq!(
            data.len(),
            expected,
            "data length {} does not match shape {:?} (product {})",
            data.len(),
            shape,
            expected
        );
        Tensor { shape, data }
    }

    /// Create a zero-filled tensor with the given shape.
    pub fn zeros(shape: Vec<usize>) -> Self {
        let n: usize = shape.iter().product();
        Tensor {
            shape,
            data: vec![0.0; n],
        }
    }

    /// Total number of elements (length of `data`).
    pub fn numel(&self) -> usize {
        self.data.len()
    }

    /// Number of dimensions (length of `shape`).
    pub fn ndim(&self) -> usize {
        self.shape.len()
    }

    /// Serialise to little-endian bytes.
    ///
    /// Layout (every field is 8 bytes): `[ndim: u64][dim_0: u64]...[dim_{ndim-1}: u64]`
    /// followed by one entry per element of `data`, each the little-endian
    /// bit pattern of the `f64`.
    pub fn to_bytes(&self) -> Vec<u8> {
        let mut buf = Vec::with_capacity(8 + 8 * self.shape.len() + 8 * self.data.len());
        buf.extend_from_slice(&(self.shape.len() as u64).to_le_bytes());
        for &d in &self.shape {
            buf.extend_from_slice(&(d as u64).to_le_bytes());
        }
        for &v in &self.data {
            buf.extend_from_slice(&v.to_bits().to_le_bytes());
        }
        buf
    }

    /// Deserialise from bytes produced by [`Tensor::to_bytes`].
    ///
    /// Returns `None` when the buffer is too short for the header or for the
    /// data it announces, when a dimension does not fit in `usize`, or when
    /// the announced sizes overflow. Trailing bytes after the tensor are ignored.
    pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
        // Number of complete 8-byte words in the buffer.
        let words = bytes.len() / 8;
        // Little-endian `u64` stored at word index `idx` (`None` past the end).
        let word = |idx: usize| -> Option<u64> {
            let chunk = bytes.get(idx * 8..idx * 8 + 8)?;
            let mut raw = [0u8; 8];
            raw.copy_from_slice(chunk);
            Some(u64::from_le_bytes(raw))
        };

        let ndim = usize::try_from(word(0)?).ok()?;
        // All size arithmetic is checked: a corrupt or hostile header must
        // yield `None`, not an overflow panic or an enormous allocation.
        let header_words = ndim.checked_add(1)?;
        if words < header_words {
            return None;
        }
        let mut shape = Vec::with_capacity(ndim);
        for i in 1..header_words {
            shape.push(usize::try_from(word(i)?).ok()?);
        }

        let n = shape
            .iter()
            .try_fold(1usize, |acc, &d| acc.checked_mul(d))?;
        let total_words = header_words.checked_add(n)?;
        if words < total_words {
            return None;
        }
        let mut data = Vec::with_capacity(n);
        for i in header_words..total_words {
            data.push(f64::from_bits(word(i)?));
        }
        Some(Tensor { shape, data })
    }

    /// Element-wise sum with `other`; `None` if the shapes differ.
    pub fn add(&self, other: &Tensor) -> Option<Tensor> {
        if self.shape != other.shape {
            return None;
        }
        let data = self
            .data
            .iter()
            .zip(&other.data)
            .map(|(a, b)| a + b)
            .collect();
        Some(Tensor {
            shape: self.shape.clone(),
            data,
        })
    }

    /// Multiply every element by the scalar `s`.
    pub fn scale(&self, s: f64) -> Tensor {
        Tensor {
            shape: self.shape.clone(),
            data: self.data.iter().map(|v| v * s).collect(),
        }
    }

    /// Sum of all elements.
    pub fn sum(&self) -> f64 {
        self.data.iter().sum()
    }

    /// Mean of all elements; `0.0` for an empty tensor.
    pub fn mean(&self) -> f64 {
        if self.data.is_empty() {
            return 0.0;
        }
        self.sum() / self.data.len() as f64
    }
}
151
152// ---------------------------------------------------------------------------
153// Layer — single dense layer
154// ---------------------------------------------------------------------------
155
156/// A single dense (fully connected) layer with weights and biases.
157#[allow(dead_code)]
158#[derive(Debug, Clone)]
159pub struct DenseLayer {
160    /// Layer name.
161    pub name: String,
162    /// Weight matrix: shape `\[out_features, in_features\]`.
163    pub weights: Tensor,
164    /// Bias vector: shape `[out_features]`.
165    pub bias: Tensor,
166    /// Activation function name (e.g. `"relu"`, `"sigmoid"`, `"tanh"`, `"linear"`).
167    pub activation: String,
168}
169
170impl DenseLayer {
171    /// Construct a zero-initialised dense layer.
172    pub fn new(
173        name: impl Into<String>,
174        in_features: usize,
175        out_features: usize,
176        activation: impl Into<String>,
177    ) -> Self {
178        DenseLayer {
179            name: name.into(),
180            weights: Tensor::zeros(vec![out_features, in_features]),
181            bias: Tensor::zeros(vec![out_features]),
182            activation: activation.into(),
183        }
184    }
185
186    /// Forward pass: `output\[i\] = sum_j(w\[i,j\] * input\[j\]) + bias\[i\]`, then activation.
187    pub fn forward(&self, input: &[f64]) -> Vec<f64> {
188        let in_feat = input.len();
189        let out_feat = self.bias.data.len();
190        let mut out = vec![0.0f64; out_feat];
191        for i in 0..out_feat {
192            let mut acc = self.bias.data[i];
193            for j in 0..in_feat.min(self.weights.data.len() / out_feat) {
194                acc += self.weights.data[i * in_feat + j] * input[j];
195            }
196            out[i] = apply_activation(acc, &self.activation);
197        }
198        out
199    }
200
201    /// Number of trainable parameters.
202    pub fn param_count(&self) -> usize {
203        self.weights.numel() + self.bias.numel()
204    }
205
206    /// Serialise to bytes (name length, name utf8, weights bytes, bias bytes).
207    pub fn to_bytes(&self) -> Vec<u8> {
208        let name_bytes = self.name.as_bytes();
209        let act_bytes = self.activation.as_bytes();
210        let mut buf = Vec::new();
211        buf.extend_from_slice(&(name_bytes.len() as u64).to_le_bytes());
212        buf.extend_from_slice(name_bytes);
213        buf.extend_from_slice(&(act_bytes.len() as u64).to_le_bytes());
214        buf.extend_from_slice(act_bytes);
215        let wb = self.weights.to_bytes();
216        buf.extend_from_slice(&(wb.len() as u64).to_le_bytes());
217        buf.extend_from_slice(&wb);
218        let bb = self.bias.to_bytes();
219        buf.extend_from_slice(&(bb.len() as u64).to_le_bytes());
220        buf.extend_from_slice(&bb);
221        buf
222    }
223}
224
/// Apply a named activation function to a scalar.
///
/// Recognised names: `"relu"`, `"sigmoid"`, `"tanh"`, `"softplus"`, `"elu"`
/// (alpha = 1), `"leaky_relu"` (slope 0.01). Any other name, including
/// `"linear"`, is treated as the identity.
#[allow(dead_code)]
pub fn apply_activation(x: f64, activation: &str) -> f64 {
    match activation {
        "relu" => x.max(0.0),
        "sigmoid" => 1.0 / (1.0 + (-x).exp()),
        "tanh" => x.tanh(),
        // ln(1 + e^x) == max(x, 0) + ln(1 + e^{-|x|}). The naive form
        // overflows to +inf for large x and rounds to 0 for moderately
        // negative x. This form stays finite and keeps relative precision.
        "softplus" => x.max(0.0) + (-x.abs()).exp().ln_1p(),
        "elu" => {
            if x >= 0.0 {
                x
            } else {
                x.exp() - 1.0
            }
        }
        "leaky_relu" => {
            if x >= 0.0 {
                x
            } else {
                0.01 * x
            }
        }
        _ => x, // linear / identity
    }
}
250
251// ---------------------------------------------------------------------------
252// ModelWeights — named layers with binary I/O
253// ---------------------------------------------------------------------------
254
255/// A collection of named dense layers (binary-serialisable model weights).
256#[allow(dead_code)]
257#[derive(Debug, Clone, Default)]
258pub struct ModelWeights {
259    /// Layers in order.
260    pub layers: Vec<DenseLayer>,
261}
262
263impl ModelWeights {
264    /// Create an empty model.
265    pub fn new() -> Self {
266        ModelWeights { layers: Vec::new() }
267    }
268
269    /// Append a layer.
270    pub fn add_layer(&mut self, layer: DenseLayer) {
271        self.layers.push(layer);
272    }
273
274    /// Look up a layer by name.
275    pub fn get_layer(&self, name: &str) -> Option<&DenseLayer> {
276        self.layers.iter().find(|l| l.name == name)
277    }
278
279    /// Total trainable parameter count.
280    pub fn total_params(&self) -> usize {
281        self.layers.iter().map(|l| l.param_count()).sum()
282    }
283
284    /// Serialise all layers to a flat byte buffer.
285    ///
286    /// Format: `\[layer_count: u64\]\[layer_0_len: u64\][layer_0_bytes]...`
287    pub fn to_bytes(&self) -> Vec<u8> {
288        let mut buf = Vec::new();
289        buf.extend_from_slice(&(self.layers.len() as u64).to_le_bytes());
290        for layer in &self.layers {
291            let lb = layer.to_bytes();
292            buf.extend_from_slice(&(lb.len() as u64).to_le_bytes());
293            buf.extend_from_slice(&lb);
294        }
295        buf
296    }
297}
298
299// ---------------------------------------------------------------------------
300// StateDict — PyTorch-like key-value tensor store
301// ---------------------------------------------------------------------------
302
303/// PyTorch-like state dict: a `HashMap<String, Tensor>`.
304#[allow(dead_code)]
305#[derive(Debug, Clone, Default)]
306pub struct StateDict {
307    /// The underlying key-value store.
308    pub tensors: HashMap<String, Tensor>,
309}
310
311impl StateDict {
312    /// Create an empty state dict.
313    pub fn new() -> Self {
314        StateDict {
315            tensors: HashMap::new(),
316        }
317    }
318
319    /// Insert a tensor under a key.
320    pub fn insert(&mut self, key: impl Into<String>, tensor: Tensor) {
321        self.tensors.insert(key.into(), tensor);
322    }
323
324    /// Retrieve a tensor by key.
325    pub fn get(&self, key: &str) -> Option<&Tensor> {
326        self.tensors.get(key)
327    }
328
329    /// Number of tensors.
330    pub fn len(&self) -> usize {
331        self.tensors.len()
332    }
333
334    /// Whether the dict is empty.
335    pub fn is_empty(&self) -> bool {
336        self.tensors.is_empty()
337    }
338
339    /// Total number of parameters (sum of all tensor element counts).
340    pub fn total_params(&self) -> usize {
341        self.tensors.values().map(|t| t.numel()).sum()
342    }
343
344    /// Serialise to bytes.
345    ///
346    /// Format: `\[entry_count: u64\](\[key_len: u64\][key_bytes]\[tensor_len: u64\][tensor_bytes])...`
347    pub fn to_bytes(&self) -> Vec<u8> {
348        let mut buf = Vec::new();
349        buf.extend_from_slice(&(self.tensors.len() as u64).to_le_bytes());
350        let mut keys: Vec<&String> = self.tensors.keys().collect();
351        keys.sort(); // deterministic order
352        for k in keys {
353            let kb = k.as_bytes();
354            buf.extend_from_slice(&(kb.len() as u64).to_le_bytes());
355            buf.extend_from_slice(kb);
356            let tb = self.tensors[k].to_bytes();
357            buf.extend_from_slice(&(tb.len() as u64).to_le_bytes());
358            buf.extend_from_slice(&tb);
359        }
360        buf
361    }
362
363    /// Deserialise from bytes produced by [`StateDict::to_bytes`].
364    pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
365        let mut pos = 0usize;
366        let n = read_u64(bytes, &mut pos)? as usize;
367        let mut dict = StateDict::new();
368        for _ in 0..n {
369            let klen = read_u64(bytes, &mut pos)? as usize;
370            if pos + klen > bytes.len() {
371                return None;
372            }
373            let key = String::from_utf8(bytes[pos..pos + klen].to_vec()).ok()?;
374            pos += klen;
375            let tlen = read_u64(bytes, &mut pos)? as usize;
376            if pos + tlen > bytes.len() {
377                return None;
378            }
379            let tensor = Tensor::from_bytes(&bytes[pos..pos + tlen])?;
380            pos += tlen;
381            dict.insert(key, tensor);
382        }
383        Some(dict)
384    }
385}
386
/// Decode the little-endian `u64` starting at `*pos` and move `*pos` past it.
///
/// Returns `None` (leaving `*pos` untouched) when fewer than 8 bytes remain.
fn read_u64(bytes: &[u8], pos: &mut usize) -> Option<u64> {
    let end = *pos + 8;
    let chunk = bytes.get(*pos..end)?;
    let mut raw = [0u8; 8];
    raw.copy_from_slice(chunk);
    *pos = end;
    Some(u64::from_le_bytes(raw))
}
396
397// ---------------------------------------------------------------------------
398// OnnxLikeGraph — simplified ONNX-style operator graph
399// ---------------------------------------------------------------------------
400
/// One operator node of an ONNX-like compute graph.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct OnnxNode {
    /// Unique node name.
    pub name: String,
    /// Operator kind (e.g. `"MatMul"`, `"Relu"`, `"Add"`, `"Sigmoid"`).
    pub op_type: String,
    /// Names of the tensors this node consumes.
    pub inputs: Vec<String>,
    /// Names of the tensors this node produces.
    pub outputs: Vec<String>,
    /// Scalar attributes keyed by name.
    pub attributes: HashMap<String, f64>,
}

impl OnnxNode {
    /// Build a node with the given name, operator, inputs and outputs and an
    /// empty attribute map.
    pub fn new(
        name: impl Into<String>,
        op_type: impl Into<String>,
        inputs: Vec<String>,
        outputs: Vec<String>,
    ) -> Self {
        let name = name.into();
        let op_type = op_type.into();
        Self {
            name,
            op_type,
            inputs,
            outputs,
            attributes: HashMap::new(),
        }
    }

    /// Builder-style setter: store `value` under `key` (overwriting any
    /// existing value) and hand the node back.
    pub fn with_attr(self, key: impl Into<String>, value: f64) -> Self {
        let mut node = self;
        node.attributes.insert(key.into(), value);
        node
    }
}
440
441/// A simplified ONNX-like computation graph.
442#[allow(dead_code)]
443#[derive(Debug, Clone, Default)]
444pub struct OnnxLikeGraph {
445    /// Ordered list of operation nodes.
446    pub nodes: Vec<OnnxNode>,
447    /// Named initialiser tensors (model weights).
448    pub initializers: StateDict,
449    /// Input tensor names.
450    pub inputs: Vec<String>,
451    /// Output tensor names.
452    pub outputs: Vec<String>,
453    /// Graph name.
454    pub name: String,
455}
456
457impl OnnxLikeGraph {
458    /// Create an empty graph.
459    pub fn new(name: impl Into<String>) -> Self {
460        OnnxLikeGraph {
461            name: name.into(),
462            nodes: Vec::new(),
463            initializers: StateDict::new(),
464            inputs: Vec::new(),
465            outputs: Vec::new(),
466        }
467    }
468
469    /// Add an operation node.
470    pub fn add_node(&mut self, node: OnnxNode) {
471        self.nodes.push(node);
472    }
473
474    /// Add an initialiser.
475    pub fn add_initializer(&mut self, name: impl Into<String>, tensor: Tensor) {
476        self.initializers.insert(name, tensor);
477    }
478
479    /// Number of nodes in the graph.
480    pub fn node_count(&self) -> usize {
481        self.nodes.len()
482    }
483
484    /// Count nodes by op type.
485    pub fn count_op(&self, op: &str) -> usize {
486        self.nodes.iter().filter(|n| n.op_type == op).count()
487    }
488
489    /// Topological order check: returns `true` if all node input names
490    /// are either graph inputs or outputs of an earlier node.
491    pub fn is_topologically_valid(&self) -> bool {
492        let mut available: std::collections::HashSet<&str> =
493            self.inputs.iter().map(|s| s.as_str()).collect();
494        // Initialisers are always available
495        for k in self.initializers.tensors.keys() {
496            available.insert(k.as_str());
497        }
498        for node in &self.nodes {
499            for inp in &node.inputs {
500                if !available.contains(inp.as_str()) {
501                    return false;
502                }
503            }
504            for out in &node.outputs {
505                available.insert(out.as_str());
506            }
507        }
508        true
509    }
510}
511
512// ---------------------------------------------------------------------------
513// Dataset — rows of features with optional labels
514// ---------------------------------------------------------------------------
515
/// A dataset row: a feature vector and an optional label index.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct DataRow {
    /// Feature values.
    pub features: Vec<f64>,
    /// Optional class label index.
    pub label: Option<usize>,
}

impl DataRow {
    /// Construct a labelled row.
    pub fn labelled(features: Vec<f64>, label: usize) -> Self {
        DataRow {
            features,
            label: Some(label),
        }
    }

    /// Construct an unlabelled row.
    pub fn unlabelled(features: Vec<f64>) -> Self {
        DataRow {
            features,
            label: None,
        }
    }
}

/// A dataset with optional train/validation split.
#[allow(dead_code)]
#[derive(Debug, Clone, Default)]
pub struct Dataset {
    /// All rows.
    pub rows: Vec<DataRow>,
    /// Feature names.
    pub feature_names: Vec<String>,
    /// Class names (index → name).
    pub class_names: Vec<String>,
}

impl Dataset {
    /// Create an empty dataset.
    pub fn new() -> Self {
        Dataset {
            rows: Vec::new(),
            feature_names: Vec::new(),
            class_names: Vec::new(),
        }
    }

    /// Add a row.
    pub fn push(&mut self, row: DataRow) {
        self.rows.push(row);
    }

    /// Number of rows.
    pub fn len(&self) -> usize {
        self.rows.len()
    }

    /// Whether the dataset is empty.
    pub fn is_empty(&self) -> bool {
        self.rows.is_empty()
    }

    /// Number of features (from the first row, or 0).
    pub fn num_features(&self) -> usize {
        self.rows.first().map(|r| r.features.len()).unwrap_or(0)
    }

    /// Shuffle rows in place with Fisher-Yates driven by a seeded [`LcgRng`].
    ///
    /// The same `seed` always produces the same permutation.
    pub fn shuffle(&mut self, seed: u64) {
        let n = self.rows.len();
        if n < 2 {
            return;
        }
        let mut rng = LcgRng::new(seed);
        for i in (1..n).rev() {
            let j = rng.next_usize_below(i + 1);
            self.rows.swap(i, j);
        }
    }

    /// Split into `(train, validation)` subsets without reordering rows.
    ///
    /// `val_fraction` (clamped to `[0, 1]`) of the rows, rounded down, go to
    /// validation; they are taken from the end. Both subsets copy the
    /// feature and class names.
    pub fn train_val_split(&self, val_fraction: f64) -> (Dataset, Dataset) {
        // `as usize` saturates, so a NaN fraction yields an empty validation set.
        let val_count = ((self.rows.len() as f64) * val_fraction.clamp(0.0, 1.0)) as usize;
        let train_count = self.rows.len().saturating_sub(val_count);
        let (train_rows, val_rows) = self.rows.split_at(train_count);
        let subset = |rows: &[DataRow]| Dataset {
            rows: rows.to_vec(),
            feature_names: self.feature_names.clone(),
            class_names: self.class_names.clone(),
        };
        (subset(train_rows), subset(val_rows))
    }

    /// Compute per-feature mean and population standard deviation.
    ///
    /// Returns `(means, stds)`, each of length [`Dataset::num_features`]
    /// (taken from the first row). Every row counts toward the divisor. Extra
    /// features in longer rows are ignored, and missing features in shorter
    /// rows contribute nothing.
    pub fn feature_stats(&self) -> (Vec<f64>, Vec<f64>) {
        let nf = self.num_features();
        if nf == 0 || self.rows.is_empty() {
            return (vec![], vec![]);
        }
        let n = self.rows.len() as f64;
        let mut means = vec![0.0f64; nf];
        for row in &self.rows {
            // `zip` bounds the loop by `nf`, so ragged rows cannot index
            // past the end of `means`.
            for (m, &v) in means.iter_mut().zip(&row.features) {
                *m += v;
            }
        }
        for m in &mut means {
            *m /= n;
        }
        let mut stds = vec![0.0f64; nf];
        for row in &self.rows {
            for ((s, &m), &v) in stds.iter_mut().zip(&means).zip(&row.features) {
                let d = v - m;
                *s += d * d;
            }
        }
        for s in &mut stds {
            *s = (*s / n).sqrt();
        }
        (means, stds)
    }
}

// ---------------------------------------------------------------------------
// Minimal LCG RNG for shuffle (no rand dep in this module)
// ---------------------------------------------------------------------------

/// A minimal linear congruential generator used for dataset shuffling.
#[allow(dead_code)]
struct LcgRng {
    state: u64,
}

impl LcgRng {
    fn new(seed: u64) -> Self {
        LcgRng {
            state: seed ^ 0x1234_5678_9abc_def0,
        }
    }

    fn next_u64(&mut self) -> u64 {
        // Knuth's MMIX constants
        self.state = self
            .state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1_442_695_040_888_963_407);
        self.state
    }

    /// Pseudo-random index in `[0, n)`; returns 0 when `n == 0`.
    fn next_usize_below(&mut self, n: usize) -> usize {
        if n == 0 {
            return 0;
        }
        // Multiply-shift range reduction (Lemire) selects the result from the
        // HIGH bits of the state. The low bits of a power-of-two-modulus LCG
        // have very short periods (bit 0 simply alternates), so `% n` would
        // feed Fisher-Yates strongly correlated indices.
        ((u128::from(self.next_u64()) * n as u128) >> 64) as usize
    }
}
686
687// ---------------------------------------------------------------------------
688// NormalizationParams — feature normalization storage
689// ---------------------------------------------------------------------------
690
691/// Stored feature normalization parameters (mean and std for z-score normalization).
692#[allow(dead_code)]
693#[derive(Debug, Clone)]
694pub struct NormalizationParams {
695    /// Per-feature mean.
696    pub means: Vec<f64>,
697    /// Per-feature standard deviation.
698    pub stds: Vec<f64>,
699    /// Minimum values (for min-max normalization).
700    pub mins: Vec<f64>,
701    /// Maximum values (for min-max normalization).
702    pub maxs: Vec<f64>,
703}
704
705impl NormalizationParams {
706    /// Compute z-score normalization parameters from a dataset.
707    pub fn from_dataset(dataset: &Dataset) -> Self {
708        let (means, stds) = dataset.feature_stats();
709        let nf = means.len();
710        let mut mins = vec![f64::INFINITY; nf];
711        let mut maxs = vec![f64::NEG_INFINITY; nf];
712        for row in &dataset.rows {
713            for (k, &v) in row.features.iter().enumerate() {
714                if v < mins[k] {
715                    mins[k] = v;
716                }
717                if v > maxs[k] {
718                    maxs[k] = v;
719                }
720            }
721        }
722        NormalizationParams {
723            means,
724            stds,
725            mins,
726            maxs,
727        }
728    }
729
730    /// Apply z-score normalization to a feature vector.
731    pub fn normalize_zscore(&self, features: &[f64]) -> Vec<f64> {
732        features
733            .iter()
734            .enumerate()
735            .map(|(k, &v)| {
736                let s = if k < self.stds.len() {
737                    self.stds[k]
738                } else {
739                    1.0
740                };
741                let m = if k < self.means.len() {
742                    self.means[k]
743                } else {
744                    0.0
745                };
746                if s.abs() < 1e-15 { 0.0 } else { (v - m) / s }
747            })
748            .collect()
749    }
750
751    /// Apply min-max normalization to a feature vector (maps to [0, 1]).
752    pub fn normalize_minmax(&self, features: &[f64]) -> Vec<f64> {
753        features
754            .iter()
755            .enumerate()
756            .map(|(k, &v)| {
757                let mn = if k < self.mins.len() {
758                    self.mins[k]
759                } else {
760                    0.0
761                };
762                let mx = if k < self.maxs.len() {
763                    self.maxs[k]
764                } else {
765                    1.0
766                };
767                let range = mx - mn;
768                if range.abs() < 1e-15 {
769                    0.0
770                } else {
771                    (v - mn) / range
772                }
773            })
774            .collect()
775    }
776
777    /// Serialise to bytes.
778    pub fn to_bytes(&self) -> Vec<u8> {
779        let mut buf = Vec::new();
780        let write_vec = |buf: &mut Vec<u8>, v: &[f64]| {
781            buf.extend_from_slice(&(v.len() as u64).to_le_bytes());
782            for &x in v {
783                buf.extend_from_slice(&x.to_bits().to_le_bytes());
784            }
785        };
786        write_vec(&mut buf, &self.means);
787        write_vec(&mut buf, &self.stds);
788        write_vec(&mut buf, &self.mins);
789        write_vec(&mut buf, &self.maxs);
790        buf
791    }
792}
793
794// ---------------------------------------------------------------------------
795// LabelEncoder — integer ↔ class-name mapping
796// ---------------------------------------------------------------------------
797
/// Bidirectional mapping between class names and integer labels.
#[allow(dead_code)]
#[derive(Debug, Clone, Default)]
pub struct LabelEncoder {
    /// Class names; a class's position in this list is its integer label.
    pub classes: Vec<String>,
    /// Reverse lookup from class name to its position in `classes`.
    index: HashMap<String, usize>,
}

impl LabelEncoder {
    /// Create an encoder that knows no classes.
    pub fn new() -> Self {
        Self {
            classes: Vec::new(),
            index: HashMap::new(),
        }
    }

    /// Build an encoder from (possibly repeated) class names.
    ///
    /// Duplicates are removed, and labels are assigned in sorted name order.
    pub fn fit(class_names: Vec<String>) -> Self {
        let mut classes = class_names;
        classes.sort_unstable();
        classes.dedup();
        let mut index = HashMap::with_capacity(classes.len());
        for (position, name) in classes.iter().enumerate() {
            index.insert(name.clone(), position);
        }
        Self { classes, index }
    }

    /// Integer label for `name`, or `None` if the class is unknown.
    pub fn encode(&self, name: &str) -> Option<usize> {
        match self.index.get(name) {
            Some(&label) => Some(label),
            None => None,
        }
    }

    /// Class name for label `idx`, or `None` if out of range.
    pub fn decode(&self, idx: usize) -> Option<&str> {
        self.classes.get(idx).map(String::as_str)
    }

    /// How many distinct classes the encoder holds.
    pub fn num_classes(&self) -> usize {
        self.classes.len()
    }

    /// One-hot vector of length `num_classes()` with a `1.0` at `idx`.
    ///
    /// An out-of-range `idx` yields an all-zero vector.
    pub fn one_hot(&self, idx: usize) -> Vec<f64> {
        (0..self.num_classes())
            .map(|i| if i == idx { 1.0 } else { 0.0 })
            .collect()
    }
}
856
857// ---------------------------------------------------------------------------
858// ConfusionMatrix
859// ---------------------------------------------------------------------------
860
/// Square confusion matrix for multi-class classification.
///
/// Rows are indexed by the true label, columns by the predicted label.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct ConfusionMatrix {
    /// Number of classes (the matrix is `num_classes x num_classes`).
    pub num_classes: usize,
    /// Row-major counts: `counts[true_label * num_classes + predicted_label]`.
    pub counts: Vec<u64>,
}

impl ConfusionMatrix {
    /// Create a matrix for `num_classes` classes with every count at zero.
    pub fn new(num_classes: usize) -> Self {
        let cells = num_classes * num_classes;
        Self {
            num_classes,
            counts: vec![0; cells],
        }
    }

    /// Count one prediction of `predicted` for a sample whose label is `true_label`.
    ///
    /// Pairs where either label is `>= num_classes` are silently ignored.
    pub fn record(&mut self, true_label: usize, predicted: usize) {
        let k = self.num_classes;
        if true_label >= k || predicted >= k {
            return;
        }
        self.counts[true_label * k + predicted] += 1;
    }

    /// Fraction of recorded predictions on the diagonal; `0.0` when nothing was recorded.
    pub fn accuracy(&self) -> f64 {
        let total: u64 = self.counts.iter().sum();
        if total == 0 {
            return 0.0;
        }
        let k = self.num_classes;
        // Diagonal cell (i, i) lives at i * k + i == i * (k + 1).
        let correct: u64 = (0..k).map(|i| self.counts[i * (k + 1)]).sum();
        correct as f64 / total as f64
    }

    /// Precision of `class`: `TP / (TP + FP)` over its predicted column.
    ///
    /// Returns `0.0` for an out-of-range class or an empty column.
    pub fn precision(&self, class: usize) -> f64 {
        let k = self.num_classes;
        if class >= k {
            return 0.0;
        }
        let at = |truth: usize, pred: usize| self.counts[truth * k + pred] as f64;
        let tp = at(class, class);
        let fp: f64 = (0..k)
            .filter(|&truth| truth != class)
            .map(|truth| at(truth, class))
            .sum();
        let predicted_positive = tp + fp;
        if predicted_positive < 1e-15 {
            return 0.0;
        }
        tp / predicted_positive
    }

    /// Recall of `class`: `TP / (TP + FN)` over its true-label row.
    ///
    /// Returns `0.0` for an out-of-range class or an empty row.
    pub fn recall(&self, class: usize) -> f64 {
        let k = self.num_classes;
        if class >= k {
            return 0.0;
        }
        let at = |truth: usize, pred: usize| self.counts[truth * k + pred] as f64;
        let tp = at(class, class);
        let missed: f64 = (0..k)
            .filter(|&pred| pred != class)
            .map(|pred| at(class, pred))
            .sum();
        let actual_positive = tp + missed;
        if actual_positive < 1e-15 {
            return 0.0;
        }
        tp / actual_positive
    }

    /// Harmonic mean of precision and recall for `class`; `0.0` if both are zero.
    pub fn f1(&self, class: usize) -> f64 {
        let (p, r) = (self.precision(class), self.recall(class));
        let denom = p + r;
        if denom < 1e-15 {
            return 0.0;
        }
        2.0 * p * r / denom
    }

    /// Render as CSV.
    ///
    /// The output has a header row of predicted classes, then one row per
    /// true class. Each line ends with `\n`.
    pub fn to_csv(&self) -> String {
        let k = self.num_classes;
        let mut header = vec![String::from("true\\pred")];
        header.extend((0..k).map(|j| format!("class_{j}")));
        let mut csv = header.join(",");
        csv.push('\n');
        for i in 0..k {
            let mut cells = vec![format!("class_{i}")];
            cells.extend((0..k).map(|j| self.counts[i * k + j].to_string()));
            csv.push_str(&cells.join(","));
            csv.push('\n');
        }
        csv
    }
}
959
960// ---------------------------------------------------------------------------
961// TrainingHistory
962// ---------------------------------------------------------------------------
963
/// Metrics captured at the end of one training epoch.
#[allow(dead_code)]
#[derive(Clone, Debug)]
pub struct EpochRecord {
    /// Zero-based epoch number.
    pub epoch: usize,
    /// Loss on the training set.
    pub train_loss: f64,
    /// Loss on the validation set.
    pub val_loss: f64,
    /// Accuracy on the training set.
    pub train_acc: f64,
    /// Accuracy on the validation set.
    pub val_acc: f64,
    /// Learning rate in effect during the epoch.
    pub learning_rate: f64,
}

/// Ordered per-epoch metrics for one training run.
#[allow(dead_code)]
#[derive(Clone, Debug, Default)]
pub struct TrainingHistory {
    /// One record per epoch, in the order they were pushed.
    pub records: Vec<EpochRecord>,
}
989
990impl TrainingHistory {
991    /// Create an empty history.
992    pub fn new() -> Self {
993        TrainingHistory {
994            records: Vec::new(),
995        }
996    }
997
998    /// Append an epoch record.
999    pub fn push(&mut self, record: EpochRecord) {
1000        self.records.push(record);
1001    }
1002
1003    /// Number of epochs recorded.
1004    pub fn num_epochs(&self) -> usize {
1005        self.records.len()
1006    }
1007
1008    /// Best validation accuracy and the epoch at which it occurred.
1009    pub fn best_val_acc(&self) -> Option<(usize, f64)> {
1010        self.records
1011            .iter()
1012            .enumerate()
1013            .max_by(|(_, a), (_, b)| {
1014                a.val_acc
1015                    .partial_cmp(&b.val_acc)
1016                    .unwrap_or(std::cmp::Ordering::Equal)
1017            })
1018            .map(|(i, r)| (i, r.val_acc))
1019    }
1020
1021    /// Best (lowest) validation loss and its epoch.
1022    pub fn best_val_loss(&self) -> Option<(usize, f64)> {
1023        self.records
1024            .iter()
1025            .enumerate()
1026            .min_by(|(_, a), (_, b)| {
1027                a.val_loss
1028                    .partial_cmp(&b.val_loss)
1029                    .unwrap_or(std::cmp::Ordering::Equal)
1030            })
1031            .map(|(i, r)| (i, r.val_loss))
1032    }
1033
1034    /// Export history as a CSV string.
1035    pub fn to_csv(&self) -> String {
1036        let mut s = String::from("epoch,train_loss,val_loss,train_acc,val_acc,lr\n");
1037        for r in &self.records {
1038            s.push_str(&format!(
1039                "{},{:.6},{:.6},{:.6},{:.6},{:.8}\n",
1040                r.epoch, r.train_loss, r.val_loss, r.train_acc, r.val_acc, r.learning_rate
1041            ));
1042        }
1043        s
1044    }
1045}
1046
1047// ---------------------------------------------------------------------------
1048// HyperparamConfig — JSON-compatible hyperparameter store
1049// ---------------------------------------------------------------------------
1050
/// Typed hyperparameter value.
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq)]
pub enum HpValue {
    /// Floating point (also covers int values stored as f64).
    Float(f64),
    /// Boolean flag.
    Bool(bool),
    /// String-valued parameter.
    Str(String),
}

impl HpValue {
    /// Return the f64 value if this is a `Float`, else `None`.
    pub fn as_float(&self) -> Option<f64> {
        match self {
            HpValue::Float(v) => Some(*v),
            _ => None,
        }
    }

    /// Return the bool value if this is a `Bool`, else `None`.
    pub fn as_bool(&self) -> Option<bool> {
        match self {
            HpValue::Bool(v) => Some(*v),
            _ => None,
        }
    }

    /// Return the string ref if this is a `Str`, else `None`.
    pub fn as_str(&self) -> Option<&str> {
        match self {
            HpValue::Str(s) => Some(s.as_str()),
            _ => None,
        }
    }
}

/// Hyperparameter configuration container.
#[allow(dead_code)]
#[derive(Debug, Clone, Default)]
pub struct HyperparamConfig {
    /// Key-value map.
    pub params: HashMap<String, HpValue>,
}

impl HyperparamConfig {
    /// Create an empty config.
    pub fn new() -> Self {
        HyperparamConfig {
            params: HashMap::new(),
        }
    }

    /// Set a float hyperparameter (overwrites any existing value for `key`).
    pub fn set_float(&mut self, key: impl Into<String>, value: f64) {
        self.params.insert(key.into(), HpValue::Float(value));
    }

    /// Set a boolean hyperparameter (overwrites any existing value for `key`).
    pub fn set_bool(&mut self, key: impl Into<String>, value: bool) {
        self.params.insert(key.into(), HpValue::Bool(value));
    }

    /// Set a string hyperparameter (overwrites any existing value for `key`).
    pub fn set_str(&mut self, key: impl Into<String>, value: impl Into<String>) {
        self.params.insert(key.into(), HpValue::Str(value.into()));
    }

    /// Get a float hyperparameter; `None` if absent or not a `Float`.
    pub fn get_float(&self, key: &str) -> Option<f64> {
        self.params.get(key)?.as_float()
    }

    /// Get a bool hyperparameter; `None` if absent or not a `Bool`.
    pub fn get_bool(&self, key: &str) -> Option<bool> {
        self.params.get(key)?.as_bool()
    }

    /// Get a string hyperparameter; `None` if absent or not a `Str`.
    pub fn get_str(&self, key: &str) -> Option<&str> {
        self.params.get(key)?.as_str()
    }

    /// Serialise as a JSON object (no external dependencies).
    ///
    /// Keys are emitted in sorted order, so the output is deterministic.
    /// Keys and `Str` values are escaped as JSON strings: `"`, `\` and all
    /// control characters below U+0020 are escaped. JSON has no
    /// representation for NaN or ±infinity, so non-finite `Float` values are
    /// written as `null`.
    pub fn to_json(&self) -> String {
        /// Append `s` to `out` as a quoted, escaped JSON string literal.
        fn push_json_string(out: &mut String, s: &str) {
            out.push('"');
            for c in s.chars() {
                match c {
                    '"' => out.push_str("\\\""),
                    '\\' => out.push_str("\\\\"),
                    '\n' => out.push_str("\\n"),
                    '\r' => out.push_str("\\r"),
                    '\t' => out.push_str("\\t"),
                    c if (c as u32) < 0x20 => out.push_str(&format!("\\u{:04x}", c as u32)),
                    c => out.push(c),
                }
            }
            out.push('"');
        }

        let mut entries: Vec<(&String, &HpValue)> = self.params.iter().collect();
        entries.sort_by(|a, b| a.0.cmp(b.0));

        let mut out = String::from("{");
        for (idx, (key, value)) in entries.into_iter().enumerate() {
            if idx > 0 {
                out.push(',');
            }
            push_json_string(&mut out, key);
            out.push(':');
            match value {
                // f64's Display never uses exponent notation, so finite values
                // are always valid JSON numbers.
                HpValue::Float(f) if f.is_finite() => out.push_str(&f.to_string()),
                HpValue::Float(_) => out.push_str("null"),
                HpValue::Bool(b) => out.push_str(if *b { "true" } else { "false" }),
                HpValue::Str(s) => push_json_string(&mut out, s),
            }
        }
        out.push('}');
        out
    }
}
1154
1155// ---------------------------------------------------------------------------
1156// ModelCheckpoint
1157// ---------------------------------------------------------------------------
1158
/// Metadata stored alongside a model checkpoint.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct CheckpointMeta {
    /// Epoch at which the checkpoint was saved.
    pub epoch: usize,
    /// Validation loss recorded when the checkpoint was taken.
    pub val_loss: f64,
    /// Validation accuracy recorded when the checkpoint was taken.
    pub val_acc: f64,
    /// Cumulative wall-clock training time, in seconds.
    pub train_time_secs: f64,
    /// Model architecture description.
    pub architecture: String,
    /// Framework version string.
    pub framework_version: String,
}

impl CheckpointMeta {
    /// Render the metadata as `key=value` lines, one field per line.
    ///
    /// The loss and accuracy are printed with 8 decimal places and the
    /// training time with 3. Every line, including the last, ends in `\n`.
    /// String fields are inserted verbatim, without escaping.
    pub fn to_text(&self) -> String {
        let Self {
            epoch,
            val_loss,
            val_acc,
            train_time_secs,
            architecture,
            framework_version,
        } = self;
        format!(
            "epoch={epoch}\n\
             val_loss={val_loss:.8}\n\
             val_acc={val_acc:.8}\n\
             train_time_secs={train_time_secs:.3}\n\
             architecture={architecture}\n\
             framework_version={framework_version}\n"
        )
    }
}
1191
1192/// A model checkpoint: state dict + metadata + hyperparameters.
1193#[allow(dead_code)]
1194#[derive(Debug, Clone)]
1195pub struct ModelCheckpoint {
1196    /// Model weights.
1197    pub state: StateDict,
1198    /// Checkpoint metadata.
1199    pub meta: CheckpointMeta,
1200    /// Hyperparameters used when this checkpoint was saved.
1201    pub hparams: HyperparamConfig,
1202}
1203
1204impl ModelCheckpoint {
1205    /// Create a new checkpoint.
1206    pub fn new(state: StateDict, meta: CheckpointMeta, hparams: HyperparamConfig) -> Self {
1207        ModelCheckpoint {
1208            state,
1209            meta,
1210            hparams,
1211        }
1212    }
1213
1214    /// Serialise to a flat byte buffer.
1215    ///
1216    /// Format: `\[state_len: u64\][state_bytes]\[meta_text_len: u64\][meta_text_utf8]\[hp_json_len: u64\][hp_json_utf8]`
1217    pub fn to_bytes(&self) -> Vec<u8> {
1218        let mut buf = Vec::new();
1219        let sb = self.state.to_bytes();
1220        buf.extend_from_slice(&(sb.len() as u64).to_le_bytes());
1221        buf.extend_from_slice(&sb);
1222        let mt = self.meta.to_text();
1223        let mb = mt.as_bytes();
1224        buf.extend_from_slice(&(mb.len() as u64).to_le_bytes());
1225        buf.extend_from_slice(mb);
1226        let hp = self.hparams.to_json();
1227        let hb = hp.as_bytes();
1228        buf.extend_from_slice(&(hb.len() as u64).to_le_bytes());
1229        buf.extend_from_slice(hb);
1230        buf
1231    }
1232
1233    /// Byte size of the serialised checkpoint.
1234    pub fn byte_size(&self) -> usize {
1235        self.to_bytes().len()
1236    }
1237}
1238
1239// ---------------------------------------------------------------------------
1240// Utility functions
1241// ---------------------------------------------------------------------------
1242
/// Compute softmax of a slice.
///
/// Uses the max-subtraction trick for numerical stability. Degenerate
/// inputs are handled explicitly:
/// - empty input → empty output;
/// - any NaN → every output is NaN (the normaliser is poisoned);
/// - one or more `+inf` logits → probability mass is split equally among
///   them (the limit of softmax), all other entries are `0.0`;
/// - every logit `-inf` (e.g. a fully masked row) → uniform distribution.
#[allow(dead_code)]
pub fn softmax(logits: &[f64]) -> Vec<f64> {
    let n = logits.len();
    if n == 0 {
        return vec![];
    }
    if logits.iter().any(|x| x.is_nan()) {
        return vec![f64::NAN; n];
    }
    let max_v = logits.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    if max_v == f64::INFINITY {
        // Subtracting +inf from +inf would yield NaN; take the limit instead.
        let n_inf = logits.iter().filter(|&&x| x == f64::INFINITY).count() as f64;
        return logits
            .iter()
            .map(|&x| if x == f64::INFINITY { 1.0 / n_inf } else { 0.0 })
            .collect();
    }
    if max_v == f64::NEG_INFINITY {
        // -inf - (-inf) is NaN; with no finite logit there is no preference.
        return vec![1.0 / n as f64; n];
    }
    let exps: Vec<f64> = logits.iter().map(|&x| (x - max_v).exp()).collect();
    // The maximum element contributes exp(0) = 1, so `sum >= 1` here.
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}
1258
/// Compute cross-entropy loss between `probs` and one-hot `targets`.
///
/// Returns `-Σ t_i · ln(max(p_i, 1e-15))`. Probabilities are floored at
/// `1e-15`, so a zero probability yields a large finite loss rather than
/// infinity. Elements are paired with `zip`, so any extra elements in the
/// longer slice are ignored.
#[allow(dead_code)]
pub fn cross_entropy_loss(probs: &[f64], targets: &[f64]) -> f64 {
    const PROB_FLOOR: f64 = 1e-15;
    let term = |(&p, &t): (&f64, &f64)| -> f64 { -t * p.max(PROB_FLOOR).ln() };
    probs.iter().zip(targets.iter()).map(term).sum::<f64>()
}
1268
/// Argmax: index of the maximum value.
///
/// NaN entries are ignored. A NaN compares as neither greater nor less than
/// anything, so letting it into the scan would make the result depend on
/// where the NaN sits. When several entries share the maximum, the index of
/// the last one is returned. Returns `0` for an empty slice or one containing
/// only NaNs.
#[allow(dead_code)]
pub fn argmax(values: &[f64]) -> usize {
    values
        .iter()
        .enumerate()
        .filter(|(_, v)| !v.is_nan())
        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
        .map(|(i, _)| i)
        .unwrap_or(0)
}
1279
/// Compute mean squared error.
///
/// Only the first `min(predictions.len(), targets.len())` pairs are used.
/// Returns `0.0` when there are no pairs, i.e. when either slice is empty.
#[allow(dead_code)]
pub fn mse(predictions: &[f64], targets: &[f64]) -> f64 {
    let n = predictions.len().min(targets.len());
    if n == 0 {
        // Guard on the pair count, not just `predictions`: an empty
        // `targets` would otherwise give 0.0 / 0.0 = NaN.
        return 0.0;
    }
    let total: f64 = predictions
        .iter()
        .zip(targets)
        .map(|(&p, &t)| {
            let d = p - t;
            d * d
        })
        .sum();
    total / n as f64
}
1297
/// Compute mean absolute error.
///
/// Only the first `min(predictions.len(), targets.len())` pairs are used.
/// Returns `0.0` when there are no pairs, i.e. when either slice is empty.
#[allow(dead_code)]
pub fn mae(predictions: &[f64], targets: &[f64]) -> f64 {
    let n = predictions.len().min(targets.len());
    if n == 0 {
        // Guard on the pair count, not just `predictions`: an empty
        // `targets` would otherwise give 0.0 / 0.0 = NaN.
        return 0.0;
    }
    let total: f64 = predictions
        .iter()
        .zip(targets)
        .map(|(&p, &t)| (p - t).abs())
        .sum();
    total / n as f64
}
1312
1313// ---------------------------------------------------------------------------
1314// Unit tests
1315// ---------------------------------------------------------------------------
1316
// Unit tests for this module. Each `// --- X ---` section targets one type or
// helper; most tests build a tiny fixture and check a single property.
#[cfg(test)]
mod tests {
    use super::*;

    // --- Tensor ---

    // `Tensor::new` asserts that `data.len()` equals the product of `shape`.
    #[test]
    fn test_tensor_new_shape_mismatch_panics() {
        let result = std::panic::catch_unwind(|| Tensor::new(vec![2, 3], vec![0.0; 5]));
        assert!(result.is_err());
    }

    #[test]
    fn test_tensor_zeros() {
        let t = Tensor::zeros(vec![3, 4]);
        assert_eq!(t.numel(), 12);
        assert!(t.data.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_tensor_numel() {
        let t = Tensor::new(vec![2, 3], vec![1.0; 6]);
        assert_eq!(t.numel(), 6);
        assert_eq!(t.ndim(), 2);
    }

    #[test]
    fn test_tensor_sum_mean() {
        let t = Tensor::new(vec![4], vec![1.0, 2.0, 3.0, 4.0]);
        assert!((t.sum() - 10.0).abs() < 1e-12);
        assert!((t.mean() - 2.5).abs() < 1e-12);
    }

    #[test]
    fn test_tensor_scale() {
        let t = Tensor::new(vec![3], vec![1.0, 2.0, 3.0]);
        let t2 = t.scale(2.0);
        assert!((t2.data[1] - 4.0).abs() < 1e-12);
    }

    #[test]
    fn test_tensor_add() {
        let a = Tensor::new(vec![3], vec![1.0, 2.0, 3.0]);
        let b = Tensor::new(vec![3], vec![4.0, 5.0, 6.0]);
        let c = a.add(&b).unwrap();
        assert!((c.data[2] - 9.0).abs() < 1e-12);
    }

    // `add` reports a shape mismatch as `None` rather than panicking.
    #[test]
    fn test_tensor_add_shape_mismatch() {
        let a = Tensor::new(vec![2], vec![1.0, 2.0]);
        let b = Tensor::new(vec![3], vec![1.0, 2.0, 3.0]);
        assert!(a.add(&b).is_none());
    }

    // Round-trip through the binary format, including large and small magnitudes.
    #[test]
    fn test_tensor_roundtrip_bytes() {
        let t = Tensor::new(vec![2, 3], vec![1.0, -2.5, 0.0, 3.125, 1e10, -1e-5]);
        let bytes = t.to_bytes();
        let t2 = Tensor::from_bytes(&bytes).unwrap();
        assert_eq!(t2.shape, t.shape);
        for (a, b) in t.data.iter().zip(&t2.data) {
            assert!((a - b).abs() < 1e-15);
        }
    }

    #[test]
    fn test_tensor_from_bytes_empty_is_none() {
        assert!(Tensor::from_bytes(&[]).is_none());
    }

    // --- DenseLayer ---

    #[test]
    fn test_dense_layer_param_count() {
        let layer = DenseLayer::new("fc1", 4, 3, "relu");
        // weights: 3*4=12, bias: 3
        assert_eq!(layer.param_count(), 15);
    }

    #[test]
    fn test_dense_layer_forward_zero_weights() {
        let layer = DenseLayer::new("fc", 3, 2, "linear");
        let input = vec![1.0, 2.0, 3.0];
        let out = layer.forward(&input);
        assert_eq!(out.len(), 2);
        // All zero weights+bias → output is 0
        for v in &out {
            assert!(v.abs() < 1e-12);
        }
    }

    #[test]
    fn test_dense_layer_activation_relu() {
        assert!((apply_activation(-5.0, "relu")).abs() < 1e-12);
        assert!((apply_activation(3.0, "relu") - 3.0).abs() < 1e-12);
    }

    #[test]
    fn test_dense_layer_activation_sigmoid() {
        let v = apply_activation(0.0, "sigmoid");
        assert!((v - 0.5).abs() < 1e-12);
    }

    #[test]
    fn test_dense_layer_activation_tanh() {
        let v = apply_activation(0.0, "tanh");
        assert!(v.abs() < 1e-12);
    }

    // --- ModelWeights ---

    #[test]
    fn test_model_weights_add_and_get() {
        let mut model = ModelWeights::new();
        model.add_layer(DenseLayer::new("l1", 4, 8, "relu"));
        model.add_layer(DenseLayer::new("l2", 8, 2, "sigmoid"));
        assert_eq!(model.layers.len(), 2);
        assert!(model.get_layer("l1").is_some());
        assert!(model.get_layer("l3").is_none());
    }

    #[test]
    fn test_model_weights_total_params() {
        let mut model = ModelWeights::new();
        model.add_layer(DenseLayer::new("l1", 4, 3, "relu")); // 12+3=15
        model.add_layer(DenseLayer::new("l2", 3, 2, "linear")); // 6+2=8
        assert_eq!(model.total_params(), 23);
    }

    #[test]
    fn test_model_weights_to_bytes_nonempty() {
        let mut model = ModelWeights::new();
        model.add_layer(DenseLayer::new("l1", 2, 2, "relu"));
        let bytes = model.to_bytes();
        assert!(!bytes.is_empty());
    }

    // --- StateDict ---

    #[test]
    fn test_state_dict_insert_and_get() {
        let mut sd = StateDict::new();
        sd.insert("w1", Tensor::zeros(vec![4, 4]));
        assert_eq!(sd.len(), 1);
        assert_eq!(sd.get("w1").unwrap().numel(), 16);
    }

    // total_params sums element counts across all tensors: 9 + 3 = 12.
    #[test]
    fn test_state_dict_total_params() {
        let mut sd = StateDict::new();
        sd.insert("a", Tensor::zeros(vec![3, 3]));
        sd.insert("b", Tensor::zeros(vec![3]));
        assert_eq!(sd.total_params(), 12);
    }

    #[test]
    fn test_state_dict_roundtrip() {
        let mut sd = StateDict::new();
        sd.insert("w", Tensor::new(vec![2, 2], vec![1.0, 2.0, 3.0, 4.0]));
        sd.insert("b", Tensor::new(vec![2], vec![0.5, -0.5]));
        let bytes = sd.to_bytes();
        let sd2 = StateDict::from_bytes(&bytes).unwrap();
        assert_eq!(sd2.len(), 2);
        let w = sd2.get("w").unwrap();
        assert!((w.data[3] - 4.0).abs() < 1e-12);
    }

    // --- OnnxLikeGraph ---

    #[test]
    fn test_onnx_graph_node_count() {
        let mut g = OnnxLikeGraph::new("test_model");
        g.add_node(OnnxNode::new(
            "n0",
            "MatMul",
            vec!["x".into(), "w0".into()],
            vec!["h0".into()],
        ));
        g.add_node(OnnxNode::new(
            "n1",
            "Relu",
            vec!["h0".into()],
            vec!["h1".into()],
        ));
        assert_eq!(g.node_count(), 2);
        assert_eq!(g.count_op("Relu"), 1);
    }

    // Every node input here is a graph input, an initializer, or an output of
    // an earlier node, so the ordering is expected to be valid.
    #[test]
    fn test_onnx_graph_topological_valid() {
        let mut g = OnnxLikeGraph::new("model");
        g.inputs.push("x".into());
        g.add_initializer("w0", Tensor::zeros(vec![4, 4]));
        g.add_node(OnnxNode::new(
            "mm",
            "MatMul",
            vec!["x".into(), "w0".into()],
            vec!["y".into()],
        ));
        g.add_node(OnnxNode::new(
            "act",
            "Relu",
            vec!["y".into()],
            vec!["z".into()],
        ));
        assert!(g.is_topologically_valid());
    }

    #[test]
    fn test_onnx_graph_topological_invalid() {
        let mut g = OnnxLikeGraph::new("model");
        g.inputs.push("x".into());
        // "undefined" is neither an input nor output of any prior node
        g.add_node(OnnxNode::new(
            "act",
            "Relu",
            vec!["undefined".into()],
            vec!["z".into()],
        ));
        assert!(!g.is_topologically_valid());
    }

    // --- Dataset ---

    #[test]
    fn test_dataset_len_and_features() {
        let mut ds = Dataset::new();
        ds.push(DataRow::labelled(vec![1.0, 2.0], 0));
        ds.push(DataRow::labelled(vec![3.0, 4.0], 1));
        assert_eq!(ds.len(), 2);
        assert_eq!(ds.num_features(), 2);
    }

    // With 20 distinct rows and a fixed seed, the shuffled order is expected to
    // differ from insertion order; the test depends on the seed being 42.
    #[test]
    fn test_dataset_shuffle_changes_order() {
        let mut ds = Dataset::new();
        for i in 0..20 {
            ds.push(DataRow::labelled(vec![i as f64], 0));
        }
        let original: Vec<f64> = ds.rows.iter().map(|r| r.features[0]).collect();
        ds.shuffle(42);
        let shuffled: Vec<f64> = ds.rows.iter().map(|r| r.features[0]).collect();
        assert_ne!(original, shuffled);
    }

    // A 0.2 validation fraction of 100 rows is expected to give an 80/20 split.
    #[test]
    fn test_dataset_train_val_split() {
        let mut ds = Dataset::new();
        for i in 0..100 {
            ds.push(DataRow::labelled(vec![i as f64], 0));
        }
        let (train, val) = ds.train_val_split(0.2);
        assert_eq!(train.len(), 80);
        assert_eq!(val.len(), 20);
    }

    #[test]
    fn test_dataset_feature_stats() {
        let mut ds = Dataset::new();
        ds.push(DataRow::labelled(vec![0.0, 10.0], 0));
        ds.push(DataRow::labelled(vec![2.0, 10.0], 1));
        let (means, _stds) = ds.feature_stats();
        assert!((means[0] - 1.0).abs() < 1e-12);
        assert!((means[1] - 10.0).abs() < 1e-12);
    }

    // --- NormalizationParams ---

    #[test]
    fn test_normalization_zscore() {
        let mut ds = Dataset::new();
        ds.push(DataRow::labelled(vec![0.0], 0));
        ds.push(DataRow::labelled(vec![2.0], 0));
        let norm = NormalizationParams::from_dataset(&ds);
        let z = norm.normalize_zscore(&[1.0]);
        // (1 - 1) / std = 0
        assert!(z[0].abs() < 1e-10);
    }

    // Min-max scaling over [0, 10] maps the midpoint 5 to 0.5.
    #[test]
    fn test_normalization_minmax() {
        let mut ds = Dataset::new();
        ds.push(DataRow::labelled(vec![0.0], 0));
        ds.push(DataRow::labelled(vec![10.0], 0));
        let norm = NormalizationParams::from_dataset(&ds);
        let v = norm.normalize_minmax(&[5.0]);
        assert!((v[0] - 0.5).abs() < 1e-12);
    }

    #[test]
    fn test_normalization_bytes_nonempty() {
        let mut ds = Dataset::new();
        ds.push(DataRow::labelled(vec![1.0, 2.0], 0));
        let norm = NormalizationParams::from_dataset(&ds);
        assert!(!norm.to_bytes().is_empty());
    }

    // --- LabelEncoder ---

    // encode followed by decode must return the original label.
    #[test]
    fn test_label_encoder_fit_and_encode() {
        let enc = LabelEncoder::fit(vec!["cat".into(), "dog".into(), "bird".into()]);
        assert_eq!(enc.num_classes(), 3);
        let i = enc.encode("dog").unwrap();
        assert_eq!(enc.decode(i), Some("dog"));
    }

    // A one-hot vector has exactly one 1.0 entry and sums to 1.
    #[test]
    fn test_label_encoder_one_hot() {
        let enc = LabelEncoder::fit(vec!["a".into(), "b".into(), "c".into()]);
        let oh = enc.one_hot(enc.encode("b").unwrap());
        assert_eq!(oh.iter().filter(|&&v| v == 1.0).count(), 1);
        assert!((oh.iter().sum::<f64>() - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_label_encoder_unknown_returns_none() {
        let enc = LabelEncoder::fit(vec!["a".into()]);
        assert!(enc.encode("z").is_none());
    }

    // --- ConfusionMatrix ---

    // `record` takes (true, predicted); 3 of the 4 samples are correct.
    #[test]
    fn test_confusion_matrix_accuracy() {
        let mut cm = ConfusionMatrix::new(2);
        cm.record(0, 0);
        cm.record(0, 0);
        cm.record(1, 1);
        cm.record(1, 0); // wrong
        assert!((cm.accuracy() - 0.75).abs() < 1e-12);
    }

    #[test]
    fn test_confusion_matrix_precision_recall() {
        let mut cm = ConfusionMatrix::new(2);
        cm.record(0, 0); // TP for class 0
        cm.record(0, 1); // FN for class 0
        cm.record(1, 0); // FP for class 0
        cm.record(1, 1); // TN for class 0
        let p = cm.precision(0);
        let r = cm.recall(0);
        assert!((p - 0.5).abs() < 1e-12);
        assert!((r - 0.5).abs() < 1e-12);
    }

    #[test]
    fn test_confusion_matrix_to_csv() {
        let mut cm = ConfusionMatrix::new(2);
        cm.record(0, 0);
        cm.record(1, 1);
        let csv = cm.to_csv();
        assert!(csv.contains("class_0"));
        assert!(csv.contains("class_1"));
    }

    // --- TrainingHistory ---

    // val_acc increases monotonically (0.18 per epoch), so the last epoch wins.
    #[test]
    fn test_training_history_best_val_acc() {
        let mut hist = TrainingHistory::new();
        for e in 0..5 {
            hist.push(EpochRecord {
                epoch: e,
                train_loss: 1.0 - e as f64 * 0.1,
                val_loss: 1.0 - e as f64 * 0.08,
                train_acc: e as f64 * 0.2,
                val_acc: e as f64 * 0.18,
                learning_rate: 0.001,
            });
        }
        let (best_epoch, best_acc) = hist.best_val_acc().unwrap();
        assert_eq!(best_epoch, 4);
        assert!((best_acc - 0.72).abs() < 1e-10);
    }

    #[test]
    fn test_training_history_to_csv() {
        let mut hist = TrainingHistory::new();
        hist.push(EpochRecord {
            epoch: 0,
            train_loss: 0.9,
            val_loss: 0.85,
            train_acc: 0.6,
            val_acc: 0.62,
            learning_rate: 0.01,
        });
        let csv = hist.to_csv();
        assert!(csv.starts_with("epoch,"));
        assert!(csv.contains("0,"));
    }

    // --- HyperparamConfig ---

    #[test]
    fn test_hyperparam_config_get_set() {
        let mut cfg = HyperparamConfig::new();
        cfg.set_float("lr", 0.001);
        cfg.set_bool("dropout", true);
        cfg.set_str("optimizer", "adam");
        assert!((cfg.get_float("lr").unwrap() - 0.001).abs() < 1e-15);
        assert!(cfg.get_bool("dropout").unwrap());
        assert_eq!(cfg.get_str("optimizer").unwrap(), "adam");
    }

    #[test]
    fn test_hyperparam_config_to_json() {
        let mut cfg = HyperparamConfig::new();
        cfg.set_float("lr", 0.01);
        let json = cfg.to_json();
        assert!(json.contains("lr"));
        assert!(json.starts_with('{'));
        assert!(json.ends_with('}'));
    }

    // --- ModelCheckpoint ---

    // Even with an empty state dict, the serialised form contains length
    // prefixes and metadata text, so it is non-empty.
    #[test]
    fn test_checkpoint_byte_size_nonzero() {
        let state = StateDict::new();
        let meta = CheckpointMeta {
            epoch: 10,
            val_loss: 0.1,
            val_acc: 0.95,
            train_time_secs: 3600.0,
            architecture: "MLP".into(),
            framework_version: "0.1.0".into(),
        };
        let hparams = HyperparamConfig::new();
        let ck = ModelCheckpoint::new(state, meta, hparams);
        assert!(ck.byte_size() > 0);
    }

    #[test]
    fn test_checkpoint_meta_to_text_contains_epoch() {
        let meta = CheckpointMeta {
            epoch: 42,
            val_loss: 0.05,
            val_acc: 0.98,
            train_time_secs: 100.0,
            architecture: "CNN".into(),
            framework_version: "0.1.0".into(),
        };
        let text = meta.to_text();
        assert!(text.contains("epoch=42"));
    }

    // --- utility functions ---

    #[test]
    fn test_softmax_sums_to_one() {
        let logits = vec![1.0, 2.0, 3.0];
        let probs = softmax(&logits);
        let total: f64 = probs.iter().sum();
        assert!((total - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_softmax_max_has_highest_prob() {
        let logits = vec![1.0, 5.0, 2.0];
        let probs = softmax(&logits);
        assert!(probs[1] > probs[0] && probs[1] > probs[2]);
    }

    // Zero-probability entries have zero target, so the 1e-15 floor adds
    // nothing. The remaining term is -1 * ln(1) = 0.
    #[test]
    fn test_cross_entropy_perfect_prediction() {
        let probs = vec![0.0, 1.0, 0.0];
        let targets = vec![0.0, 1.0, 0.0];
        let loss = cross_entropy_loss(&probs, &targets);
        assert!(loss < 1e-10);
    }

    #[test]
    fn test_argmax_basic() {
        let v = vec![0.1, 0.7, 0.2];
        assert_eq!(argmax(&v), 1);
    }

    #[test]
    fn test_mse_zero() {
        let p = vec![1.0, 2.0, 3.0];
        let t = vec![1.0, 2.0, 3.0];
        assert!(mse(&p, &t).abs() < 1e-12);
    }

    #[test]
    fn test_mse_known() {
        let p = vec![0.0, 0.0];
        let t = vec![1.0, 1.0];
        assert!((mse(&p, &t) - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_mae_basic() {
        let p = vec![0.0, 1.0, 2.0];
        let t = vec![1.0, 1.0, 3.0];
        // |0-1| + |1-1| + |2-3| = 1+0+1 = 2, /3 = 0.666...
        let m = mae(&p, &t);
        assert!((m - 2.0 / 3.0).abs() < 1e-12);
    }

    // Pins a negative-side slope of 0.01 for "leaky_relu".
    #[test]
    fn test_apply_activation_leaky_relu() {
        assert!((apply_activation(-1.0, "leaky_relu") - (-0.01)).abs() < 1e-12);
        assert!((apply_activation(2.0, "leaky_relu") - 2.0).abs() < 1e-12);
    }

    #[test]
    fn test_apply_activation_elu() {
        let v = apply_activation(-1.0, "elu");
        // elu(-1) = e^(-1) - 1 ≈ -0.6321
        assert!(v < 0.0 && v > -1.0);
    }

    // Two consecutive draws from the same generator should differ.
    #[test]
    fn test_lcg_rng_produces_different_values() {
        let mut rng = LcgRng::new(1234);
        let a = rng.next_u64();
        let b = rng.next_u64();
        assert_ne!(a, b);
    }
}