#[cfg(feature = "cuda-oxide-copy-u8")]
use crate::build_flags::ensure_cuda_oxide_copy_u8_ptx_built;
#[cfg(feature = "cuda-oxide-copy-u8")]
use crate::kernels;
use crate::{
build_flags::{ensure_kernel_ptx_built, CUDA_IDWT_TRACE_ENV_VAR},
bytes::{f32_slice_as_bytes_mut, i32_slice_as_bytes_mut},
driver::{CuContext, CuFunction, CuModule, Driver},
error::CudaError,
execution::{CudaExecutionStats, CudaLaunchMode},
htj2k_decode::{
htj2k_decode_needs_zero_fill, CudaHtj2kCodeBlockJob, CudaHtj2kDecodeOutput,
CudaHtj2kDecodeStageTimings, CudaQueuedHtj2kCleanup,
},
htj2k_encode::{
CudaHtj2kEncodeStageTimings, CudaHtj2kEncodeStatus, CudaHtj2kEncodedCodeBlock,
CudaHtj2kEncodedCodeBlocks,
},
kernels::CudaKernel,
memory::{pooled_device_buffer, CudaDeviceBuffer, CudaPooledDeviceBuffer},
};
use std::{
collections::HashMap,
ffi::{c_char, c_void},
sync::{Arc, Mutex},
};
/// Shared state behind every clone of a `CudaContext`.
///
/// Owns the loaded driver, the raw context handle, a cache of loaded kernel modules and the
/// pinned upload staging buffers; all of these are released in this type's `Drop` impl.
pub(crate) struct ContextInner {
    /// Driver entry points (obtained via `Driver::load`).
    pub(crate) driver: Driver,
    /// Raw driver context handle; destroyed with `cuCtxDestroy` on drop when non-null.
    pub(crate) context: CuContext,
    /// Modules loaded on first use, keyed by kernel; each is unloaded on drop.
    pub(crate) modules: Mutex<HashMap<CompiledKernelKey, CompiledKernel>>,
    /// Pinned host staging buffers (presumably retained for reuse across uploads — confirm at
    /// the call sites); each is freed with `cuMemFreeHost` on drop.
    pub(crate) pinned_upload_staging: Mutex<Vec<PinnedUploadStaging>>,
}
/// Host staging memory used for uploads, released with `cuMemFreeHost`.
///
/// NOTE(review): `ptr` is assumed to come from a pinned host allocation
/// (`cuMemAllocHost`/`cuMemHostAlloc`) of at least `len` bytes that this value owns; the
/// allocation site is not visible here — confirm there. A null `ptr` is treated as "no
/// allocation" by every method.
pub(crate) struct PinnedUploadStaging {
    pub(crate) ptr: *mut u8,
    pub(crate) len: usize,
}

impl PinnedUploadStaging {
    /// Views the staging memory as a byte slice.
    ///
    /// Returns an empty slice when `len` is zero or `ptr` is null. `slice::from_raw_parts`
    /// requires a non-null pointer, and the fields are constructible crate-wide, so a null
    /// pointer must never reach it.
    pub(crate) fn as_slice(&self) -> &[u8] {
        if self.len == 0 || self.ptr.is_null() {
            return &[];
        }
        // SAFETY: `ptr` is non-null and, per the type's invariant, points to `len` initialised
        // bytes owned by `self`; the returned borrow is tied to `&self`.
        unsafe { std::slice::from_raw_parts(self.ptr.cast_const(), self.len) }
    }

    /// Views the staging memory as a mutable byte slice.
    ///
    /// Returns an empty slice when `len` is zero or `ptr` is null, for the same reason as
    /// [`Self::as_slice`].
    pub(crate) fn as_mut_slice(&mut self) -> &mut [u8] {
        if self.len == 0 || self.ptr.is_null() {
            return &mut [];
        }
        // SAFETY: `ptr` is non-null and, per the type's invariant, points to `len` bytes owned
        // exclusively by `self`; the returned borrow is tied to `&mut self`.
        unsafe { std::slice::from_raw_parts_mut(self.ptr, self.len) }
    }

    /// Releases the host allocation with `cuMemFreeHost`.
    ///
    /// A null `ptr` is a no-op and returns `Ok(())`.
    ///
    /// # Errors
    /// Returns the driver error reported by `cuMemFreeHost`.
    pub(crate) fn free(self, driver: &Driver) -> Result<(), CudaError> {
        if self.ptr.is_null() {
            return Ok(());
        }
        // SAFETY: `ptr` is non-null and, per the type's invariant, was allocated as pinned host
        // memory; consuming `self` prevents a second free through this value.
        driver.check("cuMemFreeHost", unsafe {
            (driver.cu_mem_free_host)(self.ptr.cast())
        })
    }
}

// SAFETY: the raw pointer refers to host memory owned by this value and is not aliased by
// any other `PinnedUploadStaging`, so moving it to another thread is sound.
// NOTE(review): relies on the allocation site never sharing `ptr` — confirm there.
unsafe impl Send for PinnedUploadStaging {}
impl ContextInner {
    /// Makes this context current on the calling thread (`cuCtxSetCurrent`).
    ///
    /// # Errors
    /// Returns the driver error if the call fails.
    pub(crate) fn set_current(&self) -> Result<(), CudaError> {
        let status = unsafe { (self.driver.cu_ctx_set_current)(self.context) };
        self.driver.check("cuCtxSetCurrent", status)
    }

    /// Returns the function handle for a builtin kernel, loading its module on first use.
    pub(crate) fn kernel_function(&self, kernel: CudaKernel) -> Result<CuFunction, CudaError> {
        let key = CompiledKernelKey::Builtin(kernel);
        self.kernel_function_from_key(key)
    }

    /// Returns the function handle for the cuda-oxide generated copy kernel, building its PTX
    /// first if needed.
    #[cfg(feature = "cuda-oxide-copy-u8")]
    pub(crate) fn cuda_oxide_copy_u8_kernel_function(&self) -> Result<CuFunction, CudaError> {
        ensure_cuda_oxide_copy_u8_ptx_built()?;
        let key = CompiledKernelKey::CudaOxideCopyU8;
        self.kernel_function_from_key(key)
    }

    /// Looks `key` up in the module cache, loading and caching the module on a miss.
    ///
    /// The cache lock is held across the load, so concurrent callers asking for the same
    /// kernel load it only once.
    ///
    /// # Errors
    /// Propagates PTX build, context, and module-load errors, and returns
    /// `CudaError::StatePoisoned` if the cache mutex is poisoned.
    fn kernel_function_from_key(&self, key: CompiledKernelKey) -> Result<CuFunction, CudaError> {
        ensure_kernel_ptx_built(key.kernel())?;
        self.set_current()?;
        let mut cache = self.modules.lock().map_err(|poisoned| CudaError::StatePoisoned {
            message: poisoned.to_string(),
        })?;

        if let Some(cached) = cache.get(&key).map(|entry| entry.function) {
            return Ok(cached);
        }

        let loaded = CompiledKernel::load(self, key)?;
        let function = loaded.function;
        cache.insert(key, loaded);
        Ok(function)
    }
}
impl Drop for ContextInner {
    /// Releases everything owned by the context: staging buffers, then modules, then the
    /// context itself.
    ///
    /// All driver errors are ignored because `drop` cannot report them. Poisoned mutexes are
    /// recovered so their resources are still released.
    fn drop(&mut self) {
        if self.context.is_null() {
            return;
        }
        // Best effort: later driver calls may require this context to be current.
        let _ = self.set_current();

        let staging_buffers = self
            .pinned_upload_staging
            .get_mut()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        for staging in staging_buffers.drain(..) {
            let _ = staging.free(&self.driver);
        }

        let loaded_modules = self
            .modules
            .get_mut()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        for (_, compiled) in loaded_modules.drain() {
            let _ = unsafe { (self.driver.cu_module_unload)(compiled.module) };
        }

        let _ = unsafe { (self.driver.cu_ctx_destroy)(self.context) };
    }
}
// SAFETY (review): `ContextInner` holds a raw driver context handle, so neither auto trait is
// derived. Its mutable state (`modules`, `pinned_upload_staging`) is guarded by `Mutex`, and
// driver work in this file calls `set_current` first so the context is bound to whichever
// thread runs it. NOTE(review): this relies on the CUDA driver API accepting a context handle
// from any thread for every entry point `Driver` exposes — confirm.
unsafe impl Send for ContextInner {}
unsafe impl Sync for ContextInner {}
/// Cheaply cloneable handle to a CUDA driver context.
///
/// Clones share the same inner state through an `Arc`. The driver context, its loaded
/// modules, and its staging buffers are released when the last clone is dropped.
#[derive(Clone)]
pub struct CudaContext {
    pub(crate) inner: Arc<ContextInner>,
}
/// One encoded HTJ2K code block whose bytes live inside a shared payload buffer.
#[derive(Debug)]
pub struct CudaHtj2kCompactEncodedCodeBlock {
    /// Byte range of this block's data within the owning payload buffer.
    pub(crate) payload_range: std::ops::Range<usize>,
    /// Raw encoder status; the length and pass accessors derive their values from it.
    pub(crate) status: CudaHtj2kEncodeStatus,
    /// Execution counters attributed to this block.
    pub(crate) execution: CudaExecutionStats,
    /// Stage timings attributed to this block.
    pub(crate) stage_timings: CudaHtj2kEncodeStageTimings,
}
impl CudaHtj2kCompactEncodedCodeBlock {
    /// Byte range of this block's data inside the shared payload buffer.
    pub fn payload_range(&self) -> std::ops::Range<usize> {
        self.payload_range.clone()
    }

    /// Length in bytes of the cleanup segment.
    ///
    /// With at most one coding pass this is the whole `status.data_len`. Otherwise the
    /// value is taken from `status.reserved0`.
    pub fn cleanup_length(&self) -> u32 {
        if self.status.number_of_coding_passes <= 1 {
            self.status.data_len
        } else {
            self.status.reserved0
        }
    }

    /// Length in bytes of the refinement segment.
    ///
    /// Zero with at most one coding pass; otherwise taken from `status.reserved1`.
    pub fn refinement_length(&self) -> u32 {
        if self.status.number_of_coding_passes <= 1 {
            0
        } else {
            self.status.reserved1
        }
    }

    /// Number of coding passes, saturated to `u8::MAX` if the status value does not fit.
    pub fn num_coding_passes(&self) -> u8 {
        u8::try_from(self.status.number_of_coding_passes).unwrap_or(u8::MAX)
    }

    /// Number of missing (zero) bit-planes, saturated to `u8::MAX` if the status value does
    /// not fit.
    pub fn num_zero_bitplanes(&self) -> u8 {
        u8::try_from(self.status.missing_bit_planes).unwrap_or(u8::MAX)
    }

    /// Consumes the block and returns
    /// `(payload_range, cleanup_length, refinement_length, num_coding_passes, num_zero_bitplanes)`.
    ///
    /// Every component is computed by the matching accessor, so the tuple always agrees with
    /// the borrowed API.
    pub fn into_parts(self) -> (std::ops::Range<usize>, u32, u32, u8, u8) {
        let cleanup_length = self.cleanup_length();
        let refinement_length = self.refinement_length();
        let num_coding_passes = self.num_coding_passes();
        let num_zero_bitplanes = self.num_zero_bitplanes();
        (
            self.payload_range,
            cleanup_length,
            refinement_length,
            num_coding_passes,
            num_zero_bitplanes,
        )
    }

    /// Raw encoder status for this block.
    pub fn status(&self) -> CudaHtj2kEncodeStatus {
        self.status
    }

    /// Execution counters attributed to this block.
    pub fn execution(&self) -> CudaExecutionStats {
        self.execution
    }

    /// Stage timings attributed to this block.
    pub fn stage_timings(&self) -> CudaHtj2kEncodeStageTimings {
        self.stage_timings
    }
}
/// Batch of HTJ2K code blocks that share a single payload buffer.
///
/// Each entry in `code_blocks` refers to its bytes through a `payload_range` into `payload`.
#[derive(Debug)]
pub struct CudaHtj2kCompactEncodedCodeBlocks {
    /// Shared byte buffer indexed by each code block's `payload_range`.
    pub(crate) payload: Vec<u8>,
    /// Per-block metadata and ranges into `payload`.
    pub(crate) code_blocks: Vec<CudaHtj2kCompactEncodedCodeBlock>,
    /// Execution counters for the whole batch.
    pub(crate) execution: CudaExecutionStats,
    /// Stage timings for the whole batch.
    pub(crate) stage_timings: CudaHtj2kEncodeStageTimings,
}
impl CudaHtj2kCompactEncodedCodeBlocks {
    /// Shared payload buffer that every code block's range indexes into.
    pub fn payload(&self) -> &[u8] {
        self.payload.as_slice()
    }

    /// Per-block metadata, in encode order.
    pub fn code_blocks(&self) -> &[CudaHtj2kCompactEncodedCodeBlock] {
        self.code_blocks.as_slice()
    }

    /// Splits the batch into its payload buffer and block list, dropping the batch-level stats.
    pub fn into_payload_and_code_blocks(self) -> (Vec<u8>, Vec<CudaHtj2kCompactEncodedCodeBlock>) {
        let Self {
            payload,
            code_blocks,
            ..
        } = self;
        (payload, code_blocks)
    }

    /// Execution counters for the whole batch.
    pub fn execution(&self) -> CudaExecutionStats {
        self.execution
    }

    /// Stage timings for the whole batch.
    pub fn stage_timings(&self) -> CudaHtj2kEncodeStageTimings {
        self.stage_timings
    }

    /// Converts to the owned representation by copying each block's bytes out of the payload.
    ///
    /// # Errors
    /// Returns `CudaError::LengthTooLarge` with the range end when any block's range is
    /// reversed or extends past the payload.
    pub(crate) fn into_owned_code_blocks(self) -> Result<CudaHtj2kEncodedCodeBlocks, CudaError> {
        let Self {
            payload,
            code_blocks: compact_blocks,
            execution,
            stage_timings,
        } = self;

        let mut code_blocks = Vec::new();
        for block in compact_blocks {
            let range_end = block.payload_range.end;
            // `slice::get` rejects exactly the ranges with `start > end` or `end > len`.
            let bytes = payload
                .get(block.payload_range)
                .ok_or(CudaError::LengthTooLarge { len: range_end })?;
            code_blocks.push(CudaHtj2kEncodedCodeBlock {
                data: bytes.to_vec(),
                status: block.status,
                execution: block.execution,
                stage_timings: block.stage_timings,
            });
        }

        Ok(CudaHtj2kEncodedCodeBlocks {
            code_blocks,
            execution,
            stage_timings,
        })
    }
}
/// Size in bytes of the HTJ2K UVLC encode lookup table, written as `75 * 6`.
///
/// NOTE(review): presumably 75 entries of 6 bytes each — confirm against the kernel that
/// consumes this table.
pub(crate) const HTJ2K_UVLC_ENCODE_TABLE_BYTES: usize = 75 * 6;
impl CudaContext {
    /// Creates a new context on device ordinal 0.
    ///
    /// Loads the driver, calls `cuInit(0)`, checks that at least one device is reported, and
    /// creates a context with flags `0` on device 0.
    ///
    /// # Errors
    /// Returns `CudaError::Unavailable` when the driver reports no devices, and propagates any
    /// driver-load or driver-call failure.
    pub fn system_default() -> Result<Self, CudaError> {
        let driver = Driver::load()?;
        driver.check("cuInit", unsafe { (driver.cu_init)(0) })?;

        let mut device_count = 0;
        driver.check("cuDeviceGetCount", unsafe {
            (driver.cu_device_get_count)(&raw mut device_count)
        })?;
        if device_count <= 0 {
            return Err(CudaError::Unavailable {
                message: "no CUDA devices reported by driver".to_string(),
            });
        }

        let mut first_device = 0;
        driver.check("cuDeviceGet", unsafe {
            (driver.cu_device_get)(&raw mut first_device, 0)
        })?;

        let mut raw_context = std::ptr::null_mut();
        driver.check("cuCtxCreate_v2", unsafe {
            (driver.cu_ctx_create)(&raw mut raw_context, 0, first_device)
        })?;

        let inner = ContextInner {
            driver,
            context: raw_context,
            modules: Mutex::new(HashMap::new()),
            pinned_upload_staging: Mutex::new(Vec::new()),
        };
        Ok(Self {
            inner: Arc::new(inner),
        })
    }

    /// Runs the multi-job HTJ2K cleanup dequantize kernel synchronously over a queued cleanup.
    ///
    /// The first entry of `cleanup.resources` is used as the job metadata buffer. A cleanup
    /// with `status_count == 0` launches nothing and returns default stats.
    ///
    /// # Errors
    /// Returns `CudaError::InvalidArgument` when there is no metadata buffer, and propagates
    /// context, pooling, and launch errors.
    pub fn j2k_dequantize_queued_htj2k_cleanup_with_pool(
        &self,
        cleanup: &CudaQueuedHtj2kCleanup,
    ) -> Result<CudaExecutionStats, CudaError> {
        self.inner.set_current()?;
        if cleanup.status_count == 0 {
            return Ok(CudaExecutionStats::default());
        }
        let jobs_buffer = cleanup
            .resources
            .first()
            .ok_or_else(|| CudaError::InvalidArgument {
                message: "queued HTJ2K cleanup has no metadata buffer".to_string(),
            })?;
        self.launch_j2k_dequantize_htj2k_cleanup_jobs_multi(
            pooled_device_buffer(jobs_buffer)?,
            cleanup.status_count,
            CudaLaunchMode::Sync,
        )?;
        let stats = CudaExecutionStats {
            kernel_dispatches: 1,
            copy_kernel_dispatches: 0,
            decode_kernel_dispatches: 1,
            hardware_decode: false,
        };
        Ok(stats)
    }

    /// Produces the decode output for code blocks that need no entropy decoding.
    ///
    /// Allocates `output_words` `f32` coefficients and zero-fills them only when
    /// `htj2k_decode_needs_zero_fill` requests it. No statuses are reported.
    ///
    /// # Errors
    /// Returns `CudaError::LengthTooLarge` if the byte size overflows `usize`, and propagates
    /// allocation and memset errors.
    pub(crate) fn decode_empty_htj2k_codeblocks(
        &self,
        jobs: &[CudaHtj2kCodeBlockJob],
        output_words: usize,
    ) -> Result<CudaHtj2kDecodeOutput, CudaError> {
        self.inner.set_current()?;
        let Some(output_bytes) = output_words.checked_mul(std::mem::size_of::<f32>()) else {
            return Err(CudaError::LengthTooLarge { len: output_words });
        };
        let coefficients = self.allocate(output_bytes)?;
        let zero_fill = htj2k_decode_needs_zero_fill(jobs, output_words)?;
        if zero_fill {
            self.memset_d32(&coefficients, 0, output_words)?;
        }
        Ok(CudaHtj2kDecodeOutput {
            coefficients,
            execution: CudaExecutionStats::default(),
            statuses: Vec::new(),
            stage_timings: CudaHtj2kDecodeStageTimings::default(),
        })
    }
}
/// Opaque debug representation.
///
/// The inner state holds raw driver handles with no meaningful textual form, so the output is
/// just `CudaContext { .. }`.
impl std::fmt::Debug for CudaContext {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CudaContext").finish_non_exhaustive()
    }
}
/// Public identifier for each builtin GPU kernel.
///
/// Each variant maps to an internal kernel (`CudaKernelName::kernel`) and to a PTX entry-point
/// symbol (`CudaKernelName::entrypoint`). The enum is `#[non_exhaustive]` so kernels can be
/// added without a breaking change.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[non_exhaustive]
pub enum CudaKernelName {
    CopyU8,
    J2kDeinterleaveToF32,
    J2kForwardRct,
    J2kForwardIct,
    J2kForwardDwt53Horizontal,
    J2kForwardDwt53Vertical,
    J2kForwardDwt97Horizontal,
    J2kForwardDwt97Vertical,
    J2kQuantizeSubband,
    J2kQuantizeSubbandStrided,
    Htj2kDecodeCodeblocks,
    Htj2kDecodeCodeblocksMultiCleanupDequantize,
    J2kDequantizeHtj2kCodeblocks,
    J2kDequantizeHtj2kCodeblocksMulti,
    J2kDequantizeHtj2kCleanupJobsMulti,
    J2kIdwtInterleave,
    J2kIdwtInterleaveHorizontal53Multi,
    J2kIdwtInterleaveHorizontal97Multi,
    J2kIdwtHorizontal,
    J2kIdwtHorizontal53,
    J2kIdwtHorizontal97,
    J2kIdwtVertical,
    J2kIdwtVertical53Multi,
    J2kIdwtVertical97Multi,
    J2kIdwtVertical97MultiCols4,
    J2kIdwtVertical53,
    J2kIdwtVertical97,
    J2kInverseDwtSingle,
    J2kInverseMct,
    J2kStoreGray8,
    J2kStoreGray16,
    J2kStoreRgb8,
    J2kStoreRgb8Mct,
    J2kStoreRgb8MctBatch,
    J2kStoreRgb16,
    J2kStoreRgb16Mct,
    Htj2kEncodeCodeblock,
    Htj2kEncodeCodeblocks,
    Htj2kEncodeCodeblocksMultiInput,
    Htj2kEncodeCodeblocksMultiInputCleanup,
    Htj2kEncodeCodeblocksMultiInputCleanup64,
    Htj2kCompactCodeblocks,
    Htj2kPacketizeCleanup,
}
impl CudaKernelName {
    /// Maps this public name to the internal kernel identifier (one-to-one, same variant name).
    pub(crate) fn kernel(self) -> CudaKernel {
        match self {
            Self::CopyU8 => CudaKernel::CopyU8,
            Self::J2kDeinterleaveToF32 => CudaKernel::J2kDeinterleaveToF32,
            Self::J2kForwardRct => CudaKernel::J2kForwardRct,
            Self::J2kForwardIct => CudaKernel::J2kForwardIct,
            Self::J2kForwardDwt53Horizontal => CudaKernel::J2kForwardDwt53Horizontal,
            Self::J2kForwardDwt53Vertical => CudaKernel::J2kForwardDwt53Vertical,
            Self::J2kForwardDwt97Horizontal => CudaKernel::J2kForwardDwt97Horizontal,
            Self::J2kForwardDwt97Vertical => CudaKernel::J2kForwardDwt97Vertical,
            Self::J2kQuantizeSubband => CudaKernel::J2kQuantizeSubband,
            Self::J2kQuantizeSubbandStrided => CudaKernel::J2kQuantizeSubbandStrided,
            Self::Htj2kDecodeCodeblocks => CudaKernel::Htj2kDecodeCodeblocks,
            Self::Htj2kDecodeCodeblocksMultiCleanupDequantize => {
                CudaKernel::Htj2kDecodeCodeblocksMultiCleanupDequantize
            }
            Self::J2kDequantizeHtj2kCodeblocks => CudaKernel::J2kDequantizeHtj2kCodeblocks,
            Self::J2kDequantizeHtj2kCodeblocksMulti => {
                CudaKernel::J2kDequantizeHtj2kCodeblocksMulti
            }
            Self::J2kDequantizeHtj2kCleanupJobsMulti => {
                CudaKernel::J2kDequantizeHtj2kCleanupJobsMulti
            }
            Self::J2kIdwtInterleave => CudaKernel::J2kIdwtInterleave,
            Self::J2kIdwtInterleaveHorizontal53Multi => {
                CudaKernel::J2kIdwtInterleaveHorizontal53Multi
            }
            Self::J2kIdwtInterleaveHorizontal97Multi => {
                CudaKernel::J2kIdwtInterleaveHorizontal97Multi
            }
            Self::J2kIdwtHorizontal => CudaKernel::J2kIdwtHorizontal,
            Self::J2kIdwtHorizontal53 => CudaKernel::J2kIdwtHorizontal53,
            Self::J2kIdwtHorizontal97 => CudaKernel::J2kIdwtHorizontal97,
            Self::J2kIdwtVertical => CudaKernel::J2kIdwtVertical,
            Self::J2kIdwtVertical53Multi => CudaKernel::J2kIdwtVertical53Multi,
            Self::J2kIdwtVertical97Multi => CudaKernel::J2kIdwtVertical97Multi,
            Self::J2kIdwtVertical97MultiCols4 => CudaKernel::J2kIdwtVertical97MultiCols4,
            Self::J2kIdwtVertical53 => CudaKernel::J2kIdwtVertical53,
            Self::J2kIdwtVertical97 => CudaKernel::J2kIdwtVertical97,
            Self::J2kInverseDwtSingle => CudaKernel::J2kInverseDwtSingle,
            Self::J2kInverseMct => CudaKernel::J2kInverseMct,
            Self::J2kStoreGray8 => CudaKernel::J2kStoreGray8,
            Self::J2kStoreGray16 => CudaKernel::J2kStoreGray16,
            Self::J2kStoreRgb8 => CudaKernel::J2kStoreRgb8,
            Self::J2kStoreRgb8Mct => CudaKernel::J2kStoreRgb8Mct,
            Self::J2kStoreRgb8MctBatch => CudaKernel::J2kStoreRgb8MctBatch,
            Self::J2kStoreRgb16 => CudaKernel::J2kStoreRgb16,
            Self::J2kStoreRgb16Mct => CudaKernel::J2kStoreRgb16Mct,
            Self::Htj2kEncodeCodeblock => CudaKernel::Htj2kEncodeCodeblock,
            Self::Htj2kEncodeCodeblocks => CudaKernel::Htj2kEncodeCodeblocks,
            Self::Htj2kEncodeCodeblocksMultiInput => CudaKernel::Htj2kEncodeCodeblocksMultiInput,
            Self::Htj2kEncodeCodeblocksMultiInputCleanup => {
                CudaKernel::Htj2kEncodeCodeblocksMultiInputCleanup
            }
            Self::Htj2kEncodeCodeblocksMultiInputCleanup64 => {
                CudaKernel::Htj2kEncodeCodeblocksMultiInputCleanup64
            }
            Self::Htj2kCompactCodeblocks => CudaKernel::Htj2kCompactCodeblocks,
            Self::Htj2kPacketizeCleanup => CudaKernel::Htj2kPacketizeCleanup,
        }
    }

    /// PTX entry-point symbol name for this kernel.
    ///
    /// These strings must match the exported kernel symbols exactly. NOTE(review): the PTX
    /// sources are outside this file — keep the two in sync when renaming a kernel.
    pub(crate) fn entrypoint(self) -> &'static str {
        match self {
            Self::CopyU8 => "j2k_copy_u8",
            Self::J2kDeinterleaveToF32 => "j2k_deinterleave_to_f32",
            Self::J2kForwardRct => "j2k_forward_rct",
            Self::J2kForwardIct => "j2k_forward_ict",
            Self::J2kForwardDwt53Horizontal => "j2k_forward_dwt53_horizontal",
            Self::J2kForwardDwt53Vertical => "j2k_forward_dwt53_vertical",
            Self::J2kForwardDwt97Horizontal => "j2k_forward_dwt97_horizontal",
            Self::J2kForwardDwt97Vertical => "j2k_forward_dwt97_vertical",
            Self::J2kQuantizeSubband => "j2k_quantize_subband",
            Self::J2kQuantizeSubbandStrided => "j2k_quantize_subband_strided",
            Self::Htj2kDecodeCodeblocks => "j2k_htj2k_decode_codeblocks",
            Self::Htj2kDecodeCodeblocksMultiCleanupDequantize => {
                "j2k_htj2k_decode_codeblocks_multi_cleanup_dequantize"
            }
            Self::J2kDequantizeHtj2kCodeblocks => "j2k_dequantize_htj2k_codeblocks",
            Self::J2kDequantizeHtj2kCodeblocksMulti => "j2k_dequantize_htj2k_codeblocks_multi",
            Self::J2kDequantizeHtj2kCleanupJobsMulti => "j2k_dequantize_htj2k_cleanup_jobs_multi",
            Self::J2kIdwtInterleave => "j2k_idwt_interleave",
            Self::J2kIdwtInterleaveHorizontal53Multi => "j2k_idwt_interleave_horizontal_53_multi",
            Self::J2kIdwtInterleaveHorizontal97Multi => "j2k_idwt_interleave_horizontal_97_multi",
            Self::J2kIdwtHorizontal => "j2k_idwt_horizontal",
            Self::J2kIdwtHorizontal53 => "j2k_idwt_horizontal_53",
            Self::J2kIdwtHorizontal97 => "j2k_idwt_horizontal_97",
            Self::J2kIdwtVertical => "j2k_idwt_vertical",
            Self::J2kIdwtVertical53Multi => "j2k_idwt_vertical_53_multi",
            Self::J2kIdwtVertical97Multi => "j2k_idwt_vertical_97_multi",
            Self::J2kIdwtVertical97MultiCols4 => "j2k_idwt_vertical_97_multi_cols4",
            Self::J2kIdwtVertical53 => "j2k_idwt_vertical_53",
            Self::J2kIdwtVertical97 => "j2k_idwt_vertical_97",
            Self::J2kInverseDwtSingle => "j2k_inverse_dwt_single",
            Self::J2kInverseMct => "j2k_inverse_mct",
            Self::J2kStoreGray8 => "j2k_store_gray8",
            Self::J2kStoreGray16 => "j2k_store_gray16",
            Self::J2kStoreRgb8 => "j2k_store_rgb8",
            Self::J2kStoreRgb8Mct => "j2k_store_rgb8_mct",
            Self::J2kStoreRgb8MctBatch => "j2k_store_rgb8_mct_batch",
            Self::J2kStoreRgb16 => "j2k_store_rgb16",
            Self::J2kStoreRgb16Mct => "j2k_store_rgb16_mct",
            Self::Htj2kEncodeCodeblock => "j2k_htj2k_encode_codeblock",
            Self::Htj2kEncodeCodeblocks => "j2k_htj2k_encode_codeblocks",
            Self::Htj2kEncodeCodeblocksMultiInput => "j2k_htj2k_encode_codeblocks_multi_input",
            Self::Htj2kEncodeCodeblocksMultiInputCleanup => {
                "j2k_htj2k_encode_codeblocks_multi_input_cleanup"
            }
            Self::Htj2kEncodeCodeblocksMultiInputCleanup64 => {
                "j2k_htj2k_encode_codeblocks_multi_input_cleanup_64"
            }
            Self::Htj2kCompactCodeblocks => "j2k_htj2k_compact_codeblocks",
            Self::Htj2kPacketizeCleanup => "j2k_htj2k_packetize_cleanup",
        }
    }
}
/// Public description of a kernel: its name paired with an entry-point symbol string.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct CudaKernelModule {
    /// Which kernel this entry describes.
    pub(crate) kernel: CudaKernelName,
    /// Entry-point symbol recorded for the kernel.
    pub(crate) entrypoint: &'static str,
}
impl CudaKernelModule {
    /// Returns the kernel this entry describes.
    pub fn kernel(&self) -> CudaKernelName {
        let Self { kernel, .. } = *self;
        kernel
    }

    /// Returns the entry-point symbol recorded for this kernel.
    pub fn entrypoint(&self) -> &'static str {
        let Self { entrypoint, .. } = *self;
        entrypoint
    }
}
/// Reports whether IDWT tracing is requested.
///
/// Tracing is on when the `CUDA_IDWT_TRACE_ENV_VAR` environment variable is set to any value,
/// including an empty or non-UTF-8 one.
pub(crate) fn cuda_idwt_trace_enabled() -> bool {
    matches!(std::env::var_os(CUDA_IDWT_TRACE_ENV_VAR), Some(_))
}
impl CudaContext {
    /// Downloads `count` `i32` values from `buffer` into a new host vector.
    ///
    /// When `count` is zero, no copy is issued and an empty vector is returned.
    /// NOTE(review): whether `buffer` may be larger than `count` values is up to
    /// `copy_to_host` — confirm its length contract.
    pub(crate) fn download_i32_band(
        buffer: &CudaDeviceBuffer,
        count: usize,
    ) -> Result<Vec<i32>, CudaError> {
        let mut host_band = vec![0_i32; count];
        if count > 0 {
            let host_bytes = i32_slice_as_bytes_mut(&mut host_band);
            buffer.copy_to_host(host_bytes)?;
        }
        Ok(host_band)
    }
}
impl CudaContext {
    /// Downloads `count` `f32` values from a plain device buffer into a new host vector.
    ///
    /// When `count` is zero, no copy is issued and an empty vector is returned.
    pub(crate) fn download_f32_band(
        buffer: &CudaDeviceBuffer,
        count: usize,
    ) -> Result<Vec<f32>, CudaError> {
        let mut host_band = vec![0.0_f32; count];
        if count > 0 {
            let host_bytes = f32_slice_as_bytes_mut(&mut host_band);
            buffer.copy_to_host(host_bytes)?;
        }
        Ok(host_band)
    }

    /// Downloads `count` `f32` values from a pooled device buffer into a new host vector.
    ///
    /// When `count` is zero, no copy is issued and an empty vector is returned.
    pub(crate) fn download_pooled_f32_band(
        buffer: &CudaPooledDeviceBuffer,
        count: usize,
    ) -> Result<Vec<f32>, CudaError> {
        let mut host_band = vec![0.0_f32; count];
        if count > 0 {
            let host_bytes = f32_slice_as_bytes_mut(&mut host_band);
            buffer.copy_to_host(host_bytes)?;
        }
        Ok(host_band)
    }
}
/// A loaded CUDA module together with the kernel function resolved from it.
///
/// Instances cached in the context's module map are unloaded (`cuModuleUnload`) when the
/// context is dropped.
#[derive(Debug)]
pub(crate) struct CompiledKernel {
    /// Module handle returned by `cuModuleLoadData`.
    pub(crate) module: CuModule,
    /// Function handle returned by `cuModuleGetFunction` for the key's entry point.
    pub(crate) function: CuFunction,
}
/// Cache key for loaded kernel modules.
///
/// Either a builtin kernel or, behind the `cuda-oxide-copy-u8` feature, the cuda-oxide
/// generated copy kernel (whose PTX comes from a different source than `CopyU8`).
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub(crate) enum CompiledKernelKey {
    Builtin(CudaKernel),
    #[cfg(feature = "cuda-oxide-copy-u8")]
    CudaOxideCopyU8,
}
impl CompiledKernelKey {
    /// Builtin kernel associated with this key.
    ///
    /// The cuda-oxide copy key maps to `CudaKernel::CopyU8`, so it shares that kernel's
    /// PTX-build check and entry-point name. NOTE(review): assumes the cuda-oxide PTX exports
    /// a symbol with the same name as the builtin copy kernel — confirm.
    pub(crate) fn kernel(self) -> CudaKernel {
        match self {
            Self::Builtin(kernel) => kernel,
            #[cfg(feature = "cuda-oxide-copy-u8")]
            Self::CudaOxideCopyU8 => CudaKernel::CopyU8,
        }
    }

    /// PTX image bytes for this key: the builtin kernel's PTX, or the cuda-oxide generated
    /// PTX when the feature is enabled.
    pub(crate) fn ptx(self) -> &'static [u8] {
        match self {
            Self::Builtin(kernel) => kernel.ptx(),
            #[cfg(feature = "cuda-oxide-copy-u8")]
            Self::CudaOxideCopyU8 => kernels::cuda_oxide_copy_u8_ptx(),
        }
    }

    /// Entry-point symbol bytes, taken from the associated builtin kernel.
    ///
    /// `CompiledKernel::load` passes these to `cuModuleGetFunction` as a C string, so they are
    /// expected to be NUL-terminated. NOTE(review): confirm in `CudaKernel::entrypoint`.
    pub(crate) fn entrypoint(self) -> &'static [u8] {
        self.kernel().entrypoint()
    }
}
impl CompiledKernel {
    /// Loads `key`'s PTX into a new module on `context` and resolves its entry point.
    ///
    /// If the function lookup fails, the freshly loaded module is unloaded before the error is
    /// returned. Otherwise its handle would be lost, since nothing else has seen it yet.
    ///
    /// # Errors
    /// Propagates failures from `cuCtxSetCurrent`, `cuModuleLoadData`, and
    /// `cuModuleGetFunction`.
    pub(crate) fn load(context: &ContextInner, key: CompiledKernelKey) -> Result<Self, CudaError> {
        context.set_current()?;

        // NOTE(review): cuModuleLoadData reads a PTX image as NUL-terminated text; this assumes
        // `key.ptx()` includes the trailing NUL — confirm in `kernels`.
        let mut module = std::ptr::null_mut();
        context.driver.check("cuModuleLoadData", unsafe {
            (context.driver.cu_module_load_data)(
                &raw mut module,
                key.ptx().as_ptr().cast::<c_void>(),
            )
        })?;

        // NOTE(review): the entry-point bytes are passed as a C string and are therefore
        // assumed to be NUL-terminated — confirm in `CudaKernel::entrypoint`.
        let mut function = std::ptr::null_mut();
        let lookup = context.driver.check("cuModuleGetFunction", unsafe {
            (context.driver.cu_module_get_function)(
                &raw mut function,
                module,
                key.entrypoint().as_ptr().cast::<c_char>(),
            )
        });
        if let Err(error) = lookup {
            // Best-effort cleanup: the lookup failure is the error worth reporting, so an
            // unload failure here is deliberately ignored.
            let _ = unsafe { (context.driver.cu_module_unload)(module) };
            return Err(error);
        }

        Ok(Self { module, function })
    }
}

// SAFETY (review): `CompiledKernel` only holds opaque driver module/function handles and owns
// no thread-affine Rust state. Moving it between threads is sound provided the CUDA driver
// accepts these handles from any thread with the owning context current — NOTE(review):
// confirm.
unsafe impl Send for CompiledKernel {}