1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
use crate::*;
use num_enum::TryFromPrimitive;
use std::{ptr::null_mut, rc::Rc};

/// A CUDA application context.
/// To start interacting with a device, you want to [`Context::enter`]
#[derive(Debug)]
pub struct Context {
    pub(crate) inner: *mut sys::CUctx_st,
}

impl Context {
    /// Creates a new [`Context`] for a given [`Device`]
    pub fn new(device: &Device) -> CudaResult<Context> {
        let mut inner = null_mut();
        cuda_error(unsafe {
            sys::cuCtxCreate_v2(
                &mut inner as *mut _,
                sys::CUctx_flags_enum_CU_CTX_SCHED_BLOCKING_SYNC,
                device.handle,
            )
        })?;
        Ok(Context { inner })
    }

    /// Gets the API version of the [`Context`].
    /// This is not the compute capability of the device and probably not what you are looking for. See [`Device::compute_capability`]
    pub fn version(&self) -> CudaResult<CudaVersion> {
        let mut out = 0u32;
        cuda_error(unsafe { sys::cuCtxGetApiVersion(self.inner, &mut out as *mut u32) })?;
        Ok(out.into())
    }

    /// Synchronize a [`Context`], running all active handles to completion
    pub fn synchronize(&self) -> CudaResult<()> {
        cuda_error(unsafe { sys::cuCtxSynchronize() })
    }

    /// Set a CUDA context limit
    pub fn set_limit(&mut self, limit: LimitType, value: u64) -> CudaResult<()> {
        cuda_error(unsafe { sys::cuCtxSetLimit(limit as u32, value) })
    }

    /// Get a CUDA context limit
    pub fn get_limit(&self, limit: LimitType) -> CudaResult<u64> {
        let mut out = 0u64;
        cuda_error(unsafe { sys::cuCtxGetLimit(&mut out as *mut u64, limit as u32) })?;
        Ok(out)
    }

    /// Enter a [`Context`], consuming a mutable reference to the context, and allowing thread-local operations to happen.
    pub fn enter<'a>(&'a mut self) -> CudaResult<Rc<Handle<'a>>> {
        cuda_error(unsafe { sys::cuCtxSetCurrent(self.inner) })?;
        Ok(Rc::new(Handle {
            context: self,
            // async_stream_pool: RefCell::new(vec![]),
        }))
    }
}

impl Drop for Context {
    fn drop(&mut self) {
        if let Err(e) = cuda_error(unsafe { sys::cuCtxDestroy_v2(self.inner) }) {
            eprintln!("CUDA: failed to destroy cuda context: {:?}", e);
        }
    }
}

/// A CUDA [`Context`] handle for executing thread-local operations.
pub struct Handle<'a> {
    pub(crate) context: &'a mut Context,
    // async_stream_pool: RefCell<Vec<Stream<'a>>>,
}

impl<'a> Handle<'a> {
    /// Get an immutable reference to the source context.
    pub fn context(&self) -> &Context {
        &self.context
    }

    // pub(crate) fn get_async_stream(self: &Rc<Handle<'a>>) -> CudaResult<Stream<'a>> {
    //     let mut pool = self.async_stream_pool.borrow_mut();
    //     if pool.is_empty() {
    //         Stream::new(self)
    //     } else {
    //         Ok(pool.pop().unwrap())
    //     }
    // }

    // pub(crate) fn reset_async_stream(self: &Rc<Handle<'a>>, stream: Stream<'a>) {
    //     let mut pool = self.async_stream_pool.borrow_mut();
    //     pool.push(stream);
    // }
}

impl<'a> Drop for Handle<'a> {
    fn drop(&mut self) {
        if let Err(e) = cuda_error(unsafe { sys::cuCtxSetCurrent(null_mut()) }) {
            eprintln!("CUDA: error dropping context handle: {:?}", e);
        }
    }
}

/// Context limit types
#[derive(Clone, Copy, Debug, TryFromPrimitive)]
#[repr(u32)]
pub enum LimitType {
    /// GPU thread stack size
    StackSize = 0x00,
    /// GPU printf FIFO size
    PrintfFifoSize = 0x01,
    /// GPU malloc heap size
    MallocHeapSize = 0x02,
    /// GPU device runtime launch synchronize depth
    DevRuntimeSyncDepth = 0x03,
    /// GPU device runtime pending launch count
    DevRuntimePendingLaunchCount = 0x04,
    /// A value between 0 and 128 that indicates the maximum fetch granularity of L2 (in Bytes). This is a hint
    MaxL2FetchGranularity = 0x05,
    /// A size in bytes for L2 persisting lines cache size
    PersistingL2CacheSize = 0x06,
}