1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
use std::{
fs::File,
io::{Read, Write,BufWriter},
mem,
};
use num_bytes::{IntoBytes,TryFromBytes};
use ndarray::{Axis,Array2};
/// Reads a binary file of fixed-size samples into a data matrix and a label column.
///
/// The whole file at `path` is read into memory and split into consecutive samples.
/// Each sample consists of `example_size` values of `T` followed by one value of `P`,
/// each decoded from little-endian bytes. This is the layout emitted by [`write`].
/// Each value occupies `mem::size_of` of its type in bytes.
///
/// Returns `(data, labels)`, where `data` has shape `(samples, example_size)` and
/// `labels` has shape `(samples, 1)`.
///
/// # Errors
///
/// Returns an error if:
/// - the file cannot be opened or read;
/// - the per-sample byte size is zero;
/// - the file length is not a whole multiple of the per-sample byte size;
/// - any value fails to decode;
/// - the output arrays cannot be shaped.
pub fn read<T: TryFromBytes, P: TryFromBytes>(
    path: &str,
    example_size: usize,
) -> Result<(Array2<T>, Array2<P>), Box<dyn std::error::Error>> {
    let mut file = File::open(path)?;
    let mut buffer: Vec<u8> = Vec::new();
    file.read_to_end(&mut buffer)?;

    let point_data_size = mem::size_of::<T>();
    let label_size = mem::size_of::<P>();
    let data_size = point_data_size * example_size;
    let sample_size = data_size + label_size;

    // Malformed input is reported to the caller rather than panicking
    // (a zero sample size would otherwise divide by zero below).
    if sample_size == 0 {
        return Err("sample size is zero bytes; cannot split file into samples".into());
    }
    if buffer.len() % sample_size != 0 {
        return Err(format!(
            "file length {} is not a multiple of the sample size ({} bytes)",
            buffer.len(),
            sample_size
        )
        .into());
    }
    let length = buffer.len() / sample_size;

    // Decode straight into the flat row-major buffer the data array is built from,
    // instead of going through per-sample Vecs and Vecs of Results.
    let mut flat_data: Vec<T> = Vec::with_capacity(length * example_size);
    let mut labels: Vec<P> = Vec::with_capacity(length);
    for sample in buffer.chunks_exact(sample_size) {
        let (point_bytes, label_bytes) = sample.split_at(data_size);
        for value_bytes in point_bytes.chunks_exact(point_data_size) {
            flat_data.push(T::try_from_le_bytes(value_bytes)?);
        }
        labels.push(P::try_from_le_bytes(label_bytes)?);
    }

    Ok((
        ndarray::Array::from_shape_vec((length, example_size), flat_data)?,
        ndarray::Array::from_shape_vec((length, 1), labels)?,
    ))
}
/// Writes `data` and `labels` to `path` as a flat binary file of samples.
///
/// Row `i` is written as the little-endian bytes of every element of `data`'s
/// row `i` (in logical column order), immediately followed by the little-endian
/// bytes of `labels[[i, 0]]`. No header or separator is written, so the output can
/// be read back with [`read`] given the same element types and row width.
///
/// # Errors
///
/// Returns an error if:
/// - `data` and `labels` have a different number of rows;
/// - `labels` does not have exactly one column;
/// - the file cannot be created, written, or flushed.
///
/// The shape checks run before the file is created, so invalid input never
/// truncates an existing file.
pub fn write<T: IntoBytes + Copy, P: IntoBytes + Copy>(
    path: &str,
    data: &Array2<T>,
    labels: &Array2<P>,
) -> Result<(), Box<dyn std::error::Error>> {
    let rows = data.len_of(Axis(0));
    let label_rows = labels.len_of(Axis(0));
    if rows != label_rows {
        return Err(format!(
            "data has {} rows but labels has {} rows",
            rows, label_rows
        )
        .into());
    }
    let label_cols = labels.len_of(Axis(1));
    if label_cols != 1 {
        return Err(format!("labels must have exactly 1 column, found {}", label_cols).into());
    }

    let mut writer = BufWriter::new(File::create(path)?);
    for (row, label_row) in data.axis_iter(Axis(0)).zip(labels.axis_iter(Axis(0))) {
        // Element iteration (instead of `as_slice().unwrap()`) works for any memory
        // layout, e.g. transposed or Fortran-ordered arrays. BufWriter batches the
        // small writes, so no per-row byte Vec is needed.
        for &value in row.iter() {
            writer.write_all(&value.into_le_bytes())?;
        }
        let label: P = label_row[0];
        writer.write_all(&label.into_le_bytes())?;
    }
    // Dropping a BufWriter flushes but discards any error; flush explicitly so
    // a failed final write is reported.
    writer.flush()?;
    Ok(())
}