use std::fmt::Debug;

use rcudnn::API;
use rcudnn_sys::cudnnBatchNormMode_t::{CUDNN_BATCHNORM_PER_ACTIVATION, CUDNN_BATCHNORM_SPATIAL};
use rcudnn_sys::{cudnnBatchNormalizationBackward, cudnnBatchNormalizationForwardInference, cudnnBatchNormalizationForwardTraining, cudnnDeriveBNTensorDescriptor, cudnnStatus_t};

use crate::arr::{Arr, ArrView, IntoConverter, SerializedVec, SerializedVecView};
use crate::ope::Sum;
use crate::collection::Broadcast;
use crate::computational_graph::{BroadcastNode, GraphNode, SqrtNode, SquareNode, SumNode};
use crate::cuda::{AsMutVoidPtr, AsVoidPtr, CudaTensor1dPtr, CudaTensor1dPtrView, CudaVec, CudaVecView, DataTypeInfo, WriteMemory, ReadMemory, MemoryMoveTo};
use crate::cuda::cudnn::tensor::CudnnTensor4dDescriptor;
use crate::device::{DeviceCpu, DeviceGpu, DeviceMemoryPool};
use crate::error::{EvaluateError, TrainingError, TypeConvertError};
use crate::layer::{BatchDataType, BatchSize};
use crate::ope::UnitValue;

/// Interface for devices implementing batch normalization.
///
/// `U` is the scalar type, `C` the container holding the per-feature
/// parameters and statistics, `I` the input type and `N` the number of
/// features.
pub trait DeviceBatchNorm<U,C,I,const N:usize>
    where U: UnitValue<U>,
          I: BatchDataType + Debug + 'static,
          <I as BatchDataType>::Type: Debug + 'static {
    /// Inference-time forward pass using the running statistics.
    fn forward_batch_norm<'a>(&self, input: &'a I, scale: &C, bias: &C,
                              estimated_mean: &C, estimated_variance: &C) -> Result<I,EvaluateError>;
    /// Training-time forward pass for a single sample; also returns the saved
    /// mean and saved inverse standard deviation for the backward pass.
    fn forward_batch_norm_train<'a>(&self, input: &'a I, scale: &C, bias: &C,
                                    estimated_mean: &C, estimated_variance: &C) -> Result<(I,C,C),EvaluateError>;
    /// Inference-time forward pass over a mini-batch.
    fn batch_forward_batch_norm<'a>(&self, input: &'a <I as BatchDataType>::Type, scale: &C, bias: &C,
                                    estimated_mean: &C, estimated_variance: &C) -> Result<<I as BatchDataType>::Type,EvaluateError>;
    /// Training-time forward pass over a mini-batch; returns the output, the
    /// batch mean, the batch inverse standard deviation and the updated
    /// running mean and running variance.
    fn batch_forward_batch_norm_train<'a>(&self, input: &'a <I as BatchDataType>::Type, scale: &C, bias: &C,
                                          running_mean: &C, running_variance: &C, momentum: U)
                                          -> Result<(<I as BatchDataType>::Type,C,C,C,C),TrainingError>;
    /// Backward pass for a single sample; returns the input gradient and the
    /// gradients for scale and bias.
    fn backward_batch_norm<'a>(&self, loss: &'a I, input: &'a I, scale: &C,
                               saved_mean: &C, saved_inv_variance: &C) -> Result<(I,C,C), TrainingError>;
    /// Backward pass over a mini-batch; returns the input gradients and the
    /// gradients for scale and bias.
    fn batch_backward_batch_norm<'a>(&self, loss: &'a <I as BatchDataType>::Type, input: &'a <I as BatchDataType>::Type,
                                     scale: &C, saved_mean: &C, saved_inv_variance: &C)
                                     -> Result<(<I as BatchDataType>::Type,C,C), TrainingError>;
}
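
// A minimal usage sketch (comment only): assuming a `DeviceCpu::<f32>::new()`
// constructor and `Arr<f32,N>` values built as elsewhere in this crate, the
// trait is driven as follows.
//
//     let device = DeviceCpu::<f32>::new()?;
//     // Inference: normalize with the running statistics.
//     let y = device.forward_batch_norm(&x, &scale, &bias, &running_mean, &running_variance)?;
//     // Training over a mini-batch: additionally returns the saved statistics
//     // for the backward pass and the updated running statistics.
//     let (y, mean, inv_variance, running_mean, running_variance) =
//         device.batch_forward_batch_norm_train(&xs, &scale, &bias, &running_mean, &running_variance, 0.9)?;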
impl<U,I,const N:usize> DeviceBatchNorm<U,Arr<U,N>,I,N> for DeviceCpu<U>
    where U: UnitValue<U>,
          I: BatchDataType + Debug + From<Arr<U,N>> + 'static,
          <I as BatchDataType>::Type: Debug + 'static,
          <I as BatchDataType>::Type: TryFrom<<SerializedVec<U,Arr<U,N>> as IntoConverter>::Converter,Error=TypeConvertError>,
          SerializedVec<U,Arr<U,N>>: IntoConverter,
          for<'a> ArrView<'a,U,N>: From<&'a I>,
          for<'a> SerializedVecView<'a,U,Arr<U,N>>: TryFrom<&'a <I as BatchDataType>::Type,Error=TypeConvertError> {
    #[inline]
    fn forward_batch_norm<'a>(&self, input: &'a I, scale: &Arr<U,N>, bias: &Arr<U,N>,
                              estimated_mean: &Arr<U,N>, estimated_variance: &Arr<U,N>) -> Result<I,EvaluateError> {
        let input = ArrView::<'a,U,N>::from(input);

        let eps = U::from_f64(1e-6).ok_or(EvaluateError::TypeCastError(String::from(
            "Error in type conversion from f64."
        )))?;

        // y = scale * (x - mean) / sqrt(variance + eps) + bias, evaluated
        // element-wise with the running statistics treated as constants.
        Ok(Arr::try_from(input.iter()
            .zip(scale.iter())
            .zip(bias.iter())
            .zip(estimated_mean.iter())
            .zip(estimated_variance.iter())
            .map(|((((&i,&scale),&bias),&mean),&variance)| {
                scale * ((i - mean) / SqrtNode::new().forward(variance + eps)) + bias
            }).collect::<Vec<U>>())?.into())
    }

    #[inline]
    fn forward_batch_norm_train<'a>(&self, input: &'a I,
                                    scale: &Arr<U,N>,
                                    bias: &Arr<U,N>,
                                    estimated_mean: &Arr<U,N>,
                                    estimated_variance: &Arr<U,N>) -> Result<(I,Arr<U,N>,Arr<U,N>),EvaluateError> {
        let input = ArrView::<'a,U,N>::from(input);

        let eps = U::from_f64(1e-6).ok_or(EvaluateError::TypeCastError(String::from(
            "Error in type conversion from f64."
        )))?;

        // For a single sample the running statistics are used as-is; the saved
        // mean is the running mean and the saved inverse standard deviation is
        // 1/sqrt(running_variance + eps).
        Ok((Arr::try_from(input.iter()
            .zip(scale.iter())
            .zip(bias.iter())
            .zip(estimated_mean.iter())
            .zip(estimated_variance.iter())
            .map(|((((&i,&scale),&bias),&mean),&variance)| {
                scale * ((i - mean) / SqrtNode::new().forward(variance + eps)) + bias
            }).collect::<Vec<U>>())?.into(),
            estimated_mean.clone(),
            estimated_variance.iter().map(|&v| U::one() / SqrtNode::new().forward(v + eps)).collect::<Vec<U>>().try_into()?
        ))
    }

    #[inline]
    fn batch_forward_batch_norm<'a>(&self, input: &'a <I as BatchDataType>::Type, scale: &Arr<U,N>, bias: &Arr<U,N>,
                                    estimated_mean: &Arr<U,N>, estimated_variance: &Arr<U,N>)
                                    -> Result<<I as BatchDataType>::Type, EvaluateError> {
        let input = SerializedVecView::<'a,U,Arr<U,N>>::try_from(input)?;

        let eps = U::from_f64(1e-6).ok_or(EvaluateError::TypeCastError(String::from(
            "Error in type conversion from f64."
        )))?;

        // Each sample is normalized independently with the shared running
        // statistics.
        Ok(SerializedVec::from(input.iter().map(|input| {
            input.iter()
                .zip(scale.iter())
                .zip(bias.iter())
                .zip(estimated_mean.iter())
                .zip(estimated_variance.iter())
                .map(|((((&i,&scale),&bias),&mean),&variance)| {
                    scale * (i - mean) / SqrtNode::new().forward(variance + eps) + bias
                }).collect::<Vec<U>>().try_into()
        }).collect::<Result<Vec<Arr<U,N>>,_>>()?).into_converter().try_into()?)
    }

    #[inline]
    fn batch_forward_batch_norm_train<'a>(&self, input: &'a <I as BatchDataType>::Type,
                                          scale: &Arr<U,N>, bias: &Arr<U,N>,
                                          running_mean: &Arr<U,N>, running_variance: &Arr<U,N>,
                                          momentum: U)
                                          -> Result<(<I as BatchDataType>::Type,Arr<U,N>,Arr<U,N>,Arr<U,N>,Arr<U,N>), TrainingError> {
        let input = SerializedVecView::<'a,U,Arr<U,N>>::try_from(input)?;

        let eps = U::from_f64(1e-6).ok_or(TrainingError::TypeCastError(String::from(
            "Error in type conversion from f64."
        )))?;

        let n = input.len();
        let un = U::from_usize(n).ok_or(TrainingError::TypeCastError(String::from(
            "Error in type conversion from usize."
        )))?;

        let un_inv = U::from_f64(1.).ok_or(TrainingError::TypeCastError(
            String::from(
                "Error in type conversion from f64."
            )
        ))? / un;

        // Batch mean over the mini-batch.
        let mean:Arr<U,N> = SumNode::<U,SerializedVecView<'_,U,Arr<U,N>>>::new().forward(input) * un_inv;

        // Batch variance: mean of the squared deviations from the batch mean.
        let variance:SerializedVec<U,Arr<U,N>> = (input - Broadcast::<Arr<U,N>>(mean.clone()))
            .iter()
            .map(|i| {
                i.iter().map(|&i| {
                    SquareNode::new().forward(i)
                }).collect::<Vec<U>>().try_into()
            }).collect::<Result<Vec<Arr<U,N>>,_>>()?.into();
        let variance = variance.sum() * un_inv;

        let inv_variance:Arr<U,N> = variance.iter().map(|&v| U::one() / SqrtNode::new().forward(v + eps)).collect::<Vec<U>>().try_into()?;

        let o:SerializedVec<U,Arr<U,N>> = Broadcast(inv_variance.clone()) * (input - Broadcast(mean.clone()));

        // Exponential moving average of the running statistics.
        let running_mean = running_mean * momentum + &mean * (U::one() - momentum);
        let running_variance = running_variance * momentum + variance * (U::one() - momentum);
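        // For example, with momentum = 0.9, a running mean of 1.0 and a batch
        // mean of 0.5, the updated running mean is 0.9 * 1.0 + 0.1 * 0.5 = 0.95.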

        // Affine transform with the learned scale and bias.
        let o = (BroadcastNode::<U,&SerializedVec<U,Arr<U,N>>>::new().forward((scale,n)) * o) + Broadcast(bias.clone());

        Ok((o.into_converter().try_into()?,mean,inv_variance,running_mean,running_variance))
    }

    #[inline]
    fn backward_batch_norm<'a>(&self, loss: &'a I, input: &'a I,
                               scale: &Arr<U,N>, saved_mean: &Arr<U,N>, saved_inv_variance: &Arr<U,N>)
                               -> Result<(I, Arr<U,N>, Arr<U,N>), TrainingError> {
        let loss = ArrView::<'a,U,N>::from(loss);
        let input = ArrView::<'a,U,N>::from(input);

        // Bias gradient: dy.
        let b = loss.clone();

        let x = input - saved_mean;

        // Scale gradient: (x - μ) * inv_σ * dy.
        let s = (&x * saved_inv_variance) * loss;

        // Input gradient, following the same graph as the batch version below
        // specialized to a batch of one.
        let dx1 = scale * loss;
        let dx2 = &dx1 * saved_inv_variance;
        let dx3 = &x * dx1;
        let dx4 = -(saved_inv_variance * saved_inv_variance) * dx3;
        let dx5 = dx4 * (saved_inv_variance * U::from_f64(0.5).ok_or(TrainingError::TypeCastError(String::from(
            "Error in type conversion from f64.")
        ))?);
        let dx6 = &x * dx5 * U::from_usize(2).ok_or(TrainingError::TypeCastError(String::from(
            "Error in type conversion from usize."
        )))?;
        let dx7 = dx2 + dx6;
        // With n = 1 the mean path cancels the direct path exactly.
        let dx8 = &dx7;
        let dx9 = -&dx7;
        let dx = dx8 + dx9;

        Ok((dx.into(),s,b.into()))
    }

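    // The mini-batch backward pass below unrolls the usual batch-norm
    // gradient. Writing μ for saved_mean, inv_σ for saved_inv_variance and
    // dy for the incoming loss:
    //   dbias  = Σ_batch dy
    //   dscale = Σ_batch (x - μ) * inv_σ * dy
    //   dx8    = scale * dy * inv_σ
    //            - (x - μ) * inv_σ³ * Σ_batch((x - μ) * scale * dy) / n
    //   dx     = dx8 - (1/n) * Σ_batch dx8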
    #[inline]
    fn batch_backward_batch_norm<'a>(&self, loss: &'a <I as BatchDataType>::Type,
                                     input: &'a <I as BatchDataType>::Type,
                                     scale: &Arr<U,N>,
                                     saved_mean: &Arr<U,N>, saved_inv_variance: &Arr<U,N>)
                                     -> Result<(<I as BatchDataType>::Type, Arr<U,N>, Arr<U,N>), TrainingError> {
        let loss = SerializedVecView::<'a,U,Arr<U,N>>::try_from(loss)?;
        let input = SerializedVecView::<'a,U,Arr<U,N>>::try_from(input)?;

        let n = input.len();

        let un = U::from_usize(n).ok_or(TrainingError::TypeCastError(String::from(
            "Error in type conversion from usize."
        )))?;

        let un_inv = U::from_usize(1).ok_or(TrainingError::TypeCastError(String::from(
            "Error in type conversion from usize."
        )))? / un;

        // Bias gradient: the loss summed over the batch.
        let b = BroadcastNode::<U,SerializedVecView<'_,U,Arr<U,N>>>::new().backward(loss);

        let x = BroadcastNode::<U,&SerializedVec<U,Arr<U,N>>>::new().forward((saved_mean,n));
        let x2 = input - &x;
        let iv = BroadcastNode::<U,&SerializedVec<U,Arr<U,N>>>::new().forward((saved_inv_variance,n));

        // Scale gradient summed over the batch.
        let s = BroadcastNode::<U,&SerializedVec<U,Arr<U,N>>>::new().backward(&(&x2 * &iv * loss));

        // Direct path through the normalization.
        let dx1 = Broadcast(scale.clone()) * loss;
        let dx2 = &dx1 * iv;
        // Path through the variance.
        let dx3 = BroadcastNode::<U,&SerializedVec<U,Arr<U,N>>>::new().backward(&(&x2 * dx1));
        let dx4 = -(saved_inv_variance * saved_inv_variance) * dx3;
        let dx5 = dx4 * (saved_inv_variance * U::from_f64(0.5).ok_or(TrainingError::TypeCastError(String::from(
            "Error in type conversion from f64.")
        ))?);
        let dx6 = SumNode::<U,SerializedVec<U,Arr<U,N>>>::new().backward((&(dx5 * un_inv),n));
        let dx7 = x2 * dx6 * U::from_usize(2).ok_or(TrainingError::TypeCastError(String::from(
            "Error in type conversion from usize."
        )))?;
        let dx8 = dx2 + dx7;
        // Path through the mean: subtract the batch average of dx8.
        let dx9 = &dx8;
        let dx10 = -&dx8;
        let dx11 = BroadcastNode::<U,&SerializedVec<U,Arr<U,N>>>::new().backward(&dx10);
        let dx12 = SumNode::<U,SerializedVec<U,Arr<U,N>>>::new().backward((&dx11,n)) * un_inv;

        let dx = dx9 + dx12;

        Ok((dx.into_converter().try_into()?,s,b))
    }
}
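
// A quick sanity check for the CPU path (comment only; assumes a
// `DeviceCpu::<f32>::new()` constructor): with scale = 1, bias = 0, mean = 0
// and variance = 1, `forward_batch_norm` is the identity up to the 1e-6
// epsilon, so each output element should match its input to within ~1e-6.
//
//     let device = DeviceCpu::<f32>::new()?;
//     let x: Arr<f32,4> = vec![1.0, -2.0, 3.0, 0.5].try_into()?;
//     let ones: Arr<f32,4> = vec![1.0; 4].try_into()?;
//     let zeros: Arr<f32,4> = vec![0.0; 4].try_into()?;
//     let y: Arr<f32,4> = device.forward_batch_norm(&x, &ones, &zeros, &zeros, &ones)?;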
impl<U,I,const N:usize> DeviceBatchNorm<U,CudaTensor1dPtr<U,N>,I,N> for DeviceGpu<U>
    where U: UnitValue<U> + DataTypeInfo + AsVoidPtr,
          I: BatchDataType + Debug + From<CudaTensor1dPtr<U,N>> + 'static,
          <I as BatchDataType>::Type: Debug + 'static,
          <I as BatchDataType>::Type: TryFrom<<CudaVec<U,CudaTensor1dPtr<U,N>> as IntoConverter>::Converter,Error=TypeConvertError>,
          CudaVec<U,CudaTensor1dPtr<U,N>>: IntoConverter,
          for<'a> CudaTensor1dPtrView<'a,U,N>: From<&'a I>,
          for<'a> CudaVecView<'a,U,CudaTensor1dPtr<U,N>>: TryFrom<&'a <I as BatchDataType>::Type,Error=TypeConvertError>,
          f64: From<U> {
    fn forward_batch_norm<'a>(&self, input: &'a I, scale: &CudaTensor1dPtr<U,N>, bias: &CudaTensor1dPtr<U,N>,
                              estimated_mean: &CudaTensor1dPtr<U,N>, estimated_variance: &CudaTensor1dPtr<U,N>)
                              -> Result<I,EvaluateError> {
        let input = CudaTensor1dPtrView::<'a,U,N>::from(input);

        let mut output_ptr = CudaTensor1dPtr::<U,N>::new(self.get_memory_pool())?;

        let bn_scale_bias_mean_var_desc = API::create_tensor_descriptor()?;
        // A single sample is laid out as a 1 x N x 1 x 1 tensor.
        let xd = CudnnTensor4dDescriptor::<U>::new(1,N,1,1)?;

        unsafe {
            match cudnnDeriveBNTensorDescriptor(bn_scale_bias_mean_var_desc,*xd.id_c(),CUDNN_BATCHNORM_SPATIAL) {
                cudnnStatus_t::CUDNN_STATUS_SUCCESS => (),
                cudnnStatus_t::CUDNN_STATUS_BAD_PARAM => {
                    return Err(EvaluateError::CudnnError(
                        rcudnn::Error::BadParam("The parameter passed to cudnnDeriveBNTensorDescriptor is invalid.")));
                },
                status => {
                    return Err(EvaluateError::CudnnError(
                        rcudnn::Error::Unknown("Unable to derive the batch normalization tensor descriptor.", status as i32 as u64)));
                }
            }
        }

        let alpha = U::one();
        let beta = U::default();

        let eps = 1e-6;

        unsafe {
            match cudnnBatchNormalizationForwardInference(
                *self.cudnn.id_c(),
                CUDNN_BATCHNORM_SPATIAL,
                alpha.as_void_ptr(),
                beta.as_void_ptr(),
                *xd.id_c(),
                input.as_void_ptr(),
                *xd.id_c(),
                output_ptr.as_mut_void_ptr(),
                bn_scale_bias_mean_var_desc,
                scale.as_void_ptr(),
                bias.as_void_ptr(),
                estimated_mean.as_void_ptr(),
                estimated_variance.as_void_ptr(),
                eps) {
                cudnnStatus_t::CUDNN_STATUS_SUCCESS => {
                    return Ok(output_ptr.into());
                },
                cudnnStatus_t::CUDNN_STATUS_NOT_SUPPORTED => {
                    return Err(EvaluateError::CudnnError(rcudnn::Error::NotSupported("The function does not support the provided configuration.")));
                },
                cudnnStatus_t::CUDNN_STATUS_BAD_PARAM => {
                    return Err(EvaluateError::CudnnError(
                        rcudnn::Error::BadParam("The parameter passed to cudnnBatchNormalizationForwardInference is invalid.")));
                },
                status => {
                    return Err(EvaluateError::CudnnError(
                        rcudnn::Error::Unknown("cudnnBatchNormalizationForwardInference failed.", status as i32 as u64)));
                }
            }
        }
    }

    fn forward_batch_norm_train<'a>(&self, input: &'a I,
                                    scale: &CudaTensor1dPtr<U,N>,
                                    bias: &CudaTensor1dPtr<U,N>,
                                    estimated_mean: &CudaTensor1dPtr<U,N>,
                                    estimated_variance: &CudaTensor1dPtr<U,N>) -> Result<(I,CudaTensor1dPtr<U,N>,CudaTensor1dPtr<U,N>),EvaluateError> {
        let input = CudaTensor1dPtrView::<'a,U,N>::from(input);

        let mut output_ptr = CudaTensor1dPtr::<U,N>::new(self.get_memory_pool())?;

        let bn_scale_bias_mean_var_desc = API::create_tensor_descriptor()?;
        let xd = CudnnTensor4dDescriptor::<U>::new(1,N,1,1)?;

        unsafe {
            match cudnnDeriveBNTensorDescriptor(bn_scale_bias_mean_var_desc,*xd.id_c(),CUDNN_BATCHNORM_SPATIAL) {
                cudnnStatus_t::CUDNN_STATUS_SUCCESS => (),
                cudnnStatus_t::CUDNN_STATUS_BAD_PARAM => {
                    return Err(EvaluateError::CudnnError(
                        rcudnn::Error::BadParam("The parameter passed to cudnnDeriveBNTensorDescriptor is invalid.")));
                },
                status => {
                    return Err(EvaluateError::CudnnError(
                        rcudnn::Error::Unknown("Unable to derive the batch normalization tensor descriptor.", status as i32 as u64)));
                }
            }
        }

        let alpha = U::one();
        let beta = U::default();

        let eps = U::from_f64(1e-6).ok_or(
            EvaluateError::TypeCastError(String::from("An error occurred in floating point type conversion.")))?;

        let mut mean = CudaTensor1dPtr::<U,N>::new(self.get_memory_pool())?;
        let mut inv_variance = CudaTensor1dPtr::<U,N>::new(self.get_memory_pool())?;

        estimated_mean.memcpy_to(&mut mean,N)?;
        inv_variance.memcpy(estimated_variance.read_to_vec()?.into_boxed_slice()
            .iter()
            .map(|&v| U::one() / SqrtNode::new().forward(v + eps))
            .collect::<Vec<U>>().as_ptr(),N)?;
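        // The saved statistics returned to the caller mirror what
        // cudnnBatchNormalizationForwardTraining would cache: the mean and the
        // inverse standard deviation 1/sqrt(variance + eps), which is the form
        // cudnnBatchNormalizationBackward expects. For a single sample the
        // running statistics are used unchanged, so both are derived on the
        // host here.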

        let eps = 1e-6;

        unsafe {
            match cudnnBatchNormalizationForwardInference(
                *self.cudnn.id_c(),
                CUDNN_BATCHNORM_SPATIAL,
                alpha.as_void_ptr(),
                beta.as_void_ptr(),
                *xd.id_c(),
                input.as_void_ptr(),
                *xd.id_c(),
                output_ptr.as_mut_void_ptr(),
                bn_scale_bias_mean_var_desc,
                scale.as_void_ptr(),
                bias.as_void_ptr(),
                estimated_mean.as_void_ptr(),
                estimated_variance.as_void_ptr(),
                eps) {
                cudnnStatus_t::CUDNN_STATUS_SUCCESS => {
                    return Ok((output_ptr.into(),mean,inv_variance));
                },
                cudnnStatus_t::CUDNN_STATUS_NOT_SUPPORTED => {
                    return Err(EvaluateError::CudnnError(rcudnn::Error::NotSupported("The function does not support the provided configuration.")));
                },
                cudnnStatus_t::CUDNN_STATUS_BAD_PARAM => {
                    return Err(EvaluateError::CudnnError(
                        rcudnn::Error::BadParam("The parameter passed to cudnnBatchNormalizationForwardInference is invalid.")));
                },
                status => {
                    return Err(EvaluateError::CudnnError(
                        rcudnn::Error::Unknown("cudnnBatchNormalizationForwardInference failed.", status as i32 as u64)));
                }
            }
        }
    }

    fn batch_forward_batch_norm<'a>(&self, input: &'a <I as BatchDataType>::Type,
                                    scale: &CudaTensor1dPtr<U,N>,
                                    bias: &CudaTensor1dPtr<U,N>,
                                    estimated_mean: &CudaTensor1dPtr<U,N>, estimated_variance: &CudaTensor1dPtr<U,N>)
                                    -> Result<<I as BatchDataType>::Type, EvaluateError> {
        let input = CudaVecView::<'a,U,CudaTensor1dPtr<U,N>>::try_from(input)?;

        let len = input.size();

        let mut output_ptr = CudaVec::<U,CudaTensor1dPtr<U,N>>::new(len,&self.memory_pool)?;

        let bn_scale_bias_mean_var_desc = API::create_tensor_descriptor()?;
        // The mini-batch is laid out as a len x N x 1 x 1 tensor.
        let xd = CudnnTensor4dDescriptor::<U>::new(len,N,1,1)?;

        unsafe {
            match cudnnDeriveBNTensorDescriptor(bn_scale_bias_mean_var_desc,*xd.id_c(),CUDNN_BATCHNORM_SPATIAL) {
                cudnnStatus_t::CUDNN_STATUS_SUCCESS => (),
                cudnnStatus_t::CUDNN_STATUS_BAD_PARAM => {
                    return Err(EvaluateError::CudnnError(
                        rcudnn::Error::BadParam("The parameter passed to cudnnDeriveBNTensorDescriptor is invalid.")));
                },
                status => {
                    return Err(EvaluateError::CudnnError(
                        rcudnn::Error::Unknown("Unable to derive the batch normalization tensor descriptor.", status as i32 as u64)));
                }
            }
        }

        let alpha = U::one();
        let beta = U::default();

        let eps = 1e-6;

        unsafe {
            match cudnnBatchNormalizationForwardInference(
                *self.cudnn.id_c(),
                CUDNN_BATCHNORM_SPATIAL,
                alpha.as_void_ptr(),
                beta.as_void_ptr(),
                *xd.id_c(),
                input.as_void_ptr(),
                *xd.id_c(),
                output_ptr.as_mut_void_ptr(),
                bn_scale_bias_mean_var_desc,
                scale.as_void_ptr(),
                bias.as_void_ptr(),
                estimated_mean.as_void_ptr(),
                estimated_variance.as_void_ptr(),
                eps) {
                cudnnStatus_t::CUDNN_STATUS_SUCCESS => {
                    return Ok(output_ptr.into_converter().try_into()?);
                },
                cudnnStatus_t::CUDNN_STATUS_NOT_SUPPORTED => {
                    return Err(EvaluateError::CudnnError(rcudnn::Error::NotSupported("The function does not support the provided configuration.")));
                },
                cudnnStatus_t::CUDNN_STATUS_BAD_PARAM => {
                    return Err(EvaluateError::CudnnError(
                        rcudnn::Error::BadParam("The parameter passed to cudnnBatchNormalizationForwardInference is invalid.")));
                },
                status => {
                    return Err(EvaluateError::CudnnError(
                        rcudnn::Error::Unknown("cudnnBatchNormalizationForwardInference failed.", status as i32 as u64)));
                }
            }
        }
    }

    fn batch_forward_batch_norm_train<'a>(&self, input: &'a <I as BatchDataType>::Type,
                                          scale: &CudaTensor1dPtr<U,N>, bias: &CudaTensor1dPtr<U,N>,
                                          running_mean: &CudaTensor1dPtr<U,N>, running_variance: &CudaTensor1dPtr<U,N>,
                                          momentum: U)
                                          -> Result<(<I as BatchDataType>::Type,
                                                     CudaTensor1dPtr<U,N>,
                                                     CudaTensor1dPtr<U,N>,
                                                     CudaTensor1dPtr<U,N>,
                                                     CudaTensor1dPtr<U,N>), TrainingError> {
        let input = CudaVecView::<'a,U,CudaTensor1dPtr<U,N>>::try_from(input)?;

        let len = input.size();

        let mut output_ptr = CudaVec::<U,CudaTensor1dPtr<U,N>>::new(len,self.get_memory_pool())?;

        let bn_scale_bias_mean_var_desc = API::create_tensor_descriptor()?;
        let xd = CudnnTensor4dDescriptor::<U>::new(len,N,1,1)?;

        unsafe {
            match cudnnDeriveBNTensorDescriptor(bn_scale_bias_mean_var_desc,*xd.id_c(),CUDNN_BATCHNORM_SPATIAL) {
                cudnnStatus_t::CUDNN_STATUS_SUCCESS => (),
                cudnnStatus_t::CUDNN_STATUS_BAD_PARAM => {
                    return Err(TrainingError::CudnnError(
                        rcudnn::Error::BadParam("The parameter passed to cudnnDeriveBNTensorDescriptor is invalid.")));
                },
                status => {
                    return Err(TrainingError::CudnnError(
                        rcudnn::Error::Unknown("Unable to derive the batch normalization tensor descriptor.", status as i32 as u64)));
                }
            }
        }
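        // Note: the descriptor above is derived with CUDNN_BATCHNORM_SPATIAL
        // while the call below uses CUDNN_BATCHNORM_PER_ACTIVATION. For a
        // len x N x 1 x 1 tensor (H = W = 1) the two modes coincide: both
        // normalize per feature over the batch dimension and derive the same
        // 1 x N x 1 x 1 scale/bias descriptor.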

        let alpha = U::one();
        let beta = U::default();

        let eps = 1e-6;

        // cuDNN updates the running statistics in place, so work on copies and
        // return them instead of mutating the caller's buffers.
        let mut new_running_mean = CudaTensor1dPtr::<U,N>::new(self.get_memory_pool())?;
        let mut new_running_variance = CudaTensor1dPtr::<U,N>::new(self.get_memory_pool())?;

        running_mean.memcpy_to(&mut new_running_mean, N)?;
        running_variance.memcpy_to(&mut new_running_variance, N)?;

        let mut mean = CudaTensor1dPtr::<U,N>::new(self.get_memory_pool())?;
        let mut inv_variance = CudaTensor1dPtr::<U,N>::new(self.get_memory_pool())?;

        unsafe {
            match cudnnBatchNormalizationForwardTraining(
                *self.cudnn.id_c(),
                CUDNN_BATCHNORM_PER_ACTIVATION,
                alpha.as_void_ptr(),
                beta.as_void_ptr(),
                *xd.id_c(),
                input.as_void_ptr(),
                *xd.id_c(),
                output_ptr.as_mut_void_ptr(),
                bn_scale_bias_mean_var_desc,
                scale.as_void_ptr(),
                bias.as_void_ptr(),
                // cuDNN's exponentialAverageFactor weights the new batch
                // statistic, so it is 1 - momentum here (matching the CPU path).
                1. - f64::from(momentum),
                new_running_mean.as_mut_void_ptr(),
                new_running_variance.as_mut_void_ptr(),
                eps,
                mean.as_mut_void_ptr(),
                inv_variance.as_mut_void_ptr()) {
                cudnnStatus_t::CUDNN_STATUS_SUCCESS => {
                    return Ok((output_ptr.into_converter().try_into()?,
                               mean,
                               inv_variance,
                               new_running_mean,
                               new_running_variance));
                },
                cudnnStatus_t::CUDNN_STATUS_NOT_SUPPORTED => {
                    return Err(TrainingError::CudnnError(rcudnn::Error::NotSupported("The function does not support the provided configuration.")));
                },
                cudnnStatus_t::CUDNN_STATUS_BAD_PARAM => {
                    return Err(TrainingError::CudnnError(
                        rcudnn::Error::BadParam("The parameter passed to cudnnBatchNormalizationForwardTraining is invalid.")));
                },
                status => {
                    return Err(TrainingError::CudnnError(
                        rcudnn::Error::Unknown("cudnnBatchNormalizationForwardTraining failed.", status as i32 as u64)));
                }
            }
        }
    }

    fn backward_batch_norm<'a>(&self, loss: &'a I, input: &'a I,
                               scale: &CudaTensor1dPtr<U,N>,
                               saved_mean: &CudaTensor1dPtr<U,N>,
                               saved_inv_variance: &CudaTensor1dPtr<U,N>)
                               -> Result<(I, CudaTensor1dPtr<U,N>, CudaTensor1dPtr<U,N>), TrainingError> {
        let loss = CudaTensor1dPtrView::<'a,U,N>::from(loss);
        let input = CudaTensor1dPtrView::<'a,U,N>::from(input);

        let mut output_ptr = CudaTensor1dPtr::<U,N>::new(self.get_memory_pool())?;

        let bn_scale_bias_diff_desc = API::create_tensor_descriptor()?;
        let xd = CudnnTensor4dDescriptor::<U>::new(1,N,1,1)?;

        unsafe {
            match cudnnDeriveBNTensorDescriptor(bn_scale_bias_diff_desc,*xd.id_c(),CUDNN_BATCHNORM_SPATIAL) {
                cudnnStatus_t::CUDNN_STATUS_SUCCESS => (),
                cudnnStatus_t::CUDNN_STATUS_BAD_PARAM => {
                    return Err(TrainingError::CudnnError(
                        rcudnn::Error::BadParam("The parameter passed to cudnnDeriveBNTensorDescriptor is invalid.")));
                },
                status => {
                    return Err(TrainingError::CudnnError(
                        rcudnn::Error::Unknown("Unable to derive the batch normalization tensor descriptor.", status as i32 as u64)));
                }
            }
        }

        let eps = 1e-6;

        let alpha = U::one();
        let beta = U::default();

        let mut result_scale = CudaTensor1dPtr::<U,N>::new(self.get_memory_pool())?;
        let mut result_bias = CudaTensor1dPtr::<U,N>::new(self.get_memory_pool())?;

        unsafe {
            match cudnnBatchNormalizationBackward(
                *self.cudnn.id_c(),
                CUDNN_BATCHNORM_PER_ACTIVATION,
                alpha.as_void_ptr(),
                beta.as_void_ptr(),
                alpha.as_void_ptr(),
                beta.as_void_ptr(),
                *xd.id_c(),
                input.as_void_ptr(),
                *xd.id_c(),
                loss.as_void_ptr(),
                *xd.id_c(),
                output_ptr.as_mut_void_ptr(),
                bn_scale_bias_diff_desc,
                scale.as_void_ptr(),
                result_scale.as_mut_void_ptr(),
                result_bias.as_mut_void_ptr(),
                eps,
                saved_mean.as_void_ptr(),
                saved_inv_variance.as_void_ptr()) {
                cudnnStatus_t::CUDNN_STATUS_SUCCESS => {
                    return Ok((output_ptr.into(),result_scale,result_bias));
                },
                cudnnStatus_t::CUDNN_STATUS_NOT_SUPPORTED => {
                    return Err(TrainingError::CudnnError(rcudnn::Error::NotSupported("The function does not support the provided configuration.")));
                },
                cudnnStatus_t::CUDNN_STATUS_BAD_PARAM => {
                    return Err(TrainingError::CudnnError(
                        rcudnn::Error::BadParam("The parameter passed to cudnnBatchNormalizationBackward is invalid.")));
                },
                status => {
                    return Err(TrainingError::CudnnError(
                        rcudnn::Error::Unknown("cudnnBatchNormalizationBackward failed.", status as i32 as u64)));
                }
            }
        }
    }

    fn batch_backward_batch_norm<'a>(&self, loss: &'a <I as BatchDataType>::Type,
                                     input: &'a <I as BatchDataType>::Type,
                                     scale: &CudaTensor1dPtr<U,N>,
                                     saved_mean: &CudaTensor1dPtr<U,N>, saved_inv_variance: &CudaTensor1dPtr<U,N>)
                                     -> Result<(<I as BatchDataType>::Type, CudaTensor1dPtr<U,N>, CudaTensor1dPtr<U,N>), TrainingError> {
        let loss = CudaVecView::<'a,U,CudaTensor1dPtr<U,N>>::try_from(loss)?;
        let input = CudaVecView::<'a,U,CudaTensor1dPtr<U,N>>::try_from(input)?;

        let len = input.size();

        let mut output_ptr = CudaVec::<U,CudaTensor1dPtr<U,N>>::new(len,self.get_memory_pool())?;

        let bn_scale_bias_diff_desc = API::create_tensor_descriptor()?;
        let xd = CudnnTensor4dDescriptor::<U>::new(len,N,1,1)?;

        unsafe {
            match cudnnDeriveBNTensorDescriptor(bn_scale_bias_diff_desc, *xd.id_c(), CUDNN_BATCHNORM_SPATIAL) {
                cudnnStatus_t::CUDNN_STATUS_SUCCESS => (),
                cudnnStatus_t::CUDNN_STATUS_BAD_PARAM => {
                    return Err(TrainingError::CudnnError(
                        rcudnn::Error::BadParam("The parameter passed to cudnnDeriveBNTensorDescriptor is invalid.")));
                },
                status => {
                    return Err(TrainingError::CudnnError(
                        rcudnn::Error::Unknown("Unable to derive the batch normalization tensor descriptor.", status as i32 as u64)));
                }
            }
        }

        let eps = 1e-6;

        let alpha = U::one();
        let beta = U::default();

        let mut result_scale = CudaTensor1dPtr::<U,N>::new(self.get_memory_pool())?;
        let mut result_bias = CudaTensor1dPtr::<U,N>::new(self.get_memory_pool())?;

        unsafe {
            match cudnnBatchNormalizationBackward(
                *self.cudnn.id_c(),
                CUDNN_BATCHNORM_PER_ACTIVATION,
                alpha.as_void_ptr(),
                beta.as_void_ptr(),
                alpha.as_void_ptr(),
                beta.as_void_ptr(),
                *xd.id_c(),
                input.as_void_ptr(),
                *xd.id_c(),
                loss.as_void_ptr(),
                *xd.id_c(),
                output_ptr.as_mut_void_ptr(),
                bn_scale_bias_diff_desc,
                scale.as_void_ptr(),
                result_scale.as_mut_void_ptr(),
                result_bias.as_mut_void_ptr(),
                eps,
                saved_mean.as_void_ptr(),
                saved_inv_variance.as_void_ptr()) {
                cudnnStatus_t::CUDNN_STATUS_SUCCESS => {
                    return Ok((output_ptr.into_converter().try_into()?,result_scale,result_bias));
                },
                cudnnStatus_t::CUDNN_STATUS_NOT_SUPPORTED => {
                    return Err(TrainingError::CudnnError(rcudnn::Error::NotSupported("The function does not support the provided configuration.")));
                },
                cudnnStatus_t::CUDNN_STATUS_BAD_PARAM => {
                    return Err(TrainingError::CudnnError(
                        rcudnn::Error::BadParam("The parameter passed to cudnnBatchNormalizationBackward is invalid.")));
                },
                status => {
                    return Err(TrainingError::CudnnError(
                        rcudnn::Error::Unknown("cudnnBatchNormalizationBackward failed.", status as i32 as u64)));
                }
            }
        }
    }
}