pr_ml/
util.rs

1//! Utility functions that don't fit elsewhere.
2
3use super::RowVector;
4use matfile::NumericData;
5
6/// Get the data type of the given [`NumericData`] as a string.
7#[must_use]
8pub const fn get_type(data: &NumericData) -> &'static str {
9    match data {
10        NumericData::Int8 { .. } => "i8",
11        NumericData::UInt8 { .. } => "u8",
12        NumericData::Int16 { .. } => "i16",
13        NumericData::UInt16 { .. } => "u16",
14        NumericData::Int32 { .. } => "i32",
15        NumericData::UInt32 { .. } => "u32",
16        NumericData::Int64 { .. } => "i64",
17        NumericData::UInt64 { .. } => "u64",
18        NumericData::Single { .. } => "f32",
19        NumericData::Double { .. } => "f64",
20    }
21}
22
23/// Display the MNIST image.
24pub fn display_image(image: &RowVector<784, u8>) {
25    for y in 0..28 {
26        for x in 0..28 {
27            let pixel = image[y * 28 + x];
28            if pixel > 128 {
29                print!("██");
30            } else if pixel > 64 {
31                print!("▓▓");
32            } else if pixel > 32 {
33                print!("▒▒");
34            } else if pixel > 16 {
35                print!("░░");
36            } else {
37                print!("  ");
38            }
39            // Print by hex value
40            // print!("{pixel:02X} ");
41        }
42        println!();
43    }
44}
45
46/// Get the label (index of the maximum value) from the output vector.
47///
48/// # Panics
49///
50/// If the output vector is empty, or contains values that cannot be compared.
51#[must_use]
52pub fn get_label(output: &RowVector<10>) -> usize {
53    output
54        .iter()
55        .enumerate()
56        .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
57        .map(|(idx, _)| idx)
58        .unwrap()
59}