coaster_nn/frameworks/native/
mod.rs

1//! Provides NN for a Native backend.
2
3#![allow(unused_imports)]
4#![allow(unused_variables)]
5#![allow(unreachable_code)]
6
7use std::cmp::PartialOrd;
8use std::fmt::Debug;
9use std::ops::*;
10
11#[cfg(feature = "native")]
12use rand::{distributions::Distribution, Rng, SeedableRng};
13
14use crate::plugin::*;
15use co::plugin::numeric_helpers::Bounded;
16use co::plugin::numeric_helpers::Float;
17use co::plugin::Error as PluginError;
18use co::prelude::*;
19use co::Error;
20use coaster as co;
21
22#[macro_use]
23pub mod helper;
24
25// Those functions should be in helper.rs, but there is no point to make them
26// public.
27fn lens_eq<T>(xs: &[T], ys: &[T]) -> Result<(), Error> {
28    if xs.len() != ys.len() {
29        return Err(PluginError::Operation("Tensor dimension mismatch").into());
30    }
31    Ok(())
32}
33
34fn map1_inplace<T, F>(src: &mut [T], f: F) -> Result<(), Error>
35where
36    T: Float,
37    F: Fn(T) -> T,
38{
39    for i in 0..src.len() {
40        src[i] = f(src[i]);
41    }
42    Ok(())
43}
44
45fn map2_inplace<T, F>(src1: &[T], src2: &mut [T], f: F) -> Result<(), Error>
46where
47    T: Float,
48    F: Fn(T, T) -> T,
49{
50    lens_eq(src1, src2)?;
51    for i in 0..src2.len() {
52        src2[i] = f(src1[i], src2[i]);
53    }
54    Ok(())
55}
56
57fn map1<T, F>(src: &[T], dst: &mut [T], f: F) -> Result<(), Error>
58where
59    T: Float,
60    F: Fn(T) -> T,
61{
62    lens_eq(dst, src)?;
63    for i in 0..dst.len() {
64        dst[i] = f(src[i]);
65    }
66    Ok(())
67}
68
69fn map2<T, F>(src1: &[T], src2: &[T], dst: &mut [T], f: F) -> Result<(), Error>
70where
71    T: Float,
72    F: Fn(T, T) -> T,
73{
74    lens_eq(dst, src1)?;
75    lens_eq(dst, src2)?;
76    for i in 0..dst.len() {
77        dst[i] = f(src1[i], src2[i]);
78    }
79    Ok(())
80}
81
// Wires the native helper config types into the `NN` plugin trait's
// associated types. `T` only needs the arithmetic ops the naive native
// kernels actually use (add, mul, a zero-like default, and copying).
impl<T> NN<T> for Backend<Native>
where
    T: Add<T, Output = T> + Mul<T, Output = T> + Default + Copy,
{
    type CC = helper::ConvolutionConfig;
    type CLRN = helper::NormalizationConfig;
    type CPOOL = helper::PoolingConfig;
    // type CACTI = helper::ActivationConfig;
    type CDROP = helper::DropoutConfig;
    type CRNN = helper::RnnConfig;

    // The native backend has no global state to initialize.
    fn init_nn() {}
}
95
96impl<'a, T> NNOperationConfig<T> for helper::ConvolutionConfig where
97    T: Add<T, Output = T> + Mul<T, Output = T> + Default + Copy
98{
99}
100impl<'a, T> ConvolutionConfig<T> for helper::ConvolutionConfig where
101    T: Add<T, Output = T> + Mul<T, Output = T> + Default + Copy
102{
103}
104impl<'a, T> RnnConfig<T> for helper::RnnConfig where
105    T: Add<T, Output = T> + Mul<T, Output = T> + Default + Copy
106{
107}
108impl<T> NNOperationConfig<T> for helper::NormalizationConfig where
109    T: Add<T, Output = T> + Mul<T, Output = T> + Default + Copy
110{
111}
112impl<T> NNOperationConfig<T> for helper::PoolingConfig where
113    T: Add<T, Output = T> + Mul<T, Output = T> + Default + Copy
114{
115}
116// impl<T> NNOperationConfig<T> for helper::ActivationConfig
117//     where T: Add<T, Output = T> + Mul<T, Output = T> + Default + Copy
118// {
119// }
120impl<T> NNOperationConfig<T> for helper::DropoutConfig where
121    T: Add<T, Output = T> + Mul<T, Output = T> + Default + Copy
122{
123}
124
125impl<T> NNOperationConfig<T> for helper::RnnConfig where
126    T: Add<T, Output = T> + Mul<T, Output = T> + Default + Copy
127{
128}
129
130impl<T> Convolution<T> for Backend<Native>
131where
132    T: Add<T, Output = T> + Mul<T, Output = T> + Default + Copy,
133{
134    fn new_convolution_config(
135        &self,
136        src: &SharedTensor<T>,
137        dest: &SharedTensor<T>,
138        filter: &SharedTensor<T>,
139        algo_fwd: ConvForwardAlgo,
140        algo_bwd_filter: ConvBackwardFilterAlgo,
141        algo_bwd_data: ConvBackwardDataAlgo,
142        stride: &[i32],
143        zero_padding: &[i32],
144    ) -> Result<Self::CC, Error> {
145        // TODO: check dimensions of config
146        match algo_fwd {
147            ConvForwardAlgo::Auto | ConvForwardAlgo::ImplicitGEMM => {}
148            _ => {
149                return Err(Error::Plugin(PluginError::Plugin("Unimplemented.")));
150            }
151        }
152        match algo_bwd_filter {
153            ConvBackwardFilterAlgo::Auto | ConvBackwardFilterAlgo::ImplicitGEMM => {}
154            _ => {
155                return Err(Error::Plugin(PluginError::Plugin("Unimplemented.")));
156            }
157        }
158        match algo_bwd_data {
159            ConvBackwardDataAlgo::Auto | ConvBackwardDataAlgo::ImplicitGEMM => {}
160            _ => {
161                return Err(Error::Plugin(PluginError::Plugin("Unimplemented.")));
162            }
163        }
164
165        Ok(helper::ConvolutionConfig {
166            filter_shape: filter.desc().clone(),
167            stride: stride.to_vec(),
168            padding: zero_padding.to_vec(),
169        })
170    }
171
172    fn convolution(
173        &self,
174        filter: &SharedTensor<T>,
175        x: &SharedTensor<T>,
176        result: &mut SharedTensor<T>,
177        _workspace: &mut SharedTensor<u8>,
178        config: &Self::CC,
179    ) -> Result<(), Error> {
180        let dev = self.device();
181
182        let input_dim = x.desc();
183        let input = x.read(dev).unwrap().as_slice::<T>();
184        let input_stride = input_dim.default_stride();
185
186        let output_dim = result.desc().clone();
187        // this is ok, we only read parts we already wrote
188        let output = result.write_only(dev).unwrap().as_mut_slice::<T>();
189
190        let output_stride = output_dim.default_stride();
191        {
192            for o in output.iter_mut() {
193                *o = Default::default();
194            }
195        }
196
197        let filter_dim = filter.desc();
198        let filter = filter.read(dev).unwrap().as_slice::<T>();
199        let filter_stride = filter_dim.default_stride();
200
201        // sanity check
202        assert!(input_dim[0] == output_dim[0]);
203        assert!(filter_dim[0] == output_dim[1]);
204        assert!(input_dim[1] == filter_dim[1]);
205
206        // TODO: specializations for spatial input
207
208        // recursively sum up elementwise multiplication of the hyperplanes.
209        fn filter_<T>(
210            input: &[T],
211            input_stride: &[usize],
212            input_dim: &[usize],
213            input_offset: usize,
214            input_idx_base: &[usize],
215            filter: &[T],
216            filter_stride: &[usize],
217            filter_dim: &[usize],
218            filter_offset: usize,
219            padding: &[i32],
220            depth: usize,
221            depth_end: usize,
222            acc: Option<T>,
223        ) -> T
224        where
225            T: Add<T, Output = T> + Mul<T, Output = T> + Default + Copy,
226        {
227            let mut acc = acc.unwrap_or_default();
228
229            let p = padding[0] as usize;
230            let input_idx_end = input_dim[0] + 2 * p;
231
232            for filter_idx in 0..filter_dim[0] {
233                let input_idx = input_idx_base[0] + filter_idx;
234                let i_offset = input_offset + (input_idx - p) * input_stride[0];
235                let f_offset = filter_offset + filter_idx * filter_stride[0];
236
237                let v = if input_idx < p || input_idx + 1 > input_idx_end - p {
238                    Default::default()
239                } else if depth + 1 >= depth_end {
240                    input[i_offset] * filter[f_offset]
241                } else {
242                    filter_(
243                        input,
244                        &input_stride[1..],
245                        &input_dim[1..],
246                        i_offset,
247                        &input_idx_base[1..],
248                        filter,
249                        &filter_stride[1..],
250                        &filter_dim[1..],
251                        f_offset,
252                        &padding[1..],
253                        depth + 1,
254                        depth_end,
255                        None,
256                    )
257                };
258                acc = acc + v;
259            }
260            return acc;
261        }
262
263        // depth == 0 is the first level
264        fn conv<T>(
265            input: &[T],
266            input_stride: &[usize],
267            input_dim: &[usize],
268            top_input_offset: usize,
269            input_offset: usize,
270            input_idx_base: &mut [usize],
271            filter: &[T],
272            filter_stride: &[usize],
273            filter_dim: &[usize],
274            filter_offset: usize,
275            depth: usize,
276            padding: &[i32],
277            stride: &[i32],
278            output: &mut [T],
279            output_stride: &[usize],
280            output_dim: &[usize],
281            output_offset: usize,
282        ) where
283            T: Add<T, Output = T> + Mul<T, Output = T> + Default + Copy,
284        {
285            let p = padding[depth] as usize;
286            //let input_end = input_dim[depth] + 2 * p - (filter_dim[depth]);
287
288            for output_idx in 0..output_dim[0] {
289                let input_i = output_idx * stride[0] as usize;
290                input_idx_base[depth] = input_i;
291                let input_offset = input_offset + input_i * input_stride[depth];
292                let output_offset = output_offset + output_idx * output_stride[0];
293
294                if depth + 1 < input_dim.len() {
295                    conv(
296                        input,
297                        input_stride,
298                        input_dim,
299                        top_input_offset,
300                        input_offset,
301                        input_idx_base,
302                        filter,
303                        filter_stride,
304                        filter_dim,
305                        filter_offset,
306                        depth + 1,
307                        padding,
308                        &stride[1..],
309                        output,
310                        &output_stride[1..],
311                        &output_dim[1..],
312                        output_offset,
313                    );
314                } else {
315                    let v = filter_(
316                        input,
317                        input_stride,
318                        input_dim,
319                        top_input_offset,
320                        &input_idx_base[..],
321                        filter,
322                        filter_stride,
323                        filter_dim,
324                        filter_offset,
325                        padding,
326                        0,
327                        input_dim.len(),
328                        None,
329                    );
330                    output[output_offset] = output[output_offset] + v;
331                }
332            }
333        }
334
335        fn conv_k_d1<T>(
336            _batch: usize,
337            input: &[T],
338            input_stride: &[usize],
339            input_dim: &[usize],
340            input_offset: usize,
341            input_idx_base: &mut [usize],
342            filter: &[T],
343            filter_stride: &[usize],
344            filter_dim: &[usize],
345            padding: &[i32],
346            stride: &[i32],
347            output: &mut [T],
348            output_stride: &[usize],
349            output_dim: &[usize],
350            output_offset: usize,
351        ) where
352            T: Add<T, Output = T> + Mul<T, Output = T> + Default + Copy,
353        {
354            for k in 0..filter_dim[0] {
355                let output_offset = output_offset + k * output_stride[0];
356                let filter_offset = k * filter_stride[0];
357                for d1 in 0..input_dim[0] {
358                    let input_offset = input_offset + d1 * input_stride[0];
359                    let filter_offset = filter_offset + d1 * filter_stride[1];
360
361                    conv(
362                        input,
363                        &input_stride[1..],
364                        &input_dim[1..],
365                        input_offset,
366                        input_offset,
367                        input_idx_base,
368                        filter,
369                        &filter_stride[2..],
370                        &filter_dim[2..],
371                        filter_offset,
372                        0,
373                        padding,
374                        stride,
375                        output,
376                        &output_stride[1..],
377                        &output_dim[1..],
378                        output_offset,
379                    );
380                }
381            }
382        }
383
384        let mut input_idx = Vec::new();
385        input_idx.resize(input_dim.len() - 2, 0);
386        let mut output_idx = Vec::new();
387        output_idx.resize(output_dim.len(), 0);
388
389        let batches = input_dim[0];
390        for batch in 0..batches {
391            let input_offset = batch * input_stride[0];
392            let output_offset = batch * output_stride[0];
393
394            conv_k_d1(
395                batch,
396                input,
397                &input_stride[1..],
398                &input_dim[1..],
399                input_offset,
400                &mut input_idx[..],
401                filter,
402                &filter_stride[..],
403                &filter_dim[..],
404                &config.padding[..],
405                &config.stride[..],
406                output,
407                &output_stride[1..],
408                &output_dim[1..],
409                output_offset,
410            );
411        }
412
413        Ok(())
414    }
415
416    fn convolution_grad_filter(
417        &self,
418        src_data: &SharedTensor<T>,
419        dest_diff: &SharedTensor<T>,
420        filter_diff: &mut SharedTensor<T>,
421        workspace: &mut SharedTensor<u8>,
422        config: &Self::CC,
423    ) -> Result<(), Error> {
424        unimplemented!()
425    }
426
427    fn convolution_grad_data(
428        &self,
429        filter: &SharedTensor<T>,
430        x_diff: &SharedTensor<T>,
431        result_diff: &mut SharedTensor<T>,
432        workspace: &mut SharedTensor<u8>,
433        config: &Self::CC,
434    ) -> Result<(), Error> {
435        unimplemented!()
436    }
437}
438
439impl<T> Pooling<T> for Backend<Native>
440where
441    T: Add<T, Output = T> + Mul<T, Output = T> + Default + Copy + PartialOrd + Bounded,
442{
443    fn new_pooling_config(
444        &self,
445        window: &[i32],
446        stride: &[i32],
447        padding: &[i32],
448    ) -> Result<Self::CPOOL, Error> {
449        Ok(helper::PoolingConfig {
450            window: window.to_vec(),
451            stride: stride.to_vec(),
452            padding: padding.to_vec(),
453        })
454    }
455
456    fn pooling_max(
457        &self,
458        x: &SharedTensor<T>,
459        result: &mut SharedTensor<T>,
460        config: &Self::CPOOL,
461    ) -> Result<(), Error> {
462        let dev = self.device();
463
464        let input_dim = x.desc(); // [4, 4, 4, 4]
465        let input = x.read(dev).unwrap().as_slice::<T>();
466        let input_stride = input_dim.default_stride(); // [64, 16, 4, 1];
467
468        let output_dim = result.desc().clone(); // [4,4,2,2]
469                                                // this is ok, we only read parts we already wrote
470        let output = result.write_only(dev).unwrap().as_mut_slice::<T>();
471        let output_stride = output_dim.default_stride(); // [16, 4, 2, 1]
472        {
473            for o in output.iter_mut() {
474                *o = Default::default();
475            }
476        }
477
478        fn max_pooling_<T>(
479            input: &[T],
480            input_stride: &[usize],
481            input_dim: &[usize],
482            input_offset: usize,
483            input_idx_base: &[usize],
484            window: &[i32],
485            padding: &[i32],
486            depth: usize,
487            depth_end: usize,
488            current_max: Option<T>,
489        ) -> T
490        where
491            T: Add<T, Output = T> + Mul<T, Output = T> + Default + Copy + PartialOrd + Bounded,
492        {
493            let mut current_max = current_max.unwrap_or(T::min_value());
494
495            let p = padding[0] as usize;
496            let input_idx_end = input_dim[0] + 2 * p;
497
498            for window_idx in 0..window[0] {
499                let input_idx = input_idx_base[0] + window_idx as usize;
500
501                let v = if input_idx < p || input_idx + 1 > input_idx_end - p {
502                    T::min_value()
503                } else {
504                    let i_mem_offset = input_offset + (input_idx - p) * input_stride[0];
505                    if depth + 1 >= depth_end {
506                        input[i_mem_offset]
507                    } else {
508                        max_pooling_(
509                            input,
510                            &input_stride[1..],
511                            &input_dim[1..],
512                            i_mem_offset,
513                            &input_idx_base[1..],
514                            &window[1..],
515                            &padding[1..],
516                            depth + 1,
517                            depth_end,
518                            None,
519                        )
520                    }
521                };
522                // TODO: Handle NAN, inf and so on
523                current_max = if current_max >= v {
524                    current_max
525                } else if current_max < v {
526                    v
527                } else {
528                    //TODO honour the configuration to pass on NaN or not, see cudnn API
529                    panic!("NaN")
530                };
531            }
532            current_max
533        }
534
535        fn recurse<T>(
536            input: &[T],
537            input_stride: &[usize],
538            input_dim: &[usize],
539            top_input_offset: usize,
540            input_offset: usize,
541            input_idx_base: &mut [usize],
542            window: &[i32],
543            depth: usize,
544            stride: &[i32],
545            padding: &[i32],
546            output: &mut [T],
547            output_stride: &[usize],
548            output_dim: &[usize],
549            output_offset: usize,
550        ) where
551            T: Add<T, Output = T> + Mul<T, Output = T> + Default + Copy + PartialOrd + Bounded,
552        {
553            let p = padding[depth] as usize; // 0
554            let w = window[depth] as usize; // 2
555
556            for output_idx in 0..output_dim[0] {
557                let input_idx = output_idx * stride[0] as usize;
558                input_idx_base[depth] = input_idx;
559                // memory offset of linear input_idx
560                let input_offset = input_offset + input_idx * input_stride[depth];
561                let output_offset = output_offset + output_idx * output_stride[0];
562                //println!("input_offset {} <- output_offset {}", input_offset, output_offset);
563
564                if depth + 1 < input_dim.len() {
565                    recurse(
566                        input,
567                        input_stride,
568                        input_dim,
569                        top_input_offset,
570                        input_offset,
571                        input_idx_base,
572                        window,
573                        depth + 1,
574                        &stride[1..],
575                        padding,
576                        output,
577                        &output_stride[1..],
578                        &output_dim[1..],
579                        output_offset,
580                    );
581                } else {
582                    let v = max_pooling_(
583                        input,
584                        input_stride,
585                        input_dim,
586                        top_input_offset,
587                        &input_idx_base[..],
588                        window,
589                        padding,
590                        0,
591                        input_dim.len(),
592                        None,
593                    );
594                    output[output_offset] = v;
595                }
596            }
597        }
598
599        let mut input_idx = Vec::new();
600        input_idx.resize(input_dim.len() - 2, 0);
601        let mut output_idx = Vec::new();
602        output_idx.resize(output_dim.len(), 0);
603
604        let window = &config.window[..];
605        let stride = &config.stride[..];
606        let padding = &config.padding[..];
607        // do everything for each batch
608        for batch in 0..input_dim[0] {
609            // iterate over the batches!
610            let input_offset = batch * input_stride[0];
611            let output_offset = batch * output_stride[0];
612
613            // iterate over the chanels
614            for d1 in 0..input_dim[1] {
615                let input_offset = input_offset + d1 * input_stride[1];
616                let output_offset = output_offset + d1 * output_stride[1];
617                // pass on the remaining dimensions (no batches, no channels, thus [2..]
618                recurse(
619                    input,
620                    &input_stride[2..],
621                    &input_dim[2..],
622                    input_offset,
623                    input_offset,
624                    &mut input_idx,
625                    &window,
626                    0,
627                    &stride,
628                    &padding,
629                    output,
630                    &output_stride[2..],
631                    &output_dim[2..],
632                    output_offset,
633                );
634            }
635        }
636
637        Ok(())
638    }
639
640    // x, x_diff are known outputs of the forward propagation
641    // result is the previous layer which derivate we want to know
642    // FIXME verify
643    fn pooling_max_grad(
644        &self,
645        x: &SharedTensor<T>,
646        x_diff: &SharedTensor<T>,
647        result: &SharedTensor<T>,
648        result_diff: &mut SharedTensor<T>,
649        config: &Self::CPOOL,
650    ) -> Result<(), Error> {
651        let dev = self.device();
652
653        let input_dim = x.desc(); // []
654        println!("x dims {:?}", input_dim);
655        let input = x.read(dev).unwrap().as_slice::<T>();
656        let input_stride = input_dim.default_stride(); // [];
657
658        let x_diff_dim = x_diff.desc(); // []
659        let x_diff = x_diff.read(dev).unwrap().as_slice::<T>();
660        println!("x_diff dims {:?}", x_diff_dim);
661
662        let output_dim = result_diff.desc().clone(); // []
663        println!("result dims {:?}", result.desc());
664        println!("result_diff dims {:?}", output_dim);
665
666        // this is ok, we only read parts we already wrote
667        let output = result_diff.write_only(dev).unwrap().as_mut_slice::<T>();
668        let output_stride = output_dim.default_stride(); // []
669        {
670            for o in output.iter_mut() {
671                *o = Default::default();
672            }
673        }
674
675        fn max_pooling_<T>(
676            input: &[T],
677            input_stride: &[usize],
678            input_dim: &[usize],
679            input_offset: usize,
680            input_idx_base: &[usize],
681            window: &[i32],
682            padding: &[i32],
683            depth: usize,
684            depth_end: usize,
685            current_max: Option<T>,
686            current_max_index: Option<usize>,
687        ) -> (T, usize)
688        where
689            T: Add<T, Output = T> + Mul<T, Output = T> + Default + Copy + PartialOrd + Bounded,
690        {
691            let mut current_max = (
692                current_max.unwrap_or(T::min_value()),
693                current_max_index.unwrap_or(0usize),
694            );
695
696            let p = padding[0] as usize;
697            let input_idx_end = input_dim[0] + 2 * p;
698
699            for window_idx in 0..window[0] {
700                let input_idx = input_idx_base[0] + window_idx as usize;
701
702                let (v, v_index) = if input_idx < p || input_idx + 1 > input_idx_end - p {
703                    (T::min_value(), 0usize)
704                } else {
705                    let i_mem_offset = input_offset + (input_idx - p) * input_stride[0];
706                    if depth + 1 >= depth_end {
707                        (input[i_mem_offset], i_mem_offset)
708                    } else {
709                        max_pooling_(
710                            input,
711                            &input_stride[1..],
712                            &input_dim[1..],
713                            i_mem_offset,
714                            &input_idx_base[1..],
715                            &window[1..],
716                            &padding[1..],
717                            depth + 1,
718                            depth_end,
719                            None,
720                            None,
721                        )
722                    }
723                };
724                current_max = if current_max.0 >= v {
725                    current_max
726                } else if current_max.0 < v {
727                    (v, v_index)
728                } else {
729                    //TODO honour the configuration to pass on NaN or not, see cudnn API
730                    panic!("NaN")
731                };
732            }
733            current_max
734        }
735
736        fn recurse<T>(
737            input: &[T],
738            input_stride: &[usize],
739            input_dim: &[usize],
740            top_input_offset: usize,
741            input_offset: usize,
742            input_idx_base: &mut [usize],
743            window: &[i32],
744            depth: usize,
745            stride: &[i32],
746            padding: &[i32],
747            output: &mut [T],
748            output_stride: &[usize],
749            output_dim: &[usize],
750            output_offset: usize,
751            dx: &[T],
752        ) where
753            T: Add<T, Output = T> + Mul<T, Output = T> + Default + Copy + PartialOrd + Bounded,
754        {
755            let p = padding[depth] as usize; // 0
756            let w = window[depth] as usize; // 2
757
758            for output_idx in 0..output_dim[0] {
759                let input_idx = output_idx * stride[0] as usize;
760                input_idx_base[depth] = input_idx;
761                // memory offset of linear input_idx
762                let input_offset = input_offset + input_idx * input_stride[depth];
763                let output_offset = output_offset + output_idx * output_stride[0];
764                //println!("input_offset {} <- output_offset {}", input_offset, output_offset);
765
766                if depth + 1 < input_dim.len() {
767                    recurse(
768                        input,
769                        input_stride,
770                        input_dim,
771                        top_input_offset,
772                        input_offset,
773                        input_idx_base,
774                        window,
775                        depth + 1,
776                        &stride[1..],
777                        padding,
778                        output,
779                        &output_stride[1..],
780                        &output_dim[1..],
781                        output_offset,
782                        dx,
783                    );
784                } else {
785                    let (val, index) = max_pooling_(
786                        input,
787                        input_stride,
788                        input_dim,
789                        top_input_offset,
790                        &input_idx_base[..],
791                        window,
792                        padding,
793                        0,
794                        input_dim.len(),
795                        None,
796                        None,
797                    );
798                    // if the stride is 1 and the size is i.e. multiple outputs of the forward propagation
799                    // can map back to one input
800                    // TODO sum up
801                    output[index] = dx[0]; // FIXME we need a second index for this shit
802                }
803            }
804        }
805
806        let mut input_idx = Vec::new();
807        input_idx.resize(input_dim.len() - 2, 0);
808        let mut output_idx = Vec::new();
809        output_idx.resize(output_dim.len(), 0);
810
811        let window = &config.window[..];
812        let stride = &config.stride[..];
813        let padding = &config.padding[..];
814        // do everything for each batch
815        for batch in 0..input_dim[0] {
816            // iterate over the batches!
817            let input_offset = batch * input_stride[0];
818            let output_offset = batch * output_stride[0];
819
820            // iterate over the chanels
821            for d1 in 0..input_dim[1] {
822                let input_offset = input_offset + d1 * input_stride[1];
823                let output_offset = output_offset + d1 * output_stride[1];
824                // pass on the remaining dimensions (no batches, no channels, thus [2..]
825                recurse(
826                    input,
827                    &input_stride[2..],
828                    &input_dim[2..],
829                    input_offset,
830                    input_offset,
831                    &mut input_idx,
832                    &window,
833                    0,
834                    &stride,
835                    &padding,
836                    output,
837                    &output_stride[2..],
838                    &output_dim[2..],
839                    output_offset,
840                    x_diff,
841                );
842            }
843        }
844        Ok(())
845    }
846
847    fn pooling_avg(
848        &self,
849        x: &SharedTensor<T>,
850        result: &mut SharedTensor<T>,
851        config: &Self::CPOOL,
852    ) -> Result<(), Error> {
853        return Err(Error::Plugin(PluginError::Plugin("Unimplemented.")));
854    }
855
856    fn pooling_avg_grad(
857        &self,
858        x: &SharedTensor<T>,
859        x_diff: &SharedTensor<T>,
860        result: &SharedTensor<T>,
861        result_diff: &mut SharedTensor<T>,
862        config: &Self::CPOOL,
863    ) -> Result<(), Error> {
864        return Err(Error::Plugin(PluginError::Plugin("Unimplemented.")));
865    }
866}
867
// RNN support is not implemented for the native backend yet: every method
// below panics via `unimplemented!()`. The impl exists so the native
// backend still satisfies the full plugin API surface.
impl<T> Rnn<T> for Backend<Native>
where
    T: Float + Default + Copy + PartialOrd + Bounded,
{
    /// Would build the RNN configuration (layer sizes, modes, dropout).
    fn new_rnn_config(
        &self,
        src: &SharedTensor<T>,
        dropout_probability: Option<f32>,
        dropout_seed: Option<u64>,
        sequence_length: i32,
        network_mode: RnnNetworkMode,
        input_mode: RnnInputMode,
        direction_mode: DirectionMode,
        algorithm: RnnAlgorithm,
        hidden_size: i32,
        num_layers: i32,
        batch_size: i32,
    ) -> Result<Self::CRNN, Error> {
        // TODO: Implement Config to hold parameters regarding the RNN
        unimplemented!()
    }

    /// Would describe the shape of the RNN weight tensor.
    fn generate_rnn_weight_description(
        &self,
        rnn_config: &Self::CRNN,
        batch_size: i32,
        input_size: i32,
    ) -> Result<Vec<usize>, Error> {
        // This will end up being the tensor descriptor for the weights associated with the RNN pass
        unimplemented!()
    }

    /// Would run the RNN forward pass.
    fn rnn_forward(
        &self,
        src: &SharedTensor<T>,
        output: &mut SharedTensor<T>,
        rnn_config: &Self::CRNN,
        weight: &SharedTensor<T>,
        workspace: &mut SharedTensor<u8>,
    ) -> Result<(), Error> {
        // TODO: Implement RNN Forward Pass
        unimplemented!()
    }

    /// Would compute the gradient with respect to the RNN input.
    fn rnn_backward_data(
        &self,
        src: &SharedTensor<T>,
        src_gradient: &mut SharedTensor<T>,
        output: &SharedTensor<T>,
        output_gradient: &SharedTensor<T>,
        rnn_config: &Self::CRNN,
        weight: &SharedTensor<T>,
        workspace: &mut SharedTensor<u8>,
    ) -> Result<(), Error> {
        // TODO: Implement Backward Pass for RNN for the Input
        unimplemented!()
    }

    /// Would compute the gradient with respect to the RNN weights.
    fn rnn_backward_weights(
        &self,
        src: &SharedTensor<T>,
        output: &SharedTensor<T>,
        filter: &mut SharedTensor<T>,
        rnn_config: &Self::CRNN,
        workspace: &mut SharedTensor<u8>,
    ) -> Result<(), Error> {
        // TODO: Implement Backward Pass with respect to Weights
        unimplemented!()
    }
}
938
#[cfg(feature = "native")]
impl<T> Dropout<T> for Backend<Native>
where
    T: Float + Add<T, Output = T> + Mul<T, Output = T> + Default + Copy + PartialOrd + Bounded,
{
    /// Builds a dropout configuration holding the drop probability and RNG seed.
    fn new_dropout_config(&self, probability: f32, seed: u64) -> Result<Self::CDROP, Error> {
        Ok(helper::DropoutConfig { probability, seed })
    }

    /// Writes a dropped-out copy of `x` into `result`: each element is kept
    /// as-is or zeroed, decided by a seeded uniform draw per element.
    ///
    /// NOTE(review): no 1/(1-p) rescaling is applied to surviving elements —
    /// confirm callers expect an unscaled mask.
    ///
    /// # Errors
    /// Returns a plugin error when `x` and `result` differ in length, or
    /// propagates any error from reading/writing the tensors.
    // TODO this is supposed to be an in place operation
    #[cfg(feature = "native")] // redundant with the impl-level cfg, kept for safety
    fn dropout(
        &self,
        x: &SharedTensor<T>,
        result: &mut SharedTensor<T>,
        config: &Self::CDROP,
    ) -> Result<(), Error> {
        let dev = self.device();

        // Propagate tensor access failures instead of panicking via `unwrap`.
        let input = x.read(dev)?.as_slice::<T>();
        let output = result.write_only(dev)?.as_mut_slice::<T>();

        // Fail with the file's standard plugin error on a size mismatch,
        // rather than panicking inside a slice copy.
        lens_eq(input, &*output)?;

        // Expand the 64-bit user seed into ChaCha's 256-bit seed by repeating
        // it at byte offsets 0, 12 and 24 (bytes 8..12 and 20..24 stay zero).
        // This reproduces the original seeding scheme, so masks remain
        // deterministic for a given `config.seed`.
        let seed: [u8; 8] = config.seed.to_le_bytes();
        let mut extrapolated_seed = [0u8; 32];
        extrapolated_seed[0..8].copy_from_slice(&seed);
        extrapolated_seed[12..20].copy_from_slice(&seed);
        extrapolated_seed[24..32].copy_from_slice(&seed);
        let mut rng = ::rand_chacha::ChaChaRng::from_seed(extrapolated_seed);

        let dist = ::rand::distributions::Uniform::<f32>::new_inclusive(0., 1.);

        // Keep an element when the draw reaches `probability`, zero it
        // otherwise. Writing every slot directly makes the up-front
        // `clone_from_slice` of the original implementation unnecessary.
        for (out, &inp) in output.iter_mut().zip(input.iter()) {
            *out = if dist.sample(&mut rng) >= config.probability {
                inp
            } else {
                T::zero()
            };
        }
        Ok(())
    }

    /// Backward pass of dropout.
    ///
    /// Intentionally a no-op: the gradient is passed through unchanged.
    // TODO check if there is anything to do here?
    #[allow(unused_variables)]
    fn dropout_grad(
        &self,
        x: &SharedTensor<T>,
        x_diff: &SharedTensor<T>,
        result: &SharedTensor<T>,
        result_diff: &mut SharedTensor<T>,
        config: &Self::CDROP,
    ) -> Result<(), Error> {
        Ok(())
    }
}
998
// Generate the f32 activation and softmax trait impls for the native backend
// via the helper macros from `helper.rs`.
// Convolution is not needed here; it is well implemented without the macro madness.
impl_ops_sigmoid_for!(f32, Backend<Native>);
impl_ops_relu_for!(f32, Backend<Native>);
impl_ops_tanh_for!(f32, Backend<Native>);
impl_ops_softmax_for!(f32, Backend<Native>);
impl_ops_log_softmax_for!(f32, Backend<Native>);
// LRN is intentionally disabled for f32 on the native backend:
// impl_ops_lrn_for!(f32, Backend<Native>);
1006
// Sketch of a dedicated NN<f64> impl, kept for reference; the generic
// `impl<T> NN<T>` above already covers f64.
//impl NN<f64> for Backend<Native> {
//type CC = helper::ConvolutionConfig;
//type CLRN = helper::NormalizationConfig;
//type CPOOL = helper::PoolingConfig;

//fn init_nn() { }
//fn device(&self) -> &DeviceType { self.device() }
//}

// Generate the f64 activation and softmax trait impls for the native backend
// via the helper macros from `helper.rs`.
impl_ops_sigmoid_for!(f64, Backend<Native>);
impl_ops_relu_for!(f64, Backend<Native>);
impl_ops_tanh_for!(f64, Backend<Native>);
impl_ops_softmax_for!(f64, Backend<Native>);
impl_ops_log_softmax_for!(f64, Backend<Native>);
// LRN is intentionally disabled for f64 on the native backend:
// impl_ops_lrn_for!(f64, Backend<Native>);