nncombinator/device/batchnormalization.rs

//! Implementation of the calculation process for batch normalization
use std::fmt::Debug;
use rcudnn::{API};
use rcudnn_sys::cudnnBatchNormMode_t::{CUDNN_BATCHNORM_PER_ACTIVATION, CUDNN_BATCHNORM_SPATIAL};
use rcudnn_sys::{cudnnBatchNormalizationBackward, cudnnBatchNormalizationForwardInference, cudnnBatchNormalizationForwardTraining, cudnnDeriveBNTensorDescriptor, cudnnStatus_t};

use crate::arr::{Arr, ArrView, IntoConverter, SerializedVec, SerializedVecView};
use crate::ope::Sum;
use crate::collection::Broadcast;
use crate::computational_graph::{BroadcastNode, GraphNode, SqrtNode, SquareNode, SumNode};
use crate::cuda::{AsMutVoidPtr, AsVoidPtr, CudaTensor1dPtr, CudaTensor1dPtrView, CudaVec, CudaVecView, DataTypeInfo, WriteMemory, ReadMemory, MemoryMoveTo};
use crate::cuda::cudnn::tensor::CudnnTensor4dDescriptor;
use crate::device::{DeviceCpu, DeviceGpu, DeviceMemoryPool};
use crate::error::{EvaluateError, TrainingError, TypeConvertError};
use crate::layer::{BatchDataType, BatchSize};
use crate::ope::UnitValue;

/// Features defining the implementation of the various computational processes in the batch normalization layer
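/// # Example
///
/// A minimal illustrative sketch (marked `ignore`: the device construction and
/// the concrete types of `input` and the parameter tensors are assumptions, not
/// something this trait prescribes):
///
/// ```ignore
/// let device = DeviceCpu::<f32>::new()?;
///
/// // scale = γ, bias = β, plus the running statistics, each of length N.
/// let (output, saved_mean, saved_inv_variance) = device.forward_batch_norm_train(
///     &input, &scale, &bias, &running_mean, &running_variance
/// )?;
///
/// // The saved statistics feed straight into the backward pass.
/// let (dx, dscale, dbias) = device.backward_batch_norm(
///     &loss, &input, &scale, &saved_mean, &saved_inv_variance
/// )?;
/// ```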
pub trait DeviceBatchNorm<U,C,I,const N:usize>
    where U: UnitValue<U>,
          I: BatchDataType + Debug + 'static,
          <I as BatchDataType>::Type: Debug + 'static {
    /// Forward propagation calculation
    /// # Arguments
    /// * `input` - input
    /// * `scale` - γ
    /// * `bias` - β
    /// * `estimated_mean` - μ_B
    /// * `estimated_variance` - σ²_B
    ///
    /// output = γ * ((input - μ_B) / sqrt(σ²_B + 1e-6)) + β
    /// # Errors
    ///
    /// This function may return the following errors
    /// * [`EvaluateError`]
    fn forward_batch_norm<'a>(&self, input: &'a I, scale: &C, bias: &C,
                          estimated_mean: &C, estimated_variance: &C) -> Result<I,EvaluateError>;
    /// Forward propagation calculation (implemented in training mode)
    /// # Arguments
    /// * `input` - input
    /// * `scale` - γ
    /// * `bias` - β
    /// * `estimated_mean` - μ_B
    /// * `estimated_variance` - σ²_B
    ///
    /// output = (γ * ((input - μ_B) / sqrt(σ²_B + 1e-6)) + β, μ_B, 1 / sqrt(σ²_B + 1e-6))
    /// # Errors
    ///
    /// This function may return the following errors
    /// * [`EvaluateError`]
    fn forward_batch_norm_train<'a>(&self, input: &'a I, scale: &C, bias: &C,
                                estimated_mean: &C, estimated_variance: &C) -> Result<(I,C,C),EvaluateError>;
    /// Forward propagation calculation in batch
    /// # Arguments
    /// * `input` - input
    /// * `scale` - γ
    /// * `bias` - β
    /// * `estimated_mean` - μ_B
    /// * `estimated_variance` - σ²_B
    ///
    /// output = γ * ((input - μ_B) / sqrt(σ²_B + 1e-6)) + β
    /// # Errors
    ///
    /// This function may return the following errors
    /// * [`EvaluateError`]
    fn batch_forward_batch_norm<'a>(&self, input: &'a <I as BatchDataType>::Type, scale: &C , bias: &C,
                                estimated_mean: &C, estimated_variance: &C) -> Result<<I as BatchDataType>::Type,EvaluateError>;
    /// Forward propagation calculation in batch (implemented in training mode)
    /// # Arguments
    /// * `input` - input
    /// * `scale` - γ
    /// * `bias` - β
    /// * `running_mean` - μ_B
    /// * `running_variance` - σ²_B
    ///
    /// running_mean = running_mean * momentum + (1 - momentum) * μ_B
    /// running_variance = running_variance * momentum + (1 - momentum) * σ²_B
    ///
    /// output = (γ * ((input - μ_B) / sqrt(σ²_B + 1e-6)) + β, μ_B, 1 / sqrt(σ²_B + 1e-6), running_mean, running_variance)
    /// # Errors
    ///
    /// This function may return the following errors
    /// * [`TrainingError`]
    fn batch_forward_batch_norm_train<'a>(&self, input: &'a <I as BatchDataType>::Type, scale: &C, bias: &C,
                                      running_mean: &C, running_variance: &C, momentum: U)
                                      -> Result<(<I as BatchDataType>::Type,C,C,C,C),TrainingError>;
    /// Error back propagation calculation
    /// # Arguments
    /// * `loss` - loss input
    /// * `input` - input
    /// * `scale` - γ
    /// * `saved_mean` - μ_B calculated during forward propagation
    /// * `saved_inv_variance` - inverse of the standard deviation calculated during forward propagation (1 / sqrt(σ²_B + 1e-6))
    ///
    /// # Errors
    ///
    /// This function may return the following errors
    /// * [`TrainingError`]
    fn backward_batch_norm<'a>(&self, loss: &'a I, input: &'a I, scale: &C,
                           saved_mean: &C, saved_inv_variance: &C) -> Result<(I,C,C), TrainingError>;
    /// Error back propagation calculation in batch
    /// # Arguments
    /// * `loss` - loss input
    /// * `input` - input
    /// * `scale` - γ
    /// * `saved_mean` - μ_B calculated during forward propagation
    /// * `saved_inv_variance` - inverse of the standard deviation calculated during forward propagation (1 / sqrt(σ²_B + 1e-6))
    ///
    /// # Errors
    ///
    /// This function may return the following errors
    /// * [`TrainingError`]
    fn batch_backward_batch_norm<'a>(&self, loss:&'a <I as BatchDataType>::Type, input: &'a <I as BatchDataType>::Type,
                                     scale: &C, saved_mean: &C, saved_inv_variance: &C)
        -> Result<(<I as BatchDataType>::Type,C,C), TrainingError>;
}
impl<U,I,const N:usize> DeviceBatchNorm<U,Arr<U,N>,I,N> for DeviceCpu<U>
    where U: UnitValue<U>,
          I: BatchDataType + Debug + From<Arr<U,N>> + 'static,
          <I as BatchDataType>::Type: Debug + 'static,
          <I as BatchDataType>::Type: TryFrom<<SerializedVec<U,Arr<U,N>> as IntoConverter>::Converter,Error=TypeConvertError>,
          SerializedVec<U,Arr<U,N>>: IntoConverter,
          for<'a> ArrView<'a,U,N>: From<&'a I>,
          for<'a> SerializedVecView<'a,U,Arr<U,N>>: TryFrom<&'a <I as BatchDataType>::Type,Error=TypeConvertError> {
    #[inline]
    fn forward_batch_norm<'a>(&self, input: &'a I, scale: &Arr<U,N>, bias: &Arr<U,N>,
                          estimated_mean: &Arr<U,N>, estimated_variance: &Arr<U,N>) -> Result<I,EvaluateError> {
        let input = ArrView::<'a,U,N>::from(input);

        let eps = U::from_f64(1e-6).ok_or(EvaluateError::TypeCastError(String::from(
            "Error in type conversion from f64."
        )))?;

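        // Inference-mode normalization, applied per element using the running
        // statistics: y = γ · (x − μ_B) / sqrt(σ²_B + ε) + β.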
        Ok(Arr::try_from(input.iter()
            .zip(scale.iter())
            .zip(bias.iter())
            .zip(estimated_mean.iter())
            .zip(estimated_variance.iter())
            .map(|((((&i,&scale),&bias),&mean),&variance)| {
                scale * ((i - mean) / SqrtNode::new().forward(variance + eps)) + bias
            }).collect::<Vec<U>>())?.into())
    }

    #[inline]
    fn forward_batch_norm_train<'a>(&self, input: &'a I,
                                scale: &Arr<U,N>,
                                bias: &Arr<U,N>,
                                estimated_mean: &Arr<U,N>,
                                estimated_variance: &Arr<U,N>) -> Result<(I,Arr<U,N>,Arr<U,N>),EvaluateError> {
        let input = ArrView::<'a,U,N>::from(input);

        let eps = U::from_f64(1e-6).ok_or(EvaluateError::TypeCastError(String::from(
            "Error in type conversion from f64."
        )))?;

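        // Returns the normalized output together with the saved mean and the
        // inverse standard deviation 1 / sqrt(σ² + ε); the backward pass reuses both.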
        Ok((Arr::try_from(input.iter()
                .zip(scale.iter())
                .zip(bias.iter())
                .zip(estimated_mean.iter())
                .zip(estimated_variance.iter())
                .map(|((((&i,&scale),&bias),&mean),&variance)| {
                    scale * ((i - mean) / SqrtNode::new().forward(variance + eps)) + bias
                }).collect::<Vec<U>>())?.into(),
            estimated_mean.clone(),
            estimated_variance.iter().map(|&v| U::one() / SqrtNode::new().forward(v + eps)).collect::<Vec<U>>().try_into()?
        ))
    }

    #[inline]
    fn batch_forward_batch_norm<'a>(&self, input: &'a <I as BatchDataType>::Type, scale: &Arr<U,N>, bias: &Arr<U,N>,
                                    estimated_mean: &Arr<U,N>, estimated_variance: &Arr<U,N>)
        -> Result<<I as BatchDataType>::Type, EvaluateError> {
        let input = SerializedVecView::<'a,U,Arr<U,N>>::try_from(input)?;

        let eps = U::from_f64(1e-6).ok_or(EvaluateError::TypeCastError(String::from(
            "Error in type conversion from f64."
        )))?;

        Ok(SerializedVec::from(input.iter().map(|input| {
            input.iter()
                .zip(scale.iter())
                .zip(bias.iter())
                .zip(estimated_mean.iter())
                .zip(estimated_variance.iter())
                .map(|((((&i,&scale),&bias),&mean),&variance)| {
                    scale * ((i - mean) / SqrtNode::new().forward(variance + eps)) + bias
                }).collect::<Vec<U>>().try_into()
        }).collect::<Result<Vec<Arr<U,N>>,_>>()?).into_converter().try_into()?)
    }

    #[inline]
    fn batch_forward_batch_norm_train<'a>(&self, input: &'a <I as BatchDataType>::Type,
                                      scale: &Arr<U,N>, bias: &Arr<U,N>,
                                      running_mean: &Arr<U,N>, running_variance: &Arr<U,N>,
                                      momentum: U)
                                      -> Result<(<I as BatchDataType>::Type,Arr<U,N>,Arr<U,N>,Arr<U,N>,Arr<U,N>), TrainingError> {
        let input = SerializedVecView::<'a,U,Arr<U,N>>::try_from(input)?;

        let eps = U::from_f64(1e-6).ok_or(TrainingError::TypeCastError(String::from(
            "Error in type conversion from f64."
        )))?;

        let n = input.len();
        let un = U::from_usize(n).ok_or(TrainingError::TypeCastError(String::from(
            "Error in type conversion from usize."
        )))?;

        let un_inv = U::from_f64(1.).ok_or(TrainingError::TypeCastError(
            String::from(
                "Error in type conversion from f64."
            )
        ))? / un;

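        // Per-feature batch mean: μ_B = (1/n) Σ_i x_i.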
        let mean:Arr<U,N> = SumNode::<U,SerializedVecView<'_,U,Arr<U,N>>>::new().forward(input) * un_inv;

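        // Per-feature batch variance: σ²_B = (1/n) Σ_i (x_i − μ_B)².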
        let variance:SerializedVec<U,Arr<U,N>> = (input - Broadcast::<Arr<U,N>>(mean.clone()))
            .iter()
            .map(|i| {
                i.iter().map(|&i| {
                    SquareNode::new().forward(i)
                }).collect::<Vec<U>>().try_into()
            }).collect::<Result<Vec<Arr<U,N>>,_>>()?.into();
        let variance = variance.sum() * un_inv;

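        // Inverse standard deviation 1 / sqrt(σ²_B + ε), saved for the backward pass.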
        let inv_variance:Arr<U,N> = variance.iter().map(|&v| U::one() / SqrtNode::new().forward(v + eps)).collect::<Vec<U>>().try_into()?;

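        // Normalize: x̂ = (x − μ_B) · inv_variance.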
        let o:SerializedVec<U,Arr<U,N>> = Broadcast(inv_variance.clone()) * (input - Broadcast(mean.clone()));

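        // Exponential moving average of the statistics; momentum is the
        // retention factor of the previous running values.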
        let running_mean = running_mean * momentum + &mean * (U::one() - momentum);
        let running_variance = running_variance * momentum + variance * (U::one() - momentum);

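        // Scale and shift: y = γ · x̂ + β.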
        let o = (BroadcastNode::<U,&SerializedVec<U,Arr<U,N>>>::new().forward((scale,n)) * o) + Broadcast(bias.clone());

        Ok((o.into_converter().try_into()?,mean,inv_variance,running_mean,running_variance))
    }

    #[inline]
    fn backward_batch_norm<'a>(&self, loss: &'a I, input: &'a I,
                           scale: &Arr<U,N>, saved_mean: &Arr<U,N>, saved_inv_variance: &Arr<U,N>)
                           -> Result<(I, Arr<U,N>, Arr<U,N>), TrainingError> {
        let loss = ArrView::<'a,U,N>::from(loss);
        let input = ArrView::<'a,U,N>::from(input);

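        // Gradient w.r.t. β is the incoming loss itself; gradient w.r.t. γ is
        // x̂ · loss, with x̂ = (x − μ_B) · saved_inv_variance.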
        let b = loss.clone();

        let x = input - saved_mean;

        let s = (&x * saved_inv_variance) * loss;

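        // Input gradient via the computational graph (the single-sample variant
        // of the batch version below, with n = 1): dx2 is the direct path through
        // the division by σ, dx4..dx6 backpropagate through σ = sqrt(σ² + ε) and
        // the square, and dx9 is the path through the mean subtraction.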
        let dx1 = scale * loss;
        let dx2 = &dx1 * saved_inv_variance;
        let dx3 = &x * dx1;
        let dx4 =  -(saved_inv_variance * saved_inv_variance) * dx3;
        let dx5 = dx4 * (saved_inv_variance * U::from_f64(0.5).ok_or(TrainingError::TypeCastError(String::from(
            "Error in type conversion from f64.")
        ))?);
        let dx6 = &x * dx5 * U::from_usize(2).ok_or(TrainingError::TypeCastError(String::from(
            "Error in type conversion from usize."
        )))?;
        let dx7 = dx2 + dx6;
        let dx8 = &dx7;
        let dx9 = -&dx7;
        let dx = dx8 + dx9;

        Ok((dx.into(),s,b.into()))
    }

    #[inline]
    fn batch_backward_batch_norm<'a>(&self, loss: &'a <I as BatchDataType>::Type,
                                 input: &'a <I as BatchDataType>::Type,
                                 scale: &Arr<U,N>,
                                 saved_mean: &Arr<U,N>, saved_inv_variance: &Arr<U,N>)
                                 -> Result<(<I as BatchDataType>::Type, Arr<U,N>, Arr<U,N>), TrainingError> {
        let loss = SerializedVecView::<'a,U,Arr<U,N>>::try_from(loss)?;
        let input = SerializedVecView::<'a,U,Arr<U,N>>::try_from(input)?;

        let n = input.len();

        let un = U::from_usize(n).ok_or(TrainingError::TypeCastError(String::from(
            "Error in type conversion from usize."
        )))?;

        let un_inv = U::from_usize(1).ok_or(TrainingError::TypeCastError(String::from(
            "Error in type conversion from usize."
        )))? / un;

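        // Gradient w.r.t. β: sum the loss over the batch.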
        let b = BroadcastNode::<U,SerializedVecView<'_,U,Arr<U,N>>>::new().backward(loss);

        let x = BroadcastNode::<U,&SerializedVec<U,Arr<U,N>>>::new().forward((saved_mean,n));
        let x2 = input - &x;
        let iv = BroadcastNode::<U,&SerializedVec<U,Arr<U,N>>>::new().forward((saved_inv_variance,n));

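        // Gradient w.r.t. γ: sum x̂ · loss over the batch, where x̂ = (x − μ_B) · inv_variance.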
        let s = BroadcastNode::<U,&SerializedVec<U,Arr<U,N>>>::new().backward(&(&x2 * &iv * loss));

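        // Input gradient via the computational graph: dx2 is the direct path
        // through the division by σ, dx3..dx7 backpropagate through the variance
        // (square → mean over the batch → sqrt → reciprocal), and dx10..dx12
        // redistribute the gradient of the mean subtraction across the batch.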
        let dx1 = Broadcast(scale.clone()) * loss;
        let dx2 = &dx1 * iv;
        let dx3 = BroadcastNode::<U,&SerializedVec<U,Arr<U,N>>>::new().backward(&(&x2 * dx1));
        let dx4 = -(saved_inv_variance * saved_inv_variance) * dx3;
        let dx5 = dx4 * (saved_inv_variance * U::from_f64(0.5).ok_or(TrainingError::TypeCastError(String::from(
            "Error in type conversion from f64.")
        ))?);
        let dx6 = SumNode::<U,SerializedVec<U,Arr<U,N>>>::new().backward((&(dx5 * un_inv),n));
        let dx7 = x2 * dx6 * U::from_usize(2).ok_or(TrainingError::TypeCastError(String::from(
            "Error in type conversion from usize."
        )))?;
        let dx8 = dx2 + dx7;
        let dx9 = &dx8;
        let dx10 = -&dx8;
        let dx11 = BroadcastNode::<U,&SerializedVec<U,Arr<U,N>>>::new().backward(&dx10);
        let dx12 = SumNode::<U,SerializedVec<U,Arr<U,N>>>::new().backward((&dx11,n)) * un_inv;

        let dx = dx9 + dx12;

        Ok((dx.into_converter().try_into()?,s,b))
    }
}
impl<U,I,const N:usize> DeviceBatchNorm<U,CudaTensor1dPtr<U,N>,I,N> for DeviceGpu<U>
    where U: UnitValue<U> + DataTypeInfo + AsVoidPtr,
          I: BatchDataType + Debug + From<CudaTensor1dPtr<U,N>> + 'static,
          <I as BatchDataType>::Type: Debug + 'static,
          <I as BatchDataType>::Type: TryFrom<<CudaVec<U,CudaTensor1dPtr<U,N>> as IntoConverter>::Converter,Error=TypeConvertError>,
          CudaVec<U,CudaTensor1dPtr<U,N>>: IntoConverter,
          for<'a> CudaTensor1dPtrView<'a,U,N>: From<&'a I>,
          for<'a> CudaVecView<'a,U,CudaTensor1dPtr<U,N>>: TryFrom<&'a <I as BatchDataType>::Type,Error=TypeConvertError>,
          f64: From<U> {
    fn forward_batch_norm<'a>(&self, input: &'a I, scale: &CudaTensor1dPtr<U,N>, bias: &CudaTensor1dPtr<U,N>,
                          estimated_mean: &CudaTensor1dPtr<U,N>, estimated_variance: &CudaTensor1dPtr<U,N>)
        -> Result<I,EvaluateError> {
        let input = CudaTensor1dPtrView::<'a,U,N>::from(input);

        let len = N as i32;

        let mut output_ptr = CudaTensor1dPtr::<U,N>::new(self.get_memory_pool())?;

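        // A single sample is presented to cuDNN as a 1×N×1×1 (NCHW) tensor.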
        let bn_scale_bias_mean_var_desc = API::create_tensor_descriptor()?;
        let xd = CudnnTensor4dDescriptor::<U>::new(1,len as usize,1,1)?;

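        // Derive the 1×N×1×1 layout of the scale/bias/mean/variance tensor from
        // the x descriptor; with H = W = 1 the SPATIAL and PER_ACTIVATION
        // parameter shapes coincide.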
        unsafe {
            match cudnnDeriveBNTensorDescriptor(bn_scale_bias_mean_var_desc,*xd.id_c(),CUDNN_BATCHNORM_SPATIAL) {
                cudnnStatus_t::CUDNN_STATUS_SUCCESS => (),
                cudnnStatus_t::CUDNN_STATUS_BAD_PARAM => {
                    return Err(EvaluateError::CudnnError(
                        rcudnn::Error::BadParam("The parameter passed to cudnnDeriveBNTensorDescriptor is invalid.")));
                },
                status => {
                    return Err(EvaluateError::CudnnError(
                        rcudnn::Error::Unknown("Unable to create the CUDA cuDNN context/resources.", status as i32 as u64)));
                }
            }
        }

        let alpha = U::one();
        let beta = U::default();

        let eps = 1e-6;

        unsafe {
            match cudnnBatchNormalizationForwardInference(
                *self.cudnn.id_c(),
                CUDNN_BATCHNORM_SPATIAL,
                alpha.as_void_ptr(),
                beta.as_void_ptr(),
                *xd.id_c(),
                input.as_void_ptr(),
                *xd.id_c(),
                output_ptr.as_mut_void_ptr(),
                bn_scale_bias_mean_var_desc,
                scale.as_void_ptr(),
                bias.as_void_ptr(),
                estimated_mean.as_void_ptr(),
                estimated_variance.as_void_ptr(),
                eps as f64) {
                cudnnStatus_t::CUDNN_STATUS_SUCCESS => {
                    return Ok(output_ptr.into());
                },
                cudnnStatus_t::CUDNN_STATUS_NOT_SUPPORTED => {
                    return Err(EvaluateError::CudnnError(rcudnn::Error::NotSupported("The function does not support the provided configuration.")));
                },
                cudnnStatus_t::CUDNN_STATUS_BAD_PARAM => {
                    return Err(EvaluateError::CudnnError(
                        rcudnn::Error::BadParam("The parameter passed to cudnnBatchNormalizationForwardInference is invalid.")));
                },
                status => {
                    return Err(EvaluateError::CudnnError(
                        rcudnn::Error::Unknown("Unable to create the CUDA cuDNN context/resources.", status as i32 as u64)));
                }
            }
        }
    }

    fn forward_batch_norm_train<'a>(&self, input: &'a I,
                                scale: &CudaTensor1dPtr<U,N>,
                                bias: &CudaTensor1dPtr<U,N>,
                                estimated_mean: &CudaTensor1dPtr<U,N>,
                                estimated_variance: &CudaTensor1dPtr<U,N>) -> Result<(I,CudaTensor1dPtr<U,N>,CudaTensor1dPtr<U,N>),EvaluateError> {
        let input = CudaTensor1dPtrView::<'a,U,N>::from(input);

        let len = N as i32;

        let mut output_ptr = CudaTensor1dPtr::<U,N>::new(self.get_memory_pool())?;

        let bn_scale_bias_mean_var_desc = API::create_tensor_descriptor()?;
        let xd = CudnnTensor4dDescriptor::<U>::new(1,len as usize,1,1)?;

        unsafe {
            match cudnnDeriveBNTensorDescriptor(bn_scale_bias_mean_var_desc,*xd.id_c(),CUDNN_BATCHNORM_SPATIAL) {
                cudnnStatus_t::CUDNN_STATUS_SUCCESS => (),
                cudnnStatus_t::CUDNN_STATUS_BAD_PARAM => {
                    return Err(EvaluateError::CudnnError(
                        rcudnn::Error::BadParam("The parameter passed to cudnnDeriveBNTensorDescriptor is invalid.")));
                },
                status => {
                    return Err(EvaluateError::CudnnError(
                        rcudnn::Error::Unknown("Unable to create the CUDA cuDNN context/resources.", status as i32 as u64)));
                }
            }
        }

        let alpha = U::one();
        let beta = U::default();

        let eps = U::from_f64(1e-6).ok_or(
            EvaluateError::TypeCastError(String::from("An error occurred in floating point type conversion.")))?;

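        // The statistics saved for the backward pass are derived from the running
        // estimates here: the mean is copied as-is, and the inverse standard
        // deviation 1 / sqrt(σ̂² + ε) is computed on the host and uploaded.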
        let mut mean = CudaTensor1dPtr::<U,N>::new(self.get_memory_pool())?;
        let mut inv_variance = CudaTensor1dPtr::<U,N>::new(self.get_memory_pool())?;

        estimated_mean.memcpy_to(&mut mean,N)?;
        inv_variance.memcpy(estimated_variance.read_to_vec()?.into_boxed_slice()
                                                                .iter()
                                                                .map(|&v| U::one() / SqrtNode::new().forward(v + eps))
                                                                .collect::<Vec<U>>().as_ptr(),N)?;

        let eps = 1e-6;

        unsafe {
            match cudnnBatchNormalizationForwardInference(
                *self.cudnn.id_c(),
                CUDNN_BATCHNORM_SPATIAL,
                alpha.as_void_ptr(),
                beta.as_void_ptr(),
                *xd.id_c(),
                input.as_void_ptr(),
                *xd.id_c(),
                output_ptr.as_mut_void_ptr(),
                bn_scale_bias_mean_var_desc,
                scale.as_void_ptr(),
                bias.as_void_ptr(),
                estimated_mean.as_void_ptr(),
                estimated_variance.as_void_ptr(),
                eps as f64) {
                cudnnStatus_t::CUDNN_STATUS_SUCCESS => {
                    return Ok((output_ptr.into(),mean,inv_variance));
                },
                cudnnStatus_t::CUDNN_STATUS_NOT_SUPPORTED => {
                    return Err(EvaluateError::CudnnError(rcudnn::Error::NotSupported("The function does not support the provided configuration.")));
                },
                cudnnStatus_t::CUDNN_STATUS_BAD_PARAM => {
                    return Err(EvaluateError::CudnnError(
                        rcudnn::Error::BadParam("The parameter passed to cudnnBatchNormalizationForwardInference is invalid.")));
                },
                status => {
                    return Err(EvaluateError::CudnnError(
                        rcudnn::Error::Unknown("Unable to create the CUDA cuDNN context/resources.", status as i32 as u64)));
                }
            }
        }
    }

    fn batch_forward_batch_norm<'a>(&self, input: &'a <I as BatchDataType>::Type,
                                    scale: &CudaTensor1dPtr<U,N>,
                                    bias: &CudaTensor1dPtr<U,N>,
                                    estimated_mean: &CudaTensor1dPtr<U,N>, estimated_variance: &CudaTensor1dPtr<U,N>)
        -> Result<<I as BatchDataType>::Type, EvaluateError> {
        let input = CudaVecView::<'a,U,CudaTensor1dPtr<U,N>>::try_from(input)?;

        let len = input.size();

        let mut output_ptr = CudaVec::<U,CudaTensor1dPtr<U,N>>::new(len,&self.memory_pool)?;

        let len = len as i32;

        let bn_scale_bias_mean_var_desc = API::create_tensor_descriptor()?;
        let xd = CudnnTensor4dDescriptor::<U>::new(len as usize,N,1,1)?;

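        // The batch is presented as a len×N×1×1 (NCHW) tensor, so the derived
        // parameter descriptor again has one entry per feature.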
        unsafe {
            match cudnnDeriveBNTensorDescriptor(bn_scale_bias_mean_var_desc,*xd.id_c(),CUDNN_BATCHNORM_SPATIAL) {
                cudnnStatus_t::CUDNN_STATUS_SUCCESS => (),
                cudnnStatus_t::CUDNN_STATUS_BAD_PARAM => {
                    return Err(EvaluateError::CudnnError(
                        rcudnn::Error::BadParam("The parameter passed to cudnnDeriveBNTensorDescriptor is invalid.")));
                },
                status => {
                    return Err(EvaluateError::CudnnError(
                        rcudnn::Error::Unknown("Unable to create the CUDA cuDNN context/resources.", status as i32 as u64)));
                }
            }
        }

        let alpha = U::one();
        let beta = U::default();

        let eps = 1e-6;

        unsafe {
            match cudnnBatchNormalizationForwardInference(
                *self.cudnn.id_c(),
                CUDNN_BATCHNORM_SPATIAL,
                alpha.as_void_ptr(),
                beta.as_void_ptr(),
                *xd.id_c(),
                input.as_void_ptr(),
                *xd.id_c(),
                output_ptr.as_mut_void_ptr(),
                bn_scale_bias_mean_var_desc,
                scale.as_void_ptr(),
                bias.as_void_ptr(),
                estimated_mean.as_void_ptr(),
                estimated_variance.as_void_ptr(),
                eps as f64) {
                cudnnStatus_t::CUDNN_STATUS_SUCCESS => {
                    return Ok(output_ptr.into_converter().try_into()?);
                },
                cudnnStatus_t::CUDNN_STATUS_NOT_SUPPORTED => {
                    return Err(EvaluateError::CudnnError(rcudnn::Error::NotSupported("The function does not support the provided configuration.")));
                },
                cudnnStatus_t::CUDNN_STATUS_BAD_PARAM => {
                    return Err(EvaluateError::CudnnError(
                        rcudnn::Error::BadParam("The parameter passed to cudnnBatchNormalizationForwardInference is invalid.")));
                },
                status => {
                    return Err(EvaluateError::CudnnError(
                        rcudnn::Error::Unknown("Unable to create the CUDA cuDNN context/resources.", status as i32 as u64)));
                }
            }
        }
    }

    fn batch_forward_batch_norm_train<'a>(&self, input: &'a <I as BatchDataType>::Type,
                                      scale: &CudaTensor1dPtr<U,N>, bias: &CudaTensor1dPtr<U,N>,
                                      running_mean: &CudaTensor1dPtr<U,N>, running_variance: &CudaTensor1dPtr<U,N>,
                                      momentum: U)
        -> Result<(<I as BatchDataType>::Type,
                   CudaTensor1dPtr<U,N>,
                   CudaTensor1dPtr<U,N>,
                   CudaTensor1dPtr<U,N>,
                   CudaTensor1dPtr<U,N>), TrainingError> {
        let input = CudaVecView::<'a,U,CudaTensor1dPtr<U,N>>::try_from(input)?;

        let len = input.size();

        let mut output_ptr = CudaVec::<U,CudaTensor1dPtr<U,N>>::new(len,self.get_memory_pool())?;

        let len = len as i32;

        let bn_scale_bias_mean_var_desc = API::create_tensor_descriptor()?;
        let xd = CudnnTensor4dDescriptor::<U>::new(len as usize,N,1,1)?;

        unsafe {
            match cudnnDeriveBNTensorDescriptor(bn_scale_bias_mean_var_desc,*xd.id_c(),CUDNN_BATCHNORM_SPATIAL) {
                cudnnStatus_t::CUDNN_STATUS_SUCCESS => (),
                cudnnStatus_t::CUDNN_STATUS_BAD_PARAM => {
                    return Err(TrainingError::CudnnError(
                        rcudnn::Error::BadParam("The parameter passed to cudnnDeriveBNTensorDescriptor is invalid.")));
                },
                status => {
                    return Err(TrainingError::CudnnError(
                        rcudnn::Error::Unknown("Unable to create the CUDA cuDNN context/resources.", status as i32 as u64)));
                }
            }
        }

        let alpha = U::one();
        let beta = U::default();

        let eps = 1e-6;

        let mut new_running_mean = CudaTensor1dPtr::<U,N>::new(self.get_memory_pool())?;
        let mut new_running_variance = CudaTensor1dPtr::<U,N>::new(self.get_memory_pool())?;

        running_mean.memcpy_to(&mut new_running_mean, N)?;
        running_variance.memcpy_to(&mut new_running_variance, N)?;

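        // cuDNN writes the batch mean and the inverse standard deviation
        // (resultSaveMean / resultSaveInvVariance) into these buffers; they are
        // returned for use in the backward pass.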
        let mut mean = CudaTensor1dPtr::<U,N>::new(self.get_memory_pool())?;
        let mut inv_variance = CudaTensor1dPtr::<U,N>::new(self.get_memory_pool())?;

        unsafe {
            match cudnnBatchNormalizationForwardTraining(
                *self.cudnn.id_c(),
                CUDNN_BATCHNORM_PER_ACTIVATION,
                alpha.as_void_ptr(),
                beta.as_void_ptr(),
                *xd.id_c(),
                input.as_void_ptr(),
                *xd.id_c(),
                output_ptr.as_mut_void_ptr(),
                bn_scale_bias_mean_var_desc,
                scale.as_void_ptr(),
                bias.as_void_ptr(),
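                // cuDNN expects the exponentialAverageFactor (the weight of the
                // new batch statistics), so the retention-style momentum used by
                // this crate is converted to 1 − momentum.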
                1. - f64::from(momentum),
                new_running_mean.as_mut_void_ptr(),
                new_running_variance.as_mut_void_ptr(),
                eps as f64,
                mean.as_mut_void_ptr(),
                inv_variance.as_mut_void_ptr()) {
                cudnnStatus_t::CUDNN_STATUS_SUCCESS => {
                    return Ok((output_ptr.into_converter().try_into()?,
                               mean,
                               inv_variance,
                               new_running_mean,
                               new_running_variance));
                },
                cudnnStatus_t::CUDNN_STATUS_NOT_SUPPORTED => {
                    return Err(TrainingError::CudnnError(rcudnn::Error::NotSupported("The function does not support the provided configuration.")));
                },
                cudnnStatus_t::CUDNN_STATUS_BAD_PARAM => {
                    return Err(TrainingError::CudnnError(
                        rcudnn::Error::BadParam("The parameter passed to cudnnBatchNormalizationForwardTraining is invalid.")));
                },
                status => {
                    return Err(TrainingError::CudnnError(
                        rcudnn::Error::Unknown("Unable to create the CUDA cuDNN context/resources.", status as i32 as u64)));
                }
            }
        }
    }

    fn backward_batch_norm<'a>(&self, loss: &'a I, input: &'a I,
                               scale: &CudaTensor1dPtr<U,N>,
                               saved_mean: &CudaTensor1dPtr<U,N>,
                               saved_inv_variance: &CudaTensor1dPtr<U,N>)
        -> Result<(I, CudaTensor1dPtr<U,N>, CudaTensor1dPtr<U,N>), TrainingError> {
        let loss = CudaTensor1dPtrView::<'a,U,N>::from(loss);
        let input = CudaTensor1dPtrView::<'a,U,N>::from(input);

        let len = N as i32;

        let mut output_ptr = CudaTensor1dPtr::<U,N>::new(self.get_memory_pool())?;

        let bn_scale_bias_diff_desc = API::create_tensor_descriptor()?;
        let xd = CudnnTensor4dDescriptor::<U>::new(1,len as usize,1,1)?;

        unsafe {
            match cudnnDeriveBNTensorDescriptor(bn_scale_bias_diff_desc,*xd.id_c(),CUDNN_BATCHNORM_SPATIAL) {
                cudnnStatus_t::CUDNN_STATUS_SUCCESS => (),
                cudnnStatus_t::CUDNN_STATUS_BAD_PARAM => {
                    return Err(TrainingError::CudnnError(
                        rcudnn::Error::BadParam("The parameter passed to cudnnDeriveBNTensorDescriptor is invalid.")));
                },
                status => {
                    return Err(TrainingError::CudnnError(
                        rcudnn::Error::Unknown("Unable to create the CUDA cuDNN context/resources.", status as i32 as u64)));
                }
            }
        }

        let eps = 1e-6;

        let alpha = U::one();
        let beta = U::default();

        let mut result_scale = CudaTensor1dPtr::<U,N>::new(self.get_memory_pool())?;
        let mut result_bias = CudaTensor1dPtr::<U,N>::new(self.get_memory_pool())?;

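        // A single cuDNN call yields all three gradients: dx into output_ptr,
        // dγ into result_scale and dβ into result_bias, reusing the mean and
        // inverse variance saved during the forward pass.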
        unsafe {
            match cudnnBatchNormalizationBackward(
                *self.cudnn.id_c(),
                CUDNN_BATCHNORM_PER_ACTIVATION,
                alpha.as_void_ptr(),
                beta.as_void_ptr(),
                alpha.as_void_ptr(),
                beta.as_void_ptr(),
                *xd.id_c(),
                input.as_void_ptr(),
                *xd.id_c(),
                loss.as_void_ptr(),
                *xd.id_c(),
                output_ptr.as_mut_void_ptr(),
                bn_scale_bias_diff_desc,
                scale.as_void_ptr(),
                result_scale.as_mut_void_ptr(),
                result_bias.as_mut_void_ptr(),
                eps as f64,
                saved_mean.as_void_ptr(),
                saved_inv_variance.as_void_ptr()) {
                cudnnStatus_t::CUDNN_STATUS_SUCCESS => {
                    return Ok((output_ptr.into(),result_scale,result_bias));
                },
                cudnnStatus_t::CUDNN_STATUS_NOT_SUPPORTED => {
                    return Err(TrainingError::CudnnError(rcudnn::Error::NotSupported("The function does not support the provided configuration.")));
                },
                cudnnStatus_t::CUDNN_STATUS_BAD_PARAM => {
                    return Err(TrainingError::CudnnError(
                        rcudnn::Error::BadParam("The parameter passed to cudnnBatchNormalizationBackward is invalid.")));
                },
                status => {
                    return Err(TrainingError::CudnnError(
                        rcudnn::Error::Unknown("Unable to create the CUDA cuDNN context/resources.", status as i32 as u64)));
                }
            }
        }
    }

    fn batch_backward_batch_norm<'a>(&self, loss: &'a <I as BatchDataType>::Type,
                                 input: &'a <I as BatchDataType>::Type,
                                 scale: &CudaTensor1dPtr<U,N>,
                                 saved_mean: &CudaTensor1dPtr<U,N>, saved_inv_variance: &CudaTensor1dPtr<U,N>)
        -> Result<(<I as BatchDataType>::Type, CudaTensor1dPtr<U,N>, CudaTensor1dPtr<U,N>), TrainingError> {

        let loss = CudaVecView::<'a,U,CudaTensor1dPtr<U,N>>::try_from(loss)?;
        let input = CudaVecView::<'a,U,CudaTensor1dPtr<U,N>>::try_from(input)?;

        let len = input.size();

        let mut output_ptr = CudaVec::<U,CudaTensor1dPtr<U,N>>::new(len,self.get_memory_pool())?;

        let bn_scale_bias_diff_desc = API::create_tensor_descriptor()?;
        let xd = CudnnTensor4dDescriptor::<U>::new(len as usize,N,1,1)?;

        unsafe {
            match cudnnDeriveBNTensorDescriptor(bn_scale_bias_diff_desc, *xd.id_c(), CUDNN_BATCHNORM_SPATIAL) {
                cudnnStatus_t::CUDNN_STATUS_SUCCESS => (),
                cudnnStatus_t::CUDNN_STATUS_BAD_PARAM => {
                    return Err(TrainingError::CudnnError(
                        rcudnn::Error::BadParam("The parameter passed to cudnnDeriveBNTensorDescriptor is invalid.")));
                },
                status => {
                    return Err(TrainingError::CudnnError(
                        rcudnn::Error::Unknown("Unable to create the CUDA cuDNN context/resources.", status as i32 as u64)));
                }
            }
        }

        let eps = 1e-6;

        let alpha = U::one();
        let beta = U::default();

        let mut result_scale = CudaTensor1dPtr::<U,N>::new(self.get_memory_pool())?;
        let mut result_bias = CudaTensor1dPtr::<U,N>::new(self.get_memory_pool())?;

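        // As in the single-sample case, one cuDNN call yields dx, dγ and dβ at
        // once, reusing the statistics saved by the training forward pass.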
        unsafe {
            match cudnnBatchNormalizationBackward(
                *self.cudnn.id_c(),
                CUDNN_BATCHNORM_PER_ACTIVATION,
                alpha.as_void_ptr(),
                beta.as_void_ptr(),
                alpha.as_void_ptr(),
                beta.as_void_ptr(),
                *xd.id_c(),
                input.as_void_ptr(),
                *xd.id_c(),
                loss.as_void_ptr(),
                *xd.id_c(),
                output_ptr.as_mut_void_ptr(),
                bn_scale_bias_diff_desc,
                scale.as_void_ptr(),
                result_scale.as_mut_void_ptr(),
                result_bias.as_mut_void_ptr(),
                eps as f64,
                saved_mean.as_void_ptr(),
                saved_inv_variance.as_void_ptr()) {
                cudnnStatus_t::CUDNN_STATUS_SUCCESS => {
                    return Ok((output_ptr.into_converter().try_into()?,result_scale,result_bias));
                },
                cudnnStatus_t::CUDNN_STATUS_NOT_SUPPORTED => {
                    return Err(TrainingError::CudnnError(rcudnn::Error::NotSupported("The function does not support the provided configuration.")));
                },
                cudnnStatus_t::CUDNN_STATUS_BAD_PARAM => {
                    return Err(TrainingError::CudnnError(
                        rcudnn::Error::BadParam("The parameter passed to cudnnBatchNormalizationBackward is invalid.")));
                },
                status => {
                    return Err(TrainingError::CudnnError(
                        rcudnn::Error::Unknown("Unable to create the CUDA cuDNN context/resources.", status as i32 as u64)));
                }
            }
        }
    }
}