use std::marker::PhantomData;
use ffi::*;
/// Selects one of two directions.
///
/// NOTE(review): `Fr` and `Bc` presumably abbreviate "forward" and
/// "backward" — confirm against the call sites.
#[derive(Clone, Copy, Debug)]
pub enum Direction {
    /// First direction (presumably forward — TODO confirm).
    Fr,
    /// Second direction (presumably backward — TODO confirm).
    Bc,
}
/// Bundles the algorithm selection and workspace (raw pointer plus size)
/// for the forward pass and for the backward-data pass of a convolution.
///
/// The struct only stores the raw pointers handed to it; nothing visible
/// here allocates or frees the memory behind them.
#[allow(missing_copy_implementations, missing_debug_implementations)]
pub struct ConvolutionConfig {
    // Forward pass.
    forward_algo: cudnnConvolutionFwdAlgo_t,
    forward_workspace: *mut ::libc::c_void,
    // NOTE(review): presumably a size in bytes — confirm against callers.
    forward_workspace_size: usize,

    // Backward (data) pass.
    backward_algo: cudnnConvolutionBwdDataAlgo_t,
    backward_workspace: *mut ::libc::c_void,
    // NOTE(review): presumably a size in bytes — confirm against callers.
    backward_workspace_size: usize,
}
impl ConvolutionConfig {
    /// Builds a `ConvolutionConfig` from the forward algorithm, workspace
    /// pointer and workspace size, followed by the same three values for
    /// the backward pass.
    ///
    /// The arguments are stored as given; no validation is performed.
    pub fn new(
        algo_fwd: cudnnConvolutionFwdAlgo_t,
        workspace_fwd: *mut ::libc::c_void,
        workspace_size_fwd: usize,
        algo_bwd: cudnnConvolutionBwdDataAlgo_t,
        workspace_bwd: *mut ::libc::c_void,
        workspace_size_bwd: usize,
    ) -> Self {
        Self {
            // Forward pass.
            forward_algo: algo_fwd,
            forward_workspace: workspace_fwd,
            forward_workspace_size: workspace_size_fwd,
            // Backward pass.
            backward_algo: algo_bwd,
            backward_workspace: workspace_bwd,
            backward_workspace_size: workspace_size_bwd,
        }
    }

    /// Returns a reference to the stored forward algorithm.
    pub fn forward_algo(&self) -> &cudnnConvolutionFwdAlgo_t {
        let Self { forward_algo, .. } = self;
        forward_algo
    }

    /// Returns a reference to the stored forward workspace pointer.
    pub fn forward_workspace(&self) -> &*mut ::libc::c_void {
        let Self { forward_workspace, .. } = self;
        forward_workspace
    }

    /// Returns a reference to the stored forward workspace size.
    pub fn forward_workspace_size(&self) -> &usize {
        let Self { forward_workspace_size, .. } = self;
        forward_workspace_size
    }

    /// Returns a reference to the stored backward algorithm.
    pub fn backward_algo(&self) -> &cudnnConvolutionBwdDataAlgo_t {
        let Self { backward_algo, .. } = self;
        backward_algo
    }

    /// Returns a reference to the stored backward workspace pointer.
    pub fn backward_workspace(&self) -> &*mut ::libc::c_void {
        let Self { backward_workspace, .. } = self;
        backward_workspace
    }

    /// Returns a reference to the stored backward workspace size.
    pub fn backward_workspace_size(&self) -> &usize {
        let Self { backward_workspace_size, .. } = self;
        backward_workspace_size
    }
}
/// A pair of type-erased pointers to two scalar values, tagged with the
/// scalar type `T` they are meant to point at.
///
/// NOTE(review): `a` and `b` presumably are the alpha/beta scaling factors
/// passed to cuDNN routines — confirm against the call sites.
#[allow(missing_copy_implementations, missing_debug_implementations)]
pub struct ScalParams<T> {
    /// Untyped pointer to the first scalar.
    pub a: *const ::libc::c_void,
    /// Untyped pointer to the second scalar.
    pub b: *const ::libc::c_void,
    // Zero-sized marker recording the scalar type; no `T` is stored.
    scal_type: PhantomData<T>,
}
impl ScalParams<f32> {
    /// Returns scaling parameters where `a` points at `1.0f32` and `b`
    /// points at `0.0f32`.
    ///
    /// The pointed-to values live in function-local `static` items, so the
    /// returned pointers remain valid for the whole program. Pointing at a
    /// temporary array instead would leave the pointers valid only if the
    /// compiler happened to promote that temporary to a constant.
    pub fn default() -> ScalParams<f32> {
        static ALPHA: f32 = 1.0;
        static BETA: f32 = 0.0;

        ScalParams {
            a: &ALPHA as *const f32 as *const ::libc::c_void,
            b: &BETA as *const f32 as *const ::libc::c_void,
            scal_type: PhantomData,
        }
    }
}
impl ScalParams<f64> {
    /// Returns scaling parameters where `a` points at `1.0f64` and `b`
    /// points at `0.0f64`.
    ///
    /// The pointed-to values live in function-local `static` items, so the
    /// returned pointers remain valid for the whole program. Pointing at a
    /// temporary array instead would leave the pointers valid only if the
    /// compiler happened to promote that temporary to a constant.
    pub fn default() -> ScalParams<f64> {
        static ALPHA: f64 = 1.0;
        static BETA: f64 = 0.0;

        ScalParams {
            a: &ALPHA as *const f64 as *const ::libc::c_void,
            b: &BETA as *const f64 as *const ::libc::c_void,
            scal_type: PhantomData,
        }
    }
}