brique/
utils.rs

1use crate::matrix::*;
2use rand::seq::SliceRandom;
3use std::fs::read;
4
5pub fn generate_vec_rand_unique(size: u32) -> Vec<u32> {
6    let mut rng = rand::rng();
7    let mut output: Vec<u32> = (0..size).collect();
8
9    output.shuffle(&mut rng);
10    output
11}
12
/// Splits `index_table` into consecutive batches of at most `batch_size` indices.
///
/// The indices keep the order they have in `index_table`. Every batch holds exactly
/// `batch_size` values except possibly the last one, which holds whatever remains. Each
/// index is converted to `f64`.
///
/// Returning `f64` values is not optimal; it could be avoided with a matrix type that is
/// generic over its element type.
///
/// # Panics
///
/// Panics if `batch_size` is greater than the number of indices, or if `batch_size` is zero.
pub fn generate_batch_index(index_table: &[u32], batch_size: u32) -> Vec<Vec<f64>> {
    // Compare in `usize` (a lossless widening on 32/64-bit targets): narrowing `len()` to
    // `u32` would silently truncate for very large tables.
    let batch_len = batch_size as usize;
    assert!(
        index_table.len() >= batch_len,
        "Batch size cannot be bigger than training dataset size"
    );
    assert!(batch_size > 0, "Batch size must be strictly positive");

    // `chunks` yields full batches followed by one shorter batch when the length is not a
    // multiple of `batch_len`; it is never called with 0 thanks to the assertion above.
    index_table
        .chunks(batch_len)
        .map(|batch| batch.iter().map(|&index| f64::from(index)).collect())
        .collect()
}
42
/// Decodes exactly four bytes as a big-endian (most significant byte first) `u32`.
///
/// # Panics
///
/// Panics if `bytes` does not contain exactly four elements.
fn convert_4_bytes_to_u32_big_endian(bytes: Vec<u8>) -> u32 {
    assert_eq!(bytes.len(), 4, "byte array should be of size 4");
    // The standard library decoder replaces the hand-written powers-of-two arithmetic.
    u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]])
}
52
/// Validates the header of an MNIST label file whose raw bytes are in `array`.
///
/// Checks performed (format documentation: http://yann.lecun.com/exdb/mnist/):
/// - bytes `0..4` must equal `[0, 0, 8, 1]`;
/// - bytes `4..8`, read as a big-endian `u32`, must equal the number of bytes that follow
///   the 8-byte header.
///
/// # Panics
///
/// Panics if `array` is shorter than the header or if either check fails.
fn check_label_file_header(array: &[u8]) {
    const HEADER_LEN: usize = 8;
    const EXPECTED_FILE_HEADER: [u8; 4] = [0, 0, 8, 1];
    const MISMATCH: &str =
        "File incompatibility detected, are you sure you added the correct LABEL file ?";

    // Reject truncated input first so it fails with the message above instead of an
    // out-of-range slice panic.
    assert!(array.len() >= HEADER_LEN, "{}", MISMATCH);

    assert_eq!(&array[..4], &EXPECTED_FILE_HEADER[..], "{}", MISMATCH);

    let array_size = u32::from_be_bytes([array[4], array[5], array[6], array[7]]);
    // Compare in `u64`: narrowing the file length to `u32` could truncate it.
    assert_eq!(
        u64::from(array_size),
        (array.len() - HEADER_LEN) as u64,
        "{}",
        MISMATCH
    );
}
69
/// Validates the header of an MNIST image file whose raw bytes are in `array`.
///
/// Checks performed (format documentation: http://yann.lecun.com/exdb/mnist/):
/// - bytes `0..4` must equal `[0, 0, 8, 3]`;
/// - the three big-endian `u32` fields at bytes `4..8`, `8..12` and `12..16` (image count,
///   row count, column count) must multiply to the number of bytes that follow the
///   16-byte header.
///
/// # Panics
///
/// Panics if `array` is shorter than the header or if either check fails.
fn check_image_file_header(array: &[u8]) {
    const HEADER_LEN: usize = 16;
    const EXPECTED_FILE_HEADER: [u8; 4] = [0, 0, 8, 3];
    const MISMATCH: &str =
        "File incompatibility detected, are you sure you added the correct IMAGE file ?";

    // Reject truncated input first so it fails with the message above instead of an
    // out-of-range slice panic.
    assert!(array.len() >= HEADER_LEN, "{}", MISMATCH);

    assert_eq!(&array[..4], &EXPECTED_FILE_HEADER[..], "{}", MISMATCH);

    // Each field is widened to `u128` so the product of three `u32` values can never
    // overflow (a wrapped `u32` product could match the payload size by accident).
    let field = |offset: usize| -> u128 {
        u128::from(u32::from_be_bytes([
            array[offset],
            array[offset + 1],
            array[offset + 2],
            array[offset + 3],
        ]))
    };
    let expected_payload_len = field(4) * field(8) * field(12);

    assert_eq!(
        expected_payload_len,
        (array.len() - HEADER_LEN) as u128,
        "{}",
        MISMATCH
    );
}
89
90pub fn extract_labels(path: &str) -> Matrix {
91    let res: Vec<u8> = read(path).unwrap();
92    check_label_file_header(&res);
93    let slice: Vec<u8> = res[8..].to_vec();
94    let mut output: Matrix = Matrix::init_zero(1, slice.len());
95
96    slice
97        .iter()
98        .enumerate()
99        .for_each(|(index, value)| output.set(*value as f64, 0, index));
100
101    output
102}
103
104pub fn extract_images(path: &str) -> Matrix {
105    let res: Vec<u8> = read(path).unwrap();
106
107    check_image_file_header(&res);
108
109    let array_size: u32 = convert_4_bytes_to_u32_big_endian(res[4..8].to_vec());
110    let array_size_row: u32 = convert_4_bytes_to_u32_big_endian(res[8..12].to_vec());
111    let array_size_column: u32 = convert_4_bytes_to_u32_big_endian(res[12..16].to_vec());
112
113    let pixels_per_image: u32 = array_size_row * array_size_column;
114
115    let mut output: Matrix = Matrix::init_zero(
116        array_size.try_into().unwrap(),
117        pixels_per_image.try_into().unwrap(),
118    );
119    let mut index = 0;
120
121    for i in res[16..].to_vec() {
122        let x: usize = (index / pixels_per_image).try_into().unwrap();
123        let y: usize = (index % pixels_per_image).try_into().unwrap();
124        output.set(i as f64, x, y);
125        index += 1;
126    }
127
128    output
129}