use super::{ConvolutionDescriptor, NormalizationDescriptor, FilterDescriptor, PoolingDescriptor};
use ffi::*;
use std::marker::PhantomData;
/// Selects a numeric data type by precision.
///
/// NOTE(review): presumably mirrors the element data types understood by cuDNN;
/// the conversion to the FFI representation is not visible in this file — confirm.
#[derive(Debug, Copy, Clone)]
pub enum DataType {
/// Variant named for single precision (presumably 32-bit floats — confirm).
Float,
/// Variant named for double precision (presumably 64-bit floats — confirm).
Double,
/// Variant named for half precision (presumably 16-bit floats — confirm).
Half,
}
/// Bundles the algorithm choices, workspace sizes, and descriptors kept for a convolution.
///
/// One algorithm / workspace-size pair is stored for each of the three
/// directions named by the fields: forward, backward-filter, and backward-data.
// No `Debug` or `Copy` impl is derived for this type, so the matching lints are silenced.
#[allow(missing_debug_implementations, missing_copy_implementations)]
pub struct ConvolutionConfig {
/// Algorithm selected for the forward direction.
forward_algo: cudnnConvolutionFwdAlgo_t,
/// Algorithm selected for the backward-filter direction.
backward_filter_algo: cudnnConvolutionBwdFilterAlgo_t,
/// Algorithm selected for the backward-data direction.
backward_data_algo: cudnnConvolutionBwdDataAlgo_t,
/// Workspace size paired with `forward_algo` (presumably in bytes — confirm).
forward_workspace_size: usize,
/// Workspace size paired with `backward_filter_algo` (presumably in bytes — confirm).
backward_filter_workspace_size: usize,
/// Workspace size paired with `backward_data_algo` (presumably in bytes — confirm).
backward_data_workspace_size: usize,
/// Convolution descriptor owned by this configuration.
conv_desc: ConvolutionDescriptor,
/// Filter descriptor owned by this configuration.
filter_desc: FilterDescriptor,
}
impl ConvolutionConfig {
pub fn new(
algo_fwd: cudnnConvolutionFwdAlgo_t,
workspace_size_fwd: usize,
algo_filter_bwd: cudnnConvolutionBwdFilterAlgo_t,
workspace_filter_size_bwd: usize,
algo_data_bwd: cudnnConvolutionBwdDataAlgo_t,
workspace_data_size_bwd: usize,
conv_desc: ConvolutionDescriptor,
filter_desc: FilterDescriptor,
) -> ConvolutionConfig {
ConvolutionConfig {
forward_algo: algo_fwd,
forward_workspace_size: workspace_size_fwd,
backward_filter_algo: algo_filter_bwd,
backward_filter_workspace_size: workspace_filter_size_bwd,
backward_data_algo: algo_data_bwd,
backward_data_workspace_size: workspace_data_size_bwd,
conv_desc: conv_desc,
filter_desc: filter_desc,
}
}
pub fn largest_workspace_size(&self) -> &usize {
if self.backward_data_workspace_size() >= self.backward_filter_workspace_size() && self.backward_data_workspace_size() >= self.forward_workspace_size() {
self.backward_data_workspace_size()
} else if self.backward_filter_workspace_size() >= self.backward_data_workspace_size() && self.backward_filter_workspace_size() >= self.forward_workspace_size() {
self.backward_filter_workspace_size()
} else {
self.forward_workspace_size()
}
}
pub fn forward_algo(&self) -> &cudnnConvolutionFwdAlgo_t {
&self.forward_algo
}
pub fn forward_workspace_size(&self) -> &usize {
&self.forward_workspace_size
}
pub fn backward_filter_algo(&self) -> &cudnnConvolutionBwdFilterAlgo_t {
&self.backward_filter_algo
}
pub fn backward_filter_workspace_size(&self) -> &usize {
&self.backward_filter_workspace_size
}
pub fn backward_data_algo(&self) -> &cudnnConvolutionBwdDataAlgo_t {
&self.backward_data_algo
}
pub fn backward_data_workspace_size(&self) -> &usize {
&self.backward_data_workspace_size
}
pub fn conv_desc(&self) -> &ConvolutionDescriptor {
&self.conv_desc
}
pub fn filter_desc(&self) -> &FilterDescriptor {
&self.filter_desc
}
}
/// Holds the descriptor kept for a normalization operation.
///
/// NOTE(review): the `lrn` prefix presumably stands for local response
/// normalization — confirm against the descriptor's definition.
// No `Debug` or `Copy` impl is derived for this type, so the matching lints are silenced.
#[allow(missing_debug_implementations, missing_copy_implementations)]
pub struct NormalizationConfig {
/// Normalization descriptor owned by this configuration.
lrn_desc: NormalizationDescriptor,
}
impl NormalizationConfig {
pub fn new(lrn_desc: NormalizationDescriptor) -> NormalizationConfig {
NormalizationConfig {
lrn_desc: lrn_desc,
}
}
pub fn lrn_desc(&self) -> &NormalizationDescriptor {
&self.lrn_desc
}
}
/// Holds the two pooling descriptors kept for a pooling operation.
// No `Debug` or `Copy` impl is derived for this type, so the matching lints are silenced.
#[allow(missing_debug_implementations, missing_copy_implementations)]
pub struct PoolingConfig {
/// Descriptor stored under the `avg` name (presumably average pooling — confirm).
pooling_avg_desc: PoolingDescriptor,
/// Descriptor stored under the `max` name (presumably max pooling — confirm).
pooling_max_desc: PoolingDescriptor,
}
impl PoolingConfig {
pub fn new(
pooling_avg_desc: PoolingDescriptor,
pooling_max_desc: PoolingDescriptor,
) -> PoolingConfig {
PoolingConfig {
pooling_avg_desc: pooling_avg_desc,
pooling_max_desc: pooling_max_desc,
}
}
pub fn pooling_avg_desc(&self) -> &PoolingDescriptor {
&self.pooling_avg_desc
}
pub fn pooling_max_desc(&self) -> &PoolingDescriptor {
&self.pooling_max_desc
}
}
/// A pair of type-erased scalar pointers tagged with the scalar type `T`.
///
/// The `Default` impls in this file point `a` at a value of `1.0` and `b` at a
/// value of `0.0` of the corresponding float type.
///
/// NOTE(review): presumably these are the alpha/beta scaling factors handed to
/// cuDNN calls — confirm at the call sites.
// No `Debug` or `Copy` impl is derived for this type, so the matching lints are silenced.
#[allow(missing_debug_implementations, missing_copy_implementations)]
pub struct ScalParams<T> {
/// Untyped pointer to the first scalar; the pointee is expected to be a `T`.
pub a: *const ::libc::c_void,
/// Untyped pointer to the second scalar; the pointee is expected to be a `T`.
pub b: *const ::libc::c_void,
/// Zero-sized marker recording the scalar type without storing a value of it.
scal_type: PhantomData<T>,
}
impl Default for ScalParams<f32> {
    /// Returns parameters whose `a` points at `1.0f32` and whose `b` points at `0.0f32`.
    ///
    /// The pointed-to values are function-local `static`s, so the returned raw
    /// pointers stay valid for the whole program instead of being derived from
    /// a temporary array expression. Every returned instance shares the same
    /// two addresses.
    fn default() -> ScalParams<f32> {
        // Explicit statics give the pointees a guaranteed `'static` lifetime.
        static ALPHA: f32 = 1.0;
        static BETA: f32 = 0.0;

        ScalParams {
            a: &ALPHA as *const f32 as *const ::libc::c_void,
            b: &BETA as *const f32 as *const ::libc::c_void,
            scal_type: PhantomData,
        }
    }
}
impl Default for ScalParams<f64> {
    /// Returns parameters whose `a` points at `1.0f64` and whose `b` points at `0.0f64`.
    ///
    /// The pointed-to values are function-local `static`s, so the returned raw
    /// pointers stay valid for the whole program instead of being derived from
    /// a temporary array expression. Every returned instance shares the same
    /// two addresses.
    fn default() -> ScalParams<f64> {
        // Explicit statics give the pointees a guaranteed `'static` lifetime.
        static ALPHA: f64 = 1.0;
        static BETA: f64 = 0.0;

        ScalParams {
            a: &ALPHA as *const f64 as *const ::libc::c_void,
            b: &BETA as *const f64 as *const ::libc::c_void,
            scal_type: PhantomData,
        }
    }
}