1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
// No sync because of https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#thread-safety
use super::{result, result::CudnnError, sys};
use crate::driver::{CudaDevice, CudaStream};
use std::sync::Arc;

/// A safe wrapper around a raw cuDNN context handle (`cudnnHandle_t`).
///
/// Holds an [`Arc<CudaDevice>`] so the device — and the stream the handle was
/// bound to in [`Cudnn::new`] — stays alive at least as long as this handle.
/// The raw handle is destroyed when this struct is dropped.
#[derive(Debug)]
pub struct Cudnn {
    // Raw cuDNN handle; swapped to null in `Drop` after destruction.
    pub(crate) handle: sys::cudnnHandle_t,
    // Keeps the owning device (and its default work stream) alive.
    pub(crate) device: Arc<CudaDevice>,
}

impl Cudnn {
    /// Creates a new cudnn handle and sets the stream to the `device`'s stream.
    ///
    /// # Errors
    /// Returns a [`CudnnError`] if handle creation or stream binding fails.
    pub fn new(device: Arc<CudaDevice>) -> Result<Self, CudnnError> {
        let handle = result::create_handle()?;
        // Bind the fresh handle to the device's default work stream up front.
        unsafe { result::set_stream(handle, device.stream as *mut _) }?;
        Ok(Self { handle, device })
    }

    /// Sets the handle's current to either the stream specified, or the device's default work
    /// stream.
    ///
    /// # Safety
    /// This is unsafe because you can end up scheduling multiple concurrent kernels that all
    /// write to the same memory address.
    pub unsafe fn set_stream(&self, opt_stream: Option<&CudaStream>) -> Result<(), CudnnError> {
        if let Some(custom) = opt_stream {
            result::set_stream(self.handle, custom.stream as *mut _)
        } else {
            // Fall back to the stream this handle was created with.
            result::set_stream(self.handle, self.device.stream as *mut _)
        }
    }
}

impl Drop for Cudnn {
    fn drop(&mut self) {
        // Take ownership of the raw handle, leaving null behind so the
        // destroy call can never run twice for the same handle.
        let raw = std::mem::replace(&mut self.handle, std::ptr::null_mut());
        if raw.is_null() {
            return;
        }
        // NOTE(review): unwrap here will panic (and may abort during unwind)
        // if cuDNN reports a failure on destroy — matches the crate's
        // existing fail-loud drop convention.
        unsafe { result::destroy_handle(raw) }.unwrap();
    }
}

#[cfg(test)]
mod tests {
    use super::Cudnn;
    use crate::driver::CudaDevice;

    /// Smoke test: constructing and dropping a handle must not panic.
    /// Requires a physical CUDA device at ordinal 0.
    #[test]
    fn create_and_drop() {
        let device = CudaDevice::new(0).unwrap();
        let _cudnn = Cudnn::new(device).unwrap();
    }
}