1use std::{
2 cell::RefCell,
3 collections::HashMap,
4 sync::atomic::{AtomicU64, Ordering},
5 sync::Arc,
6};
7
8use log::debug;
9
10use ndarray_rand::rand::rngs::SmallRng;
11use ndarray_rand::rand::SeedableRng;
12use ndarray_rand::rand_distr::{Distribution, Uniform};
13use ndarray_rand::RandomExt;
14
15use super::util::{Array1D, Array2D, Float, WsBlob, WsMat};
16
/// Identifies one buffer inside a `CpuParams` map; the discriminant
/// (obtained with `as i32`) is the key the buffer is stored under.
///
/// `Copy` lets the same selector be reused after an `as i32` conversion, and
/// `Debug`/`Eq`/`Hash` allow printing it and using it as a set or map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(i32)]
pub enum TypeBuffer {
    /// Weight matrix, created with shape `(size, prev_size)`.
    Weights = 0,
    /// Weight-matrix gradient, created zeroed with the same shape as `Weights`.
    WeightsGrad = 1,
    /// Layer output; created as `(1, size)` and resized to one row per batch sample.
    Output = 2,
    /// Bias vector of length `size`.
    Bias = 3,
    /// Per-neuron gradient buffer, kept the same shape as `Output`.
    NeuGrad = 4,
    /// Bias gradient, created zeroed with length `size`.
    BiasGrad = 5,
}
27
/// A parameter buffer of either rank, held behind a shared, interiorly
/// mutable handle.
///
/// The derived `Clone` clones the `Arc`, so the clone refers to the same
/// underlying array. The private `copy` method makes an independent deep copy.
///
/// NOTE(review): `RefCell` is not `Sync`, so `Arc<RefCell<_>>` is neither
/// `Send` nor `Sync`; the `Arc` therefore provides no cross-thread sharing
/// (`Rc` would behave the same). Changing it would alter the public type.
#[derive(Clone)]
pub enum VariantParamArc {
    /// A shared 1-D buffer (e.g. the `Bias`/`BiasGrad` vectors).
    Array1(Arc<RefCell<Array1D>>),
    /// A shared 2-D buffer (e.g. weights, outputs, neuron gradients).
    Array2(Arc<RefCell<Array2D>>),
}
33
34impl VariantParamArc {
35 pub fn get_arr_1d(&self) -> Arc<RefCell<Array1D>> {
36 if let VariantParamArc::Array1(a) = self {
37 return a.clone();
38 } else {
39 panic!("VariantParamArc is not Array1");
40 }
41 }
42
43 pub fn get_arr_2d(&self) -> Arc<RefCell<Array2D>> {
44 if let VariantParamArc::Array2(a) = self {
45 return a.clone();
46 } else {
47 panic!("VariantParamArc is not Array2");
48 }
49 }}
50
/// An owned (unshared) parameter array of either rank.
///
/// `VariantParam::copy_zeroed_shape_from` produces one as a zero-filled array
/// with the same shape as a `VariantParamArc`.
#[derive(Clone)]
pub enum VariantParam {
    /// An owned 1-D array.
    Array1(Array1D),
    /// An owned 2-D array.
    Array2(Array2D),
}
56
57impl VariantParam {
58 pub fn copy_zeroed_shape_from(arg: &VariantParamArc) -> Self {
59 match arg {
60 VariantParamArc::Array1(arr1) => {
61 let arr1_bor = arr1.borrow();
62 return VariantParam::Array1(Array1D::zeros(arr1_bor.len()));
63 }
64 VariantParamArc::Array2(arr2) => {
65 let arr2_bor = arr2.borrow();
66 return VariantParam::Array2(Array2D::zeros((
67 arr2_bor.shape()[0],
68 arr2_bor.shape()[1],
69 )));
70 }
71 }
72 }
73}
74
75impl From<i32> for TypeBuffer {
76 fn from(value: i32) -> Self {
77 if value == 0 {
78 return TypeBuffer::Weights;
79 } else if value == 1 {
80 return TypeBuffer::WeightsGrad;
81 } else if value == 2 {
82 return TypeBuffer::Output;
83 } else if value == 3 {
84 return TypeBuffer::Bias;
85 } else if value == 4 {
86 return TypeBuffer::NeuGrad;
87 } else if value == 5 {
88 return TypeBuffer::BiasGrad;
89 } else {
90 panic!("Invalid integer to convert");
91 }
92 }
93}
94
95impl VariantParamArc {
96 fn copy(&self) -> VariantParamArc {
97 match self {
98 VariantParamArc::Array1(arr) => {
99 let arr_b = arr.borrow().clone();
100 return VariantParamArc::Array1(Arc::new(RefCell::new(arr_b)));
101 }
102 VariantParamArc::Array2(arr) => {
103 let arr_b = arr.borrow().clone();
104 return VariantParamArc::Array2(Arc::new(RefCell::new(arr_b)));
105 }
106 }
107 }
108}
109
/// A named set of CPU-side buffers (weights, gradients, outputs, biases),
/// presumably belonging to a single layer — TODO confirm against callers.
///
/// The derived `Clone` is shallow: every cloned entry is a `VariantParamArc`
/// whose `Arc` still points at the same array. Use `CpuParams::copy` for an
/// independent deep copy.
///
/// NOTE(review): the derived `Default` sets `id` to 0, which is also the first
/// value handed out by `CpuParams::generate_u64_id`; confirm nothing relies on
/// ids being unique for default-constructed instances.
#[derive(Clone, Default)]
pub struct CpuParams {
    /// Buffers keyed by `TypeBuffer as i32` in the constructors; `insert_buf`
    /// accepts arbitrary `i32` keys as well.
    pub params: HashMap<i32, VariantParamArc>,
    /// Identifier taken from a process-wide counter by the constructors;
    /// `copy` carries it over unchanged.
    pub id: u64,
}
115
/// Shared, interiorly mutable handle to a `CpuParams`.
pub type LearnParamsPtr = Arc<RefCell<CpuParams>>;
/// An owned list of parameter sets, one `CpuParams` per entry.
pub type ParamsBlob = Vec<CpuParams>;
118
119impl CpuParams {
120 pub fn new(size: usize, prev_size: usize) -> Self {
121 let ws = VariantParamArc::Array2(Arc::new(RefCell::new(WsMat::random(
122 (size, prev_size),
123 Uniform::new(-0.1, 0.1),
124 ))));
125 let ws_grad =
126 VariantParamArc::Array2(Arc::new(RefCell::new(WsMat::zeros((size, prev_size)))));
127 let output = VariantParamArc::Array2(Arc::new(RefCell::new(Array2D::zeros((1, size)))));
128 let neu_grad = VariantParamArc::Array2(Arc::new(RefCell::new(Array2D::zeros((1, size)))));
129
130 let mut m = HashMap::new();
131 m.insert(TypeBuffer::Output as i32, output);
132 m.insert(TypeBuffer::Weights as i32, ws);
133 m.insert(TypeBuffer::WeightsGrad as i32, ws_grad);
134 m.insert(TypeBuffer::NeuGrad as i32, neu_grad);
135
136 Self {
137 params: m,
138 id: CpuParams::generate_u64_id(),
139 }
140 }
141
142 pub fn empty() -> Self {
143 Self {
144 params: HashMap::new(),
145 id: CpuParams::generate_u64_id(),
146 }
147 }
148
149 pub fn new_only_output(size: usize) -> Self {
150 let output = VariantParamArc::Array2(Arc::new(RefCell::new(Array2D::zeros((1, size)))));
151
152 let mut m = HashMap::new();
153 m.insert(TypeBuffer::Output as i32, output);
154
155 Self {
156 params: m,
157 id: CpuParams::generate_u64_id(),
158 }
159 }
160
161 pub fn new_with_bias(size: usize, prev_size: usize) -> Self {
162 let ws = VariantParamArc::Array2(Arc::new(RefCell::new(WsMat::random(
163 (size, prev_size),
164 Uniform::new(-0.1, 0.1),
165 ))));
166 let ws_grad =
167 VariantParamArc::Array2(Arc::new(RefCell::new(WsMat::zeros((size, prev_size)))));
168 let bias = VariantParamArc::Array1(Arc::new(RefCell::new(Array1D::random(
169 size,
170 Uniform::new(-0.1, 0.1),
171 ))));
172 let neu_grad = VariantParamArc::Array2(Arc::new(RefCell::new(Array2D::random(
173 (1, size),
174 Uniform::new(-0.1, 0.1),
175 ))));
176 let output = VariantParamArc::Array2(Arc::new(RefCell::new(Array2D::zeros((1, size)))));
177 let bias_grad = VariantParamArc::Array1(Arc::new(RefCell::new(Array1D::zeros(size))));
178
179 let mut m = HashMap::new();
180 m.insert(TypeBuffer::Output as i32, output);
181 m.insert(TypeBuffer::Weights as i32, ws);
182 m.insert(TypeBuffer::WeightsGrad as i32, ws_grad);
183 m.insert(TypeBuffer::Bias as i32, bias);
184 m.insert(TypeBuffer::NeuGrad as i32, neu_grad);
185 m.insert(TypeBuffer::BiasGrad as i32, bias_grad);
186
187 Self {
188 params: m,
189 id: CpuParams::generate_u64_id(),
190 }
191 }
192
193 pub fn new_with_const_bias(size: usize, prev_size: usize, bias_val: f32) -> Self {
194 let ws = VariantParamArc::Array2(Arc::new(RefCell::new(WsMat::random(
195 (size, prev_size),
196 Uniform::new(-0.1, 0.1),
197 ))));
198 let ws_grad =
199 VariantParamArc::Array2(Arc::new(RefCell::new(WsMat::zeros((size, prev_size)))));
200 let bias =
201 VariantParamArc::Array1(Arc::new(RefCell::new(Array1D::from_elem(size, bias_val))));
202 let neu_grad = VariantParamArc::Array2(Arc::new(RefCell::new(Array2D::zeros((1, size)))));
203 let output = VariantParamArc::Array2(Arc::new(RefCell::new(Array2D::zeros((1, size)))));
204 let bias_grad = VariantParamArc::Array1(Arc::new(RefCell::new(Array1D::zeros(size))));
205
206 let mut m = HashMap::new();
207 m.insert(TypeBuffer::Output as i32, output);
208 m.insert(TypeBuffer::Weights as i32, ws);
209 m.insert(TypeBuffer::WeightsGrad as i32, ws_grad);
210 m.insert(TypeBuffer::Bias as i32, bias);
211 m.insert(TypeBuffer::NeuGrad as i32, neu_grad);
212 m.insert(TypeBuffer::BiasGrad as i32, bias_grad);
213
214 Self {
215 params: m,
216 id: CpuParams::generate_u64_id(),
217 }
218 }
219
220 pub fn get_1d_buf(&self, id: i32) -> Arc<RefCell<Array1D>> {
221 let res_prm = self.params.get(&id).unwrap();
222
223 if let VariantParamArc::Array1(arr) = res_prm {
224 return arr.clone();
225 } else {
226 panic!("No Array1D with id {} was found", id);
227 }
228 }
229
230 pub fn get_1d_buf_t(&self, id: TypeBuffer) -> Arc<RefCell<Array1D>> {
231 return self.get_1d_buf(id as i32);
232 }
233
234 pub fn get_2d_buf(&self, id: i32) -> Arc<RefCell<Array2D>> {
235 let res_prm = self.params.get(&id).unwrap();
236
237 if let VariantParamArc::Array2(arr) = res_prm {
238 return arr.clone();
239 } else {
240 panic!("No Array2D with id {} was found", id);
241 }
242 }
243
244 pub fn get_2d_buf_t(&self, id: TypeBuffer) -> Arc<RefCell<Array2D>> {
245 return self.get_2d_buf(id as i32);
246 }
247
248 pub fn get_param(&self, id: i32) -> VariantParamArc {
249 return self.params.get(&id).unwrap().clone();
250 }
251
252 pub fn get_param_t(&self, id: TypeBuffer) -> VariantParamArc {
253 return self.params.get(&(id as i32)).unwrap().clone();
254 }
255
256 pub fn insert_buf(&mut self, id: i32, p: VariantParamArc) {
257 self.params.insert(id, p);
258 }
259
260 pub fn remove_buf(&mut self, id: i32) {
261 self.params.remove(&id);
262 }
263
264 pub fn contains_buf(&self, id: i32) -> bool {
265 self.params.contains_key(&id)
266 }
267
268 pub fn contains_buf_t(&self, id: TypeBuffer) -> bool {
269 self.contains_buf(id as i32)
270 }
271
272 pub fn fit_to_batch_size(&mut self, new_batch_size: usize) {
273 let out_m = self.get_2d_buf_t(TypeBuffer::Output);
274 let mut out_m = out_m.borrow_mut();
275 let size = out_m.ncols();
276
277 if out_m.nrows() != new_batch_size {
278 *out_m = Array2D::zeros((new_batch_size, size));
279
280 if self.contains_buf_t(TypeBuffer::NeuGrad) {
281 let err_m = self.get_2d_buf_t(TypeBuffer::NeuGrad);
282 let mut err_m = err_m.borrow_mut();
283 *err_m = Array2D::zeros((new_batch_size, size));
284 }
285 }
286 }
287
288 pub fn prepare_for_tests(&mut self, batch_size: usize) {
289 let out_size = self.get_2d_buf_t(TypeBuffer::Output).borrow().ncols();
290
291 self.params.remove(&(TypeBuffer::NeuGrad as i32));
292 self.params.remove(&(TypeBuffer::Output as i32));
293
294 let output = VariantParamArc::Array2(Arc::new(RefCell::new(Array2D::zeros((
295 batch_size, out_size,
296 )))));
297
298 self.params.insert(TypeBuffer::Output as i32, output);
299 }
300
301 pub fn copy(&self) -> Self {
305 let mut lp = CpuParams::default();
306
307 {
308 for (k, v) in self.params.iter() {
309 let v_copied = v.copy();
310 lp.params.insert(*k, v_copied);
311 }
312
313 lp.id = self.id.clone();
314 }
315
316 lp
317 }
318
319 pub fn generate_u64_id() -> u64 {
321 static COUNTER: AtomicU64 = AtomicU64::new(0);
322 COUNTER.fetch_add(1, Ordering::Relaxed)
323 }
324}