1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
//! Defines a Pooling Descriptor.

use super::{Error, API};
use crate::ffi::*;

/// Describes a Pooling Descriptor.
///
/// Wraps a cuDNN pooling descriptor handle (`cudnnPoolingDescriptor_t`). The handle is
/// destroyed through `API::destroy_pooling_descriptor` when this value is dropped.
// NOTE(review): the derived `Clone` copies the raw handle, so a clone and its original
// both call `destroy_pooling_descriptor` on the same handle when they are dropped
// (double destroy). Consider removing `Clone`, or implementing it by creating a fresh
// descriptor. Either is a breaking API change, so confirm with callers first.
#[derive(Debug, Clone)]
pub struct PoolingDescriptor {
    // Raw cuDNN handle; destroyed by this type's `Drop` implementation.
    id: cudnnPoolingDescriptor_t,
}

impl Drop for PoolingDescriptor {
    /// Destroys the underlying cuDNN pooling descriptor.
    ///
    /// # Panics
    ///
    /// Panics if destroying the descriptor returns an error. The exception is when the
    /// current thread is already unwinding from another panic. In that case the error
    /// is ignored, because a second panic inside `drop` would abort the whole process.
    fn drop(&mut self) {
        if let Err(e) = API::destroy_pooling_descriptor(*self.id_c()) {
            if !std::thread::panicking() {
                panic!("failed to destroy cuDNN pooling descriptor: {:?}", e);
            }
        }
    }
}

impl PoolingDescriptor {
    /// Initializes a new CUDA cuDNN Pooling Descriptor.
    ///
    /// `window`, `padding` and `stride` are passed to cuDNN as raw pointers, and
    /// `window.len()` is passed as their element count. This presumably wraps
    /// `cudnnSetPoolingNdDescriptor`'s `nbDims`; confirm against the `API` wrapper.
    /// NaN values are configured as not propagated (`CUDNN_NOT_PROPAGATE_NAN`).
    ///
    /// # Errors
    ///
    /// Returns an error if creating or configuring the descriptor fails. If
    /// configuration fails, the freshly created descriptor is destroyed before the
    /// error is returned.
    ///
    /// # Panics
    ///
    /// Panics if `padding` or `stride` has fewer elements than `window`. Such a slice
    /// would otherwise be read past its end through the raw pointer.
    pub fn new(
        mode: cudnnPoolingMode_t,
        window: &[i32],
        padding: &[i32],
        stride: &[i32],
    ) -> Result<PoolingDescriptor, Error> {
        assert!(
            padding.len() >= window.len(),
            "pooling padding has {} entries but window has {} dimensions",
            padding.len(),
            window.len()
        );
        assert!(
            stride.len() >= window.len(),
            "pooling stride has {} entries but window has {} dimensions",
            stride.len(),
            window.len()
        );

        // Take ownership of the handle immediately, so that `Drop` destroys it if
        // configuring it below fails. Otherwise the early `?` return would leak it.
        let desc = PoolingDescriptor::from_c(API::create_pooling_descriptor()?);
        API::set_pooling_descriptor(
            *desc.id_c(),
            mode,
            cudnnNanPropagation_t::CUDNN_NOT_PROPAGATE_NAN, // TODO check if this is sane to do
            window.len() as i32,
            window.as_ptr(),
            padding.as_ptr(),
            stride.as_ptr(),
        )?;

        Ok(desc)
    }

    /// Initializes a new CUDA cuDNN PoolingDescriptor from its C type.
    ///
    /// The returned value takes ownership of `id`. The handle is destroyed through
    /// `API::destroy_pooling_descriptor` when the `PoolingDescriptor` is dropped.
    pub fn from_c(id: cudnnPoolingDescriptor_t) -> PoolingDescriptor {
        PoolingDescriptor { id }
    }

    /// Returns the CUDA cuDNN Pooling Descriptor as its C type.
    ///
    /// The handle stays owned by `self`. Do not destroy it through the returned value.
    pub fn id_c(&self) -> &cudnnPoolingDescriptor_t {
        &self.id
    }
}