use std::{ffi::CString, ptr, sync::Arc};
use singe_cuda_sys::driver;
use crate::{
device::Device,
error::{Error, Result},
jit::JitOptions,
library::Library,
module::{Module, ModuleImage},
nvrtc::{self, CompilationArtifact, OutputKind},
try_cuda,
types::Limit,
};
bitflags::bitflags! {
/// Bit set of flags that can be passed when creating a [`Context`] and that
/// [`Context::flags`] reports back.
///
/// Every constant takes its value from the correspondingly named
/// `driver::CUctx_flags` variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ContextFlags: u32 {
/// Value of `CU_CTX_SCHED_AUTO`.
// NOTE(review): if this driver constant is zero, `contains(SCHEDULE_AUTO)`
// is true for every flag set — confirm and avoid testing for it that way.
const SCHEDULE_AUTO = driver::CUctx_flags::CU_CTX_SCHED_AUTO as _;
/// Value of `CU_CTX_SCHED_SPIN`.
const SCHEDULE_SPIN = driver::CUctx_flags::CU_CTX_SCHED_SPIN as _;
/// Value of `CU_CTX_SCHED_YIELD`.
const SCHEDULE_YIELD = driver::CUctx_flags::CU_CTX_SCHED_YIELD as _;
/// Value of `CU_CTX_SCHED_BLOCKING_SYNC`.
const SCHEDULE_BLOCKING_SYNC = driver::CUctx_flags::CU_CTX_SCHED_BLOCKING_SYNC as _;
/// Value of `CU_CTX_MAP_HOST`.
const MAP_HOST = driver::CUctx_flags::CU_CTX_MAP_HOST as _;
/// Value of `CU_CTX_LMEM_RESIZE_TO_MAX`.
const LOCAL_MEMORY_RESIZE_TO_MAX = driver::CUctx_flags::CU_CTX_LMEM_RESIZE_TO_MAX as _;
/// Value of `CU_CTX_COREDUMP_ENABLE`.
const COREDUMP_ENABLE = driver::CUctx_flags::CU_CTX_COREDUMP_ENABLE as _;
/// Value of `CU_CTX_USER_COREDUMP_ENABLE`.
const USER_COREDUMP_ENABLE = driver::CUctx_flags::CU_CTX_USER_COREDUMP_ENABLE as _;
/// Value of `CU_CTX_SYNC_MEMOPS`.
const SYNC_MEMORY_OPERATIONS = driver::CUctx_flags::CU_CTX_SYNC_MEMOPS as _;
}
}
/// Owning wrapper around a raw CUDA driver context handle.
///
/// Instances are handed out as `Arc<Context>` by the `create*` constructors,
/// and the handle is passed to `cuCtxDestroy_v2` when the value is dropped.
#[derive(Debug)]
pub struct Context {
/// Raw driver handle written by `cuCtxCreate_v4`; checked to be non-null at creation.
handle: driver::CUcontext,
/// Device that was passed to the constructor that created this context.
device: Device,
}
impl Context {
/// Creates a context with an empty flag set by delegating to
/// [`Context::create_with_flags`].
///
/// # Errors
///
/// Returns whatever error [`Context::create_with_flags`] reports.
pub fn create() -> Result<Arc<Self>> {
    let no_flags = ContextFlags::empty();
    Self::create_with_flags(no_flags)
}
/// Creates a context with `flags` on the device returned by `Device::current()`.
///
/// # Errors
///
/// Fails if `Device::current()` fails or if
/// [`Context::create_for_device_with_flags`] fails.
pub fn create_with_flags(flags: ContextFlags) -> Result<Arc<Self>> {
    Self::create_for_device_with_flags(Device::current()?, flags)
}
/// Creates a context on `device` with an empty flag set.
///
/// # Errors
///
/// Returns whatever error [`Context::create_for_device_with_flags`] reports.
pub fn create_for_device(device: Device) -> Result<Arc<Self>> {
    let no_flags = ContextFlags::empty();
    Self::create_for_device_with_flags(device, no_flags)
}
pub fn create_for_device_with_flags(device: Device, flags: ContextFlags) -> Result<Arc<Self>> {
unsafe {
try_cuda!(driver::cuInit(0))?;
let mut handle = ptr::null_mut();
try_cuda!(driver::cuCtxCreate_v4(
&raw mut handle,
ptr::null_mut(), flags.bits(),
device.id() as _,
))?;
if handle.is_null() {
return Err(Error::NullHandle);
}
Ok(Arc::new(Self { handle, device }))
}
}
/// Makes this context the current one via `cuCtxSetCurrent`, unless
/// `cuCtxGetCurrent` already reports this context's handle.
///
/// # Errors
///
/// Propagates a failure of either driver call.
pub fn bind(&self) -> Result<()> {
    let mut active = ptr::null_mut();
    // SAFETY: `active` is a live local for the driver to write into, and the
    // handle passed to `cuCtxSetCurrent` is the one stored in this wrapper.
    unsafe {
        try_cuda!(driver::cuCtxGetCurrent(&raw mut active))?;
        if active != self.as_raw() {
            try_cuda!(driver::cuCtxSetCurrent(self.as_raw()))?;
        }
    }
    Ok(())
}
pub fn load_module(self: &Arc<Self>, image: &ModuleImage<'_>) -> Result<Module> {
self.bind()?;
unsafe {
let mut module_handle = ptr::null_mut();
try_cuda!(driver::cuModuleLoadData(
&raw mut module_handle,
image.as_ptr() as _,
))?;
if module_handle.is_null() {
return Err(Error::NullHandle);
}
Ok(Module::from_raw(module_handle, Arc::clone(self)))
}
}
/// Consumes `module` and drops it immediately; always returns `Ok(())`.
///
/// The receiver is not consulted: the method only takes ownership of the
/// module and releases it.
pub fn unload_module(self: &Arc<Self>, module: Module) -> Result<()> {
    std::mem::drop(module);
    Ok(())
}
pub fn load_module_with_options(
self: &Arc<Self>,
image: &ModuleImage<'_>,
mut jit_options: JitOptions<'_>,
) -> Result<Module> {
self.bind()?;
let mut jit_options = jit_options.build();
unsafe {
let mut module_handle = ptr::null_mut();
try_cuda!(driver::cuModuleLoadDataEx(
&raw mut module_handle,
image.as_ptr() as _,
jit_options.names.len() as _,
jit_options.names.as_mut_ptr() as _,
jit_options.values.as_mut_ptr() as _,
))?;
if module_handle.is_null() {
return Err(Error::NullHandle);
}
Ok(Module::from_raw(module_handle, Arc::clone(self)))
}
}
/// Loads the `output` artifact of `program` as a module using
/// `JitOptions::default()`.
///
/// # Errors
///
/// Returns whatever error [`Context::load_nvrtc_module_with_options`] reports.
pub fn load_nvrtc_module(
    self: &Arc<Self>,
    program: &nvrtc::Program,
    output: OutputKind,
) -> Result<Module> {
    let default_options = JitOptions::default();
    self.load_nvrtc_module_with_options(program, output, default_options)
}
/// Fetches the `output` artifact from `program`, converts it with
/// `module_loadable_image`, and loads the result through
/// [`Context::load_module_with_options`].
///
/// # Errors
///
/// Fails if the artifact cannot be obtained, if `module_loadable_image`
/// rejects it, or if loading the module fails.
pub fn load_nvrtc_module_with_options(
    self: &Arc<Self>,
    program: &nvrtc::Program,
    output: OutputKind,
    jit_options: JitOptions<'_>,
) -> Result<Module> {
    let artifact = program.artifact(output)?;
    let loadable = module_loadable_image(artifact)?;
    self.load_module_with_options(&loadable, jit_options)
}
/// Loads `image` as a library using `JitOptions::default()`.
///
/// # Errors
///
/// Returns whatever error [`Context::load_library_with_options`] reports.
pub fn load_library(self: &Arc<Self>, image: &ModuleImage<'_>) -> Result<Library> {
    let default_options = JitOptions::default();
    self.load_library_with_options(image, default_options)
}
pub fn load_library_with_options(
self: &Arc<Self>,
image: &ModuleImage<'_>,
mut jit_options: JitOptions<'_>,
) -> Result<Library> {
self.bind()?;
let mut jit_options = jit_options.build();
let mut handle = ptr::null_mut();
unsafe {
try_cuda!(driver::cuLibraryLoadData(
&raw mut handle,
image.as_ptr() as _,
jit_options.names.as_mut_ptr() as _,
jit_options.values.as_mut_ptr() as _,
jit_options.names.len() as _,
ptr::null_mut(),
ptr::null_mut(),
0,
))?;
}
if handle.is_null() {
return Err(Error::NullHandle);
}
Ok(unsafe { Library::from_raw(handle, Arc::clone(self)) })
}
/// Loads the `output` artifact of `program` as a library using
/// `JitOptions::default()`.
///
/// # Errors
///
/// Returns whatever error [`Context::load_nvrtc_library_with_options`] reports.
pub fn load_nvrtc_library(
    self: &Arc<Self>,
    program: &nvrtc::Program,
    output: OutputKind,
) -> Result<Library> {
    let default_options = JitOptions::default();
    self.load_nvrtc_library_with_options(program, output, default_options)
}
/// Fetches the `output` artifact from `program`, converts it with
/// `library_loadable_image`, and loads the result through
/// [`Context::load_library_with_options`].
///
/// # Errors
///
/// Fails if the artifact cannot be obtained, if `library_loadable_image`
/// rejects it, or if loading the library fails.
pub fn load_nvrtc_library_with_options(
    self: &Arc<Self>,
    program: &nvrtc::Program,
    output: OutputKind,
    jit_options: JitOptions<'_>,
) -> Result<Library> {
    let artifact = program.artifact(output)?;
    let loadable = library_loadable_image(artifact)?;
    self.load_library_with_options(&loadable, jit_options)
}
pub fn load_library_from_file(self: &Arc<Self>, path: &str) -> Result<Library> {
self.bind()?;
let path = CString::new(path)?;
let mut handle = ptr::null_mut();
unsafe {
try_cuda!(driver::cuLibraryLoadFromFile(
&raw mut handle,
path.as_ptr(),
ptr::null_mut(),
ptr::null_mut(),
0,
ptr::null_mut(),
ptr::null_mut(),
0,
))?;
}
if handle.is_null() {
return Err(Error::NullHandle);
}
Ok(unsafe { Library::from_raw(handle, Arc::clone(self)) })
}
/// Binds this context and then calls `cuCtxSynchronize`.
///
/// # Errors
///
/// Propagates a failure of [`Context::bind`] or of the driver call.
pub fn synchronize(&self) -> Result<()> {
    self.bind()?;
    // SAFETY: the call takes no arguments; this context was bound just above.
    let status = unsafe { try_cuda!(driver::cuCtxSynchronize()) };
    status?;
    Ok(())
}
/// Binds this context and reads its flags with `cuCtxGetFlags`.
///
/// The raw value is converted with `ContextFlags::from_bits_truncate`.
///
/// # Errors
///
/// Propagates a failure of [`Context::bind`] or of the driver call.
pub fn flags(&self) -> Result<ContextFlags> {
    self.bind()?;
    let mut raw_bits = 0;
    // SAFETY: `raw_bits` is a live local for the driver to write into.
    let status = unsafe { try_cuda!(driver::cuCtxGetFlags(&raw mut raw_bits)) };
    status?;
    Ok(ContextFlags::from_bits_truncate(raw_bits))
}
/// Binds this context and queries `limit` with `cuCtxGetLimit`.
///
/// # Errors
///
/// Propagates a failure of [`Context::bind`] or of the driver call.
pub fn limit(&self, limit: Limit) -> Result<usize> {
    self.bind()?;
    let mut current = 0;
    // SAFETY: `current` is a live local for the driver to write into.
    let status = unsafe { try_cuda!(driver::cuCtxGetLimit(&raw mut current, limit.into())) };
    status?;
    Ok(current as usize)
}
/// Binds this context and sets `limit` to `value` with `cuCtxSetLimit`.
///
/// # Errors
///
/// Propagates a failure of [`Context::bind`] or of the driver call.
pub fn set_limit(&self, limit: Limit, value: usize) -> Result<()> {
    self.bind()?;
    // SAFETY: both arguments are passed by value; this context was bound just above.
    let status = unsafe { try_cuda!(driver::cuCtxSetLimit(limit.into(), value as _)) };
    status?;
    Ok(())
}
/// Returns a copy of the [`Device`] stored in this context at creation.
pub const fn device(&self) -> Device {
self.device
}
/// Returns the raw driver handle stored in this context.
///
/// # Safety
///
/// The body only copies the stored handle, so the `unsafe` marker is a
/// contract on the caller rather than on this function.
/// NOTE(review): presumably the caller must not destroy the handle or keep
/// using it after this `Context` is dropped — confirm the intended contract.
pub const unsafe fn as_raw(&self) -> driver::CUcontext {
self.handle
}
}
// `Context` stores a raw pointer handle, so `Send` and `Sync` are not derived
// automatically and are asserted by hand here.
// NOTE(review): soundness relies on the CUDA driver allowing one context
// handle to be used from several threads — confirm against the driver API docs.
unsafe impl Send for Context {}
unsafe impl Sync for Context {}
impl Drop for Context {
    /// Destroys the wrapped driver context with `cuCtxDestroy_v2`.
    ///
    /// `drop` cannot return an error, so a failure is reported on stderr in
    /// builds with debug assertions and otherwise ignored.
    fn drop(&mut self) {
        // SAFETY: `self.handle` is the handle stored in this wrapper, and the
        // wrapper is being dropped, so the handle is not used through it again.
        let outcome = unsafe { try_cuda!(driver::cuCtxDestroy_v2(self.handle)) };
        if let Err(err) = outcome {
            // `cfg!` (rather than a `#[cfg]` attribute on the statement) keeps
            // `err` used in every profile, so release builds no longer emit an
            // `unused_variables` warning; the branch is a compile-time constant.
            if cfg!(debug_assertions) {
                eprintln!("failed to destroy CUDA context: {err}");
            }
        }
    }
}
impl PartialEq for Context {
    /// Two contexts compare equal when they wrap the same raw driver handle.
    fn eq(&self, other: &Self) -> bool {
        // `as_raw` returns exactly this field, so comparing the fields directly
        // is equivalent and needs no `unsafe`.
        self.handle == other.handle
    }
}

impl Eq for Context {}
fn module_loadable_image(artifact: CompilationArtifact) -> Result<ModuleImage<'static>> {
match artifact {
CompilationArtifact::Ptx(image) | CompilationArtifact::Cubin(image) => Ok(image),
CompilationArtifact::LtoIr(_) | CompilationArtifact::OptixIr(_) => Err(Error::InvalidValue),
}
}
fn library_loadable_image(artifact: CompilationArtifact) -> Result<ModuleImage<'static>> {
match artifact {
CompilationArtifact::Ptx(image) | CompilationArtifact::Cubin(image) => Ok(image),
CompilationArtifact::LtoIr(_) | CompilationArtifact::OptixIr(_) => Err(Error::InvalidValue),
}
}