1use super::{
4 ActivationDescriptor, ConvolutionDescriptor, DropoutDescriptor, FilterDescriptor,
5 NormalizationDescriptor, PoolingDescriptor, RnnDescriptor,
6};
7use crate::cuda::CudaDeviceMemory;
8
9use crate::ffi::*;
10
11use num::traits::*;
12
/// Element data types understood by cuDNN.
///
/// Rust scalar types are mapped onto these tags by [`DataTypeInfo::cudnn_data_type`].
// `PartialEq`/`Eq`/`Hash` are derived so callers can compare data types and use them as
// map/set keys; the enum is a plain fieldless tag, so these derives are trivially correct.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum DataType {
    /// 32-bit floating point (the tag reported for `f32`).
    Float,
    /// 64-bit floating point (the tag reported for `f64`).
    Double,
    /// Half-precision floating point; no Rust scalar in this module maps to it.
    Half,
}
23
24pub trait DataTypeInfo {
26 fn cudnn_data_type() -> DataType;
28 fn size() -> usize;
29}
30impl DataTypeInfo for f32 {
31 fn cudnn_data_type() -> DataType {
32 DataType::Float
33 }
34 fn size() -> usize {
35 4_usize
36 }
37}
38impl DataTypeInfo for f64 {
39 fn cudnn_data_type() -> DataType {
40 DataType::Double
41 }
42 fn size() -> usize {
43 8_usize
44 }
45}
46#[allow(missing_debug_implementations, missing_copy_implementations)]
49pub struct ConvolutionConfig {
54 forward_algo: cudnnConvolutionFwdAlgo_t,
55 backward_filter_algo: cudnnConvolutionBwdFilterAlgo_t,
56 backward_data_algo: cudnnConvolutionBwdDataAlgo_t,
57 forward_workspace_size: usize,
58 backward_filter_workspace_size: usize,
59 backward_data_workspace_size: usize,
60 conv_desc: ConvolutionDescriptor,
61 filter_desc: FilterDescriptor,
62}
63
64impl ConvolutionConfig {
65 #[allow(clippy::too_many_arguments)]
67 pub fn new(
68 algo_fwd: cudnnConvolutionFwdAlgo_t,
69 workspace_size_fwd: usize,
70 algo_filter_bwd: cudnnConvolutionBwdFilterAlgo_t,
71 workspace_filter_size_bwd: usize,
72 algo_data_bwd: cudnnConvolutionBwdDataAlgo_t,
73 workspace_data_size_bwd: usize,
74 conv_desc: ConvolutionDescriptor,
75 filter_desc: FilterDescriptor,
76 ) -> ConvolutionConfig {
77 ConvolutionConfig {
78 forward_algo: algo_fwd,
79 forward_workspace_size: workspace_size_fwd,
80 backward_filter_algo: algo_filter_bwd,
81 backward_filter_workspace_size: workspace_filter_size_bwd,
82 backward_data_algo: algo_data_bwd,
83 backward_data_workspace_size: workspace_data_size_bwd,
84 conv_desc,
85 filter_desc,
86 }
87 }
88
89 pub fn largest_workspace_size(&self) -> usize {
93 if self.backward_data_workspace_size() >= self.backward_filter_workspace_size()
94 && self.backward_data_workspace_size() >= self.forward_workspace_size()
95 {
96 self.backward_data_workspace_size()
97 } else if self.backward_filter_workspace_size() >= self.backward_data_workspace_size()
98 && self.backward_filter_workspace_size() >= self.forward_workspace_size()
99 {
100 self.backward_filter_workspace_size()
101 } else {
102 self.forward_workspace_size()
103 }
104 }
105
106 pub fn forward_algo(&self) -> &cudnnConvolutionFwdAlgo_t {
108 &self.forward_algo
109 }
110
111 pub fn forward_workspace_size(&self) -> usize {
113 self.forward_workspace_size
114 }
115
116 pub fn backward_filter_algo(&self) -> &cudnnConvolutionBwdFilterAlgo_t {
118 &self.backward_filter_algo
119 }
120
121 pub fn backward_filter_workspace_size(&self) -> usize {
123 self.backward_filter_workspace_size
124 }
125
126 pub fn backward_data_algo(&self) -> &cudnnConvolutionBwdDataAlgo_t {
128 &self.backward_data_algo
129 }
130
131 pub fn backward_data_workspace_size(&self) -> usize {
133 self.backward_data_workspace_size
134 }
135
136 pub fn conv_desc(&self) -> &ConvolutionDescriptor {
138 &self.conv_desc
139 }
140
141 pub fn filter_desc(&self) -> &FilterDescriptor {
143 &self.filter_desc
144 }
145}
146
147#[allow(missing_debug_implementations, missing_copy_implementations)]
148pub struct NormalizationConfig {
152 lrn_desc: NormalizationDescriptor,
153}
154
155impl NormalizationConfig {
156 pub fn new(lrn_desc: NormalizationDescriptor) -> NormalizationConfig {
158 NormalizationConfig { lrn_desc }
159 }
160
161 pub fn lrn_desc(&self) -> &NormalizationDescriptor {
163 &self.lrn_desc
164 }
165}
166
167#[allow(missing_debug_implementations, missing_copy_implementations)]
168pub struct PoolingConfig {
172 pooling_avg_desc: PoolingDescriptor,
173 pooling_max_desc: PoolingDescriptor,
174}
175
176impl PoolingConfig {
177 pub fn new(
179 pooling_avg_desc: PoolingDescriptor,
180 pooling_max_desc: PoolingDescriptor,
181 ) -> PoolingConfig {
182 PoolingConfig {
183 pooling_avg_desc,
184 pooling_max_desc,
185 }
186 }
187
188 pub fn pooling_avg_desc(&self) -> &PoolingDescriptor {
190 &self.pooling_avg_desc
191 }
192
193 pub fn pooling_max_desc(&self) -> &PoolingDescriptor {
195 &self.pooling_max_desc
196 }
197}
198
199#[allow(missing_debug_implementations, missing_copy_implementations)]
200pub struct ActivationConfig {
204 activation_sigmoid_desc: ActivationDescriptor,
205 activation_relu_desc: ActivationDescriptor,
206 activation_clipped_relu_desc: ActivationDescriptor,
207 activation_tanh_desc: ActivationDescriptor,
208}
209
210impl ActivationConfig {
211 pub fn new(
213 activation_sigmoid_desc: ActivationDescriptor,
214 activation_relu_desc: ActivationDescriptor,
215 activation_clipped_relu_desc: ActivationDescriptor,
216 activation_tanh_desc: ActivationDescriptor,
217 ) -> ActivationConfig {
218 ActivationConfig {
219 activation_sigmoid_desc,
220 activation_relu_desc,
221 activation_clipped_relu_desc,
222 activation_tanh_desc,
223 }
224 }
225
226 pub fn activation_sigmoid_desc(&self) -> &ActivationDescriptor {
228 &self.activation_sigmoid_desc
229 }
230 pub fn activation_relu_desc(&self) -> &ActivationDescriptor {
232 &self.activation_relu_desc
233 }
234 pub fn activation_clipped_relu_desc(&self) -> &ActivationDescriptor {
236 &self.activation_clipped_relu_desc
237 }
238 pub fn activation_tanh_desc(&self) -> &ActivationDescriptor {
240 &self.activation_tanh_desc
241 }
242}
243
244#[allow(missing_debug_implementations, missing_copy_implementations)]
245#[derive(Debug)]
249pub struct DropoutConfig {
250 dropout_desc: DropoutDescriptor,
251 reserve_space: CudaDeviceMemory,
252}
253
254impl DropoutConfig {
255 pub fn new(dropout_desc: DropoutDescriptor, reserve: CudaDeviceMemory) -> DropoutConfig {
257 DropoutConfig {
258 dropout_desc,
259 reserve_space: reserve,
260 }
261 }
262 pub fn dropout_desc(&self) -> &DropoutDescriptor {
264 &self.dropout_desc
265 }
266
267 pub fn take_mem(self) -> CudaDeviceMemory {
269 self.reserve_space
270 }
271
272 pub fn reserved_space(&self) -> &CudaDeviceMemory {
274 &self.reserve_space
275 }
276}
277
278#[allow(missing_debug_implementations, missing_copy_implementations)]
279pub struct RnnConfig {
304 rnn_desc: RnnDescriptor,
305 pub hidden_size: ::libc::c_int,
307 pub num_layers: ::libc::c_int,
309 pub sequence_length: ::libc::c_int,
311 dropout_desc: cudnnDropoutDescriptor_t,
312 input_mode: cudnnRNNInputMode_t,
313 direction_mode: cudnnDirectionMode_t,
314 rnn_mode: cudnnRNNMode_t,
315 algo: cudnnRNNAlgo_t,
316 data_type: cudnnDataType_t,
317 workspace_size: usize,
318 training_reserve_size: usize,
319 training_reserve: CudaDeviceMemory,
320}
321
322impl RnnConfig {
323 #[allow(clippy::too_many_arguments)]
325 pub fn new(
326 rnn_desc: RnnDescriptor,
327 hidden_size: i32,
328 num_layers: i32,
329 sequence_length: i32,
330 dropout_desc: cudnnDropoutDescriptor_t,
331 input_mode: cudnnRNNInputMode_t,
332 direction_mode: cudnnDirectionMode_t,
333 rnn_mode: cudnnRNNMode_t,
334 algo: cudnnRNNAlgo_t,
336 data_type: cudnnDataType_t,
337 workspace_size: usize,
338 training_reserve_size: usize,
339 training_reserve: CudaDeviceMemory,
340 ) -> RnnConfig {
341 RnnConfig {
342 rnn_desc,
343 hidden_size,
344 num_layers,
345 sequence_length,
346 dropout_desc,
347 input_mode,
348 direction_mode,
349 rnn_mode,
350 algo,
351 data_type,
352 workspace_size,
353 training_reserve_size,
354 training_reserve,
355 }
356 }
357
358 pub fn rnn_workspace_size(&self) -> usize {
360 self.workspace_size
361 }
362 pub fn largest_workspace_size(&self) -> usize {
364 self.rnn_workspace_size()
365 }
366 pub fn training_reserve_size(&self) -> usize {
368 self.training_reserve_size
369 }
370 pub fn training_reserve(&self) -> &CudaDeviceMemory {
372 &self.training_reserve
373 }
374
375 pub fn rnn_desc(&self) -> &RnnDescriptor {
377 &self.rnn_desc
378 }
379
380 pub fn sequence_length(&self) -> &i32 {
382 &self.sequence_length
383 }
384}
385
386#[allow(missing_debug_implementations, missing_copy_implementations)]
387pub struct ScalParams<T>
398where
399 T: Float + DataTypeInfo,
400{
401 pub a: T,
403 pub b: T,
405}
406
407impl<T> Default for ScalParams<T>
408where
409 T: Float + Zero + One + DataTypeInfo,
410{
411 fn default() -> ScalParams<T> {
413 ScalParams {
414 a: One::one(),
415 b: Zero::zero(),
416 }
417 }
418}