nevermind_neu/
cpu_params.rs

1use std::{
2    cell::RefCell,
3    collections::HashMap,
4    sync::atomic::{AtomicU64, Ordering},
5    sync::Arc,
6};
7
8use log::debug;
9
10use ndarray_rand::rand::rngs::SmallRng;
11use ndarray_rand::rand::SeedableRng;
12use ndarray_rand::rand_distr::{Distribution, Uniform};
13use ndarray_rand::RandomExt;
14
15use super::util::{Array1D, Array2D, Float, WsBlob, WsMat};
16
/// Well-known keys for the buffers stored in `CpuParams::params`.
///
/// The explicit discriminants are the `i32` map keys (`TypeBuffer::X as i32`) and must stay
/// in sync with `impl From<i32> for TypeBuffer`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(i32)]
pub enum TypeBuffer {
    /// Layer weight matrix.
    Weights = 0,
    /// Gradient buffer for `Weights` (created with the same shape).
    WeightsGrad = 1,
    /// Layer output; `CpuParams::fit_to_batch_size` gives it one row per batch sample.
    Output = 2,
    /// Bias vector.
    Bias = 3,
    /// Per-neuron gradient buffer.
    NeuGrad = 4,
    /// Bias gradient (averaged neuron gradient).
    BiasGrad = 5,
}
27
28#[derive(Clone)]
29pub enum VariantParamArc {
30    Array1(Arc<RefCell<Array1D>>),
31    Array2(Arc<RefCell<Array2D>>),
32}
33
34impl VariantParamArc {
35    pub fn get_arr_1d(&self) -> Arc<RefCell<Array1D>> {
36        if let VariantParamArc::Array1(a) = self {
37            return a.clone();
38        } else { 
39            panic!("VariantParamArc is not Array1");
40        }
41    }
42
43    pub fn get_arr_2d(&self) -> Arc<RefCell<Array2D>> {
44        if let VariantParamArc::Array2(a) = self {
45            return a.clone();
46        } else { 
47            panic!("VariantParamArc is not Array2");
48        }
49    }}
50
51#[derive(Clone)]
52pub enum VariantParam {
53    Array1(Array1D),
54    Array2(Array2D),
55}
56
57impl VariantParam {
58    pub fn copy_zeroed_shape_from(arg: &VariantParamArc) -> Self {
59        match arg {
60            VariantParamArc::Array1(arr1) => {
61                let arr1_bor = arr1.borrow();
62                return VariantParam::Array1(Array1D::zeros(arr1_bor.len()));
63            }
64            VariantParamArc::Array2(arr2) => {
65                let arr2_bor = arr2.borrow();
66                return VariantParam::Array2(Array2D::zeros((
67                    arr2_bor.shape()[0],
68                    arr2_bor.shape()[1],
69                )));
70            }
71        }
72    }
73}
74
75impl From<i32> for TypeBuffer {
76    fn from(value: i32) -> Self {
77        if value == 0 {
78            return TypeBuffer::Weights;
79        } else if value == 1 {
80            return TypeBuffer::WeightsGrad;
81        } else if value == 2 {
82            return TypeBuffer::Output;
83        } else if value == 3 {
84            return TypeBuffer::Bias;
85        } else if value == 4 {
86            return TypeBuffer::NeuGrad;
87        } else if value == 5 {
88            return TypeBuffer::BiasGrad;
89        } else {
90            panic!("Invalid integer to convert");
91        }
92    }
93}
94
95impl VariantParamArc {
96    fn copy(&self) -> VariantParamArc {
97        match self {
98            VariantParamArc::Array1(arr) => {
99                let arr_b = arr.borrow().clone();
100                return VariantParamArc::Array1(Arc::new(RefCell::new(arr_b)));
101            }
102            VariantParamArc::Array2(arr) => {
103                let arr_b = arr.borrow().clone();
104                return VariantParamArc::Array2(Arc::new(RefCell::new(arr_b)));
105            }
106        }
107    }
108}
109
110#[derive(Clone, Default)]
111pub struct CpuParams {
112    pub params: HashMap<i32, VariantParamArc>, // pub may be removed further
113    pub id: u64,
114}
115
116pub type LearnParamsPtr = Arc<RefCell<CpuParams>>;
117pub type ParamsBlob = Vec<CpuParams>;
118
119impl CpuParams {
120    pub fn new(size: usize, prev_size: usize) -> Self {
121        let ws = VariantParamArc::Array2(Arc::new(RefCell::new(WsMat::random(
122            (size, prev_size),
123            Uniform::new(-0.1, 0.1),
124        ))));
125        let ws_grad =
126            VariantParamArc::Array2(Arc::new(RefCell::new(WsMat::zeros((size, prev_size)))));
127        let output = VariantParamArc::Array2(Arc::new(RefCell::new(Array2D::zeros((1, size)))));
128        let neu_grad = VariantParamArc::Array2(Arc::new(RefCell::new(Array2D::zeros((1, size)))));
129
130        let mut m = HashMap::new();
131        m.insert(TypeBuffer::Output as i32, output);
132        m.insert(TypeBuffer::Weights as i32, ws);
133        m.insert(TypeBuffer::WeightsGrad as i32, ws_grad);
134        m.insert(TypeBuffer::NeuGrad as i32, neu_grad);
135
136        Self {
137            params: m,
138            id: CpuParams::generate_u64_id(),
139        }
140    }
141
142    pub fn empty() -> Self {
143        Self {
144            params: HashMap::new(),
145            id: CpuParams::generate_u64_id(),
146        }
147    }
148
149    pub fn new_only_output(size: usize) -> Self {
150        let output = VariantParamArc::Array2(Arc::new(RefCell::new(Array2D::zeros((1, size)))));
151
152        let mut m = HashMap::new();
153        m.insert(TypeBuffer::Output as i32, output);
154
155        Self {
156            params: m,
157            id: CpuParams::generate_u64_id(),
158        }
159    }
160
161    pub fn new_with_bias(size: usize, prev_size: usize) -> Self {
162        let ws = VariantParamArc::Array2(Arc::new(RefCell::new(WsMat::random(
163            (size, prev_size),
164            Uniform::new(-0.1, 0.1),
165        ))));
166        let ws_grad =
167            VariantParamArc::Array2(Arc::new(RefCell::new(WsMat::zeros((size, prev_size)))));
168        let bias = VariantParamArc::Array1(Arc::new(RefCell::new(Array1D::random(
169            size,
170            Uniform::new(-0.1, 0.1),
171        ))));
172        let neu_grad = VariantParamArc::Array2(Arc::new(RefCell::new(Array2D::random(
173            (1, size),
174            Uniform::new(-0.1, 0.1),
175        ))));
176        let output = VariantParamArc::Array2(Arc::new(RefCell::new(Array2D::zeros((1, size)))));
177        let bias_grad = VariantParamArc::Array1(Arc::new(RefCell::new(Array1D::zeros(size))));
178
179        let mut m = HashMap::new();
180        m.insert(TypeBuffer::Output as i32, output);
181        m.insert(TypeBuffer::Weights as i32, ws);
182        m.insert(TypeBuffer::WeightsGrad as i32, ws_grad);
183        m.insert(TypeBuffer::Bias as i32, bias);
184        m.insert(TypeBuffer::NeuGrad as i32, neu_grad);
185        m.insert(TypeBuffer::BiasGrad as i32, bias_grad);
186
187        Self {
188            params: m,
189            id: CpuParams::generate_u64_id(),
190        }
191    }
192
193    pub fn new_with_const_bias(size: usize, prev_size: usize, bias_val: f32) -> Self {
194        let ws = VariantParamArc::Array2(Arc::new(RefCell::new(WsMat::random(
195            (size, prev_size),
196            Uniform::new(-0.1, 0.1),
197        ))));
198        let ws_grad =
199            VariantParamArc::Array2(Arc::new(RefCell::new(WsMat::zeros((size, prev_size)))));
200        let bias =
201            VariantParamArc::Array1(Arc::new(RefCell::new(Array1D::from_elem(size, bias_val))));
202        let neu_grad = VariantParamArc::Array2(Arc::new(RefCell::new(Array2D::zeros((1, size)))));
203        let output = VariantParamArc::Array2(Arc::new(RefCell::new(Array2D::zeros((1, size)))));
204        let bias_grad = VariantParamArc::Array1(Arc::new(RefCell::new(Array1D::zeros(size))));
205
206        let mut m = HashMap::new();
207        m.insert(TypeBuffer::Output as i32, output);
208        m.insert(TypeBuffer::Weights as i32, ws);
209        m.insert(TypeBuffer::WeightsGrad as i32, ws_grad);
210        m.insert(TypeBuffer::Bias as i32, bias);
211        m.insert(TypeBuffer::NeuGrad as i32, neu_grad);
212        m.insert(TypeBuffer::BiasGrad as i32, bias_grad);
213
214        Self {
215            params: m,
216            id: CpuParams::generate_u64_id(),
217        }
218    }
219
220    pub fn get_1d_buf(&self, id: i32) -> Arc<RefCell<Array1D>> {
221        let res_prm = self.params.get(&id).unwrap();
222
223        if let VariantParamArc::Array1(arr) = res_prm {
224            return arr.clone();
225        } else {
226            panic!("No Array1D with id {} was found", id);
227        }
228    }
229
230    pub fn get_1d_buf_t(&self, id: TypeBuffer) -> Arc<RefCell<Array1D>> {
231        return self.get_1d_buf(id as i32);
232    }
233
234    pub fn get_2d_buf(&self, id: i32) -> Arc<RefCell<Array2D>> {
235        let res_prm = self.params.get(&id).unwrap();
236
237        if let VariantParamArc::Array2(arr) = res_prm {
238            return arr.clone();
239        } else {
240            panic!("No Array2D with id {} was found", id);
241        }
242    }
243
244    pub fn get_2d_buf_t(&self, id: TypeBuffer) -> Arc<RefCell<Array2D>> {
245        return self.get_2d_buf(id as i32);
246    }
247
248    pub fn get_param(&self, id: i32) -> VariantParamArc {
249        return self.params.get(&id).unwrap().clone();
250    }
251
252    pub fn get_param_t(&self, id: TypeBuffer) -> VariantParamArc {
253        return self.params.get(&(id as i32)).unwrap().clone();
254    }
255
256    pub fn insert_buf(&mut self, id: i32, p: VariantParamArc) {
257        self.params.insert(id, p);
258    }
259
260    pub fn remove_buf(&mut self, id: i32) {
261        self.params.remove(&id);
262    }
263
264    pub fn contains_buf(&self, id: i32) -> bool {
265        self.params.contains_key(&id)
266    }
267
268    pub fn contains_buf_t(&self, id: TypeBuffer) -> bool {
269        self.contains_buf(id as i32)
270    }
271
272    pub fn fit_to_batch_size(&mut self, new_batch_size: usize) {
273        let out_m = self.get_2d_buf_t(TypeBuffer::Output);
274        let mut out_m = out_m.borrow_mut();
275        let size = out_m.ncols();
276
277        if out_m.nrows() != new_batch_size {
278            *out_m = Array2D::zeros((new_batch_size, size));
279
280            if self.contains_buf_t(TypeBuffer::NeuGrad) {
281                let err_m = self.get_2d_buf_t(TypeBuffer::NeuGrad);
282                let mut err_m = err_m.borrow_mut();
283                *err_m = Array2D::zeros((new_batch_size, size));
284            }
285        }
286    }
287
288    pub fn prepare_for_tests(&mut self, batch_size: usize) {
289        let out_size = self.get_2d_buf_t(TypeBuffer::Output).borrow().ncols();
290
291        self.params.remove(&(TypeBuffer::NeuGrad as i32));
292        self.params.remove(&(TypeBuffer::Output as i32));
293
294        let output = VariantParamArc::Array2(Arc::new(RefCell::new(Array2D::zeros((
295            batch_size, out_size,
296        )))));
297
298        self.params.insert(TypeBuffer::Output as i32, output);
299    }
300
301    /// Copies learn_params memory of (weights, gradients, output...)
302    /// This function DO copy memory
303    /// To clone only Arc<...> use .clone() function
304    pub fn copy(&self) -> Self {
305        let mut lp = CpuParams::default();
306
307        {
308            for (k, v) in self.params.iter() {
309                let v_copied = v.copy();
310                lp.params.insert(*k, v_copied);
311            }
312
313            lp.id = self.id.clone();
314        }
315
316        lp
317    }
318
319    /// Could be used also for OclParams and others
320    pub fn generate_u64_id() -> u64 {
321        static COUNTER: AtomicU64 = AtomicU64::new(0);
322        COUNTER.fetch_add(1, Ordering::Relaxed)
323    }
324}