flashlight/input_handler/
mod.rs1use flashlight_tensor::prelude::*;
2
3use rand::seq::SliceRandom;
4
5pub struct InputPrePrepared{
7 pub input_data: Vec<Tensor<f32>>,
8 pub output_data: Vec<Tensor<f32>>,
9 bach_size: u32,
10}
11
12pub struct InputHandler{
14 input_data: Tensor<f32>,
15 output_data: Tensor<f32>,
16 bach_size: u32,
17}
18
19impl InputPrePrepared{
20 pub fn new(input_sample: &Tensor<f32>, output_sample: &Tensor<f32>) -> Self{
21 Self{
22 input_data: vec!{input_sample.clone()},
23 output_data: vec!{output_sample.clone()},
24 bach_size: 1,
25 }
26 }
27 pub fn set_bach_size(&mut self, _bach_size: u32){
28 if self.input_data.len() % _bach_size as usize == 0 {
29 self.bach_size = _bach_size;
30 }
31 }
32 pub fn append(&mut self, input_sample: &Tensor<f32>, output_sample: &Tensor<f32>){
33 self.input_data.push(input_sample.clone());
34 self.output_data.push(output_sample.clone());
35 self.bach_size = 1;
36 }
37
38 pub fn to_handler(&mut self) -> InputHandler{
39 let mut rng = rand::rng();
40
41 let mut input_tensor = self.input_data[0].clone();
42 for i in 1..self.input_data.len(){
43 input_tensor = input_tensor.append(&self.input_data[i]).unwrap();
44 }
45
46 let input_mean: f32 = input_tensor.sum() / input_tensor.count_data() as f32;
47
48 let mut input_std_dev: f32 = 0.0;
49 for i in 0..input_tensor.get_data().len(){
50 input_std_dev += input_tensor.get_data()[i].powi(2);
51 }
52 input_std_dev = (input_std_dev/input_tensor.count_data() as f32).sqrt();
53
54 let mut normalized_vec: Vec<f32> = Vec::with_capacity(input_tensor.count_data());
55 for i in 0..input_tensor.get_data().len(){
56 normalized_vec.push((input_tensor.get_data()[i] - input_mean) / input_std_dev);
57 }
58
59 let normalized_tensor: Tensor<f32> = Tensor::from_data(&normalized_vec, &input_tensor.get_sizes()).unwrap();
60
61 let mut output_tensor = self.output_data[0].clone();
62 for i in 1..self.output_data.len(){
63 output_tensor = output_tensor.append(&self.output_data[i]).unwrap();
64 }
65
66 InputHandler{
67 input_data: normalized_tensor,
68 output_data: output_tensor,
69 bach_size: self.bach_size,
70 }
71 }
72}
73
74impl InputHandler{
75
76 pub fn len(&self) -> u32{
77 self.input_data.get_sizes()[0] / self.bach_size
78 }
79 pub fn input_bach(&self, n: u32) -> Tensor<f32>{
80 let mut bach = self.input_data.matrix_row(n*self.bach_size).unwrap();
81
82 for i in 1..self.bach_size{
83 let mut next_col = self.input_data.matrix_row(n*self.bach_size + i).unwrap();
84 bach = bach.append(&next_col).unwrap();
85 }
86
87 bach.matrix_transpose().unwrap()
88 }
89 pub fn output_bach(&self, n: u32) -> Tensor<f32>{
90 let mut bach = self.output_data.matrix_row(n*self.bach_size).unwrap();
91
92 for i in 1..self.bach_size{
93 let mut next_col = self.output_data.matrix_row(n*self.bach_size + i).unwrap();
94 bach = bach.append(&next_col).unwrap();
95 }
96
97 bach.matrix_transpose().unwrap()
98 }
99}