1use crate::matrix::*;
2use rand::seq::SliceRandom;
3use std::fs::read;
4
5pub fn generate_vec_rand_unique(size: u32) -> Vec<u32> {
6 let mut rng = rand::rng();
7 let mut output: Vec<u32> = (0..size).collect();
8
9 output.shuffle(&mut rng);
10 output
11}
12
/// Splits `index_table` into consecutive batches of `batch_size` indices,
/// converting every index to `f64`.
///
/// The order of `index_table` is preserved. Every batch holds exactly
/// `batch_size` values except possibly the last one, which holds the
/// remainder when the table length is not a multiple of `batch_size`.
///
/// # Panics
///
/// * if `batch_size` is larger than `index_table.len()`;
/// * if `batch_size` is `0`.
pub fn generate_batch_index(index_table: &[u32], batch_size: u32) -> Vec<Vec<f64>> {
    // Widen the batch size instead of narrowing the table length: casting
    // `len()` down to `u32` would truncate for very large tables and could let
    // an oversized batch slip through the check below.
    let batch_len = batch_size as usize;

    assert!(
        index_table.len() >= batch_len,
        "Batch size cannot be bigger than training dataset size"
    );
    assert!(batch_size > 0, "Batch size must be strictly positive");

    // `chunks` yields full batches followed by the (possibly shorter) tail,
    // and is safe here because `batch_len` was just checked to be non-zero.
    index_table
        .chunks(batch_len)
        .map(|batch| batch.iter().map(|&index| f64::from(index)).collect())
        .collect()
}
42
/// Interprets `bytes` as a big-endian (most significant byte first) `u32`.
///
/// # Panics
///
/// Panics if `bytes` does not contain exactly 4 elements.
fn convert_4_bytes_to_u32_big_endian(bytes: Vec<u8>) -> u32 {
    assert_eq!(bytes.len(), 4, "byte array should be of size 4");

    // The length was checked above, so the four indexed reads cannot fail.
    u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]])
}
52
/// Validates the 8-byte header of a label file held in `array`.
///
/// The layout checked here is:
///
/// * bytes `0..4`: the magic number `[0, 0, 8, 1]`;
/// * bytes `4..8`: the number of labels, as a big-endian `u32`;
/// * bytes `8..`: the payload, which must contain exactly that many bytes.
///
/// # Panics
///
/// Panics with the "File incompatibility detected" message when `array` is
/// shorter than the header, when the magic number differs, or when the
/// declared label count does not match the payload length.
fn check_label_file_header(array: &[u8]) {
    const MAGIC: [u8; 4] = [0, 0, 8, 1];
    const HEADER_LEN: usize = 8;
    const MESSAGE: &str =
        "File incompatibility detected, are you sure you added the correct LABEL file ?";

    // Reject truncated input up front so it is reported as an incompatible
    // file rather than as an out-of-range slice index.
    assert!(array.len() >= HEADER_LEN, "{}", MESSAGE);

    assert_eq!(&array[..4], &MAGIC[..], "{}", MESSAGE);

    // Decode straight from the borrowed bytes; no temporary buffer is needed.
    let declared_count = u32::from_be_bytes([array[4], array[5], array[6], array[7]]);
    let payload_len = array.len() - HEADER_LEN;

    // Compare in `u64` so neither side is narrowed: narrowing the length to
    // `u32` would let a payload of 2^32 extra bytes match by accident.
    assert_eq!(u64::from(declared_count), payload_len as u64, "{}", MESSAGE);
}
69
/// Validates the 16-byte header of an image file held in `array`.
///
/// The layout checked here is:
///
/// * bytes `0..4`: the magic number `[0, 0, 8, 3]`;
/// * bytes `4..8`: the number of images, as a big-endian `u32`;
/// * bytes `8..12`: the number of rows per image, as a big-endian `u32`;
/// * bytes `12..16`: the number of columns per image, as a big-endian `u32`;
/// * bytes `16..`: the payload, which must contain exactly
///   `images * rows * columns` bytes.
///
/// # Panics
///
/// Panics with the "File incompatibility detected" message when `array` is
/// shorter than the header, when the magic number differs, or when the
/// declared dimensions do not match the payload length (including dimensions
/// so large that their product cannot be represented).
fn check_image_file_header(array: &[u8]) {
    const MAGIC: [u8; 4] = [0, 0, 8, 3];
    const HEADER_LEN: usize = 16;
    const MESSAGE: &str =
        "File incompatibility detected, are you sure you added the correct IMAGE file ?";

    // Reject truncated input up front so it is reported as an incompatible
    // file rather than as an out-of-range slice index.
    assert!(array.len() >= HEADER_LEN, "{}", MESSAGE);

    assert_eq!(&array[..4], &MAGIC[..], "{}", MESSAGE);

    // Decodes the big-endian u32 stored at `array[at..at + 4]`.
    let be_u32 =
        |at: usize| u32::from_be_bytes([array[at], array[at + 1], array[at + 2], array[at + 3]]);
    let image_count = be_u32(4);
    let rows = be_u32(8);
    let columns = be_u32(12);

    // Multiply with overflow checks in `u64`: a plain `u32` product wraps in
    // release builds, which would let a malformed header whose wrapped product
    // happens to equal the payload length pass validation.
    let expected_payload: Option<u64> = u64::from(image_count)
        .checked_mul(u64::from(rows))
        .and_then(|pixels| pixels.checked_mul(u64::from(columns)));
    let payload_len = (array.len() - HEADER_LEN) as u64;

    assert_eq!(expected_payload, Some(payload_len), "{}", MESSAGE);
}
89
90pub fn extract_labels(path: &str) -> Matrix {
91 let res: Vec<u8> = read(path).unwrap();
92 check_label_file_header(&res);
93 let slice: Vec<u8> = res[8..].to_vec();
94 let mut output: Matrix = Matrix::init_zero(1, slice.len());
95
96 slice
97 .iter()
98 .enumerate()
99 .for_each(|(index, value)| output.set(*value as f64, 0, index));
100
101 output
102}
103
104pub fn extract_images(path: &str) -> Matrix {
105 let res: Vec<u8> = read(path).unwrap();
106
107 check_image_file_header(&res);
108
109 let array_size: u32 = convert_4_bytes_to_u32_big_endian(res[4..8].to_vec());
110 let array_size_row: u32 = convert_4_bytes_to_u32_big_endian(res[8..12].to_vec());
111 let array_size_column: u32 = convert_4_bytes_to_u32_big_endian(res[12..16].to_vec());
112
113 let pixels_per_image: u32 = array_size_row * array_size_column;
114
115 let mut output: Matrix = Matrix::init_zero(
116 array_size.try_into().unwrap(),
117 pixels_per_image.try_into().unwrap(),
118 );
119 let mut index = 0;
120
121 for i in res[16..].to_vec() {
122 let x: usize = (index / pixels_per_image).try_into().unwrap();
123 let y: usize = (index % pixels_per_image).try_into().unwrap();
124 output.set(i as f64, x, y);
125 index += 1;
126 }
127
128 output
129}