coaster_nn/
plugin.rs

//! Provides the NN Plugin trait for Coaster implementations.
use crate::co::tensor::SharedTensor;
use std::fmt::Formatter;

#[derive(Debug, Copy, Clone, PartialEq, Eq)]
/// Different algorithms to compute the forward convolution.
pub enum ConvForwardAlgo {
    /// Attempt to automatically find the best algorithm of all the other available ones.
    Auto,
    /// Compute the convolution as explicit matrix product.
    ///
    /// Needs a significant memory workspace.
    GEMM,
    /// Compute the convolution as matrix product without forming the matrix that holds the input data.
    ///
    /// Does not need any memory workspace.
    ImplicitGEMM,
    /// Similar to `ImplicitGEMM` but needs some workspace to precompute the implicit indices.
    ImplicitPrecompiledGEMM,
    /// Compute the convolution as Fast Fourier Transform.
    ///
    /// Needs a significant memory workspace.
    FFT,
    /// Compute the convolution as Fast Fourier Transform with 32x32 tiles.
    ///
    /// Needs a significant memory workspace.
    FFTTiling,
    /// Compute the convolution without implicit or explicit matrix multiplication. **Do not try to use this**.
    ///
    /// Listed in the cuDNN docs, but cuDNN does not provide an implementation.
    Direct,
    /// Winograd Transform
    Winograd,
    /// Winograd Transform Non-Fused
    WinogradNonFused,
}

impl ConvForwardAlgo {
    /// Check if the algorithm should be chosen automatically.
    pub fn is_auto(&self) -> bool {
        matches!(*self, ConvForwardAlgo::Auto)
    }
}
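
// A minimal usage sketch (not part of the trait surface): `pick_forward_algo`
// is a hypothetical helper showing how a caller might resolve `Auto` before
// handing an algorithm to a backend. The fallback choice is illustrative only.
#[allow(dead_code)]
fn pick_forward_algo(requested: ConvForwardAlgo) -> ConvForwardAlgo {
    if requested.is_auto() {
        // A real backend would query the device here; fall back to the
        // workspace-free implicit GEMM variant as a conservative default.
        ConvForwardAlgo::ImplicitGEMM
    } else {
        requested
    }
}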

#[derive(Debug, Copy, Clone, PartialEq, Eq)]
/// Different algorithms to compute the gradient with respect to the filter.
pub enum ConvBackwardFilterAlgo {
    /// Attempt to automatically find the best algorithm of all the other available ones.
    Auto,
    /// Compute the convolution as matrix product without forming the matrix that holds the input data.
    ///
    /// Does not need any memory workspace.
    ///
    /// The results are deterministic.
    ImplicitGEMM,
    /// Compute the convolution as sum of matrix products without forming the matrix that holds the input data.
    ///
    /// Does not need any memory workspace.
    ///
    /// The results are non-deterministic.
    ImplicitGEMMSum,
    /// Similar to `ImplicitGEMMSum` but needs some workspace to precompute the implicit indices.
    ///
    /// The results are non-deterministic.
    ImplicitPrecompiledGEMMSum,
    /// Compute the convolution as Fast Fourier Transform.
    ///
    /// Needs a significant memory workspace.
    ///
    /// The results are deterministic.
    FFT,
    /// Winograd Transform Non-Fused
    WinogradNonFused,
}

impl ConvBackwardFilterAlgo {
    /// Check if the algorithm should be chosen automatically.
    pub fn is_auto(&self) -> bool {
        matches!(*self, ConvBackwardFilterAlgo::Auto)
    }
}

#[derive(Debug, Copy, Clone, PartialEq, Eq)]
/// Different algorithms to compute the gradient with respect to the data.
pub enum ConvBackwardDataAlgo {
    /// Attempt to automatically find the best algorithm of all the other available ones.
    Auto,
    /// Compute the convolution as matrix product without forming the matrix that holds the input data.
    ///
    /// Does not need any memory workspace.
    ///
    /// The results are deterministic.
    ImplicitGEMM,
    /// Compute the convolution as sum of matrix products without forming the matrix that holds the input data.
    ///
    /// Does not need any memory workspace.
    ///
    /// The results are non-deterministic.
    ImplicitGEMMSum,
    /// Compute the convolution as Fast Fourier Transform.
    ///
    /// Needs a significant memory workspace.
    ///
    /// The results are deterministic.
    FFT,
    /// Compute the convolution as Fast Fourier Transform with 32x32 tiles.
    ///
    /// Needs a significant memory workspace.
    ///
    /// The results are deterministic.
    FFTTiling,
    /// Winograd Transform
    Winograd,
    /// Winograd Transform Non-Fused
    WinogradNonFused,
}

impl ConvBackwardDataAlgo {
    /// Check if the algorithm should be chosen automatically.
    pub fn is_auto(&self) -> bool {
        matches!(*self, ConvBackwardDataAlgo::Auto)
    }
}

/// Provides generic NN Operation Config functionality.
///
/// Needs to be implemented for Operation specific configurations.
pub trait NNOperationConfig<F> {}

/// Provides Convolution Config functionality.
///
/// Needs to be implemented for Operation specific configurations.
pub trait ConvolutionConfig<F> {
    /// Returns the largest workspace size in bytes needed
    /// for any of the convolution operations.
    fn workspace_size(&self) -> usize {
        0
    }
}
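
// Sketch of a caller-side helper (hypothetical, not part of this API): size a
// byte workspace from a `ConvolutionConfig`. Assumes `SharedTensor::new`
// accepts a shape slice, as in the Coaster examples.
#[allow(dead_code)]
fn alloc_conv_workspace<F, C: ConvolutionConfig<F>>(config: &C) -> SharedTensor<u8> {
    // A flat byte buffer large enough for any of the configured convolution passes.
    SharedTensor::<u8>::new(&[config.workspace_size().max(1)])
}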

/// Provides Rnn Config functionality.
///
/// Needs to be implemented for Operation specific configurations.
pub trait RnnConfig<F> {
    /// Workspace size in bytes. Overridden by each plugin implementation except
    /// the native backend, which does not require a workspace.
    fn workspace_size(&self) -> usize {
        0
    }
}

/// Provides the functionality for a backend to support Neural Network related operations.
pub trait NN<F> {
    /// The Convolution Operation Config representation for this Plugin.
    type CC: NNOperationConfig<F> + ConvolutionConfig<F>;
    /// The LRN Operation Config representation for this Plugin.
    type CLRN: NNOperationConfig<F>;
    /// The Pooling Operation Config representation for this Plugin.
    type CPOOL: NNOperationConfig<F>;
    // /// The Activation Operation Config representation for this Plugin.
    // type CACTI: NNOperationConfig<F>;
    /// The Dropout Operation Config representation for this Plugin.
    type CDROP: NNOperationConfig<F>;
    /// The RNN Operation Config representation for this Plugin.
    type CRNN: NNOperationConfig<F> + RnnConfig<F>;

    /// Initializes the Plugin.
    fn init_nn();
}
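
// Sketch: how downstream code can stay backend-agnostic by naming the
// associated config types of `NN`. `ConvLayerState` is a hypothetical
// container, not part of this plugin interface.
#[allow(dead_code)]
struct ConvLayerState<F, B: NN<F>> {
    // The backend-specific convolution config, whatever the plugin defines it as.
    config: B::CC,
    // Scratch space shared between forward and backward convolution passes.
    workspace: SharedTensor<u8>,
    _marker: std::marker::PhantomData<F>,
}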

/// Provides the functionality for a Backend to support Sigmoid operations.
pub trait Sigmoid<F>: NN<F> {
    /// Computes the [Sigmoid function][sigmoid] over the input Tensor `x`.
    /// [sigmoid]: https://en.wikipedia.org/wiki/Sigmoid_function
    ///
    /// Saves the result to `result`.
    fn sigmoid(
        &self,
        x: &SharedTensor<F>,
        result: &mut SharedTensor<F>,
    ) -> Result<(), crate::co::error::Error>;

    /// Computes the gradient of a [Sigmoid function][sigmoid] over the input Tensor `x`.
    /// [sigmoid]: https://en.wikipedia.org/wiki/Sigmoid_function
    ///
    /// Saves the result to `result_diff`.
    fn sigmoid_grad(
        &self,
        x: &SharedTensor<F>,
        x_diff: &SharedTensor<F>,
        result: &SharedTensor<F>,
        result_diff: &mut SharedTensor<F>,
    ) -> Result<(), crate::co::error::Error>;
}
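
// Sketch of a backend-agnostic caller (hypothetical helper, not part of the
// trait): run the sigmoid forward pass followed by its gradient, propagating
// any backend error.
#[allow(dead_code)]
fn sigmoid_forward_backward<F, B: Sigmoid<F>>(
    backend: &B,
    x: &SharedTensor<F>,
    x_diff: &SharedTensor<F>,
    result: &mut SharedTensor<F>,
    result_diff: &mut SharedTensor<F>,
) -> Result<(), crate::co::error::Error> {
    // Forward: result = sigmoid(x).
    backend.sigmoid(x, result)?;
    // Backward: result_diff is derived from x, the incoming x_diff and the forward result.
    backend.sigmoid_grad(x, x_diff, result, result_diff)
}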

/// Provides the functionality for pointwise Sigmoid operations (overwrites the input with the result of the operation).
pub trait SigmoidPointwise<F>: NN<F> {
    /// Computes the [Sigmoid function][sigmoid] over the input Tensor `x`.
    /// [sigmoid]: https://en.wikipedia.org/wiki/Sigmoid_function
    ///
    /// Saves the result back to `x`.
    ///
    /// For a non-memory-managed version see `sigmoid_pointwise_plain`.
    fn sigmoid_pointwise(&self, x: &mut SharedTensor<F>) -> Result<(), crate::co::error::Error>;

    /// Computes the gradient of a [Sigmoid function][sigmoid] over the input Tensor `x`.
    /// [sigmoid]: https://en.wikipedia.org/wiki/Sigmoid_function
    ///
    /// Saves the result back to `x_diff`.
    fn sigmoid_pointwise_grad(
        &self,
        x: &SharedTensor<F>,
        x_diff: &mut SharedTensor<F>,
    ) -> Result<(), crate::co::error::Error>;
}
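
// Sketch (hypothetical helper): the pointwise variant mutates its input in
// place, so activation and gradient reuse the same buffers instead of
// writing to separate `result` tensors.
#[allow(dead_code)]
fn sigmoid_inplace<F, B: SigmoidPointwise<F>>(
    backend: &B,
    x: &mut SharedTensor<F>,
) -> Result<(), crate::co::error::Error> {
    // x is overwritten with sigmoid(x).
    backend.sigmoid_pointwise(x)
}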

/// Provides the functionality for a Backend to support ReLU operations.
pub trait Relu<F>: NN<F> {
    /// Computes the [Rectified linear units][relu] over the input Tensor `x`.
    /// [relu]: https://en.wikipedia.org/wiki/Rectifier_(neural_networks)
    ///
    /// Saves the result to `result`.
    fn relu(
        &self,
        x: &SharedTensor<F>,
        result: &mut SharedTensor<F>,
    ) -> Result<(), crate::co::error::Error>;

    /// Computes the gradient of [ReLU][relu] over the input Tensor `x`.
    /// [relu]: https://en.wikipedia.org/wiki/Rectifier_(neural_networks)
    ///
    /// Saves the result to `result_diff`.
    fn relu_grad(
        &self,
        x: &SharedTensor<F>,
        x_diff: &SharedTensor<F>,
        result: &SharedTensor<F>,
        result_diff: &mut SharedTensor<F>,
    ) -> Result<(), crate::co::error::Error>;
}

/// Provides the functionality for pointwise ReLU operations (overwrites the input with the result of the operation).
pub trait ReluPointwise<F>: NN<F> {
    /// Computes the [Rectified linear units][relu] over the input Tensor `x`.
    /// [relu]: https://en.wikipedia.org/wiki/Rectifier_(neural_networks)
    ///
    /// Saves the result back to `x`.
    fn relu_pointwise(&self, x: &mut SharedTensor<F>) -> Result<(), crate::co::error::Error>;

    /// Computes the gradient of [ReLU][relu] over the input Tensor `x`.
    /// [relu]: https://en.wikipedia.org/wiki/Rectifier_(neural_networks)
    ///
    /// Saves the result back to `x_diff`.
    fn relu_pointwise_grad(
        &self,
        x: &SharedTensor<F>,
        x_diff: &mut SharedTensor<F>,
    ) -> Result<(), crate::co::error::Error>;
}

/// Provides the functionality for a Backend to support TanH operations.
pub trait Tanh<F>: NN<F> {
    /// Computes the [hyperbolic Tangent][tanh] over the input Tensor `x`.
    /// [tanh]: https://en.wikipedia.org/wiki/Hyperbolic_function
    ///
    /// Saves the result to `result`.
    fn tanh(
        &self,
        x: &SharedTensor<F>,
        result: &mut SharedTensor<F>,
    ) -> Result<(), crate::co::error::Error>;

    /// Computes the gradient of [hyperbolic Tangent][tanh] over the input Tensor `x`.
    /// [tanh]: https://en.wikipedia.org/wiki/Hyperbolic_function
    ///
    /// Saves the result to `result_diff`.
    fn tanh_grad(
        &self,
        x: &SharedTensor<F>,
        x_diff: &SharedTensor<F>,
        result: &SharedTensor<F>,
        result_diff: &mut SharedTensor<F>,
    ) -> Result<(), crate::co::error::Error>;
}

/// Provides the functionality for pointwise Tanh operations (overwrites the input
/// with the result of the operation).
pub trait TanhPointwise<F>: NN<F> {
    /// Computes the [hyperbolic Tangent][tanh] over the input Tensor `x`.
    /// [tanh]: https://en.wikipedia.org/wiki/Hyperbolic_function
    ///
    /// Saves the result back to `x`.
    fn tanh_pointwise(&self, x: &mut SharedTensor<F>) -> Result<(), crate::co::error::Error>;

    /// Computes the gradient of [tanh][tanh] over the input Tensor `x`.
    /// [tanh]: https://en.wikipedia.org/wiki/Hyperbolic_function
    ///
    /// Saves the result back to `x_diff`.
    fn tanh_pointwise_grad(
        &self,
        x: &SharedTensor<F>,
        x_diff: &mut SharedTensor<F>,
    ) -> Result<(), crate::co::error::Error>;
}

/// Provides the functionality for a Backend to support RNN operations.
pub trait Rnn<F>: NN<F> {
    /// Creates an RnnConfig.
    fn new_rnn_config(
        &self,
        src: &SharedTensor<F>,
        dropout_probability: Option<f32>,
        dropout_seed: Option<u64>,
        sequence_length: i32,
        network_mode: RnnNetworkMode,
        input_mode: RnnInputMode,
        direction_mode: DirectionMode,
        algorithm: RnnAlgorithm,
        hidden_size: i32,
        num_layers: i32,
        batch_size: i32,
        // RC being RNNConfig
    ) -> Result<Self::CRNN, crate::co::error::Error>;

    /// Generates the weight description for an RNN.
    fn generate_rnn_weight_description(
        &self,
        rnn_config: &Self::CRNN,
        batch_size: i32,
        input_size: i32,
    ) -> Result<Vec<usize>, crate::co::error::Error>;

    /// Trains an LSTM network and returns the results.
    // TODO: Create alternate rnn_forward or alternate path to work with pretrained networks
    /// # Arguments
    /// * `weight` - Previously initialised weights (FilterDescriptor) for the RNN.
    fn rnn_forward(
        &self,
        src: &SharedTensor<F>,
        output: &mut SharedTensor<F>,
        rnn_config: &Self::CRNN,
        weight: &SharedTensor<F>,
        workspace: &mut SharedTensor<u8>,
    ) -> Result<(), crate::co::error::Error>;

    /// Calculates RNN gradients for the input/hidden/cell states.
    fn rnn_backward_data(
        &self,
        src: &SharedTensor<F>,
        src_gradient: &mut SharedTensor<F>,
        output: &SharedTensor<F>,
        output_gradient: &SharedTensor<F>,
        rnn_config: &Self::CRNN,
        weight: &SharedTensor<F>,
        workspace: &mut SharedTensor<u8>,
    ) -> Result<(), crate::co::error::Error>;

    /// Calculates RNN gradients for the weights.
    fn rnn_backward_weights(
        &self,
        src: &SharedTensor<F>,
        output: &SharedTensor<F>,
        filter: &mut SharedTensor<F>,
        rnn_config: &Self::CRNN,
        workspace: &mut SharedTensor<u8>,
    ) -> Result<(), crate::co::error::Error>;
}
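
// Sketch of one backend-agnostic training step (hypothetical helper): forward
// pass, then data and weight gradients, all sharing the same workspace. Real
// callers would also size `workspace` via `RnnConfig::workspace_size`.
#[allow(dead_code)]
fn rnn_step<F, B: Rnn<F>>(
    backend: &B,
    config: &B::CRNN,
    src: &SharedTensor<F>,
    src_gradient: &mut SharedTensor<F>,
    output: &mut SharedTensor<F>,
    output_gradient: &SharedTensor<F>,
    weight: &SharedTensor<F>,
    weight_gradient: &mut SharedTensor<F>,
    workspace: &mut SharedTensor<u8>,
) -> Result<(), crate::co::error::Error> {
    // Forward pass writes the sequence output.
    backend.rnn_forward(src, output, config, weight, workspace)?;
    // Gradients w.r.t. the input states.
    backend.rnn_backward_data(
        src,
        src_gradient,
        output,
        output_gradient,
        config,
        weight,
        workspace,
    )?;
    // Gradients w.r.t. the weights.
    backend.rnn_backward_weights(src, output, weight_gradient, config, workspace)
}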

#[derive(Debug, Copy, Clone, PartialEq, Eq)]
/// Network Type for RNN Networks [cudnnRNNMode_t][1]
/// [1]: https://docs.nvidia.com/deeplearning/sdk/cudnn-api/index.html#cudnnRNNMode_t
pub enum RnnNetworkMode {
    /// CUDNN_RNN_RELU - Single-gate RNN with a ReLU activation function
    ReLU,
    /// Single-gate RNN with a tanh activation function
    Tanh,
    /// Four-gate LSTM Network with no peephole connection
    LSTM,
    /// Three-gate network with Gated Recurrent Units
    GRU,
}

impl std::fmt::Display for RnnNetworkMode {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let result = match self {
            RnnNetworkMode::ReLU => "ReLU",
            RnnNetworkMode::Tanh => "Tanh",
            RnnNetworkMode::LSTM => "LSTM",
            RnnNetworkMode::GRU => "GRU",
        };
        write!(f, "{}", result)
    }
}

impl RnnNetworkMode {
    /// Convert from a String Representation to RnnNetworkMode
    pub fn from_string(input: &str) -> Result<Self, &str> {
        match input {
            "GRU" => Ok(RnnNetworkMode::GRU),
            "LSTM" => Ok(RnnNetworkMode::LSTM),
            "ReLU" => Ok(RnnNetworkMode::ReLU),
            "Tanh" => Ok(RnnNetworkMode::Tanh),
            _ => Err("Unknown RnnType used - variants are GRU, LSTM, ReLU, and Tanh"),
        }
    }
}
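
// Sketch: `from_string` and `Display` are intended to round-trip, e.g. when a
// network mode is read from a layer configuration file. `parse_network_mode`
// is a hypothetical helper and the assertion is illustrative only.
#[allow(dead_code)]
fn parse_network_mode(name: &str) -> Result<RnnNetworkMode, &str> {
    let mode = RnnNetworkMode::from_string(name)?;
    // The canonical spelling printed by Display matches the accepted input.
    debug_assert_eq!(mode.to_string(), name);
    Ok(mode)
}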

#[derive(Debug, Copy, Clone, PartialEq, Eq)]
/// Input Modes for RNN [cudnnRNNInputMode_t][1]
/// [1]: https://docs.nvidia.com/deeplearning/sdk/cudnn-api/index.html#cudnnRNNInputMode_t
pub enum RnnInputMode {
    /// CUDNN_LINEAR_INPUT - A biased matrix multiplication is performed at the input of the first
    /// recurrent layer
    LinearInput,
    /// CUDNN_SKIP_INPUT - No operation is performed at the input of the first recurrent layer -
    /// if this is used then the leading dimension of the input tensor must be equal to the hidden
    /// state size of the network.
    SkipInput,
}

impl std::fmt::Display for RnnInputMode {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let result = match self {
            RnnInputMode::LinearInput => "LinearInput",
            RnnInputMode::SkipInput => "SkipInput",
        };
        write!(f, "{}", result)
    }
}

impl RnnInputMode {
    /// Convert from a String Representation to RnnInputMode
    pub fn from_string(input: &str) -> Result<Self, &str> {
        match input {
            "LinearInput" => Ok(RnnInputMode::LinearInput),
            "SkipInput" => Ok(RnnInputMode::SkipInput),
            _ => Err("Unknown RnnInputMode used - variants are LinearInput, SkipInput"),
        }
    }
}

#[derive(Debug, Copy, Clone, PartialEq, Eq)]
/// Direction Mode for RNN [cudnnDirectionMode_t][1]
/// [1]: https://docs.nvidia.com/deeplearning/sdk/cudnn-api/index.html#cudnnDirectionMode_t
pub enum DirectionMode {
    /// CUDNN_UNIDIRECTIONAL - The network iterates from first to last
    UniDirectional,
    /// CUDNN_BIDIRECTIONAL - Concatenates the recurrent output of First -> Last && Last -> First
    BiDirectional,
}

impl std::fmt::Display for DirectionMode {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let result = match self {
            DirectionMode::UniDirectional => "UniDirectional",
            DirectionMode::BiDirectional => "BiDirectional",
        };
        write!(f, "{}", result)
    }
}

impl DirectionMode {
    /// Convert from a String Representation to DirectionMode
    pub fn from_string(input: &str) -> Result<Self, &str> {
        match input {
            "UniDirectional" => Ok(DirectionMode::UniDirectional),
            "BiDirectional" => Ok(DirectionMode::BiDirectional),
            _ => Err("Unknown DirectionMode used - variants are UniDirectional, BiDirectional"),
        }
    }
}

#[derive(Debug, Copy, Clone, PartialEq, Eq)]
/// Algorithm for RNN [cudnnRNNAlgo_t][1]
/// [1]: https://docs.nvidia.com/deeplearning/sdk/cudnn-api/index.html#cudnnRNNAlgo_t
///
/// PersistStatic requires cuDNN v6+
pub enum RnnAlgorithm {
    /// Sequence of Operations for each RNN Layer
    Standard,
    /// Uses a Persistent Kernel - fast when the first dimension of the input is small
    PersistStatic,
    /// RNN parts use a persistent kernel. Fast when the first dimension is small, and when it can
    /// reuse plans in repeated calls.
    PersistDynamic,
    /// Count - Cannot find in docs but is in Generated - FIXME
    Count,
}

impl std::fmt::Display for RnnAlgorithm {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let result = match self {
            RnnAlgorithm::Standard => "Standard",
            RnnAlgorithm::PersistStatic => "PersistStatic",
            RnnAlgorithm::PersistDynamic => "PersistDynamic",
            RnnAlgorithm::Count => "Count",
        };
        write!(f, "{}", result)
    }
}

impl RnnAlgorithm {
    /// Convert from a String Representation to RnnAlgorithm
    pub fn from_string(input: &str) -> Result<Self, &str> {
        match input {
            "Standard" => Ok(RnnAlgorithm::Standard),
            "PersistStatic" => Ok(RnnAlgorithm::PersistStatic),
            "PersistDynamic" => Ok(RnnAlgorithm::PersistDynamic),
            _ => Err(
                "Unknown RnnAlgorithm used - variants are Standard, PersistStatic, PersistDynamic",
            ),
        }
    }
}

#[derive(Debug, Copy, Clone)]
/// Enables/Disables the padded input/output [cudnnRNNPaddingMode_t][1]
/// [1]: https://docs.nvidia.com/deeplearning/sdk/cudnn-api/index.html#cudnnRNNPaddingMode_t
pub enum RnnPaddingMode {
    /// Padding disabled
    Disabled,
    /// Padding enabled
    Enabled,
}

#[derive(Debug, Copy, Clone)]
/// Indicates whether Tensor Core operations are permitted [cudnnMathType_t][1]
/// [1]: https://docs.nvidia.com/deeplearning/sdk/cudnn-api/index.html#cudnnMathType_t
pub enum MathType {
    /// No Tensor Core ops
    Default,
    /// Uses Tensor Core ops
    TensorOPMath,
    /// Uses Tensor Core ops, allowing down-conversion of FP32 input/output tensors
    TensorOPMathAllowConversion,
}
/// Provides the functionality for a Backend to support Convolution operations.
pub trait Convolution<F>: NN<F> {
    /// Creates a new ConvolutionConfig, which needs to be passed to further
    /// convolution Operations.
    fn new_convolution_config(
        &self,
        src: &SharedTensor<F>,
        dest: &SharedTensor<F>,
        filter: &SharedTensor<F>,
        algo_fwd: ConvForwardAlgo,
        algo_bwd_filter: ConvBackwardFilterAlgo,
        algo_bwd_data: ConvBackwardDataAlgo,
        stride: &[i32],
        zero_padding: &[i32],
    ) -> Result<Self::CC, crate::co::error::Error>;

    /// Computes a [CNN convolution][convolution] over the input Tensor `x`.
    /// [convolution]: https://en.wikipedia.org/wiki/Convolutional_neural_network
    ///
    /// Saves the result to `result`.
    fn convolution(
        &self,
        filter: &SharedTensor<F>,
        x: &SharedTensor<F>,
        result: &mut SharedTensor<F>,
        workspace: &mut SharedTensor<u8>,
        config: &Self::CC,
    ) -> Result<(), crate::co::error::Error>;

    /// Computes the gradient of a [CNN convolution][convolution] with respect to the filter.
    /// [convolution]: https://en.wikipedia.org/wiki/Convolutional_neural_network
    ///
    /// Saves the result to `filter_diff`.
    fn convolution_grad_filter(
        &self,
        src_data: &SharedTensor<F>,
        dest_diff: &SharedTensor<F>,
        filter_diff: &mut SharedTensor<F>,
        workspace: &mut SharedTensor<u8>,
        config: &Self::CC,
    ) -> Result<(), crate::co::error::Error>;

    /// Computes the gradient of a [CNN convolution][convolution] over the input
    /// Tensor `x` with respect to the data.
    /// [convolution]: https://en.wikipedia.org/wiki/Convolutional_neural_network
    ///
    /// Saves the result to `result_diff`.
    fn convolution_grad_data(
        &self,
        filter: &SharedTensor<F>,
        x_diff: &SharedTensor<F>,
        result_diff: &mut SharedTensor<F>,
        workspace: &mut SharedTensor<u8>,
        config: &Self::CC,
    ) -> Result<(), crate::co::error::Error>;

    // /// Computes the backward Convolution function w.r.t the bias.
    // ///
    // /// Writes the result of the computation to `bias_data`.
    // pub fn convolution_backward_bias<T>(
    //     &self,
    //     dest_grad_desc: &TensorDescriptor,
    //     dest_grad_data: *const ::libc::c_void,
    //     bias_desc: &TensorDescriptor,
    //     bias_data: *mut ::libc::c_void,
    //     scale: ScalParams<T>,
    // }
    //
    // /// Computes the backward Convolution function w.r.t the filter.
    // ///
    // /// Writes the result of the computation to `filter_data`.
    // pub fn convolution_backward_filter<T>(
    //     &self,
    //     conv_config: &ConvolutionConfig,
    //     src_desc: &TensorDescriptor,
    //     src_data: *const ::libc::c_void,
    //     dest_grad_desc: &TensorDescriptor,
    //     dest_grad_data: *const ::libc::c_void,
    //     filter_data: *mut ::libc::c_void,
    //     scale: ScalParams<T>,
    // }
}
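
// Sketch of a backend-agnostic convolution pass (hypothetical helper): build a
// config, size the workspace from it, then run the forward convolution.
// Assumes `SharedTensor::new` takes a shape slice, as in the Coaster examples.
#[allow(dead_code)]
fn convolution_forward<F, B: Convolution<F>>(
    backend: &B,
    filter: &SharedTensor<F>,
    x: &SharedTensor<F>,
    result: &mut SharedTensor<F>,
) -> Result<(), crate::co::error::Error> {
    // Let the backend pick concrete algorithms; stride 1 and no padding are
    // illustrative values only.
    let config = backend.new_convolution_config(
        x,
        result,
        filter,
        ConvForwardAlgo::Auto,
        ConvBackwardFilterAlgo::Auto,
        ConvBackwardDataAlgo::Auto,
        &[1, 1],
        &[0, 0],
    )?;
    // The workspace must be at least as large as the config requires.
    let mut workspace = SharedTensor::<u8>::new(&[config.workspace_size().max(1)]);
    backend.convolution(filter, x, result, &mut workspace, &config)
}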

/// Provides the functionality for a Backend to support Softmax operations.
pub trait Softmax<F>: NN<F> {
    /// Computes a [Softmax][softmax] over the input Tensor `x`.
    /// [softmax]: https://en.wikipedia.org/wiki/Softmax_function
    ///
    /// Saves the result to `result`.
    fn softmax(
        &self,
        x: &SharedTensor<F>,
        result: &mut SharedTensor<F>,
    ) -> Result<(), crate::co::error::Error>;

    /// Computes the gradient of a [Softmax][softmax] over the input Tensor `x`.
    /// [softmax]: https://en.wikipedia.org/wiki/Softmax_function
    ///
    /// Saves the result to `result_diff`.
    fn softmax_grad(
        &self,
        x: &SharedTensor<F>,
        x_diff: &SharedTensor<F>,
        result_diff: &mut SharedTensor<F>,
    ) -> Result<(), crate::co::error::Error>;
}

/// Provides the functionality for a Backend to support LogSoftmax operations.
pub trait LogSoftmax<F>: NN<F> {
    /// Computes a logarithmic softmax over the input Tensor `x`.
    ///
    /// Saves the result to `result`.
    fn log_softmax(
        &self,
        x: &SharedTensor<F>,
        result: &mut SharedTensor<F>,
    ) -> Result<(), crate::co::error::Error>;

    /// Computes the gradient of a logarithmic softmax over the input Tensor `x`.
    ///
    /// Saves the result to `result_diff`.
    fn log_softmax_grad(
        &self,
        x: &SharedTensor<F>,
        x_diff: &SharedTensor<F>,
        result_diff: &mut SharedTensor<F>,
    ) -> Result<(), crate::co::error::Error>;
}

/// Provides the functionality for a Backend to support Local Response Normalization operations.
pub trait LRN<F>: NN<F> {
    /// Creates a new (Local Response Normalization) LRNConfig, which needs to be
    /// passed to further LRN Operations.
    fn new_lrn_config(
        &self,
        n: u32,
        alpha: f64,
        beta: f64,
        k: f64,
    ) -> Result<Self::CLRN, crate::co::error::Error>;

    /// Computes a [LRN][lrn] over the input Tensor `x`.
    /// [lrn]: https://en.wikipedia.org/wiki/Convolutional_neural_network
    ///
    /// Saves the result to `result`.
    fn lrn(
        &self,
        x: &SharedTensor<F>,
        result: &mut SharedTensor<F>,
        config: &Self::CLRN,
    ) -> Result<(), crate::co::error::Error>;

    /// Computes the gradient of a [LRN][lrn] over the input Tensor `x`.
    /// [lrn]: https://en.wikipedia.org/wiki/Convolutional_neural_network
    ///
    /// Saves the result to `result_diff`.
    fn lrn_grad(
        &self,
        x: &SharedTensor<F>,
        x_diff: &SharedTensor<F>,
        result: &SharedTensor<F>,
        result_diff: &mut SharedTensor<F>,
        config: &Self::CLRN,
    ) -> Result<(), crate::co::error::Error>;
}
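
// Sketch (hypothetical helper): create an LRN config with the commonly cited
// AlexNet-style parameters (n = 5, alpha = 1e-4, beta = 0.75, k = 2.0) and run
// the forward pass. The parameter values are illustrative, not defaults of
// this API.
#[allow(dead_code)]
fn lrn_forward<F, B: LRN<F>>(
    backend: &B,
    x: &SharedTensor<F>,
    result: &mut SharedTensor<F>,
) -> Result<(), crate::co::error::Error> {
    let config = backend.new_lrn_config(5, 1e-4, 0.75, 2.0)?;
    backend.lrn(x, result, &config)
}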

/// Provides the functionality for a Backend to support Pooling operations.
pub trait Pooling<F>: NN<F> {
    /// Creates a new PoolingConfig, which needs to be passed to further pooling Operations.
    fn new_pooling_config(
        &self,
        window: &[i32],
        stride: &[i32],
        padding: &[i32],
    ) -> Result<Self::CPOOL, crate::co::error::Error>;

    /// Computes non-linear down-sampling ([max Pooling][pooling]) over the input Tensor `x`.
    /// [pooling]: https://en.wikipedia.org/wiki/Convolutional_neural_network#Pooling_layer
    ///
    /// Saves the result to `result`.
    fn pooling_max(
        &self,
        x: &SharedTensor<F>,
        result: &mut SharedTensor<F>,
        config: &Self::CPOOL,
    ) -> Result<(), crate::co::error::Error>;

    /// Computes the gradient of [max Pooling][pooling] over the input Tensor `x`.
    /// [pooling]: https://en.wikipedia.org/wiki/Convolutional_neural_network#Pooling_layer
    ///
    /// Saves the result to `result_diff`.
    fn pooling_max_grad(
        &self,
        x: &SharedTensor<F>,
        x_diff: &SharedTensor<F>,
        result: &SharedTensor<F>,
        result_diff: &mut SharedTensor<F>,
        config: &Self::CPOOL,
    ) -> Result<(), crate::co::error::Error>;

    /// Computes non-linear down-sampling ([average Pooling][pooling]) over the input Tensor `x`.
    /// [pooling]: https://en.wikipedia.org/wiki/Convolutional_neural_network#Pooling_layer
    ///
    /// Saves the result to `result`.
    fn pooling_avg(
        &self,
        x: &SharedTensor<F>,
        result: &mut SharedTensor<F>,
        config: &Self::CPOOL,
    ) -> Result<(), crate::co::error::Error>;

    /// Computes the gradient of [average Pooling][pooling] over the input Tensor `x`.
    /// [pooling]: https://en.wikipedia.org/wiki/Convolutional_neural_network#Pooling_layer
    ///
    /// Saves the result to `result_diff`.
    fn pooling_avg_grad(
        &self,
        x: &SharedTensor<F>,
        x_diff: &SharedTensor<F>,
        result: &SharedTensor<F>,
        result_diff: &mut SharedTensor<F>,
        config: &Self::CPOOL,
    ) -> Result<(), crate::co::error::Error>;
}
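
// Sketch (hypothetical helper): a 2x2 max pool with stride 2 and no padding,
// the typical "halve each spatial dimension" setting. Window/stride/padding
// values here are illustrative.
#[allow(dead_code)]
fn max_pool_2x2<F, B: Pooling<F>>(
    backend: &B,
    x: &SharedTensor<F>,
    result: &mut SharedTensor<F>,
) -> Result<(), crate::co::error::Error> {
    let config = backend.new_pooling_config(&[2, 2], &[2, 2], &[0, 0])?;
    backend.pooling_max(x, result, &config)
}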

/// Provides the functionality for a Backend to support Dropout operations.
pub trait Dropout<F>: NN<F> {
    /// Creates a new DropoutConfig, which needs to be passed to further dropout Operations.
    fn new_dropout_config(
        &self,
        dropout: f32,
        seed: u64,
    ) -> Result<Self::CDROP, crate::co::error::Error>;

    /// Computes [Dropout][dropout] over the input Tensor `x`.
    /// [dropout]: https://en.wikipedia.org/wiki/Dropout_(neural_networks)
    ///
    /// Saves the result to `result`.
    fn dropout(
        &self,
        x: &SharedTensor<F>,
        result: &mut SharedTensor<F>,
        config: &Self::CDROP,
    ) -> Result<(), crate::co::error::Error>;

    /// Computes the gradient of [Dropout][dropout] over the input Tensor `x`.
    /// [dropout]: https://en.wikipedia.org/wiki/Dropout_(neural_networks)
    ///
    /// Saves the result to `result_diff`.
    fn dropout_grad(
        &self,
        x: &SharedTensor<F>,
        x_diff: &SharedTensor<F>,
        result: &SharedTensor<F>,
        result_diff: &mut SharedTensor<F>,
        config: &Self::CDROP,
    ) -> Result<(), crate::co::error::Error>;
}
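
// Sketch (hypothetical helper): apply dropout with a 50% drop probability and
// a fixed seed so runs are reproducible. Both values are illustrative.
#[allow(dead_code)]
fn dropout_forward<F, B: Dropout<F>>(
    backend: &B,
    x: &SharedTensor<F>,
    result: &mut SharedTensor<F>,
) -> Result<(), crate::co::error::Error> {
    let config = backend.new_dropout_config(0.5, 42)?;
    backend.dropout(x, result, &config)
}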