coaster_nn/plugin.rs
//! Provides the NN Plugin trait for Coaster implementations.
use crate::co::tensor::SharedTensor;
use std::fmt::Formatter;

#[derive(Debug, Copy, Clone, PartialEq, Eq)]
/// Different algorithms to compute the forward convolution.
pub enum ConvForwardAlgo {
    /// Attempt to automatically find the best algorithm of all the other available ones.
    Auto,
    /// Compute the convolution as explicit matrix product.
    ///
    /// Needs a significant memory workspace.
    GEMM,
    /// Compute the convolution as matrix product without forming the matrix that holds the input data.
    ///
    /// Does not need any memory workspace.
    ImplicitGEMM,
    /// Similar to `ImplicitGEMM` but needs some workspace to precompute the implicit indices.
    ImplicitPrecompiledGEMM,
    /// Compute the convolution as Fast-Fourier Transform.
    ///
    /// Needs a significant memory workspace.
    FFT,
    /// Compute the convolution as Fast-Fourier Transform with 32x32 tiles.
    ///
    /// Needs a significant memory workspace.
    FFTTiling,
    /// Compute the convolution without implicit or explicit matrix multiplication. **Do not try to use this**.
    ///
    /// Listed in the cuDNN docs, but cuDNN does not provide an implementation.
    Direct,
    /// Winograd Transform
    Winograd,
    /// Winograd Transform Non-Fused
    WinogradNonFused,
}

impl ConvForwardAlgo {
    /// Check if the algorithm should be chosen automatically.
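    ///
    /// A minimal sketch of how a caller might branch on this (illustrative only,
    /// not a doctest):
    ///
    /// ```ignore
    /// let algo = ConvForwardAlgo::Auto;
    /// if algo.is_auto() {
    ///     // ask the backend to pick a concrete algorithm instead
    /// }
    /// ```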
    pub fn is_auto(&self) -> bool {
        matches!(*self, ConvForwardAlgo::Auto)
    }
}

#[derive(Debug, Copy, Clone, PartialEq, Eq)]
/// Different algorithms to compute the gradient with respect to the filter.
pub enum ConvBackwardFilterAlgo {
    /// Attempt to automatically find the best algorithm of all the other available ones.
    Auto,
    /// Compute the convolution as matrix product without forming the matrix that holds the input data.
    ///
    /// Does not need any memory workspace.
    ///
    /// The results are deterministic.
    ImplicitGEMM,
    /// Compute the convolution as sum of matrix products without forming the matrix that holds the input data.
    ///
    /// Does not need any memory workspace.
    ///
    /// The results are non-deterministic.
    ImplicitGEMMSum,
    /// Similar to `ImplicitGEMMSum` but needs some workspace to precompute the implicit indices.
    ///
    /// The results are non-deterministic.
    ImplicitPrecompiledGEMMSum,
    /// Compute the convolution as Fast-Fourier Transform.
    ///
    /// Needs a significant memory workspace.
    ///
    /// The results are deterministic.
    FFT,
    /// Winograd Transform Non-Fused
    WinogradNonFused,
}

impl ConvBackwardFilterAlgo {
    /// Check if the algorithm should be chosen automatically.
    pub fn is_auto(&self) -> bool {
        matches!(*self, ConvBackwardFilterAlgo::Auto)
    }
}

#[derive(Debug, Copy, Clone, PartialEq, Eq)]
/// Different algorithms to compute the gradient with respect to the data.
pub enum ConvBackwardDataAlgo {
    /// Attempt to automatically find the best algorithm of all the other available ones.
    Auto,
    /// Compute the convolution as matrix product without forming the matrix that holds the input data.
    ///
    /// Does not need any memory workspace.
    ///
    /// The results are deterministic.
    ImplicitGEMM,
    /// Compute the convolution as sum of matrix products without forming the matrix that holds the input data.
    ///
    /// Does not need any memory workspace.
    ///
    /// The results are non-deterministic.
    ImplicitGEMMSum,
    /// Compute the convolution as Fast-Fourier Transform.
    ///
    /// Needs a significant memory workspace.
    ///
    /// The results are deterministic.
    FFT,
    /// Compute the convolution as Fast-Fourier Transform with 32x32 tiles.
    ///
    /// Needs a significant memory workspace.
    ///
    /// The results are deterministic.
    FFTTiling,
    /// Winograd Transform
    Winograd,
    /// Winograd Transform Non-Fused
    WinogradNonFused,
}

impl ConvBackwardDataAlgo {
    /// Check if the algorithm should be chosen automatically.
    pub fn is_auto(&self) -> bool {
        matches!(*self, ConvBackwardDataAlgo::Auto)
    }
}

/// Provides generic NN Operation Config functionality.
///
/// Needs to be implemented for Operation specific configurations.
pub trait NNOperationConfig<F> {}

/// Provides Convolution Config functionality.
///
/// Needs to be implemented for Operation specific configurations.
pub trait ConvolutionConfig<F> {
    /// Returns the largest workspace size in bytes needed
    /// for any of the convolution operations.
    fn workspace_size(&self) -> usize {
        0
    }
}

/// Provides Rnn Config functionality.
///
/// Needs to be implemented for Operation specific configurations.
pub trait RnnConfig<F> {
    /// Returns the workspace size in bytes. Overridden by each plugin implementation
    /// except native, which does not require a workspace.
    fn workspace_size(&self) -> usize {
        0
    }
}

/// Provides the functionality for a backend to support Neural Network related operations.
pub trait NN<F> {
    /// The Convolution Operation Config representation for this Plugin.
    type CC: NNOperationConfig<F> + ConvolutionConfig<F>;
    /// The LRN Operation Config representation for this Plugin.
    type CLRN: NNOperationConfig<F>;
    /// The Pooling Operation Config representation for this Plugin.
    type CPOOL: NNOperationConfig<F>;
    // /// The Activation Operation Config representation for this Plugin.
    // type CACTI: NNOperationConfig<F>;
    /// The Dropout Operation Config representation for this Plugin.
    type CDROP: NNOperationConfig<F>;
    /// The RNN Operation Config representation for this Plugin.
    type CRNN: NNOperationConfig<F> + RnnConfig<F>;

    /// Initializes the Plugin.
    fn init_nn();
}

/// Provides the functionality for a Backend to support Sigmoid operations.
pub trait Sigmoid<F>: NN<F> {
    /// Computes the [Sigmoid function][sigmoid] over the input Tensor `x`.
    /// [sigmoid]: https://en.wikipedia.org/wiki/Sigmoid_function
    ///
    /// Saves the result to `result`.
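    ///
    /// A hedged usage sketch, assuming a `backend` that implements this trait and
    /// that `SharedTensor::new` takes a shape (illustrative only, not a doctest):
    ///
    /// ```ignore
    /// let x = SharedTensor::<f32>::new(&[4, 3]);
    /// let mut result = SharedTensor::<f32>::new(&[4, 3]);
    /// backend.sigmoid(&x, &mut result)?;
    /// ```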
    fn sigmoid(
        &self,
        x: &SharedTensor<F>,
        result: &mut SharedTensor<F>,
    ) -> Result<(), crate::co::error::Error>;

    /// Computes the gradient of a [Sigmoid function][sigmoid] over the input Tensor `x`.
    /// [sigmoid]: https://en.wikipedia.org/wiki/Sigmoid_function
    ///
    /// Saves the result to `result_diff`.
    fn sigmoid_grad(
        &self,
        x: &SharedTensor<F>,
        x_diff: &SharedTensor<F>,
        result: &SharedTensor<F>,
        result_diff: &mut SharedTensor<F>,
    ) -> Result<(), crate::co::error::Error>;
}

/// Provides the functionality for pointwise Sigmoid operations (overwrites the input with the result of the operation).
pub trait SigmoidPointwise<F>: NN<F> {
    /// Computes the [Sigmoid function][sigmoid] over the input Tensor `x`.
    /// [sigmoid]: https://en.wikipedia.org/wiki/Sigmoid_function
    ///
    /// Saves the result back to `x`.
    fn sigmoid_pointwise(&self, x: &mut SharedTensor<F>) -> Result<(), crate::co::error::Error>;

    /// Computes the gradient of a [Sigmoid function][sigmoid] over the input Tensor `x`.
    /// [sigmoid]: https://en.wikipedia.org/wiki/Sigmoid_function
    ///
    /// Saves the result back to `x_diff`.
    fn sigmoid_pointwise_grad(
        &self,
        x: &SharedTensor<F>,
        x_diff: &mut SharedTensor<F>,
    ) -> Result<(), crate::co::error::Error>;
}

/// Provides the functionality for a Backend to support ReLU operations.
pub trait Relu<F>: NN<F> {
    /// Computes the [Rectified linear units][relu] over the input Tensor `x`.
    /// [relu]: https://en.wikipedia.org/wiki/Rectifier_(neural_networks)
    ///
    /// Saves the result to `result`.
    fn relu(
        &self,
        x: &SharedTensor<F>,
        result: &mut SharedTensor<F>,
    ) -> Result<(), crate::co::error::Error>;

    /// Computes the gradient of [ReLU][relu] over the input Tensor `x`.
    /// [relu]: https://en.wikipedia.org/wiki/Rectifier_(neural_networks)
    ///
    /// Saves the result to `result_diff`.
    fn relu_grad(
        &self,
        x: &SharedTensor<F>,
        x_diff: &SharedTensor<F>,
        result: &SharedTensor<F>,
        result_diff: &mut SharedTensor<F>,
    ) -> Result<(), crate::co::error::Error>;
}

/// Provides the functionality for pointwise ReLU operations (overwrites the input with the result of the operation).
pub trait ReluPointwise<F>: NN<F> {
    /// Computes the [Rectified linear units][relu] over the input Tensor `x`.
    /// [relu]: https://en.wikipedia.org/wiki/Rectifier_(neural_networks)
    ///
    /// Saves the result back to `x`.
    fn relu_pointwise(&self, x: &mut SharedTensor<F>) -> Result<(), crate::co::error::Error>;

    /// Computes the gradient of [ReLU][relu] over the input Tensor `x`.
    /// [relu]: https://en.wikipedia.org/wiki/Rectifier_(neural_networks)
    ///
    /// Saves the result back to `x_diff`.
    fn relu_pointwise_grad(
        &self,
        x: &SharedTensor<F>,
        x_diff: &mut SharedTensor<F>,
    ) -> Result<(), crate::co::error::Error>;
}

/// Provides the functionality for a Backend to support TanH operations.
pub trait Tanh<F>: NN<F> {
    /// Computes the [hyperbolic Tangent][tanh] over the input Tensor `x`.
    /// [tanh]: https://en.wikipedia.org/wiki/Hyperbolic_function
    ///
    /// Saves the result to `result`.
    fn tanh(
        &self,
        x: &SharedTensor<F>,
        result: &mut SharedTensor<F>,
    ) -> Result<(), crate::co::error::Error>;

    /// Computes the gradient of [hyperbolic Tangent][tanh] over the input Tensor `x`.
    /// [tanh]: https://en.wikipedia.org/wiki/Hyperbolic_function
    ///
    /// Saves the result to `result_diff`.
    fn tanh_grad(
        &self,
        x: &SharedTensor<F>,
        x_diff: &SharedTensor<F>,
        result: &SharedTensor<F>,
        result_diff: &mut SharedTensor<F>,
    ) -> Result<(), crate::co::error::Error>;
}

/// Provides the functionality for pointwise TanH operations (overwrites the input
/// with the result of the operation).
pub trait TanhPointwise<F>: NN<F> {
    /// Computes the [hyperbolic Tangent][tanh] over the input Tensor `x`.
    /// [tanh]: https://en.wikipedia.org/wiki/Hyperbolic_function
    ///
    /// Saves the result back to `x`.
    fn tanh_pointwise(&self, x: &mut SharedTensor<F>) -> Result<(), crate::co::error::Error>;

    /// Computes the gradient of the [hyperbolic Tangent][tanh] over the input Tensor `x`.
    /// [tanh]: https://en.wikipedia.org/wiki/Hyperbolic_function
    ///
    /// Saves the result back to `x_diff`.
    fn tanh_pointwise_grad(
        &self,
        x: &SharedTensor<F>,
        x_diff: &mut SharedTensor<F>,
    ) -> Result<(), crate::co::error::Error>;
}

/// Provides the functionality for a Backend to support RNN operations.
pub trait Rnn<F>: NN<F> {
    /// Creates an RnnConfig.
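    ///
    /// A hedged sketch of a possible call, assuming a `backend` implementing this
    /// trait and an existing `src` tensor; all sizes below are illustrative
    /// assumptions, not values prescribed by this trait (not a doctest):
    ///
    /// ```ignore
    /// let config = backend.new_rnn_config(
    ///     &src,
    ///     None,                          // dropout probability
    ///     None,                          // dropout seed
    ///     10,                            // sequence length
    ///     RnnNetworkMode::LSTM,
    ///     RnnInputMode::LinearInput,
    ///     DirectionMode::UniDirectional,
    ///     RnnAlgorithm::Standard,
    ///     128,                           // hidden size
    ///     2,                             // number of layers
    ///     32,                            // batch size
    /// )?;
    /// ```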
    fn new_rnn_config(
        &self,
        src: &SharedTensor<F>,
        dropout_probability: Option<f32>,
        dropout_seed: Option<u64>,
        sequence_length: i32,
        network_mode: RnnNetworkMode,
        input_mode: RnnInputMode,
        direction_mode: DirectionMode,
        algorithm: RnnAlgorithm,
        hidden_size: i32,
        num_layers: i32,
        batch_size: i32,
        // Self::CRNN being the RNN config type
    ) -> Result<Self::CRNN, crate::co::error::Error>;

    /// Generates a description (dimensions) of the weight tensor for an RNN.
    fn generate_rnn_weight_description(
        &self,
        rnn_config: &Self::CRNN,
        batch_size: i32,
        input_size: i32,
    ) -> Result<Vec<usize>, crate::co::error::Error>;

    /// Computes the forward pass of the RNN/LSTM network, writing the result to `output`.
    // TODO: Create an alternate rnn_forward or an alternate path to work with pretrained networks.
    /// # Arguments
    /// * `weight` - previously initialised weight (filter) tensor
    fn rnn_forward(
        &self,
        src: &SharedTensor<F>,
        output: &mut SharedTensor<F>,
        rnn_config: &Self::CRNN,
        weight: &SharedTensor<F>,
        workspace: &mut SharedTensor<u8>,
    ) -> Result<(), crate::co::error::Error>;

    /// Calculates RNN Gradients for Input/Hidden/Cell.
    fn rnn_backward_data(
        &self,
        src: &SharedTensor<F>,
        src_gradient: &mut SharedTensor<F>,
        output: &SharedTensor<F>,
        output_gradient: &SharedTensor<F>,
        rnn_config: &Self::CRNN,
        weight: &SharedTensor<F>,
        workspace: &mut SharedTensor<u8>,
    ) -> Result<(), crate::co::error::Error>;

    /// Calculates RNN Gradients for Weights.
    fn rnn_backward_weights(
        &self,
        src: &SharedTensor<F>,
        output: &SharedTensor<F>,
        filter: &mut SharedTensor<F>,
        rnn_config: &Self::CRNN,
        workspace: &mut SharedTensor<u8>,
    ) -> Result<(), crate::co::error::Error>;
}

#[derive(Debug, Copy, Clone, PartialEq, Eq)]
/// Network Type for RNN Networks [cudnnRNNMode_t][1]
/// [1]: https://docs.nvidia.com/deeplearning/sdk/cudnn-api/index.html#cudnnRNNMode_t
pub enum RnnNetworkMode {
    /// CUDNN_RNN_RELU - Single-gate RNN with a ReLU activation function
    ReLU,
    /// Single-gate RNN with a tanh activation function
    Tanh,
    /// Four-gate LSTM network with no peephole connection
    LSTM,
    /// Three-gate network with Gated Recurrent Units
    GRU,
}

impl std::fmt::Display for RnnNetworkMode {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let result = match self {
            RnnNetworkMode::ReLU => "ReLU",
            RnnNetworkMode::Tanh => "Tanh",
            RnnNetworkMode::LSTM => "LSTM",
            RnnNetworkMode::GRU => "GRU",
        };
        write!(f, "{}", result)
    }
}

impl RnnNetworkMode {
    /// Converts a string representation to an RnnNetworkMode.
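    ///
    /// For example (illustrative only, not a doctest):
    ///
    /// ```ignore
    /// assert_eq!(RnnNetworkMode::from_string("LSTM"), Ok(RnnNetworkMode::LSTM));
    /// assert!(RnnNetworkMode::from_string("unknown").is_err());
    /// ```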
    pub fn from_string(input: &str) -> Result<Self, &str> {
        match input {
            "GRU" => Ok(RnnNetworkMode::GRU),
            "LSTM" => Ok(RnnNetworkMode::LSTM),
            "ReLU" => Ok(RnnNetworkMode::ReLU),
            "Tanh" => Ok(RnnNetworkMode::Tanh),
            _ => Err("Unknown RnnType used - variants are GRU, LSTM, ReLU, and Tanh"),
        }
    }
}

#[derive(Debug, Copy, Clone, PartialEq, Eq)]
/// Input Modes for RNN [cudnnRNNInputMode_t][1]
/// [1]: https://docs.nvidia.com/deeplearning/sdk/cudnn-api/index.html#cudnnRNNInputMode_t
pub enum RnnInputMode {
    /// CUDNN_LINEAR_INPUT - A biased matrix multiplication is performed at the input of the first
    /// recurrent layer
    LinearInput,
    /// CUDNN_SKIP_INPUT - No operation is performed at the input of the first recurrent layer.
    /// If this is used, the leading dimension of the input tensor must be equal to the hidden
    /// state size of the network.
    SkipInput,
}

impl std::fmt::Display for RnnInputMode {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let result = match self {
            RnnInputMode::LinearInput => "LinearInput",
            RnnInputMode::SkipInput => "SkipInput",
        };
        write!(f, "{}", result)
    }
}

impl RnnInputMode {
    /// Convert to RnnInputMode from String Representation
    pub fn from_string(input: &str) -> Result<Self, &str> {
        match input {
            "LinearInput" => Ok(RnnInputMode::LinearInput),
            "SkipInput" => Ok(RnnInputMode::SkipInput),
            _ => Err("Unknown RnnInputMode used - variants are LinearInput, SkipInput"),
        }
    }
}

#[derive(Debug, Copy, Clone, PartialEq, Eq)]
/// Direction Mode for RNN [cudnnDirectionMode_t][1]
/// [1]: https://docs.nvidia.com/deeplearning/sdk/cudnn-api/index.html#cudnnDirectionMode_t
pub enum DirectionMode {
    /// CUDNN_UNIDIRECTIONAL - The network iterates from the first input to the last
    UniDirectional,
    /// CUDNN_BIDIRECTIONAL - Concatenates the recurrent output of the first -> last and last -> first passes
    BiDirectional,
}

impl std::fmt::Display for DirectionMode {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let result = match self {
            DirectionMode::UniDirectional => "UniDirectional",
            DirectionMode::BiDirectional => "BiDirectional",
        };
        write!(f, "{}", result)
    }
}

impl DirectionMode {
    /// Convert to DirectionMode from String Representation
    pub fn from_string(input: &str) -> Result<Self, &str> {
        match input {
            "UniDirectional" => Ok(DirectionMode::UniDirectional),
            "BiDirectional" => Ok(DirectionMode::BiDirectional),
            _ => Err("Unknown DirectionMode used - variants are UniDirectional, BiDirectional"),
        }
    }
}

#[derive(Debug, Copy, Clone, PartialEq, Eq)]
/// Algorithm for RNN [cudnnRNNAlgo_t][1]
/// [1]: https://docs.nvidia.com/deeplearning/sdk/cudnn-api/index.html#cudnnRNNAlgo_t
///
/// `PersistStatic` requires cuDNN v6+.
pub enum RnnAlgorithm {
    /// Sequence of operations for each RNN layer
    Standard,
    /// Uses a persistent kernel - fast when the first dimension of the input is small
    PersistStatic,
    /// RNN parts use a persistent kernel. Fast when the first dimension is small, and when it can
    /// reuse plans in repeated calls.
    PersistDynamic,
    /// Count - cannot be found in the docs but is present in the generated bindings - FIXME
    Count,
}

impl std::fmt::Display for RnnAlgorithm {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let result = match self {
            RnnAlgorithm::Standard => "Standard",
            RnnAlgorithm::PersistStatic => "PersistStatic",
            RnnAlgorithm::PersistDynamic => "PersistDynamic",
            RnnAlgorithm::Count => unreachable!(),
        };
        write!(f, "{}", result)
    }
}

impl RnnAlgorithm {
    /// Convert to RnnAlgorithm from String Representation
    fn from_string(input: &str) -> Result<Self, &str> {
        match input {
            "Standard" => Ok(RnnAlgorithm::Standard),
            "PersistStatic" => Ok(RnnAlgorithm::PersistStatic),
            "PersistDynamic" => Ok(RnnAlgorithm::PersistDynamic),
            _ => Err(
                "Unknown RnnAlgorithm used - variants are Standard, PersistStatic, PersistDynamic",
            ),
        }
    }
}

#[derive(Debug, Copy, Clone)]
/// Enables/Disables the padded input/output [cudnnRNNPaddingMode_t][1]
/// [1]: https://docs.nvidia.com/deeplearning/sdk/cudnn-api/index.html#cudnnRNNPaddingMode_t
pub enum RnnPaddingMode {
    /// Padding disabled
    Disabled,
    /// Padding enabled
    Enabled,
}

#[derive(Debug, Copy, Clone)]
/// Indicates whether Tensor Core operations are permitted [cudnnMathType_t][1]
/// [1]: https://docs.nvidia.com/deeplearning/sdk/cudnn-api/index.html#cudnnMathType_t
pub enum MathType {
    /// No Tensor Core ops
    Default,
    /// Uses Tensor Core ops
    TensorOPMath,
    /// Uses Tensor Core ops, allowing FP32 input/output tensors to be down-converted for the operation
    TensorOPMathAllowConversion,
}

/// Provides the functionality for a Backend to support Convolution operations.
pub trait Convolution<F>: NN<F> {
    /// Creates a new ConvolutionConfig, which needs to be passed to further
    /// convolution Operations.
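    ///
    /// A hedged sketch of the intended flow: build the config, size the shared
    /// workspace from `workspace_size()`, then run the forward pass. The `backend`,
    /// tensors, and shapes are illustrative assumptions, as is `SharedTensor::new`
    /// taking a shape (not a doctest):
    ///
    /// ```ignore
    /// let config = backend.new_convolution_config(
    ///     &x, &result, &filter,
    ///     ConvForwardAlgo::Auto,
    ///     ConvBackwardFilterAlgo::Auto,
    ///     ConvBackwardDataAlgo::Auto,
    ///     &[1, 1],    // stride
    ///     &[0, 0],    // zero padding
    /// )?;
    /// let mut workspace = SharedTensor::<u8>::new(&[config.workspace_size()]);
    /// backend.convolution(&filter, &x, &mut result, &mut workspace, &config)?;
    /// ```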
    fn new_convolution_config(
        &self,
        src: &SharedTensor<F>,
        dest: &SharedTensor<F>,
        filter: &SharedTensor<F>,
        algo_fwd: ConvForwardAlgo,
        algo_bwd_filter: ConvBackwardFilterAlgo,
        algo_bwd_data: ConvBackwardDataAlgo,
        stride: &[i32],
        zero_padding: &[i32],
    ) -> Result<Self::CC, crate::co::error::Error>;

    /// Computes a [CNN convolution][convolution] over the input Tensor `x`.
    /// [convolution]: https://en.wikipedia.org/wiki/Convolutional_neural_network
    ///
    /// Saves the result to `result`.
    fn convolution(
        &self,
        filter: &SharedTensor<F>,
        x: &SharedTensor<F>,
        result: &mut SharedTensor<F>,
        workspace: &mut SharedTensor<u8>,
        config: &Self::CC,
    ) -> Result<(), crate::co::error::Error>;

    /// Computes the gradient of a [CNN convolution][convolution] with respect to the filter.
    /// [convolution]: https://en.wikipedia.org/wiki/Convolutional_neural_network
    ///
    /// Saves the result to `filter_diff`.
    fn convolution_grad_filter(
        &self,
        src_data: &SharedTensor<F>,
        dest_diff: &SharedTensor<F>,
        filter_diff: &mut SharedTensor<F>,
        workspace: &mut SharedTensor<u8>,
        config: &Self::CC,
    ) -> Result<(), crate::co::error::Error>;

    /// Computes the gradient of a [CNN convolution][convolution] over the input
    /// Tensor `x` with respect to the data.
    /// [convolution]: https://en.wikipedia.org/wiki/Convolutional_neural_network
    ///
    /// Saves the result to `result_diff`.
    fn convolution_grad_data(
        &self,
        filter: &SharedTensor<F>,
        x_diff: &SharedTensor<F>,
        result_diff: &mut SharedTensor<F>,
        workspace: &mut SharedTensor<u8>,
        config: &Self::CC,
    ) -> Result<(), crate::co::error::Error>;

    // /// Computes the backward Convolution function w.r.t the bias.
    // ///
    // /// Writes the result of the computation to `bias_data`.
    // pub fn convolution_backward_bias<T>(
    //     &self,
    //     dest_grad_desc: &TensorDescriptor,
    //     dest_grad_data: *const ::libc::c_void,
    //     bias_desc: &TensorDescriptor,
    //     bias_data: *mut ::libc::c_void,
    //     scale: ScalParams<T>,
    // }
    //
    // /// Computes the backward Convolution function w.r.t the filter.
    // ///
    // /// Writes the result of the computation to `filter_data`.
    // pub fn convolution_backward_filter<T>(
    //     &self,
    //     conv_config: &ConvolutionConfig,
    //     src_desc: &TensorDescriptor,
    //     src_data: *const ::libc::c_void,
    //     dest_grad_desc: &TensorDescriptor,
    //     dest_grad_data: *const ::libc::c_void,
    //     filter_data: *mut ::libc::c_void,
    //     scale: ScalParams<T>,
    // }
}

/// Provides the functionality for a Backend to support Softmax operations.
pub trait Softmax<F>: NN<F> {
    /// Computes a [Softmax][softmax] over the input Tensor `x`.
    /// [softmax]: https://en.wikipedia.org/wiki/Softmax_function
    ///
    /// Saves the result to `result`.
    fn softmax(
        &self,
        x: &SharedTensor<F>,
        result: &mut SharedTensor<F>,
    ) -> Result<(), crate::co::error::Error>;

    /// Computes the gradient of a [Softmax][softmax] over the input Tensor `x`.
    /// [softmax]: https://en.wikipedia.org/wiki/Softmax_function
    ///
    /// Saves the result to `result_diff`.
    fn softmax_grad(
        &self,
        x: &SharedTensor<F>,
        x_diff: &SharedTensor<F>,
        result_diff: &mut SharedTensor<F>,
    ) -> Result<(), crate::co::error::Error>;
}

/// Provides the functionality for a Backend to support LogSoftmax operations.
pub trait LogSoftmax<F>: NN<F> {
    /// Computes a logarithmic softmax over the input Tensor `x`.
    ///
    /// Saves the result to `result`.
    fn log_softmax(
        &self,
        x: &SharedTensor<F>,
        result: &mut SharedTensor<F>,
    ) -> Result<(), crate::co::error::Error>;

    /// Computes the gradient of a logarithmic softmax over the input Tensor `x`.
    ///
    /// Saves the result to `result_diff`.
    fn log_softmax_grad(
        &self,
        x: &SharedTensor<F>,
        x_diff: &SharedTensor<F>,
        result_diff: &mut SharedTensor<F>,
    ) -> Result<(), crate::co::error::Error>;
}

/// Provides the functionality for a Backend to support Local Response Normalization operations.
pub trait LRN<F>: NN<F> {
    /// Creates a new (Local Response Normalization) LRNConfig, which needs to be
    /// passed to further LRN Operations.
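    ///
    /// A hedged sketch with commonly cited AlexNet-style values; the `backend` and
    /// the numbers are illustrative assumptions, not defaults of this trait
    /// (not a doctest):
    ///
    /// ```ignore
    /// let config = backend.new_lrn_config(5, 1e-4, 0.75, 2.0)?;
    /// ```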
    fn new_lrn_config(
        &self,
        n: u32,
        alpha: f64,
        beta: f64,
        k: f64,
    ) -> Result<Self::CLRN, crate::co::error::Error>;

    /// Computes a [LRN][lrn] over the input Tensor `x`.
    /// [lrn]: https://en.wikipedia.org/wiki/Convolutional_neural_network
    ///
    /// Saves the result to `result`.
    fn lrn(
        &self,
        x: &SharedTensor<F>,
        result: &mut SharedTensor<F>,
        config: &Self::CLRN,
    ) -> Result<(), crate::co::error::Error>;

    /// Computes the gradient of a [LRN][lrn] over the input Tensor `x`.
    /// [lrn]: https://en.wikipedia.org/wiki/Convolutional_neural_network
    ///
    /// Saves the result to `result_diff`.
    fn lrn_grad(
        &self,
        x: &SharedTensor<F>,
        x_diff: &SharedTensor<F>,
        result: &SharedTensor<F>,
        result_diff: &mut SharedTensor<F>,
        config: &Self::CLRN,
    ) -> Result<(), crate::co::error::Error>;
}

/// Provides the functionality for a Backend to support Pooling operations.
pub trait Pooling<F>: NN<F> {
    /// Creates a new PoolingConfig, which needs to be passed to further pooling Operations.
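    ///
    /// A hedged sketch for a 2x2 window with stride 2 and no padding; the `backend`
    /// and the values are illustrative assumptions (not a doctest):
    ///
    /// ```ignore
    /// let config = backend.new_pooling_config(&[2, 2], &[2, 2], &[0, 0])?;
    /// ```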
    fn new_pooling_config(
        &self,
        window: &[i32],
        stride: &[i32],
        padding: &[i32],
    ) -> Result<Self::CPOOL, crate::co::error::Error>;

    /// Computes non-linear down-sampling ([max Pooling][pooling]) over the input Tensor `x`.
    /// [pooling]: https://en.wikipedia.org/wiki/Convolutional_neural_network#Pooling_layer
    ///
    /// Saves the result to `result`.
    fn pooling_max(
        &self,
        x: &SharedTensor<F>,
        result: &mut SharedTensor<F>,
        config: &Self::CPOOL,
    ) -> Result<(), crate::co::error::Error>;

    /// Computes the gradient of [max Pooling][pooling] over the input Tensor `x`.
    /// [pooling]: https://en.wikipedia.org/wiki/Convolutional_neural_network#Pooling_layer
    ///
    /// Saves the result to `result_diff`.
    fn pooling_max_grad(
        &self,
        x: &SharedTensor<F>,
        x_diff: &SharedTensor<F>,
        result: &SharedTensor<F>,
        result_diff: &mut SharedTensor<F>,
        config: &Self::CPOOL,
    ) -> Result<(), crate::co::error::Error>;

    /// Computes non-linear down-sampling ([average Pooling][pooling]) over the input Tensor `x`.
    /// [pooling]: https://en.wikipedia.org/wiki/Convolutional_neural_network#Pooling_layer
    ///
    /// Saves the result to `result`.
    fn pooling_avg(
        &self,
        x: &SharedTensor<F>,
        result: &mut SharedTensor<F>,
        config: &Self::CPOOL,
    ) -> Result<(), crate::co::error::Error>;

    /// Computes the gradient of [average Pooling][pooling] over the input Tensor `x`.
    /// [pooling]: https://en.wikipedia.org/wiki/Convolutional_neural_network#Pooling_layer
    ///
    /// Saves the result to `result_diff`.
    fn pooling_avg_grad(
        &self,
        x: &SharedTensor<F>,
        x_diff: &SharedTensor<F>,
        result: &SharedTensor<F>,
        result_diff: &mut SharedTensor<F>,
        config: &Self::CPOOL,
    ) -> Result<(), crate::co::error::Error>;
}

/// Provides the functionality for a Backend to support Dropout operations.
pub trait Dropout<F>: NN<F> {
    /// Creates a new DropoutConfig, which needs to be passed to further dropout Operations.
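    ///
    /// A hedged sketch that drops roughly half of the activations with a fixed seed;
    /// the `backend` and the values are illustrative assumptions (not a doctest):
    ///
    /// ```ignore
    /// let config = backend.new_dropout_config(0.5, 42)?;
    /// ```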
    fn new_dropout_config(
        &self,
        dropout: f32,
        seed: u64,
    ) -> Result<Self::CDROP, crate::co::error::Error>;

    /// Computes [Dropout][dropout] over the input Tensor `x`.
    /// [dropout]: https://en.wikipedia.org/wiki/Dropout_(neural_networks)
    ///
    /// Saves the result to `result`.
    fn dropout(
        &self,
        x: &SharedTensor<F>,
        result: &mut SharedTensor<F>,
        config: &Self::CDROP,
    ) -> Result<(), crate::co::error::Error>;

    /// Computes the gradient of [Dropout][dropout] over the input Tensor `x`.
    /// [dropout]: https://en.wikipedia.org/wiki/Dropout_(neural_networks)
    ///
    /// Saves the result to `result_diff`.
    fn dropout_grad(
        &self,
        x: &SharedTensor<F>,
        x_diff: &SharedTensor<F>,
        result: &SharedTensor<F>,
        result_diff: &mut SharedTensor<F>,
        config: &Self::CDROP,
    ) -> Result<(), crate::co::error::Error>;
}