use std::{
fmt::{self, Display, Formatter},
ptr,
};
use num_enum::{IntoPrimitive, TryFromPrimitive};
use singe_cuda_sys::{driver, runtime};
use singe_core::impl_enum_conversion;
/// Optional CUDA JIT compiler/linker settings, assembled with the `set_*` builder methods.
///
/// A field left as `None` is not handed to the driver at all. Call `build` to lower the
/// selected settings into a `JitOptionsArtifact` holding parallel option-name and
/// option-value arrays.
#[derive(Debug, Default)]
pub struct JitOptions<'a> {
    /// Value for `CU_JIT_MAX_REGISTERS`.
    pub max_registers: Option<u32>,
    /// Value for `CU_JIT_THREADS_PER_BLOCK`.
    pub threads_per_block: Option<u32>,
    /// Caller-owned slot for `CU_JIT_WALL_TIME`; `build` passes a pointer to it.
    pub wall_time: Option<&'a mut f32>,
    /// Buffer for `CU_JIT_INFO_LOG_BUFFER`; its length (saturated to `u32::MAX`) accompanies
    /// it as `CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES`.
    pub info_log_buffer: Option<&'a mut [u8]>,
    /// Buffer for `CU_JIT_ERROR_LOG_BUFFER`; its length (saturated to `u32::MAX`) accompanies
    /// it as `CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES`.
    pub error_log_buffer: Option<&'a mut [u8]>,
    /// Value for `CU_JIT_OPTIMIZATION_LEVEL`. Both the setter and `build` clamp it to at
    /// most 4, so values assigned directly to this public field are clamped as well.
    pub optimization_level: Option<u32>,
    /// `Some(())` requests `CU_JIT_TARGET_FROM_CUCONTEXT`, which `build` emits with a null value.
    pub target_from_cuda_context: Option<()>,
    /// Value for `CU_JIT_TARGET`.
    pub target: Option<JitTarget>,
    /// Value for `CU_JIT_FALLBACK_STRATEGY`.
    pub fallback_strategy: Option<JitFallback>,
    /// `CU_JIT_GENERATE_DEBUG_INFO`; `build` converts the flag to `0` / `1`.
    pub generate_debug_info: Option<bool>,
    /// `CU_JIT_LOG_VERBOSE`; `build` converts the flag to `0` / `1`.
    pub log_verbose: Option<bool>,
    /// `CU_JIT_GENERATE_LINE_INFO`; `build` converts the flag to `0` / `1`.
    pub generate_line_info: Option<bool>,
    /// Value for `CU_JIT_CACHE_MODE`.
    pub cache_mode: Option<JitCacheMode>,
}
/// Driver-facing JIT options produced by `JitOptions::build`.
///
/// `names` and `values` are parallel arrays: `values[i]` is the value for option
/// `names[i]`. They are presumably passed as the `options` / `optionValues` arguments of a
/// CUDA JIT entry point such as `cuModuleLoadDataEx` or `cuLinkCreate` (confirm at call sites).
///
/// NOTE(review): `values` can contain raw pointers into memory borrowed by the originating
/// `JitOptions` (the log buffers and the wall-time slot). This type has no lifetime
/// parameter, so the compiler does not enforce that those borrows outlive it; callers must.
pub struct JitOptionsArtifact {
    /// Option identifiers, one per entry in `values`.
    pub names: Vec<driver::CUjit_option>,
    /// Option values, one per entry in `names`.
    pub values: Vec<*mut ()>,
    // Converted option values captured by `JitOptions::build`.
    storage_target: Option<u32>,
    storage_fallback: Option<u32>,
    storage_debug_info: Option<i32>,
    storage_log_verbose: Option<i32>,
    storage_line_info: Option<i32>,
    storage_cache_mode: Option<u32>,
    // Start pointer and byte length (saturated to `u32::MAX`) of the info log buffer.
    storage_info_log_ptr: Option<*mut u8>,
    storage_info_log_size: Option<u32>,
    // Start pointer and byte length (saturated to `u32::MAX`) of the error log buffer.
    storage_error_log_ptr: Option<*mut u8>,
    storage_error_log_size: Option<u32>,
    storage_max_registers: Option<u32>,
    storage_threads_per_block: Option<u32>,
    // Already clamped to at most 4.
    storage_optimization_level: Option<u32>,
}
impl<'a> JitOptions<'a> {
pub const fn set_max_registers(mut self, value: u32) -> Self {
self.max_registers = Some(value);
self
}
pub const fn set_threads_per_block(mut self, value: u32) -> Self {
self.threads_per_block = Some(value);
self
}
pub const fn set_wall_time(mut self, value: &'a mut f32) -> Self {
self.wall_time = Some(value);
self
}
pub const fn set_info_log(mut self, buffer: &'a mut [u8]) -> Self {
self.info_log_buffer = Some(buffer);
self
}
pub const fn set_error_log(mut self, buffer: &'a mut [u8]) -> Self {
self.error_log_buffer = Some(buffer);
self
}
pub fn set_optimization_level(mut self, level: u32) -> Self {
self.optimization_level = Some(level.min(4));
self
}
pub const fn set_target_from_cuda_context(mut self) -> Self {
self.target_from_cuda_context = Some(());
self
}
pub const fn set_target(mut self, target: JitTarget) -> Self {
self.target = Some(target);
self
}
pub const fn set_fallback_strategy(mut self, strategy: JitFallback) -> Self {
self.fallback_strategy = Some(strategy);
self
}
pub const fn set_generate_debug_info(mut self, enable: bool) -> Self {
self.generate_debug_info = Some(enable);
self
}
pub const fn set_log_verbose(mut self, enable: bool) -> Self {
self.log_verbose = Some(enable);
self
}
pub const fn set_generate_line_info(mut self, enable: bool) -> Self {
self.generate_line_info = Some(enable);
self
}
pub const fn set_cache_mode(mut self, mode: JitCacheMode) -> Self {
self.cache_mode = Some(mode);
self
}
pub fn build(&mut self) -> JitOptionsArtifact {
let mut artifact = JitOptionsArtifact {
names: Vec::new(),
values: Vec::new(),
storage_target: None,
storage_fallback: None,
storage_debug_info: None,
storage_log_verbose: None,
storage_line_info: None,
storage_cache_mode: None,
storage_info_log_ptr: None,
storage_info_log_size: None,
storage_error_log_ptr: None,
storage_error_log_size: None,
storage_max_registers: self.max_registers,
storage_threads_per_block: self.threads_per_block,
storage_optimization_level: self.optimization_level.map(|value| value.clamp(0, 4)),
};
artifact.storage_target = self.target.map(Into::into);
artifact.storage_fallback = self.fallback_strategy.map(Into::into);
artifact.storage_cache_mode = self.cache_mode.map(Into::into);
artifact.storage_debug_info = self.generate_debug_info.map(i32::from);
artifact.storage_log_verbose = self.log_verbose.map(i32::from);
artifact.storage_line_info = self.generate_line_info.map(i32::from);
artifact.storage_info_log_ptr = self
.info_log_buffer
.as_mut()
.map(|slice| slice.as_mut_ptr().cast::<u8>());
artifact.storage_info_log_size = self
.info_log_buffer
.as_ref()
.map(|buffer| buffer.len().min(u32::MAX as usize) as u32);
artifact.storage_error_log_ptr = self
.error_log_buffer
.as_mut()
.map(|slice| slice.as_mut_ptr().cast::<u8>());
artifact.storage_error_log_size = self
.error_log_buffer
.as_ref()
.map(|buffer| buffer.len().min(u32::MAX as usize) as u32);
if let Some(ref mut val) = artifact.storage_max_registers {
artifact.names.push(JitOption::MaxRegisters.into());
artifact.values.push(ptr::from_mut::<u32>(val).cast());
}
if let Some(ref mut val) = artifact.storage_threads_per_block {
artifact.names.push(JitOption::ThreadsPerBlock.into());
artifact.values.push(ptr::from_mut::<u32>(val).cast());
}
if let Some(ref mut val_ref) = self.wall_time {
artifact.names.push(JitOption::WallTime.into());
artifact.values.push(ptr::from_mut::<f32>(*val_ref).cast());
}
if let Some(info_log_ptr) = artifact.storage_info_log_ptr {
artifact.names.push(JitOption::InfoLogBuffer.into());
artifact.values.push(info_log_ptr.cast());
if let Some(ref mut size_val) = artifact.storage_info_log_size {
artifact
.names
.push(JitOption::InfoLogBufferSizeBytes.into());
artifact.values.push(ptr::from_mut::<u32>(size_val).cast());
}
}
if let Some(error_log_ptr) = artifact.storage_error_log_ptr {
artifact.names.push(JitOption::ErrorLogBuffer.into());
artifact.values.push(error_log_ptr.cast());
if let Some(ref mut size_val) = artifact.storage_error_log_size {
artifact
.names
.push(JitOption::ErrorLogBufferSizeBytes.into());
artifact.values.push(ptr::from_mut::<u32>(size_val).cast());
}
}
if let Some(ref mut value) = artifact.storage_optimization_level {
artifact.names.push(JitOption::OptimizationLevel.into());
artifact.values.push(ptr::from_mut::<u32>(value).cast());
}
if self.target_from_cuda_context.is_some() {
artifact.names.push(JitOption::TargetFromCudaContext.into());
artifact.values.push(ptr::null_mut()); }
if let Some(ref mut val) = artifact.storage_target {
artifact.names.push(JitOption::Target.into());
artifact.values.push(ptr::from_mut::<u32>(val).cast());
}
if let Some(ref mut val) = artifact.storage_fallback {
artifact.names.push(JitOption::FallbackStrategy.into());
artifact.values.push(ptr::from_mut::<u32>(val).cast());
}
if let Some(ref mut val) = artifact.storage_debug_info {
artifact.names.push(JitOption::GenerateDebugInfo.into());
artifact.values.push(ptr::from_mut::<i32>(val).cast());
}
if let Some(ref mut val) = artifact.storage_log_verbose {
artifact.names.push(JitOption::LogVerbose.into());
artifact.values.push(ptr::from_mut::<i32>(val).cast());
}
if let Some(ref mut val) = artifact.storage_line_info {
artifact.names.push(JitOption::GenerateLineInfo.into());
artifact.values.push(ptr::from_mut::<i32>(val).cast());
}
if let Some(ref mut val) = artifact.storage_cache_mode {
artifact.names.push(JitOption::CacheMode.into());
artifact.values.push(ptr::from_mut::<u32>(val).cast());
}
artifact
}
}
/// Rust mirror of the CUDA JIT option identifiers. Each discriminant is the value of the
/// `cudaJitOption` enumerator of the same name.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TryFromPrimitive, IntoPrimitive)]
#[repr(u32)]
pub enum JitOption {
    MaxRegisters = runtime::cudaJitOption::CU_JIT_MAX_REGISTERS as _,
    ThreadsPerBlock = runtime::cudaJitOption::CU_JIT_THREADS_PER_BLOCK as _,
    WallTime = runtime::cudaJitOption::CU_JIT_WALL_TIME as _,
    InfoLogBuffer = runtime::cudaJitOption::CU_JIT_INFO_LOG_BUFFER as _,
    InfoLogBufferSizeBytes = runtime::cudaJitOption::CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES as _,
    ErrorLogBuffer = runtime::cudaJitOption::CU_JIT_ERROR_LOG_BUFFER as _,
    ErrorLogBufferSizeBytes = runtime::cudaJitOption::CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES as _,
    OptimizationLevel = runtime::cudaJitOption::CU_JIT_OPTIMIZATION_LEVEL as _,
    TargetFromCudaContext = runtime::cudaJitOption::CU_JIT_TARGET_FROM_CUCONTEXT as _,
    Target = runtime::cudaJitOption::CU_JIT_TARGET as _,
    FallbackStrategy = runtime::cudaJitOption::CU_JIT_FALLBACK_STRATEGY as _,
    GenerateDebugInfo = runtime::cudaJitOption::CU_JIT_GENERATE_DEBUG_INFO as _,
    LogVerbose = runtime::cudaJitOption::CU_JIT_LOG_VERBOSE as _,
    GenerateLineInfo = runtime::cudaJitOption::CU_JIT_GENERATE_LINE_INFO as _,
    CacheMode = runtime::cudaJitOption::CU_JIT_CACHE_MODE as _,
    NewSm3xOpt = runtime::cudaJitOption::CU_JIT_NEW_SM3X_OPT as _,
    FastCompile = runtime::cudaJitOption::CU_JIT_FAST_COMPILE as _,
    GlobalSymbolNames = runtime::cudaJitOption::CU_JIT_GLOBAL_SYMBOL_NAMES as _,
    GlobalSymbolAddresses = runtime::cudaJitOption::CU_JIT_GLOBAL_SYMBOL_ADDRESSES as _,
    GlobalSymbolCount = runtime::cudaJitOption::CU_JIT_GLOBAL_SYMBOL_COUNT as _,
    Lto = runtime::cudaJitOption::CU_JIT_LTO as _,
    Ftz = runtime::cudaJitOption::CU_JIT_FTZ as _,
    PrecDiv = runtime::cudaJitOption::CU_JIT_PREC_DIV as _,
    PrecSqrt = runtime::cudaJitOption::CU_JIT_PREC_SQRT as _,
    Fma = runtime::cudaJitOption::CU_JIT_FMA as _,
    ReferencedKernelNames = runtime::cudaJitOption::CU_JIT_REFERENCED_KERNEL_NAMES as _,
    ReferencedKernelCount = runtime::cudaJitOption::CU_JIT_REFERENCED_KERNEL_COUNT as _,
    ReferencedVariableNames = runtime::cudaJitOption::CU_JIT_REFERENCED_VARIABLE_NAMES as _,
    ReferencedVariableCount = runtime::cudaJitOption::CU_JIT_REFERENCED_VARIABLE_COUNT as _,
    OptimizeUnusedDeviceVariables =
        runtime::cudaJitOption::CU_JIT_OPTIMIZE_UNUSED_DEVICE_VARIABLES as _,
    NumOptions = runtime::cudaJitOption::CU_JIT_NUM_OPTIONS as _,
}
impl_enum_conversion!(u32, runtime::cudaJitOption, JitOption);
impl Display for JitOption {
    /// Writes the C enumerator name (e.g. `CU_JIT_MAX_REGISTERS`); width/fill flags are ignored.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::MaxRegisters => "CU_JIT_MAX_REGISTERS",
            Self::ThreadsPerBlock => "CU_JIT_THREADS_PER_BLOCK",
            Self::WallTime => "CU_JIT_WALL_TIME",
            Self::InfoLogBuffer => "CU_JIT_INFO_LOG_BUFFER",
            Self::InfoLogBufferSizeBytes => "CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES",
            Self::ErrorLogBuffer => "CU_JIT_ERROR_LOG_BUFFER",
            Self::ErrorLogBufferSizeBytes => "CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES",
            Self::OptimizationLevel => "CU_JIT_OPTIMIZATION_LEVEL",
            Self::TargetFromCudaContext => "CU_JIT_TARGET_FROM_CUCONTEXT",
            Self::Target => "CU_JIT_TARGET",
            Self::FallbackStrategy => "CU_JIT_FALLBACK_STRATEGY",
            Self::GenerateDebugInfo => "CU_JIT_GENERATE_DEBUG_INFO",
            Self::LogVerbose => "CU_JIT_LOG_VERBOSE",
            Self::GenerateLineInfo => "CU_JIT_GENERATE_LINE_INFO",
            Self::CacheMode => "CU_JIT_CACHE_MODE",
            Self::NewSm3xOpt => "CU_JIT_NEW_SM3X_OPT",
            Self::FastCompile => "CU_JIT_FAST_COMPILE",
            Self::GlobalSymbolNames => "CU_JIT_GLOBAL_SYMBOL_NAMES",
            Self::GlobalSymbolAddresses => "CU_JIT_GLOBAL_SYMBOL_ADDRESSES",
            Self::GlobalSymbolCount => "CU_JIT_GLOBAL_SYMBOL_COUNT",
            Self::Lto => "CU_JIT_LTO",
            Self::Ftz => "CU_JIT_FTZ",
            Self::PrecDiv => "CU_JIT_PREC_DIV",
            Self::PrecSqrt => "CU_JIT_PREC_SQRT",
            Self::Fma => "CU_JIT_FMA",
            Self::ReferencedKernelNames => "CU_JIT_REFERENCED_KERNEL_NAMES",
            Self::ReferencedKernelCount => "CU_JIT_REFERENCED_KERNEL_COUNT",
            Self::ReferencedVariableNames => "CU_JIT_REFERENCED_VARIABLE_NAMES",
            Self::ReferencedVariableCount => "CU_JIT_REFERENCED_VARIABLE_COUNT",
            Self::OptimizeUnusedDeviceVariables => "CU_JIT_OPTIMIZE_UNUSED_DEVICE_VARIABLES",
            Self::NumOptions => "CU_JIT_NUM_OPTIONS",
        };
        f.write_str(name)
    }
}
/// Compilation targets accepted by `CU_JIT_TARGET`. Each discriminant is the value of the
/// driver's `CUjit_target_enum` enumerator of the same name.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TryFromPrimitive, IntoPrimitive)]
#[repr(u32)]
pub enum JitTarget {
    Compute30 = driver::CUjit_target_enum::CU_TARGET_COMPUTE_30 as _,
    Compute32 = driver::CUjit_target_enum::CU_TARGET_COMPUTE_32 as _,
    Compute35 = driver::CUjit_target_enum::CU_TARGET_COMPUTE_35 as _,
    Compute37 = driver::CUjit_target_enum::CU_TARGET_COMPUTE_37 as _,
    Compute50 = driver::CUjit_target_enum::CU_TARGET_COMPUTE_50 as _,
    Compute52 = driver::CUjit_target_enum::CU_TARGET_COMPUTE_52 as _,
    Compute53 = driver::CUjit_target_enum::CU_TARGET_COMPUTE_53 as _,
    Compute60 = driver::CUjit_target_enum::CU_TARGET_COMPUTE_60 as _,
    Compute61 = driver::CUjit_target_enum::CU_TARGET_COMPUTE_61 as _,
    Compute62 = driver::CUjit_target_enum::CU_TARGET_COMPUTE_62 as _,
    Compute70 = driver::CUjit_target_enum::CU_TARGET_COMPUTE_70 as _,
    Compute72 = driver::CUjit_target_enum::CU_TARGET_COMPUTE_72 as _,
    Compute75 = driver::CUjit_target_enum::CU_TARGET_COMPUTE_75 as _,
    Compute80 = driver::CUjit_target_enum::CU_TARGET_COMPUTE_80 as _,
    Compute86 = driver::CUjit_target_enum::CU_TARGET_COMPUTE_86 as _,
    Compute87 = driver::CUjit_target_enum::CU_TARGET_COMPUTE_87 as _,
    Compute89 = driver::CUjit_target_enum::CU_TARGET_COMPUTE_89 as _,
    Compute90 = driver::CUjit_target_enum::CU_TARGET_COMPUTE_90 as _,
}
impl_enum_conversion!(u32, driver::CUjit_target, JitTarget);
impl Display for JitTarget {
    /// Writes the C enumerator name (e.g. `CU_TARGET_COMPUTE_80`); width/fill flags are ignored.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::Compute30 => "CU_TARGET_COMPUTE_30",
            Self::Compute32 => "CU_TARGET_COMPUTE_32",
            Self::Compute35 => "CU_TARGET_COMPUTE_35",
            Self::Compute37 => "CU_TARGET_COMPUTE_37",
            Self::Compute50 => "CU_TARGET_COMPUTE_50",
            Self::Compute52 => "CU_TARGET_COMPUTE_52",
            Self::Compute53 => "CU_TARGET_COMPUTE_53",
            Self::Compute60 => "CU_TARGET_COMPUTE_60",
            Self::Compute61 => "CU_TARGET_COMPUTE_61",
            Self::Compute62 => "CU_TARGET_COMPUTE_62",
            Self::Compute70 => "CU_TARGET_COMPUTE_70",
            Self::Compute72 => "CU_TARGET_COMPUTE_72",
            Self::Compute75 => "CU_TARGET_COMPUTE_75",
            Self::Compute80 => "CU_TARGET_COMPUTE_80",
            Self::Compute86 => "CU_TARGET_COMPUTE_86",
            Self::Compute87 => "CU_TARGET_COMPUTE_87",
            Self::Compute89 => "CU_TARGET_COMPUTE_89",
            Self::Compute90 => "CU_TARGET_COMPUTE_90",
        };
        f.write_str(name)
    }
}
/// Fallback strategies for `CU_JIT_FALLBACK_STRATEGY`, mirroring the driver's
/// `CUjit_fallback_enum`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TryFromPrimitive, IntoPrimitive)]
#[repr(u32)]
pub enum JitFallback {
    PreferPtx = driver::CUjit_fallback_enum::CU_PREFER_PTX as _,
    PreferBinary = driver::CUjit_fallback_enum::CU_PREFER_BINARY as _,
}
impl_enum_conversion!(u32, driver::CUjit_fallback, JitFallback);
impl Display for JitFallback {
    /// Writes the C enumerator name (e.g. `CU_PREFER_PTX`); width/fill flags are ignored.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Self::PreferPtx => "CU_PREFER_PTX",
            Self::PreferBinary => "CU_PREFER_BINARY",
        })
    }
}
/// Cache modes for `CU_JIT_CACHE_MODE`, mirroring the driver's `CUjit_cacheMode_enum`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TryFromPrimitive, IntoPrimitive)]
#[repr(u32)]
pub enum JitCacheMode {
    OptionNone = driver::CUjit_cacheMode_enum::CU_JIT_CACHE_OPTION_NONE as _,
    OptionCg = driver::CUjit_cacheMode_enum::CU_JIT_CACHE_OPTION_CG as _,
    OptionCa = driver::CUjit_cacheMode_enum::CU_JIT_CACHE_OPTION_CA as _,
}
impl_enum_conversion!(u32, driver::CUjit_cacheMode, JitCacheMode);
impl Display for JitCacheMode {
    /// Writes the C enumerator name (e.g. `CU_JIT_CACHE_OPTION_CA`); width/fill flags are ignored.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::OptionNone => "CU_JIT_CACHE_OPTION_NONE",
            Self::OptionCg => "CU_JIT_CACHE_OPTION_CG",
            Self::OptionCa => "CU_JIT_CACHE_OPTION_CA",
        };
        f.write_str(name)
    }
}
/// JIT input kinds, mirroring the driver's `CUjitInputType_enum`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TryFromPrimitive, IntoPrimitive)]
#[repr(u32)]
pub enum JitInputType {
    Cubin = driver::CUjitInputType_enum::CU_JIT_INPUT_CUBIN as _,
    Ptx = driver::CUjitInputType_enum::CU_JIT_INPUT_PTX as _,
    Fatbinary = driver::CUjitInputType_enum::CU_JIT_INPUT_FATBINARY as _,
    Object = driver::CUjitInputType_enum::CU_JIT_INPUT_OBJECT as _,
    Library = driver::CUjitInputType_enum::CU_JIT_INPUT_LIBRARY as _,
    Nvvm = driver::CUjitInputType_enum::CU_JIT_INPUT_NVVM as _,
    NumInputTypes = driver::CUjitInputType_enum::CU_JIT_NUM_INPUT_TYPES as _,
}
impl_enum_conversion!(u32, driver::CUjitInputType, JitInputType);
impl Display for JitInputType {
    /// Writes the C enumerator name (e.g. `CU_JIT_INPUT_PTX`); width/fill flags are ignored.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::Cubin => "CU_JIT_INPUT_CUBIN",
            Self::Ptx => "CU_JIT_INPUT_PTX",
            Self::Fatbinary => "CU_JIT_INPUT_FATBINARY",
            Self::Object => "CU_JIT_INPUT_OBJECT",
            Self::Library => "CU_JIT_INPUT_LIBRARY",
            Self::Nvvm => "CU_JIT_INPUT_NVVM",
            Self::NumInputTypes => "CU_JIT_NUM_INPUT_TYPES",
        };
        f.write_str(name)
    }
}