coaster_nn/frameworks/native/helper.rs

//! Provides useful macros for easier NN implementation for native.

use crate::{DirectionMode, RnnInputMode, RnnNetworkMode};
use co::frameworks::native::flatbox::FlatBox;
use co::plugin::numeric_helpers::Float;
use co::plugin::Error as PluginError;
use coaster as co;

#[derive(Debug, Copy, Clone)]
#[allow(missing_docs)]
pub struct NormalizationConfig;

#[derive(Debug, Clone)]
#[allow(missing_docs)]
pub struct PoolingConfig {
    pub window: Vec<i32>,
    pub padding: Vec<i32>,
    // TODO: check datatype
    pub stride: Vec<i32>,
}

#[derive(Debug, Copy, Clone)]
#[allow(missing_docs)]
pub struct DropoutConfig {
    pub probability: f32,
    pub seed: u64,
}
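// Example (sketch): these configs are plain data, so they can be built with
// struct literals. The concrete values are illustrative only; a 2x2 pooling
// window with no padding and stride 2 is assumed here:
//
//     let pool = PoolingConfig {
//         window: vec![2, 2],
//         padding: vec![0, 0],
//         stride: vec![2, 2],
//     };
//     let dropout = DropoutConfig { probability: 0.5, seed: 42 };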

/// shortcut to reading a tensor as an immutable slice
/// contains unwrap
macro_rules! read {
    ($x:ident, $t:ident, $slf:ident) => {
        $x.read($slf.device()).unwrap().as_slice::<$t>()
    };
}

/// shortcut to reading a tensor as a mutable slice
/// contains unwrap
macro_rules! read_write {
    ($x:ident, $t:ident, $slf:ident) => {
        $x.read_write($slf.device()).unwrap().as_mut_slice::<$t>()
    };
}

/// shortcut to acquiring a tensor as a write-only mutable slice
/// (previous contents are discarded)
/// contains unwrap
macro_rules! write_only {
    ($x:ident, $t:ident, $slf:ident) => {
        $x.write_only($slf.device()).unwrap().as_mut_slice::<$t>()
    };
}
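// Example (sketch): within the impl macros further down, these shortcuts
// expand to plain coaster calls. Assuming `self` is a backend exposing
// `device()` and `x`/`result` are `SharedTensor<f32>`, the expansion is
// roughly:
//
//     let xs: &[f32] = x.read(self.device()).unwrap().as_slice::<f32>();
//     let rs: &mut [f32] = result.write_only(self.device()).unwrap().as_mut_slice::<f32>();
//
// The `unwrap` turns a failed device synchronisation into a panic instead of
// an `Err`, which is why these shortcuts stay private to this crate.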

/// Just a helper function until SharedTensor has a nice interface for writing data.
pub fn write_to_memory<T: Iterator>(mem: &mut FlatBox, data: T)
where
    T::Item: Clone,
{
    let mem_buffer = mem.as_mut_slice::<T::Item>();
    for (index, datum) in data.enumerate() {
        mem_buffer[index] = datum;
    }
}
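// Example (sketch): filling a freshly allocated native tensor. `tensor` and
// `backend` are assumed names; on the native framework `write_only` is
// expected to yield a `&mut FlatBox`:
//
//     let data = vec![1.0f32, 2.0, 3.0, 4.0];
//     let mem = tensor.write_only(backend.device()).unwrap();
//     write_to_memory(mem, data.iter().cloned());
//
// The iterator must not yield more items than the buffer holds, otherwise the
// indexed write panics.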

/// Computes the Sigmoid Function on the CPU
pub fn sigmoid<T: Float>(x: T) -> T {
    T::one() / (T::one() + (-x).exp())
}

/// Computes the Sigmoid Gradient on the CPU.
/// `x` is expected to be the sigmoid output `y`, since the gradient is
/// computed as `y * (1 - y) * dx`.
pub fn sigmoid_grad<T: Float>(x: T, dx: T) -> T {
    x * (T::one() - x) * dx
}

/// Computes the ReLU Function on the CPU
pub fn relu<T: Float>(x: T) -> T {
    x.max(T::zero())
}

/// Computes the ReLU Gradient on the CPU
pub fn relu_grad<T: Float>(x: T, dx: T) -> T {
    if x > T::zero() {
        return dx;
    }
    T::zero()
}

/// Computes the Tanh Function on the CPU
pub fn tanh<T: Float>(x: T) -> T {
    x.tanh()
}

// d/dx tanh x = sech^2 x = 1 - tanh^2 x
/// Computes the Tanh Gradient on the CPU.
/// `x` is expected to be the tanh output `y`, since the gradient is computed
/// as `(1 - y^2) * dx`.
pub fn tanh_grad<T: Float>(x: T, dx: T) -> T {
    (T::one() - x.powi(2)) * dx
}
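// Worked example (sketch): the `*_grad` helpers take the activation *output*
// rather than the original input, so a backward pass can reuse the forward
// result instead of recomputing it. For sigmoid at z = 0:
//
//     let y = sigmoid(0.0f32);          // 0.5
//     let dy = sigmoid_grad(y, 1.0f32); // 0.5 * (1.0 - 0.5) * 1.0 = 0.25
//
// which matches d/dz sigmoid(z) = sigmoid(z) * (1 - sigmoid(z)) at z = 0.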

/// sigmoid impl generation macro
#[macro_export]
macro_rules! impl_ops_sigmoid_for {
    ($t:ident, $b:ty) => {
        impl Sigmoid<$t> for $b {
            fn sigmoid(
                &self,
                x: &SharedTensor<$t>,
                result: &mut SharedTensor<$t>,
            ) -> Result<(), Error> {
                map1(
                    read!(x, $t, self),
                    write_only!(result, $t, self),
                    crate::frameworks::native::helper::sigmoid,
                )
            }

            fn sigmoid_grad(
                &self,
                x: &SharedTensor<$t>,
                x_diff: &SharedTensor<$t>,
                result: &SharedTensor<$t>,
                result_diff: &mut SharedTensor<$t>,
            ) -> Result<(), Error> {
                map2(
                    read!(x, $t, self),
                    read!(x_diff, $t, self),
                    write_only!(result_diff, $t, self),
                    crate::frameworks::native::helper::sigmoid_grad,
                )
            }
        }

        impl SigmoidPointwise<$t> for $b {
            fn sigmoid_pointwise(&self, x: &mut SharedTensor<$t>) -> Result<(), Error> {
                map1_inplace(
                    read_write!(x, $t, self),
                    crate::frameworks::native::helper::sigmoid,
                )
            }

            fn sigmoid_pointwise_grad(
                &self,
                x: &SharedTensor<$t>,
                x_diff: &mut SharedTensor<$t>,
            ) -> Result<(), $crate::co::error::Error> {
                map2_inplace(
                    read!(x, $t, self),
                    read_write!(x_diff, $t, self),
                    crate::frameworks::native::helper::sigmoid_grad,
                )
            }
        }
    };
}
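// Usage sketch: the activation macros are meant to be invoked once per
// supported float type by the native backend module. The exact invocation
// site and backend type are assumptions, but the expected shape is:
//
//     impl_ops_sigmoid_for!(f32, Backend<Native>);
//     impl_ops_sigmoid_for!(f64, Backend<Native>);
//
// `impl_ops_relu_for!` and `impl_ops_tanh_for!` below follow the same
// pattern, differing only in which helper functions they map over the slices.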

/// relu impl generation macro
#[macro_export]
macro_rules! impl_ops_relu_for {
    ($t:ident, $b:ty) => {
        impl Relu<$t> for $b {
            fn relu(
                &self,
                x: &SharedTensor<$t>,
                result: &mut SharedTensor<$t>,
            ) -> Result<(), $crate::co::error::Error> {
                map1(
                    read!(x, $t, self),
                    write_only!(result, $t, self),
                    crate::frameworks::native::helper::relu,
                )
            }

            fn relu_grad(
                &self,
                x: &SharedTensor<$t>,
                x_diff: &SharedTensor<$t>,
                result: &SharedTensor<$t>,
                result_diff: &mut SharedTensor<$t>,
            ) -> Result<(), Error> {
                map2(
                    read!(x, $t, self),
                    read!(x_diff, $t, self),
                    write_only!(result_diff, $t, self),
                    crate::frameworks::native::helper::relu_grad,
                )
            }
        }
        impl ReluPointwise<$t> for $b {
            fn relu_pointwise(
                &self,
                x: &mut SharedTensor<$t>,
            ) -> Result<(), $crate::co::error::Error> {
                map1_inplace(
                    read_write!(x, $t, self),
                    crate::frameworks::native::helper::relu,
                )
            }

            fn relu_pointwise_grad(
                &self,
                x: &SharedTensor<$t>,
                x_diff: &mut SharedTensor<$t>,
            ) -> Result<(), $crate::co::error::Error> {
                map2_inplace(
                    read!(x, $t, self),
                    read_write!(x_diff, $t, self),
                    crate::frameworks::native::helper::relu_grad,
                )
            }
        }
    };
}

/// tanh impl generation macro
#[macro_export]
macro_rules! impl_ops_tanh_for {
    ($t:ident, $b:ty) => {
        impl $crate::plugin::Tanh<$t> for $b {
            fn tanh(
                &self,
                x: &SharedTensor<$t>,
                result: &mut SharedTensor<$t>,
            ) -> Result<(), $crate::co::error::Error> {
                map1(
                    read!(x, $t, self),
                    write_only!(result, $t, self),
                    crate::frameworks::native::helper::tanh,
                )
            }

            fn tanh_grad(
                &self,
                x: &SharedTensor<$t>,
                x_diff: &SharedTensor<$t>,
                result: &SharedTensor<$t>,
                result_diff: &mut SharedTensor<$t>,
            ) -> Result<(), Error> {
                map2(
                    read!(x, $t, self),
                    read!(x_diff, $t, self),
                    write_only!(result_diff, $t, self),
                    crate::frameworks::native::helper::tanh_grad,
                )
            }
        }
        impl $crate::plugin::TanhPointwise<$t> for $b {
            fn tanh_pointwise(
                &self,
                x: &mut SharedTensor<$t>,
            ) -> Result<(), $crate::co::error::Error> {
                map1_inplace(
                    read_write!(x, $t, self),
                    crate::frameworks::native::helper::tanh,
                )
            }

            fn tanh_pointwise_grad(
                &self,
                x: &SharedTensor<$t>,
                x_diff: &mut SharedTensor<$t>,
            ) -> Result<(), Error> {
                map2_inplace(
                    read!(x, $t, self),
                    read_write!(x_diff, $t, self),
                    crate::frameworks::native::helper::tanh_grad,
                )
            }
        }
    };
}

#[derive(Debug, Clone)]
#[allow(missing_docs)]
pub struct ConvolutionConfig {
    pub filter_shape: Vec<usize>,
    pub stride: Vec<i32>,
    pub padding: Vec<i32>,
}

#[derive(Debug, Clone, Copy)]
#[allow(missing_docs)]
// TODO: Keep parallel with impl in Cuda
pub struct RnnConfig {
    /// Size of the Hidden Layer
    pub hidden_size: usize,
    /// Number of Hidden Layers
    pub num_layers: usize,
    /// Dropout Probability
    pub dropout_probability: f32,
    /// Dropout Seed
    pub dropout_seed: u64,
    /// Type of RNN
    pub rnn_type: RnnNetworkMode,
    /// Input Mode
    pub input_mode: RnnInputMode,
    /// RNN Direction
    pub direction_mode: DirectionMode,
}
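// Example (sketch): constructing the descriptor struct directly. The values
// and the enum variant names (which mirror the cuDNN-style naming used by the
// crate) are assumptions; a single-layer unidirectional LSTM with a hidden
// size of 128 is sketched here:
//
//     let rnn = RnnConfig {
//         hidden_size: 128,
//         num_layers: 1,
//         dropout_probability: 0.0,
//         dropout_seed: 0,
//         rnn_type: RnnNetworkMode::LSTM,
//         input_mode: RnnInputMode::LinearInput,
//         direction_mode: DirectionMode::UniDirectional,
//     };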

/// softmax impl generation macro
#[macro_export]
macro_rules! impl_ops_softmax_for {
    ($t:ident, $b:ty) => {
        impl $crate::plugin::Softmax<$t> for $b {
            fn softmax(
                &self,
                x: &SharedTensor<$t>,
                result: &mut SharedTensor<$t>,
            ) -> Result<(), Error> {
                let xs = read!(x, $t, self);
                let rs = write_only!(result, $t, self);

                map1(xs, rs, |v| v.exp())?;

                let mut sum: $t = 0.0; // iter_arith is not stable yet
                for r in &*rs {
                    sum += *r;
                }
                for r in rs {
                    *r /= sum;
                }
                Ok(())
            }

            // TODO: check
            fn softmax_grad(
                &self,
                x: &SharedTensor<$t>,
                x_diff: &SharedTensor<$t>,
                result_diff: &mut SharedTensor<$t>,
            ) -> Result<(), Error> {
                let xs = read!(x, $t, self);
                let dxs = read!(x_diff, $t, self);
                let drs = write_only!(result_diff, $t, self);

                let mut dot: $t = 0.0;
                for (t, dt) in xs.iter().zip(dxs.iter()) {
                    dot += t * dt;
                }

                map2(xs, dxs, drs, |t, dt| t * (dt - dot))
            }
        }
    };
}
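// Derivation note (sketch): with y = softmax(x) and upstream gradient dy, the
// backward pass is the Jacobian-vector product
//
//     dx_i = y_i * (dy_i - sum_j y_j * dy_j)
//
// `softmax_grad` computes exactly this: `dot` accumulates sum_j y_j * dy_j and
// `map2` applies y_i * (dy_i - dot). Note that this assumes the `x` passed to
// `softmax_grad` already holds the softmax output, matching the convention of
// the pointwise activation gradients above.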

/// log softmax impl generation macro
#[macro_export]
macro_rules! impl_ops_log_softmax_for {
    ($t:ident, $b:ty) => {
        impl $crate::plugin::LogSoftmax<$t> for $b {
            fn log_softmax(
                &self,
                x: &SharedTensor<$t>,
                result: &mut SharedTensor<$t>,
            ) -> Result<(), $crate::co::error::Error> {
                let xs = read!(x, $t, self);
                let rs = write_only!(result, $t, self);

                let max_x = xs
                    .iter()
                    .fold(::std::$t::NEG_INFINITY, |acc, &t| acc.max(t));

                let mut logsum: $t = 0.0;
                for t in xs {
                    logsum += (-(max_x - t)).exp();
                }
                logsum = max_x + logsum.ln();

                map1(xs, rs, |t| t - logsum)
            }

            fn log_softmax_grad(
                &self,
                x: &SharedTensor<$t>,
                x_diff: &SharedTensor<$t>,
                result_diff: &mut SharedTensor<$t>,
            ) -> Result<(), $crate::co::error::Error> {
                let xs = read!(x, $t, self);
                let dxs = read!(x_diff, $t, self);
                let drs = write_only!(result_diff, $t, self);

                let mut sum: $t = 0.0;
                for &grad_val in dxs.iter() {
                    sum += grad_val;
                }
                map2(xs, dxs, drs, |t, dt| dt - t.exp() * sum)
            }
        }
    };
}
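// Derivation note (sketch): the forward pass uses the log-sum-exp trick for
// numerical stability,
//
//     log_softmax(x)_i = x_i - (max_j x_j + ln sum_j exp(x_j - max_j x_j)),
//
// and the backward pass uses dx_i = dy_i - exp(y_i) * sum_j dy_j, where
// y = log_softmax(x). As with `softmax_grad`, this assumes the `x` handed to
// `log_softmax_grad` already holds the log-softmax output.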

/// lrn impl generation macro
/// TODO: all methods are still `unimplemented!()`
#[macro_export]
macro_rules! impl_ops_lrn_for {
    ($t:ident, $b:ty) => {
        impl $crate::plugin::LRN<$t> for $b {
            fn new_lrn_config(
                &self,
                n: u32,
                alpha: f64,
                beta: f64,
                k: f64,
            ) -> Result<Self::CLRN, $crate::co::error::Error> {
                // eventually: Ok($crate::frameworks::native::helper::NormalizationConfig)
                unimplemented!()
            }

            fn lrn(
                &self,
                x: &mut SharedTensor<$t>,
                result: &mut SharedTensor<$t>,
                config: &Self::CLRN,
            ) -> Result<(), $crate::co::error::Error> {
                unimplemented!()
            }

            fn lrn_plain(
                &self,
                x: &SharedTensor<$t>,
                result: &mut SharedTensor<$t>,
                config: &Self::CLRN,
            ) -> Result<(), $crate::co::error::Error> {
                unimplemented!()
            }

            fn lrn_grad(
                &self,
                x: &mut SharedTensor<$t>,
                x_diff: &mut SharedTensor<$t>,
                result: &mut SharedTensor<$t>,
                result_diff: &mut SharedTensor<$t>,
                config: &Self::CLRN,
            ) -> Result<(), $crate::co::error::Error> {
                unimplemented!()
            }

            fn lrn_grad_plain(
                &self,
                x: &SharedTensor<$t>,
                x_diff: &SharedTensor<$t>,
                result: &SharedTensor<$t>,
                result_diff: &mut SharedTensor<$t>,
                config: &Self::CLRN,
            ) -> Result<(), $crate::co::error::Error> {
                unimplemented!()
            }
        }
    };
}