1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
use super::{result, result::CudnnError, sys};
use crate::driver::{CudaDevice, CudaStream};
use std::sync::Arc;
/// Pairs a raw cuDNN handle with the device it was created for.
///
/// The handle is produced by `result::create_handle` in [`Cudnn::new`] and is
/// passed to `result::destroy_handle` when this value is dropped.
#[derive(Debug)]
pub struct Cudnn {
/// Raw cuDNN handle. The `Drop` impl swaps this for a null pointer before
/// destroying it.
pub(crate) handle: sys::cudnnHandle_t,
/// Device whose `stream` the handle is set to in `new`, and again whenever
/// `set_stream` is called with `None`.
pub(crate) device: Arc<CudaDevice>,
}
impl Cudnn {
    /// Creates a cuDNN handle and sets its stream to `device`'s stream.
    ///
    /// # Errors
    /// Returns the [`CudnnError`] reported by `result::create_handle` or by
    /// `result::set_stream`. If setting the stream fails, the freshly created
    /// handle is destroyed before the error is returned, so it is not leaked.
    pub fn new(device: Arc<CudaDevice>) -> Result<Self, CudnnError> {
        let handle = result::create_handle()?;

        // SAFETY: `handle` was just returned by `create_handle` and has not
        // been destroyed; `device` is kept alive by the `Arc` we hold.
        // NOTE(review): assumes `device.stream` is valid for as long as
        // `device` lives - confirm against `CudaDevice`.
        if let Err(err) = unsafe { result::set_stream(handle, device.stream as *mut _) } {
            // No `Cudnn` value exists yet, so `Drop` will not run for this
            // handle: release it here. A failure to destroy is deliberately
            // ignored so the caller sees the error that caused the bail-out.
            // SAFETY: `handle` is live and is not used again after this call.
            let _ = unsafe { result::destroy_handle(handle) };
            return Err(err);
        }

        Ok(Self { handle, device })
    }

    /// Sets the stream used by this handle: the given stream for `Some`, or
    /// the owning device's stream for `None`.
    ///
    /// # Errors
    /// Returns the [`CudnnError`] reported by `result::set_stream`.
    ///
    /// # Safety
    /// The stream is handed to `result::set_stream` as a raw pointer taken
    /// from a borrowed [`CudaStream`]; nothing here ties the stream's lifetime
    /// to the handle. Presumably the caller must keep the stream alive while
    /// the handle uses it and must order work across streams correctly -
    /// NOTE(review): confirm the full contract against the driver module.
    pub unsafe fn set_stream(&self, opt_stream: Option<&CudaStream>) -> Result<(), CudnnError> {
        match opt_stream {
            Some(s) => result::set_stream(self.handle, s.stream as *mut _),
            None => result::set_stream(self.handle, self.device.stream as *mut _),
        }
    }
}
impl Drop for Cudnn {
    /// Destroys the stored handle, if there is one.
    ///
    /// # Panics
    /// Panics if `result::destroy_handle` returns an error.
    fn drop(&mut self) {
        // Take the handle out and leave null behind so the stored pointer can
        // never be handed to `destroy_handle` a second time.
        let raw = std::mem::replace(&mut self.handle, std::ptr::null_mut());
        if raw.is_null() {
            return;
        }
        // SAFETY: `raw` is non-null and has been removed from `self`, so this
        // value cannot use it again after it is destroyed.
        unsafe { result::destroy_handle(raw) }.unwrap();
    }
}
#[cfg(test)]
mod tests {
    use super::Cudnn;
    use crate::driver::CudaDevice;

    /// Smoke test: a handle can be created on device 0 and dropped without
    /// either step panicking.
    #[test]
    fn create_and_drop() {
        let device = CudaDevice::new(0).unwrap();
        let cudnn = Cudnn::new(device).unwrap();
        drop(cudnn);
    }
}