1#![allow(clippy::needless_range_loop)]
2use std::collections::HashMap;
20
/// A dense n-dimensional array of `f64` values stored as a flat buffer.
///
/// [`Tensor::new`] checks that `data.len()` equals the product of `shape`;
/// the fields are public, so code that builds or mutates a `Tensor`
/// directly must keep that invariant itself.
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq)]
pub struct Tensor {
    /// Extent of each dimension.
    pub shape: Vec<usize>,
    /// Flattened element storage (`shape.iter().product()` elements).
    pub data: Vec<f64>,
}

impl Tensor {
    /// Creates a tensor from a shape and a flat data buffer.
    ///
    /// # Panics
    ///
    /// Panics if `data.len()` differs from the product of `shape`.
    pub fn new(shape: Vec<usize>, data: Vec<f64>) -> Self {
        let expected: usize = shape.iter().product();
        assert_eq!(
            data.len(),
            expected,
            "data length {} does not match shape {:?} (product {})",
            data.len(),
            shape,
            expected
        );
        Tensor { shape, data }
    }

    /// Creates a tensor of the given shape with every element set to `0.0`.
    pub fn zeros(shape: Vec<usize>) -> Self {
        let n: usize = shape.iter().product();
        Tensor {
            shape,
            data: vec![0.0; n],
        }
    }

    /// Number of stored elements.
    pub fn numel(&self) -> usize {
        self.data.len()
    }

    /// Number of dimensions (the length of `shape`).
    pub fn ndim(&self) -> usize {
        self.shape.len()
    }

    /// Serializes the tensor.
    ///
    /// Layout (all little-endian): a `u64` rank, one `u64` per dimension,
    /// then the IEEE-754 bit pattern of each element as a `u64`.
    pub fn to_bytes(&self) -> Vec<u8> {
        let mut buf = Vec::with_capacity(8 + 8 * self.shape.len() + 8 * self.data.len());
        buf.extend_from_slice(&(self.shape.len() as u64).to_le_bytes());
        for &d in &self.shape {
            buf.extend_from_slice(&(d as u64).to_le_bytes());
        }
        for &v in &self.data {
            buf.extend_from_slice(&v.to_bits().to_le_bytes());
        }
        buf
    }

    /// Parses the layout written by [`Tensor::to_bytes`].
    ///
    /// Returns `None` if `bytes` is too short for the header or for the data
    /// the header describes. It also returns `None` if the encoded rank or
    /// dimensions imply a size that does not fit in `usize`. Bytes after the
    /// element data are ignored.
    pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
        // Every slice passed in is exactly 8 bytes (a `get(..8)` or a
        // `chunks_exact(8)` chunk), so the conversion cannot fail.
        let word = |chunk: &[u8]| {
            u64::from_le_bytes(chunk.try_into().expect("chunk is exactly 8 bytes"))
        };
        let ndim = usize::try_from(word(bytes.get(..8)?)).ok()?;
        // The header is untrusted input: use checked arithmetic so a corrupt
        // rank or dimension cannot overflow and slip past the length checks.
        let header_len = ndim.checked_mul(8)?.checked_add(8)?;
        let shape = bytes
            .get(8..header_len)?
            .chunks_exact(8)
            .map(|c| usize::try_from(word(c)).ok())
            .collect::<Option<Vec<usize>>>()?;
        let n = shape
            .iter()
            .try_fold(1usize, |acc, &d| acc.checked_mul(d))?;
        let data_end = n.checked_mul(8)?.checked_add(header_len)?;
        let data = bytes
            .get(header_len..data_end)?
            .chunks_exact(8)
            .map(|c| f64::from_bits(word(c)))
            .collect();
        Some(Tensor { shape, data })
    }

    /// Element-wise sum of two tensors; `None` if their shapes differ.
    pub fn add(&self, other: &Tensor) -> Option<Tensor> {
        if self.shape != other.shape {
            return None;
        }
        let data = self
            .data
            .iter()
            .zip(&other.data)
            .map(|(a, b)| a + b)
            .collect();
        Some(Tensor {
            shape: self.shape.clone(),
            data,
        })
    }

    /// Returns a new tensor with every element multiplied by `s`.
    pub fn scale(&self, s: f64) -> Tensor {
        Tensor {
            shape: self.shape.clone(),
            data: self.data.iter().map(|v| v * s).collect(),
        }
    }

    /// Sum of all elements.
    pub fn sum(&self) -> f64 {
        self.data.iter().sum()
    }

    /// Arithmetic mean of all elements, or `0.0` for an empty tensor.
    pub fn mean(&self) -> f64 {
        if self.data.is_empty() {
            return 0.0;
        }
        self.sum() / self.data.len() as f64
    }
}
151
152#[allow(dead_code)]
158#[derive(Debug, Clone)]
159pub struct DenseLayer {
160 pub name: String,
162 pub weights: Tensor,
164 pub bias: Tensor,
166 pub activation: String,
168}
169
170impl DenseLayer {
171 pub fn new(
173 name: impl Into<String>,
174 in_features: usize,
175 out_features: usize,
176 activation: impl Into<String>,
177 ) -> Self {
178 DenseLayer {
179 name: name.into(),
180 weights: Tensor::zeros(vec![out_features, in_features]),
181 bias: Tensor::zeros(vec![out_features]),
182 activation: activation.into(),
183 }
184 }
185
186 pub fn forward(&self, input: &[f64]) -> Vec<f64> {
188 let in_feat = input.len();
189 let out_feat = self.bias.data.len();
190 let mut out = vec![0.0f64; out_feat];
191 for i in 0..out_feat {
192 let mut acc = self.bias.data[i];
193 for j in 0..in_feat.min(self.weights.data.len() / out_feat) {
194 acc += self.weights.data[i * in_feat + j] * input[j];
195 }
196 out[i] = apply_activation(acc, &self.activation);
197 }
198 out
199 }
200
201 pub fn param_count(&self) -> usize {
203 self.weights.numel() + self.bias.numel()
204 }
205
206 pub fn to_bytes(&self) -> Vec<u8> {
208 let name_bytes = self.name.as_bytes();
209 let act_bytes = self.activation.as_bytes();
210 let mut buf = Vec::new();
211 buf.extend_from_slice(&(name_bytes.len() as u64).to_le_bytes());
212 buf.extend_from_slice(name_bytes);
213 buf.extend_from_slice(&(act_bytes.len() as u64).to_le_bytes());
214 buf.extend_from_slice(act_bytes);
215 let wb = self.weights.to_bytes();
216 buf.extend_from_slice(&(wb.len() as u64).to_le_bytes());
217 buf.extend_from_slice(&wb);
218 let bb = self.bias.to_bytes();
219 buf.extend_from_slice(&(bb.len() as u64).to_le_bytes());
220 buf.extend_from_slice(&bb);
221 buf
222 }
223}
224
/// Applies the named scalar activation function to `x`.
///
/// Recognised names are `"relu"`, `"sigmoid"`, `"tanh"`, `"softplus"`,
/// `"elu"` (alpha 1) and `"leaky_relu"` (negative slope 0.01). Any other
/// name, such as `"linear"`, is treated as the identity.
#[allow(dead_code)]
pub fn apply_activation(x: f64, activation: &str) -> f64 {
    match activation {
        "relu" => x.max(0.0),
        "sigmoid" => 1.0 / (1.0 + (-x).exp()),
        "tanh" => x.tanh(),
        // Numerically stable ln(1 + e^x) = max(x, 0) + ln(1 + e^-|x|).
        // The naive form overflows to +inf once e^x exceeds f64::MAX (x > ~709.78).
        "softplus" => x.max(0.0) + (-x.abs()).exp().ln_1p(),
        "elu" => {
            if x >= 0.0 {
                x
            } else {
                x.exp() - 1.0
            }
        }
        "leaky_relu" => {
            if x >= 0.0 {
                x
            } else {
                0.01 * x
            }
        }
        _ => x,
    }
}
250
251#[allow(dead_code)]
257#[derive(Debug, Clone, Default)]
258pub struct ModelWeights {
259 pub layers: Vec<DenseLayer>,
261}
262
263impl ModelWeights {
264 pub fn new() -> Self {
266 ModelWeights { layers: Vec::new() }
267 }
268
269 pub fn add_layer(&mut self, layer: DenseLayer) {
271 self.layers.push(layer);
272 }
273
274 pub fn get_layer(&self, name: &str) -> Option<&DenseLayer> {
276 self.layers.iter().find(|l| l.name == name)
277 }
278
279 pub fn total_params(&self) -> usize {
281 self.layers.iter().map(|l| l.param_count()).sum()
282 }
283
284 pub fn to_bytes(&self) -> Vec<u8> {
288 let mut buf = Vec::new();
289 buf.extend_from_slice(&(self.layers.len() as u64).to_le_bytes());
290 for layer in &self.layers {
291 let lb = layer.to_bytes();
292 buf.extend_from_slice(&(lb.len() as u64).to_le_bytes());
293 buf.extend_from_slice(&lb);
294 }
295 buf
296 }
297}
298
299#[allow(dead_code)]
305#[derive(Debug, Clone, Default)]
306pub struct StateDict {
307 pub tensors: HashMap<String, Tensor>,
309}
310
311impl StateDict {
312 pub fn new() -> Self {
314 StateDict {
315 tensors: HashMap::new(),
316 }
317 }
318
319 pub fn insert(&mut self, key: impl Into<String>, tensor: Tensor) {
321 self.tensors.insert(key.into(), tensor);
322 }
323
324 pub fn get(&self, key: &str) -> Option<&Tensor> {
326 self.tensors.get(key)
327 }
328
329 pub fn len(&self) -> usize {
331 self.tensors.len()
332 }
333
334 pub fn is_empty(&self) -> bool {
336 self.tensors.is_empty()
337 }
338
339 pub fn total_params(&self) -> usize {
341 self.tensors.values().map(|t| t.numel()).sum()
342 }
343
344 pub fn to_bytes(&self) -> Vec<u8> {
348 let mut buf = Vec::new();
349 buf.extend_from_slice(&(self.tensors.len() as u64).to_le_bytes());
350 let mut keys: Vec<&String> = self.tensors.keys().collect();
351 keys.sort(); for k in keys {
353 let kb = k.as_bytes();
354 buf.extend_from_slice(&(kb.len() as u64).to_le_bytes());
355 buf.extend_from_slice(kb);
356 let tb = self.tensors[k].to_bytes();
357 buf.extend_from_slice(&(tb.len() as u64).to_le_bytes());
358 buf.extend_from_slice(&tb);
359 }
360 buf
361 }
362
363 pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
365 let mut pos = 0usize;
366 let n = read_u64(bytes, &mut pos)? as usize;
367 let mut dict = StateDict::new();
368 for _ in 0..n {
369 let klen = read_u64(bytes, &mut pos)? as usize;
370 if pos + klen > bytes.len() {
371 return None;
372 }
373 let key = String::from_utf8(bytes[pos..pos + klen].to_vec()).ok()?;
374 pos += klen;
375 let tlen = read_u64(bytes, &mut pos)? as usize;
376 if pos + tlen > bytes.len() {
377 return None;
378 }
379 let tensor = Tensor::from_bytes(&bytes[pos..pos + tlen])?;
380 pos += tlen;
381 dict.insert(key, tensor);
382 }
383 Some(dict)
384 }
385}
386
387fn read_u64(bytes: &[u8], pos: &mut usize) -> Option<u64> {
389 if *pos + 8 > bytes.len() {
390 return None;
391 }
392 let v = u64::from_le_bytes(bytes[*pos..*pos + 8].try_into().ok()?);
393 *pos += 8;
394 Some(v)
395}
396
/// One operation in an ONNX-style graph, with named input/output edges and
/// scalar attributes.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct OnnxNode {
    /// Node identifier.
    pub name: String,
    /// Operator name (e.g. `"MatMul"`, `"Relu"`).
    pub op_type: String,
    /// Names of the values this node consumes.
    pub inputs: Vec<String>,
    /// Names of the values this node produces.
    pub outputs: Vec<String>,
    /// Scalar attributes keyed by name.
    pub attributes: HashMap<String, f64>,
}

impl OnnxNode {
    /// Builds a node with no attributes.
    pub fn new(
        name: impl Into<String>,
        op_type: impl Into<String>,
        inputs: Vec<String>,
        outputs: Vec<String>,
    ) -> Self {
        let (name, op_type) = (name.into(), op_type.into());
        Self {
            name,
            op_type,
            inputs,
            outputs,
            attributes: HashMap::default(),
        }
    }

    /// Builder-style setter: stores `value` under `key` (overwriting any
    /// previous value) and returns the node.
    pub fn with_attr(self, key: impl Into<String>, value: f64) -> Self {
        let mut node = self;
        node.attributes.insert(key.into(), value);
        node
    }
}
440
441#[allow(dead_code)]
443#[derive(Debug, Clone, Default)]
444pub struct OnnxLikeGraph {
445 pub nodes: Vec<OnnxNode>,
447 pub initializers: StateDict,
449 pub inputs: Vec<String>,
451 pub outputs: Vec<String>,
453 pub name: String,
455}
456
457impl OnnxLikeGraph {
458 pub fn new(name: impl Into<String>) -> Self {
460 OnnxLikeGraph {
461 name: name.into(),
462 nodes: Vec::new(),
463 initializers: StateDict::new(),
464 inputs: Vec::new(),
465 outputs: Vec::new(),
466 }
467 }
468
469 pub fn add_node(&mut self, node: OnnxNode) {
471 self.nodes.push(node);
472 }
473
474 pub fn add_initializer(&mut self, name: impl Into<String>, tensor: Tensor) {
476 self.initializers.insert(name, tensor);
477 }
478
479 pub fn node_count(&self) -> usize {
481 self.nodes.len()
482 }
483
484 pub fn count_op(&self, op: &str) -> usize {
486 self.nodes.iter().filter(|n| n.op_type == op).count()
487 }
488
489 pub fn is_topologically_valid(&self) -> bool {
492 let mut available: std::collections::HashSet<&str> =
493 self.inputs.iter().map(|s| s.as_str()).collect();
494 for k in self.initializers.tensors.keys() {
496 available.insert(k.as_str());
497 }
498 for node in &self.nodes {
499 for inp in &node.inputs {
500 if !available.contains(inp.as_str()) {
501 return false;
502 }
503 }
504 for out in &node.outputs {
505 available.insert(out.as_str());
506 }
507 }
508 true
509 }
510}
511
/// One sample: a feature vector with an optional class label.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct DataRow {
    /// Feature values.
    pub features: Vec<f64>,
    /// Class index, or `None` for unlabelled samples.
    pub label: Option<usize>,
}

impl DataRow {
    /// Builds a sample carrying class index `label`.
    pub fn labelled(features: Vec<f64>, label: usize) -> Self {
        let label = Some(label);
        Self { features, label }
    }

    /// Builds a sample with no label.
    pub fn unlabelled(features: Vec<f64>) -> Self {
        let label = None;
        Self { features, label }
    }
}
543
544#[allow(dead_code)]
546#[derive(Debug, Clone, Default)]
547pub struct Dataset {
548 pub rows: Vec<DataRow>,
550 pub feature_names: Vec<String>,
552 pub class_names: Vec<String>,
554}
555
556impl Dataset {
557 pub fn new() -> Self {
559 Dataset {
560 rows: Vec::new(),
561 feature_names: Vec::new(),
562 class_names: Vec::new(),
563 }
564 }
565
566 pub fn push(&mut self, row: DataRow) {
568 self.rows.push(row);
569 }
570
571 pub fn len(&self) -> usize {
573 self.rows.len()
574 }
575
576 pub fn is_empty(&self) -> bool {
578 self.rows.is_empty()
579 }
580
581 pub fn num_features(&self) -> usize {
583 self.rows.first().map(|r| r.features.len()).unwrap_or(0)
584 }
585
586 pub fn shuffle(&mut self, seed: u64) {
588 let n = self.rows.len();
589 if n < 2 {
590 return;
591 }
592 let mut rng = LcgRng::new(seed);
593 for i in (1..n).rev() {
594 let j = rng.next_usize_below(i + 1);
595 self.rows.swap(i, j);
596 }
597 }
598
599 pub fn train_val_split(&self, val_fraction: f64) -> (Dataset, Dataset) {
603 let val_count = ((self.rows.len() as f64) * val_fraction.clamp(0.0, 1.0)) as usize;
604 let train_count = self.rows.len().saturating_sub(val_count);
605 let mut train = Dataset {
606 rows: self.rows[..train_count].to_vec(),
607 feature_names: self.feature_names.clone(),
608 class_names: self.class_names.clone(),
609 };
610 let mut val = Dataset {
611 rows: self.rows[train_count..].to_vec(),
612 feature_names: self.feature_names.clone(),
613 class_names: self.class_names.clone(),
614 };
615 let _ = &mut train;
617 let _ = &mut val;
618 (train, val)
619 }
620
621 pub fn feature_stats(&self) -> (Vec<f64>, Vec<f64>) {
625 let nf = self.num_features();
626 if nf == 0 || self.rows.is_empty() {
627 return (vec![], vec![]);
628 }
629 let n = self.rows.len() as f64;
630 let mut means = vec![0.0f64; nf];
631 for row in &self.rows {
632 for (k, &v) in row.features.iter().enumerate() {
633 means[k] += v;
634 }
635 }
636 for m in &mut means {
637 *m /= n;
638 }
639 let mut stds = vec![0.0f64; nf];
640 for row in &self.rows {
641 for (k, &v) in row.features.iter().enumerate() {
642 let d = v - means[k];
643 stds[k] += d * d;
644 }
645 }
646 for s in &mut stds {
647 *s = (*s / n).sqrt();
648 }
649 (means, stds)
650 }
651}
652
653#[allow(dead_code)]
659struct LcgRng {
660 state: u64,
661}
662
663impl LcgRng {
664 fn new(seed: u64) -> Self {
665 LcgRng {
666 state: seed ^ 0x1234_5678_9abc_def0,
667 }
668 }
669
670 fn next_u64(&mut self) -> u64 {
671 self.state = self
673 .state
674 .wrapping_mul(6_364_136_223_846_793_005)
675 .wrapping_add(1_442_695_040_888_963_407);
676 self.state
677 }
678
679 fn next_usize_below(&mut self, n: usize) -> usize {
680 if n == 0 {
681 return 0;
682 }
683 (self.next_u64() % n as u64) as usize
684 }
685}
686
687#[allow(dead_code)]
693#[derive(Debug, Clone)]
694pub struct NormalizationParams {
695 pub means: Vec<f64>,
697 pub stds: Vec<f64>,
699 pub mins: Vec<f64>,
701 pub maxs: Vec<f64>,
703}
704
705impl NormalizationParams {
706 pub fn from_dataset(dataset: &Dataset) -> Self {
708 let (means, stds) = dataset.feature_stats();
709 let nf = means.len();
710 let mut mins = vec![f64::INFINITY; nf];
711 let mut maxs = vec![f64::NEG_INFINITY; nf];
712 for row in &dataset.rows {
713 for (k, &v) in row.features.iter().enumerate() {
714 if v < mins[k] {
715 mins[k] = v;
716 }
717 if v > maxs[k] {
718 maxs[k] = v;
719 }
720 }
721 }
722 NormalizationParams {
723 means,
724 stds,
725 mins,
726 maxs,
727 }
728 }
729
730 pub fn normalize_zscore(&self, features: &[f64]) -> Vec<f64> {
732 features
733 .iter()
734 .enumerate()
735 .map(|(k, &v)| {
736 let s = if k < self.stds.len() {
737 self.stds[k]
738 } else {
739 1.0
740 };
741 let m = if k < self.means.len() {
742 self.means[k]
743 } else {
744 0.0
745 };
746 if s.abs() < 1e-15 { 0.0 } else { (v - m) / s }
747 })
748 .collect()
749 }
750
751 pub fn normalize_minmax(&self, features: &[f64]) -> Vec<f64> {
753 features
754 .iter()
755 .enumerate()
756 .map(|(k, &v)| {
757 let mn = if k < self.mins.len() {
758 self.mins[k]
759 } else {
760 0.0
761 };
762 let mx = if k < self.maxs.len() {
763 self.maxs[k]
764 } else {
765 1.0
766 };
767 let range = mx - mn;
768 if range.abs() < 1e-15 {
769 0.0
770 } else {
771 (v - mn) / range
772 }
773 })
774 .collect()
775 }
776
777 pub fn to_bytes(&self) -> Vec<u8> {
779 let mut buf = Vec::new();
780 let write_vec = |buf: &mut Vec<u8>, v: &[f64]| {
781 buf.extend_from_slice(&(v.len() as u64).to_le_bytes());
782 for &x in v {
783 buf.extend_from_slice(&x.to_bits().to_le_bytes());
784 }
785 };
786 write_vec(&mut buf, &self.means);
787 write_vec(&mut buf, &self.stds);
788 write_vec(&mut buf, &self.mins);
789 write_vec(&mut buf, &self.maxs);
790 buf
791 }
792}
793
/// Bidirectional mapping between class names and dense integer indices.
#[allow(dead_code)]
#[derive(Debug, Clone, Default)]
pub struct LabelEncoder {
    /// Class names, sorted and de-duplicated; position = class index.
    pub classes: Vec<String>,
    /// Reverse lookup from class name to its index in `classes`.
    index: HashMap<String, usize>,
}

impl LabelEncoder {
    /// Creates an encoder with no classes.
    pub fn new() -> Self {
        Self::default()
    }

    /// Builds an encoder from `class_names`, sorting them and removing
    /// duplicates so indices are independent of input order.
    pub fn fit(class_names: Vec<String>) -> Self {
        let mut classes = class_names;
        classes.sort_unstable();
        classes.dedup();
        let index = classes.iter().cloned().zip(0..).collect();
        Self { classes, index }
    }

    /// Index of class `name`, or `None` if it was not seen by `fit`.
    pub fn encode(&self, name: &str) -> Option<usize> {
        self.index.get(name).cloned()
    }

    /// Class name for `idx`, or `None` if out of range.
    pub fn decode(&self, idx: usize) -> Option<&str> {
        self.classes.get(idx).map(String::as_str)
    }

    /// Number of distinct classes.
    pub fn num_classes(&self) -> usize {
        self.classes.len()
    }

    /// One-hot vector of length `num_classes()` with a `1.0` at `idx`
    /// (all zeros when `idx` is out of range).
    pub fn one_hot(&self, idx: usize) -> Vec<f64> {
        (0..self.num_classes())
            .map(|i| if i == idx { 1.0 } else { 0.0 })
            .collect()
    }
}
856
/// A `k × k` confusion matrix for single-label classification.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct ConfusionMatrix {
    /// Number of classes `k`.
    pub num_classes: usize,
    /// Row-major counts: `counts[t * k + p]` is how often true class `t`
    /// was predicted as class `p`.
    pub counts: Vec<u64>,
}

impl ConfusionMatrix {
    /// Creates an all-zero matrix for `num_classes` classes.
    pub fn new(num_classes: usize) -> Self {
        let cells = num_classes * num_classes;
        Self {
            num_classes,
            counts: vec![0; cells],
        }
    }

    /// Counts one prediction; pairs with an out-of-range label are ignored.
    pub fn record(&mut self, true_label: usize, predicted: usize) {
        let k = self.num_classes;
        if true_label >= k || predicted >= k {
            return;
        }
        self.counts[true_label * k + predicted] += 1;
    }

    /// Fraction of recorded predictions on the diagonal (0.0 if none recorded).
    pub fn accuracy(&self) -> f64 {
        let k = self.num_classes;
        let total = self.counts.iter().sum::<u64>();
        if total == 0 {
            return 0.0;
        }
        let correct = (0..k).map(|c| self.counts[c * k + c]).sum::<u64>();
        correct as f64 / total as f64
    }

    /// `TP / (TP + FP)` for `class`, or 0.0 when undefined or out of range.
    pub fn precision(&self, class: usize) -> f64 {
        let k = self.num_classes;
        if class >= k {
            return 0.0;
        }
        let hits = self.counts[class * k + class] as f64;
        // False positives: column `class`, excluding the diagonal cell.
        let false_pos: f64 = (0..k)
            .filter(|&row| row != class)
            .map(|row| self.counts[row * k + class] as f64)
            .sum();
        let denom = hits + false_pos;
        if denom < 1e-15 {
            0.0
        } else {
            hits / denom
        }
    }

    /// `TP / (TP + FN)` for `class`, or 0.0 when undefined or out of range.
    pub fn recall(&self, class: usize) -> f64 {
        let k = self.num_classes;
        if class >= k {
            return 0.0;
        }
        let hits = self.counts[class * k + class] as f64;
        // False negatives: row `class`, excluding the diagonal cell.
        let false_neg: f64 = (0..k)
            .filter(|&col| col != class)
            .map(|col| self.counts[class * k + col] as f64)
            .sum();
        let denom = hits + false_neg;
        if denom < 1e-15 {
            0.0
        } else {
            hits / denom
        }
    }

    /// Harmonic mean of precision and recall for `class` (0.0 if both are 0).
    pub fn f1(&self, class: usize) -> f64 {
        let (p, r) = (self.precision(class), self.recall(class));
        let denom = p + r;
        if denom < 1e-15 {
            0.0
        } else {
            2.0 * p * r / denom
        }
    }

    /// Renders the matrix as CSV with true classes as rows and predicted
    /// classes as columns.
    pub fn to_csv(&self) -> String {
        let k = self.num_classes;
        let header: String = (0..k).map(|j| format!(",class_{j}")).collect();
        let mut csv = String::from("true\\pred");
        csv.push_str(&header);
        csv.push('\n');
        for i in 0..k {
            csv.push_str(&format!("class_{i}"));
            for count in &self.counts[i * k..(i + 1) * k] {
                csv.push_str(&format!(",{}", count));
            }
            csv.push('\n');
        }
        csv
    }
}
959
/// Metrics logged for a single training epoch.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct EpochRecord {
    /// Epoch number as reported by the caller.
    pub epoch: usize,
    /// Training loss for the epoch.
    pub train_loss: f64,
    /// Validation loss for the epoch.
    pub val_loss: f64,
    /// Training accuracy for the epoch.
    pub train_acc: f64,
    /// Validation accuracy for the epoch.
    pub val_acc: f64,
    /// Learning rate recorded for the epoch.
    pub learning_rate: f64,
}

/// Ordered log of per-epoch metrics.
#[allow(dead_code)]
#[derive(Debug, Clone, Default)]
pub struct TrainingHistory {
    /// Records in the order they were pushed.
    pub records: Vec<EpochRecord>,
}

impl TrainingHistory {
    /// Creates an empty history.
    pub fn new() -> Self {
        TrainingHistory {
            records: Vec::new(),
        }
    }

    /// Appends one epoch's record.
    pub fn push(&mut self, record: EpochRecord) {
        self.records.push(record);
    }

    /// Number of recorded epochs.
    pub fn num_epochs(&self) -> usize {
        self.records.len()
    }

    /// Highest validation accuracy as `(position in records, val_acc)`.
    ///
    /// NaN accuracies are skipped; returns `None` if no record has a
    /// non-NaN value. On ties the later record wins.
    pub fn best_val_acc(&self) -> Option<(usize, f64)> {
        self.records
            .iter()
            .map(|r| r.val_acc)
            .enumerate()
            // NaN compares as "Equal" under `partial_cmp().unwrap_or(Equal)`,
            // which let a NaN displace the real maximum depending on order.
            .filter(|(_, v)| !v.is_nan())
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
    }

    /// Lowest validation loss as `(position in records, val_loss)`.
    ///
    /// NaN losses are skipped; returns `None` if no record has a non-NaN
    /// value. On ties the earlier record wins.
    pub fn best_val_loss(&self) -> Option<(usize, f64)> {
        self.records
            .iter()
            .map(|r| r.val_loss)
            .enumerate()
            .filter(|(_, v)| !v.is_nan())
            .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
    }

    /// Renders the history as CSV with a header row.
    pub fn to_csv(&self) -> String {
        let mut s = String::from("epoch,train_loss,val_loss,train_acc,val_acc,lr\n");
        for r in &self.records {
            s.push_str(&format!(
                "{},{:.6},{:.6},{:.6},{:.6},{:.8}\n",
                r.epoch, r.train_loss, r.val_loss, r.train_acc, r.val_acc, r.learning_rate
            ));
        }
        s
    }
}
1046
/// A single hyperparameter value.
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq)]
pub enum HpValue {
    /// Numeric value.
    Float(f64),
    /// Boolean flag.
    Bool(bool),
    /// Text value.
    Str(String),
}

impl HpValue {
    /// The value if this is `Float`, else `None`.
    pub fn as_float(&self) -> Option<f64> {
        match *self {
            HpValue::Float(v) => Some(v),
            _ => None,
        }
    }

    /// The value if this is `Bool`, else `None`.
    pub fn as_bool(&self) -> Option<bool> {
        match *self {
            HpValue::Bool(v) => Some(v),
            _ => None,
        }
    }

    /// The value if this is `Str`, else `None`.
    pub fn as_str(&self) -> Option<&str> {
        match self {
            HpValue::Str(s) => Some(s.as_str()),
            _ => None,
        }
    }
}

/// A string-keyed bag of typed hyperparameters.
#[allow(dead_code)]
#[derive(Debug, Clone, Default)]
pub struct HyperparamConfig {
    /// Values keyed by hyperparameter name.
    pub params: HashMap<String, HpValue>,
}

impl HyperparamConfig {
    /// Creates an empty configuration.
    pub fn new() -> Self {
        HyperparamConfig {
            params: HashMap::new(),
        }
    }

    /// Stores a float under `key`, replacing any previous value.
    pub fn set_float(&mut self, key: impl Into<String>, value: f64) {
        self.params.insert(key.into(), HpValue::Float(value));
    }

    /// Stores a bool under `key`, replacing any previous value.
    pub fn set_bool(&mut self, key: impl Into<String>, value: bool) {
        self.params.insert(key.into(), HpValue::Bool(value));
    }

    /// Stores a string under `key`, replacing any previous value.
    pub fn set_str(&mut self, key: impl Into<String>, value: impl Into<String>) {
        self.params.insert(key.into(), HpValue::Str(value.into()));
    }

    /// The float stored under `key`; `None` if absent or of another type.
    pub fn get_float(&self, key: &str) -> Option<f64> {
        self.params.get(key)?.as_float()
    }

    /// The bool stored under `key`; `None` if absent or of another type.
    pub fn get_bool(&self, key: &str) -> Option<bool> {
        self.params.get(key)?.as_bool()
    }

    /// The string stored under `key`; `None` if absent or of another type.
    pub fn get_str(&self, key: &str) -> Option<&str> {
        self.params.get(key)?.as_str()
    }

    /// Renders the configuration as a flat JSON object with sorted keys.
    ///
    /// Keys and string values are escaped per RFC 8259. JSON has no literal
    /// for NaN or infinity, so non-finite floats are written as `null`.
    pub fn to_json(&self) -> String {
        let mut keys: Vec<&String> = self.params.keys().collect();
        keys.sort();
        let parts: Vec<String> = keys
            .into_iter()
            .map(|k| {
                let value = match &self.params[k] {
                    HpValue::Float(f) if f.is_finite() => format!("{f}"),
                    HpValue::Float(_) => "null".to_string(),
                    HpValue::Bool(b) => b.to_string(),
                    HpValue::Str(s) => Self::json_string(s),
                };
                format!("{}:{}", Self::json_string(k), value)
            })
            .collect();
        format!("{{{}}}", parts.join(","))
    }

    /// Quotes `s` as a JSON string literal.
    ///
    /// Escapes `"`, `\`, and all control characters below U+0020.
    fn json_string(s: &str) -> String {
        let mut out = String::with_capacity(s.len() + 2);
        out.push('"');
        for c in s.chars() {
            match c {
                '"' => out.push_str("\\\""),
                '\\' => out.push_str("\\\\"),
                '\n' => out.push_str("\\n"),
                '\r' => out.push_str("\\r"),
                '\t' => out.push_str("\\t"),
                c if (c as u32) < 0x20 => out.push_str(&format!("\\u{:04x}", c as u32)),
                c => out.push(c),
            }
        }
        out.push('"');
        out
    }
}
1154
/// Descriptive metadata stored alongside a model checkpoint.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct CheckpointMeta {
    /// Epoch at which the checkpoint was taken.
    pub epoch: usize,
    /// Validation loss at that epoch.
    pub val_loss: f64,
    /// Validation accuracy at that epoch.
    pub val_acc: f64,
    /// Training wall time in seconds.
    pub train_time_secs: f64,
    /// Free-form architecture description.
    pub architecture: String,
    /// Version string of the producing framework.
    pub framework_version: String,
}

impl CheckpointMeta {
    /// Renders the metadata as newline-terminated `key=value` lines.
    ///
    /// Losses and accuracies use 8 decimal places; time uses 3. String fields
    /// are written verbatim, so an embedded newline would break the format.
    pub fn to_text(&self) -> String {
        let CheckpointMeta {
            epoch,
            val_loss,
            val_acc,
            train_time_secs,
            architecture,
            framework_version,
        } = self;
        format!(
            "epoch={}\nval_loss={:.8}\nval_acc={:.8}\ntrain_time_secs={:.3}\narchitecture={}\nframework_version={}\n",
            epoch, val_loss, val_acc, train_time_secs, architecture, framework_version
        )
    }
}
1191
1192#[allow(dead_code)]
1194#[derive(Debug, Clone)]
1195pub struct ModelCheckpoint {
1196 pub state: StateDict,
1198 pub meta: CheckpointMeta,
1200 pub hparams: HyperparamConfig,
1202}
1203
1204impl ModelCheckpoint {
1205 pub fn new(state: StateDict, meta: CheckpointMeta, hparams: HyperparamConfig) -> Self {
1207 ModelCheckpoint {
1208 state,
1209 meta,
1210 hparams,
1211 }
1212 }
1213
1214 pub fn to_bytes(&self) -> Vec<u8> {
1218 let mut buf = Vec::new();
1219 let sb = self.state.to_bytes();
1220 buf.extend_from_slice(&(sb.len() as u64).to_le_bytes());
1221 buf.extend_from_slice(&sb);
1222 let mt = self.meta.to_text();
1223 let mb = mt.as_bytes();
1224 buf.extend_from_slice(&(mb.len() as u64).to_le_bytes());
1225 buf.extend_from_slice(mb);
1226 let hp = self.hparams.to_json();
1227 let hb = hp.as_bytes();
1228 buf.extend_from_slice(&(hb.len() as u64).to_le_bytes());
1229 buf.extend_from_slice(hb);
1230 buf
1231 }
1232
1233 pub fn byte_size(&self) -> usize {
1235 self.to_bytes().len()
1236 }
1237}
1238
/// Converts logits into a probability distribution.
///
/// Subtracts the maximum logit before exponentiating so large finite logits
/// do not overflow. Edge cases:
/// * an empty slice yields an empty vector;
/// * any `NaN` logit makes every output `NaN`;
/// * if one or more logits are `+inf`, the mass is split equally among them;
/// * if every logit is `-inf`, the result is uniform.
#[allow(dead_code)]
pub fn softmax(logits: &[f64]) -> Vec<f64> {
    if logits.is_empty() {
        return vec![];
    }
    if logits.iter().any(|x| x.is_nan()) {
        return vec![f64::NAN; logits.len()];
    }
    let max_v = logits.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    if max_v == f64::INFINITY {
        // `inf - inf` is NaN, so the shifted form breaks down; in the limit
        // all probability sits on the +inf entries.
        let k = logits.iter().filter(|&&x| x == f64::INFINITY).count() as f64;
        return logits
            .iter()
            .map(|&x| if x == f64::INFINITY { 1.0 / k } else { 0.0 })
            .collect();
    }
    if max_v == f64::NEG_INFINITY {
        // Every logit is -inf (NaN was excluded above); `-inf - -inf` is NaN,
        // so return the uniform distribution instead.
        return vec![1.0 / logits.len() as f64; logits.len()];
    }
    let exps: Vec<f64> = logits.iter().map(|&x| (x - max_v).exp()).collect();
    // The maximal logit contributes exp(0) = 1, so `sum >= 1` here.
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}
1258
/// Cross-entropy `-Σ t · ln(p)` between predicted `probs` and `targets`.
///
/// Probabilities are floored at `1e-15` before the log so a zero probability
/// yields a large finite loss rather than infinity. Extra elements in the
/// longer slice are ignored.
#[allow(dead_code)]
pub fn cross_entropy_loss(probs: &[f64], targets: &[f64]) -> f64 {
    let term = |(&p, &t): (&f64, &f64)| -t * p.max(1e-15).ln();
    probs.iter().zip(targets).map(term).sum()
}
1268
/// Index of the largest value (0 for an empty slice).
///
/// Ties resolve to the last maximal index. A NaN comparison is treated like
/// a tie, so NaNs can move the result forward.
#[allow(dead_code)]
pub fn argmax(values: &[f64]) -> usize {
    let mut best = 0;
    for (i, v) in values.iter().enumerate().skip(1) {
        // Keep the current best only if it is strictly greater.
        if values[best].partial_cmp(v) != Some(std::cmp::Ordering::Greater) {
            best = i;
        }
    }
    best
}
1279
/// Mean squared error over the overlapping prefix of the two slices.
///
/// Returns `0.0` when there is nothing to compare (either slice empty),
/// instead of the `NaN` that `0 / 0` would produce.
#[allow(dead_code)]
pub fn mse(predictions: &[f64], targets: &[f64]) -> f64 {
    let n = predictions.len().min(targets.len());
    if n == 0 {
        return 0.0;
    }
    predictions
        .iter()
        .zip(targets)
        .map(|(&p, &t)| {
            let d = p - t;
            d * d
        })
        .sum::<f64>()
        / n as f64
}
1297
/// Mean absolute error over the overlapping prefix of the two slices.
///
/// Returns `0.0` when there is nothing to compare (either slice empty),
/// instead of the `NaN` that `0 / 0` would produce.
#[allow(dead_code)]
pub fn mae(predictions: &[f64], targets: &[f64]) -> f64 {
    let n = predictions.len().min(targets.len());
    if n == 0 {
        return 0.0;
    }
    predictions
        .iter()
        .zip(targets)
        .map(|(&p, &t)| (p - t).abs())
        .sum::<f64>()
        / n as f64
}
1312
1313#[cfg(test)]
1318mod tests {
1319 use super::*;
1320
1321 #[test]
1324 fn test_tensor_new_shape_mismatch_panics() {
1325 let result = std::panic::catch_unwind(|| Tensor::new(vec![2, 3], vec![0.0; 5]));
1326 assert!(result.is_err());
1327 }
1328
1329 #[test]
1330 fn test_tensor_zeros() {
1331 let t = Tensor::zeros(vec![3, 4]);
1332 assert_eq!(t.numel(), 12);
1333 assert!(t.data.iter().all(|&v| v == 0.0));
1334 }
1335
1336 #[test]
1337 fn test_tensor_numel() {
1338 let t = Tensor::new(vec![2, 3], vec![1.0; 6]);
1339 assert_eq!(t.numel(), 6);
1340 assert_eq!(t.ndim(), 2);
1341 }
1342
1343 #[test]
1344 fn test_tensor_sum_mean() {
1345 let t = Tensor::new(vec![4], vec![1.0, 2.0, 3.0, 4.0]);
1346 assert!((t.sum() - 10.0).abs() < 1e-12);
1347 assert!((t.mean() - 2.5).abs() < 1e-12);
1348 }
1349
1350 #[test]
1351 fn test_tensor_scale() {
1352 let t = Tensor::new(vec![3], vec![1.0, 2.0, 3.0]);
1353 let t2 = t.scale(2.0);
1354 assert!((t2.data[1] - 4.0).abs() < 1e-12);
1355 }
1356
1357 #[test]
1358 fn test_tensor_add() {
1359 let a = Tensor::new(vec![3], vec![1.0, 2.0, 3.0]);
1360 let b = Tensor::new(vec![3], vec![4.0, 5.0, 6.0]);
1361 let c = a.add(&b).unwrap();
1362 assert!((c.data[2] - 9.0).abs() < 1e-12);
1363 }
1364
1365 #[test]
1366 fn test_tensor_add_shape_mismatch() {
1367 let a = Tensor::new(vec![2], vec![1.0, 2.0]);
1368 let b = Tensor::new(vec![3], vec![1.0, 2.0, 3.0]);
1369 assert!(a.add(&b).is_none());
1370 }
1371
1372 #[test]
1373 fn test_tensor_roundtrip_bytes() {
1374 let t = Tensor::new(vec![2, 3], vec![1.0, -2.5, 0.0, 3.125, 1e10, -1e-5]);
1375 let bytes = t.to_bytes();
1376 let t2 = Tensor::from_bytes(&bytes).unwrap();
1377 assert_eq!(t2.shape, t.shape);
1378 for (a, b) in t.data.iter().zip(&t2.data) {
1379 assert!((a - b).abs() < 1e-15);
1380 }
1381 }
1382
1383 #[test]
1384 fn test_tensor_from_bytes_empty_is_none() {
1385 assert!(Tensor::from_bytes(&[]).is_none());
1386 }
1387
1388 #[test]
1391 fn test_dense_layer_param_count() {
1392 let layer = DenseLayer::new("fc1", 4, 3, "relu");
1393 assert_eq!(layer.param_count(), 15);
1395 }
1396
1397 #[test]
1398 fn test_dense_layer_forward_zero_weights() {
1399 let layer = DenseLayer::new("fc", 3, 2, "linear");
1400 let input = vec![1.0, 2.0, 3.0];
1401 let out = layer.forward(&input);
1402 assert_eq!(out.len(), 2);
1403 for v in &out {
1405 assert!(v.abs() < 1e-12);
1406 }
1407 }
1408
1409 #[test]
1410 fn test_dense_layer_activation_relu() {
1411 assert!((apply_activation(-5.0, "relu")).abs() < 1e-12);
1412 assert!((apply_activation(3.0, "relu") - 3.0).abs() < 1e-12);
1413 }
1414
1415 #[test]
1416 fn test_dense_layer_activation_sigmoid() {
1417 let v = apply_activation(0.0, "sigmoid");
1418 assert!((v - 0.5).abs() < 1e-12);
1419 }
1420
1421 #[test]
1422 fn test_dense_layer_activation_tanh() {
1423 let v = apply_activation(0.0, "tanh");
1424 assert!(v.abs() < 1e-12);
1425 }
1426
1427 #[test]
1430 fn test_model_weights_add_and_get() {
1431 let mut model = ModelWeights::new();
1432 model.add_layer(DenseLayer::new("l1", 4, 8, "relu"));
1433 model.add_layer(DenseLayer::new("l2", 8, 2, "sigmoid"));
1434 assert_eq!(model.layers.len(), 2);
1435 assert!(model.get_layer("l1").is_some());
1436 assert!(model.get_layer("l3").is_none());
1437 }
1438
1439 #[test]
1440 fn test_model_weights_total_params() {
1441 let mut model = ModelWeights::new();
1442 model.add_layer(DenseLayer::new("l1", 4, 3, "relu")); model.add_layer(DenseLayer::new("l2", 3, 2, "linear")); assert_eq!(model.total_params(), 23);
1445 }
1446
1447 #[test]
1448 fn test_model_weights_to_bytes_nonempty() {
1449 let mut model = ModelWeights::new();
1450 model.add_layer(DenseLayer::new("l1", 2, 2, "relu"));
1451 let bytes = model.to_bytes();
1452 assert!(!bytes.is_empty());
1453 }
1454
1455 #[test]
1458 fn test_state_dict_insert_and_get() {
1459 let mut sd = StateDict::new();
1460 sd.insert("w1", Tensor::zeros(vec![4, 4]));
1461 assert_eq!(sd.len(), 1);
1462 assert_eq!(sd.get("w1").unwrap().numel(), 16);
1463 }
1464
1465 #[test]
1466 fn test_state_dict_total_params() {
1467 let mut sd = StateDict::new();
1468 sd.insert("a", Tensor::zeros(vec![3, 3]));
1469 sd.insert("b", Tensor::zeros(vec![3]));
1470 assert_eq!(sd.total_params(), 12);
1471 }
1472
1473 #[test]
1474 fn test_state_dict_roundtrip() {
1475 let mut sd = StateDict::new();
1476 sd.insert("w", Tensor::new(vec![2, 2], vec![1.0, 2.0, 3.0, 4.0]));
1477 sd.insert("b", Tensor::new(vec![2], vec![0.5, -0.5]));
1478 let bytes = sd.to_bytes();
1479 let sd2 = StateDict::from_bytes(&bytes).unwrap();
1480 assert_eq!(sd2.len(), 2);
1481 let w = sd2.get("w").unwrap();
1482 assert!((w.data[3] - 4.0).abs() < 1e-12);
1483 }
1484
1485 #[test]
1488 fn test_onnx_graph_node_count() {
1489 let mut g = OnnxLikeGraph::new("test_model");
1490 g.add_node(OnnxNode::new(
1491 "n0",
1492 "MatMul",
1493 vec!["x".into(), "w0".into()],
1494 vec!["h0".into()],
1495 ));
1496 g.add_node(OnnxNode::new(
1497 "n1",
1498 "Relu",
1499 vec!["h0".into()],
1500 vec!["h1".into()],
1501 ));
1502 assert_eq!(g.node_count(), 2);
1503 assert_eq!(g.count_op("Relu"), 1);
1504 }
1505
1506 #[test]
1507 fn test_onnx_graph_topological_valid() {
1508 let mut g = OnnxLikeGraph::new("model");
1509 g.inputs.push("x".into());
1510 g.add_initializer("w0", Tensor::zeros(vec![4, 4]));
1511 g.add_node(OnnxNode::new(
1512 "mm",
1513 "MatMul",
1514 vec!["x".into(), "w0".into()],
1515 vec!["y".into()],
1516 ));
1517 g.add_node(OnnxNode::new(
1518 "act",
1519 "Relu",
1520 vec!["y".into()],
1521 vec!["z".into()],
1522 ));
1523 assert!(g.is_topologically_valid());
1524 }
1525
1526 #[test]
1527 fn test_onnx_graph_topological_invalid() {
1528 let mut g = OnnxLikeGraph::new("model");
1529 g.inputs.push("x".into());
1530 g.add_node(OnnxNode::new(
1532 "act",
1533 "Relu",
1534 vec!["undefined".into()],
1535 vec!["z".into()],
1536 ));
1537 assert!(!g.is_topologically_valid());
1538 }
1539
1540 #[test]
1543 fn test_dataset_len_and_features() {
1544 let mut ds = Dataset::new();
1545 ds.push(DataRow::labelled(vec![1.0, 2.0], 0));
1546 ds.push(DataRow::labelled(vec![3.0, 4.0], 1));
1547 assert_eq!(ds.len(), 2);
1548 assert_eq!(ds.num_features(), 2);
1549 }
1550
1551 #[test]
1552 fn test_dataset_shuffle_changes_order() {
1553 let mut ds = Dataset::new();
1554 for i in 0..20 {
1555 ds.push(DataRow::labelled(vec![i as f64], 0));
1556 }
1557 let original: Vec<f64> = ds.rows.iter().map(|r| r.features[0]).collect();
1558 ds.shuffle(42);
1559 let shuffled: Vec<f64> = ds.rows.iter().map(|r| r.features[0]).collect();
1560 assert_ne!(original, shuffled);
1561 }
1562
#[test]
fn test_dataset_train_val_split() {
    // With 100 rows, the 0.2 argument is expected to send 20 rows to the
    // validation side and leave 80 for training.
    let mut dataset = Dataset::new();
    (0..100).for_each(|k| {
        dataset.push(DataRow::labelled(vec![k as f64], 0));
    });

    let (train_set, val_set) = dataset.train_val_split(0.2);
    assert_eq!(train_set.len(), 80);
    assert_eq!(val_set.len(), 20);
}
1573
#[test]
fn test_dataset_feature_stats() {
    // Feature 0 takes the values {0, 2} (mean 1); feature 1 is constant 10.
    let mut dataset = Dataset::new();
    dataset.push(DataRow::labelled(vec![0.0, 10.0], 0));
    dataset.push(DataRow::labelled(vec![2.0, 10.0], 1));

    let (means, _) = dataset.feature_stats();
    let expected_means = [1.0, 10.0];
    for (column, want) in expected_means.iter().enumerate() {
        assert!((means[column] - want).abs() < 1e-12);
    }
}
1583
#[test]
fn test_normalization_zscore() {
    // 1.0 is the mean of {0, 2}, so its z-score must be (numerically) zero.
    let mut dataset = Dataset::new();
    for value in [0.0, 2.0] {
        dataset.push(DataRow::labelled(vec![value], 0));
    }

    let params = NormalizationParams::from_dataset(&dataset);
    let scores = params.normalize_zscore(&[1.0]);
    assert!(scores[0].abs() < 1e-10);
}
1596
#[test]
fn test_normalization_minmax() {
    // 5.0 sits halfway between the observed extremes 0.0 and 10.0.
    let mut dataset = Dataset::new();
    for value in [0.0, 10.0] {
        dataset.push(DataRow::labelled(vec![value], 0));
    }

    let params = NormalizationParams::from_dataset(&dataset);
    let scaled = params.normalize_minmax(&[5.0]);
    assert!((scaled[0] - 0.5).abs() < 1e-12);
}
1606
#[test]
fn test_normalization_bytes_nonempty() {
    // Parameters fitted on a single two-feature row must serialise to at
    // least one byte.
    let mut dataset = Dataset::new();
    dataset.push(DataRow::labelled(vec![1.0, 2.0], 0));

    let params = NormalizationParams::from_dataset(&dataset);
    let encoded = params.to_bytes();
    assert!(!encoded.is_empty());
}
1614
#[test]
fn test_label_encoder_fit_and_encode() {
    // Three distinct labels give three classes, and a known label survives an
    // encode/decode round trip.
    let encoder = LabelEncoder::fit(vec!["cat".into(), "dog".into(), "bird".into()]);
    assert_eq!(encoder.num_classes(), 3);

    let dog_index = encoder.encode("dog").unwrap();
    let decoded = encoder.decode(dog_index);
    assert_eq!(decoded, Some("dog"));
}
1624
#[test]
fn test_label_encoder_one_hot() {
    let enc = LabelEncoder::fit(vec!["a".into(), "b".into(), "c".into()]);
    let idx = enc.encode("b").unwrap();
    let oh = enc.one_hot(idx);

    // One slot per fitted class...
    assert_eq!(oh.len(), enc.num_classes());
    // ...exactly one of them set to 1.0 and the rest summing to nothing...
    assert_eq!(oh.iter().filter(|&&v| v == 1.0).count(), 1);
    assert!((oh.iter().sum::<f64>() - 1.0).abs() < 1e-12);
    // ...and the hot slot must be the class of the encoded label. Without this
    // check a one-hot at the wrong index would pass.
    let hot = oh.iter().position(|&v| v == 1.0).unwrap();
    assert_eq!(enc.decode(hot), Some("b"));
}
1632
#[test]
fn test_label_encoder_unknown_returns_none() {
    // "z" was never seen while fitting, so it must have no class index.
    let encoder = LabelEncoder::fit(vec!["a".into()]);
    let lookup = encoder.encode("z");
    assert!(lookup.is_none());
}
1638
#[test]
fn test_confusion_matrix_accuracy() {
    // Three of the four recorded pairs agree, so accuracy is 3/4.
    let mut matrix = ConfusionMatrix::new(2);
    let pairs = [(0, 0), (0, 0), (1, 1), (1, 0)];
    for &(first, second) in &pairs {
        matrix.record(first, second);
    }

    assert!((matrix.accuracy() - 0.75).abs() < 1e-12);
}
1650
#[test]
fn test_confusion_matrix_precision_recall() {
    // One entry in every cell of the 2x2 matrix, recorded in row-major order
    // (0,0), (0,1), (1,0), (1,1). Precision and recall for class 0 are both
    // expected to be 1/2.
    let mut matrix = ConfusionMatrix::new(2);
    for first in 0..2 {
        for second in 0..2 {
            matrix.record(first, second);
        }
    }

    let precision = matrix.precision(0);
    let recall = matrix.recall(0);
    assert!((precision - 0.5).abs() < 1e-12);
    assert!((recall - 0.5).abs() < 1e-12);
}
1663
#[test]
fn test_confusion_matrix_to_csv() {
    // The CSV export is expected to mention both class labels.
    let mut matrix = ConfusionMatrix::new(2);
    matrix.record(0, 0);
    matrix.record(1, 1);

    let csv = matrix.to_csv();
    for label in ["class_0", "class_1"] {
        assert!(csv.contains(label));
    }
}
1673
#[test]
fn test_training_history_best_val_acc() {
    // Validation accuracy grows by 0.18 per epoch, so the last of the five
    // epochs (index 4) is the best one, with 4 * 0.18 = 0.72.
    let mut history = TrainingHistory::new();
    for epoch in 0..5 {
        let step = epoch as f64;
        history.push(EpochRecord {
            epoch,
            train_loss: 1.0 - step * 0.1,
            val_loss: 1.0 - step * 0.08,
            train_acc: step * 0.2,
            val_acc: step * 0.18,
            learning_rate: 0.001,
        });
    }

    let (best_epoch, best_acc) = history.best_val_acc().unwrap();
    assert_eq!(best_epoch, 4);
    assert!((best_acc - 0.72).abs() < 1e-10);
}
1693
#[test]
fn test_training_history_to_csv() {
    let mut hist = TrainingHistory::new();
    hist.push(EpochRecord {
        epoch: 0,
        train_loss: 0.9,
        val_loss: 0.85,
        train_acc: 0.6,
        val_acc: 0.62,
        learning_rate: 0.01,
    });
    let csv = hist.to_csv();

    // The header comes first and names `epoch` as its first column.
    assert!(csv.starts_with("epoch,"));

    // The single data row follows the header and must start with the epoch
    // value. A bare `csv.contains("0,")` check is too weak: it is also satisfied
    // by any fixed-precision float such as "0.60,".
    let mut lines = csv.lines();
    lines.next(); // header
    let row = lines
        .next()
        .expect("to_csv should emit a data row after the header");
    assert!(row.starts_with("0,"));
}
1709
#[test]
fn test_hyperparam_config_get_set() {
    // Each typed setter must round-trip through its matching getter.
    let mut config = HyperparamConfig::new();
    config.set_float("lr", 0.001);
    config.set_bool("dropout", true);
    config.set_str("optimizer", "adam");

    let lr = config.get_float("lr").unwrap();
    assert!((lr - 0.001).abs() < 1e-15);

    let dropout = config.get_bool("dropout").unwrap();
    assert!(dropout);

    let optimizer = config.get_str("optimizer").unwrap();
    assert_eq!(optimizer, "adam");
}
1722
#[test]
fn test_hyperparam_config_to_json() {
    // A config with one float entry must serialise to a brace-delimited object
    // that mentions the entry's key.
    let mut config = HyperparamConfig::new();
    config.set_float("lr", 0.01);

    let json = config.to_json();
    assert!(json.contains("lr"));
    assert!(json.starts_with('{') && json.ends_with('}'));
}
1732
#[test]
fn test_checkpoint_byte_size_nonzero() {
    // Even with an empty state dict and a fresh hyper-parameter config, the
    // checkpoint must report a non-zero byte size.
    let meta = CheckpointMeta {
        architecture: "MLP".into(),
        framework_version: "0.1.0".into(),
        epoch: 10,
        val_loss: 0.1,
        val_acc: 0.95,
        train_time_secs: 3600.0,
    };
    let checkpoint = ModelCheckpoint::new(StateDict::new(), meta, HyperparamConfig::new());

    assert!(checkpoint.byte_size() > 0);
}
1750
#[test]
fn test_checkpoint_meta_to_text_contains_epoch() {
    // The text rendering must include the epoch as a `key=value` pair.
    let meta = CheckpointMeta {
        architecture: "CNN".into(),
        framework_version: "0.1.0".into(),
        epoch: 42,
        val_loss: 0.05,
        val_acc: 0.98,
        train_time_secs: 100.0,
    };

    assert!(meta.to_text().contains("epoch=42"));
}
1764
#[test]
fn test_softmax_sums_to_one() {
    // The softmax outputs must add up to 1 (within rounding).
    let logits = vec![1.0, 2.0, 3.0];
    let probabilities = softmax(&logits);
    let deviation = (probabilities.iter().sum::<f64>() - 1.0).abs();
    assert!(deviation < 1e-12);
}
1774
#[test]
fn test_softmax_max_has_highest_prob() {
    // The largest logit (index 1) must receive the largest probability.
    let logits = vec![1.0, 5.0, 2.0];
    let probabilities = softmax(&logits);
    let winner = probabilities[1];
    assert!(winner > probabilities[0] && winner > probabilities[2]);
}
1781
#[test]
fn test_cross_entropy_perfect_prediction() {
    // Predicting exactly the one-hot target should give (near) zero loss.
    let target = vec![0.0, 1.0, 0.0];
    let prediction = target.clone();
    let loss = cross_entropy_loss(&prediction, &target);
    assert!(loss < 1e-10);
}
1789
#[test]
fn test_argmax_basic() {
    // 0.7 is the largest score and sits at index 1.
    let scores = vec![0.1, 0.7, 0.2];
    let best = argmax(&scores);
    assert_eq!(best, 1);
}
1795
#[test]
fn test_mse_zero() {
    // Identical predictions and targets give zero squared error.
    let values = vec![1.0, 2.0, 3.0];
    let error = mse(&values, &values);
    assert!(error.abs() < 1e-12);
}
1802
#[test]
fn test_mse_known() {
    // Every element is off by exactly 1, so the mean squared error is 1.
    let predicted = vec![0.0; 2];
    let target = vec![1.0; 2];
    let error = mse(&predicted, &target);
    assert!((error - 1.0).abs() < 1e-12);
}
1809
#[test]
fn test_mae_basic() {
    // Absolute errors are 1, 0 and 1, so their mean is 2/3.
    let predicted = vec![0.0, 1.0, 2.0];
    let target = vec![1.0, 1.0, 3.0];
    let expected = 2.0 / 3.0;
    assert!((mae(&predicted, &target) - expected).abs() < 1e-12);
}
1818
#[test]
fn test_apply_activation_leaky_relu() {
    // The test expects a negative-side slope of 0.01 and identity for
    // positive inputs.
    let negative = apply_activation(-1.0, "leaky_relu");
    assert!((negative + 0.01).abs() < 1e-12);

    let positive = apply_activation(2.0, "leaky_relu");
    assert!((positive - 2.0).abs() < 1e-12);
}
1824
#[test]
fn test_apply_activation_elu() {
    // For a negative input the ELU output is expected to lie strictly in (-1, 0).
    let output = apply_activation(-1.0, "elu");
    assert!(-1.0 < output && output < 0.0);
}
1831
#[test]
fn test_lcg_rng_produces_different_values() {
    // Two consecutive draws from the same seeded generator must differ.
    let mut rng = LcgRng::new(1234);
    let draws = [rng.next_u64(), rng.next_u64()];
    assert_ne!(draws[0], draws[1]);
}
1839}