use alloc::string::ToString;
use core::ops::BitOr;
use super::{ArenaExtendStrategy, ExecutionProvider, ExecutionProviderOptions, RegisterError};
use crate::{error::Result, session::builder::SessionBuilder};
/// Bitmask selecting which attention kernel implementations the CUDA execution provider may use
/// for scaled-dot-product attention (forwarded to ONNX Runtime as the `sdpa_kernel` option).
///
/// Combine flags with `|`, e.g. `AttentionBackend::FLASH_ATTENTION | AttentionBackend::MATH`.
/// The derived [`Default`] is the empty set, equivalent to [`AttentionBackend::none`].
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(transparent)]
pub struct AttentionBackend(u32);

impl AttentionBackend {
	pub const FLASH_ATTENTION: Self = Self(1 << 0);
	pub const EFFICIENT_ATTENTION: Self = Self(1 << 1);
	pub const TRT_FUSED_ATTENTION: Self = Self(1 << 2);
	pub const CUDNN_FLASH_ATTENTION: Self = Self(1 << 3);
	/// The unfused, plain math implementation.
	pub const MATH: Self = Self(1 << 4);
	pub const TRT_FLASH_ATTENTION: Self = Self(1 << 5);
	pub const TRT_CROSS_ATTENTION: Self = Self(1 << 6);
	pub const TRT_CAUSAL_ATTENTION: Self = Self(1 << 7);
	pub const LEAN_ATTENTION: Self = Self(1 << 8);

	/// Returns the empty set of backends.
	pub const fn none() -> Self {
		AttentionBackend(0)
	}

	/// Returns the union of **every** backend flag defined on this type.
	pub const fn all() -> Self {
		// OR the raw bits directly so this can be a `const fn` (trait operators like
		// `BitOr::bitor` cannot be invoked in const context on stable Rust).
		Self(
			Self::FLASH_ATTENTION.0
				| Self::EFFICIENT_ATTENTION.0
				| Self::TRT_FUSED_ATTENTION.0
				| Self::CUDNN_FLASH_ATTENTION.0
				| Self::MATH.0
				| Self::TRT_FLASH_ATTENTION.0
				| Self::TRT_CROSS_ATTENTION.0
				| Self::TRT_CAUSAL_ATTENTION.0
				// Previously missing here, so `all()` did not actually contain all flags.
				| Self::LEAN_ATTENTION.0
		)
	}
}

impl BitOr for AttentionBackend {
	type Output = Self;

	fn bitor(self, rhs: Self) -> Self::Output {
		Self(self.0 | rhs.0)
	}
}
/// Strategy ONNX Runtime uses to choose a cuDNN convolution algorithm
/// (the `cudnn_conv_algo_search` provider option).
// `Copy`/`PartialEq`/`Eq`/`Hash` added for parity with `AttentionBackend` above — this is a
// trivially copyable fieldless enum, and callers may want to compare/store configurations.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum ConvAlgorithmSearch {
	/// Exhaustively benchmark candidate algorithms and pick the fastest
	/// (highest startup cost, best steady-state performance).
	#[default]
	Exhaustive,
	/// Use cuDNN's lightweight heuristic to pick an algorithm without benchmarking.
	Heuristic,
	/// Skip the search and use cuDNN's default algorithm.
	Default
}
/// Execution provider that runs inference on NVIDIA GPUs through CUDA/cuDNN.
///
/// Configure via the builder-style `with_*` methods, then register it on a session builder
/// (see the `ExecutionProvider` impl below).
#[derive(Debug, Default, Clone)]
pub struct CUDA {
	// String key/value pairs forwarded verbatim to ONNX Runtime's CUDA provider at registration.
	options: ExecutionProviderOptions
}
// NOTE(review): `impl_ep!` is defined in the parent module and presumably generates the shared
// execution-provider boilerplate (conversions/registration glue) for `CUDA` — confirm against
// the macro definition; its body is not visible in this file.
super::impl_ep!(arbitrary; CUDA);
impl CUDA {
	/// Renders a boolean toggle in the `"0"`/`"1"` string form ONNX Runtime provider
	/// options expect.
	fn flag(enabled: bool) -> &'static str {
		match enabled {
			true => "1",
			false => "0"
		}
	}

	/// Selects which CUDA device this provider executes on.
	#[must_use]
	pub fn with_device_id(mut self, device_id: i32) -> Self {
		let id = device_id.to_string();
		self.options.set("device_id", id);
		self
	}

	/// Caps the size of the provider's GPU memory arena, in bytes.
	#[must_use]
	pub fn with_memory_limit(mut self, limit: usize) -> Self {
		let bytes = limit.to_string();
		self.options.set("gpu_mem_limit", bytes);
		self
	}

	/// Chooses how the GPU memory arena grows when it needs more space.
	#[must_use]
	pub fn with_arena_extend_strategy(mut self, strategy: ArenaExtendStrategy) -> Self {
		let name = match strategy {
			ArenaExtendStrategy::NextPowerOfTwo => "kNextPowerOfTwo",
			ArenaExtendStrategy::SameAsRequested => "kSameAsRequested"
		};
		self.options.set("arena_extend_strategy", name);
		self
	}

	/// Chooses how cuDNN selects a convolution algorithm; see [`ConvAlgorithmSearch`].
	#[must_use]
	pub fn with_conv_algorithm_search(mut self, search: ConvAlgorithmSearch) -> Self {
		let mode = match search {
			ConvAlgorithmSearch::Exhaustive => "EXHAUSTIVE",
			ConvAlgorithmSearch::Heuristic => "HEURISTIC",
			ConvAlgorithmSearch::Default => "DEFAULT"
		};
		self.options.set("cudnn_conv_algo_search", mode);
		self
	}

	/// Toggles whether cuDNN convolutions may use the maximum workspace size.
	#[must_use]
	pub fn with_conv_max_workspace(mut self, enable: bool) -> Self {
		self.options.set("cudnn_conv_use_max_workspace", Self::flag(enable));
		self
	}

	/// Toggles padding of 1-D convolution inputs to `[N, C, 1, D]` rather than `[N, C, D, 1]`.
	#[must_use]
	pub fn with_conv1d_pad_to_nc1d(mut self, enable: bool) -> Self {
		self.options.set("cudnn_conv1d_pad_to_nc1d", Self::flag(enable));
		self
	}

	/// Toggles CUDA graph capture for session execution.
	#[must_use]
	pub fn with_cuda_graph(mut self, enable: bool) -> Self {
		self.options.set("enable_cuda_graph", Self::flag(enable));
		self
	}

	/// Toggles strict mode for the `SkipLayerNormalization` kernel.
	#[must_use]
	pub fn with_skip_layer_norm_strict_mode(mut self, enable: bool) -> Self {
		self.options.set("enable_skip_layer_norm_strict_mode", Self::flag(enable));
		self
	}

	/// Toggles use of TF32 tensor cores for single-precision math.
	#[must_use]
	pub fn with_tf32(mut self, enable: bool) -> Self {
		self.options.set("use_tf32", Self::flag(enable));
		self
	}

	/// Toggles preference for the NHWC memory layout over NCHW.
	#[must_use]
	pub fn with_prefer_nhwc(mut self, enable: bool) -> Self {
		self.options.set("prefer_nhwc", Self::flag(enable));
		self
	}

	/// Supplies a user-owned CUDA stream for the provider to compute on.
	///
	/// # Safety
	/// `stream` must be a valid CUDA stream handle that outlives any session this provider is
	/// registered with.
	#[must_use]
	pub unsafe fn with_compute_stream(mut self, stream: *mut ()) -> Self {
		self.options.set("has_user_compute_stream", "1");
		// ONNX Runtime accepts the raw stream pointer as a stringified integer.
		let address = (stream as usize).to_string();
		self.options.set("user_compute_stream", address);
		self
	}

	/// Restricts which attention kernels may be used; see [`AttentionBackend`].
	#[must_use]
	pub fn with_attention_backend(mut self, flags: AttentionBackend) -> Self {
		let bits = flags.0.to_string();
		self.options.set("sdpa_kernel", bits);
		self
	}

	/// Toggles fusion of convolution and bias-add operations.
	#[must_use]
	pub fn with_fuse_conv_bias(mut self, enable: bool) -> Self {
		self.options.set("fuse_conv_bias", Self::flag(enable));
		self
	}
}
impl ExecutionProvider for CUDA {
	// The provider name ONNX Runtime recognizes for this EP.
	fn name(&self) -> &'static str {
		"CUDAExecutionProvider"
	}

	// The CUDA EP is only available for x86-64 Linux/Windows and aarch64 Linux targets;
	// evaluated at compile time via `cfg!`.
	fn supported_by_platform(&self) -> bool {
		cfg!(any(all(target_os = "linux", any(target_arch = "aarch64", target_arch = "x86_64")), all(target_os = "windows", target_arch = "x86_64")))
	}

	// Registers the CUDA EP on `session_builder` through the C API. When neither the
	// `load-dynamic` nor `cuda` feature is enabled, the cfg-gated block compiles out and
	// this returns `MissingFeature` (hence the `unused`/`unreachable_code` allowances).
	#[allow(unused, unreachable_code)]
	fn register(&self, session_builder: &mut SessionBuilder) -> Result<(), RegisterError> {
		#[cfg(any(feature = "load-dynamic", feature = "cuda"))]
		{
			use core::ptr;
			use crate::{AsPointer, ortsys, util};
			// Create an options struct owned by ONNX Runtime; `cuda_options` is filled in
			// by the FFI call below.
			let mut cuda_options: *mut ort_sys::OrtCUDAProviderOptionsV2 = ptr::null_mut();
			ortsys![unsafe CreateCUDAProviderOptions(&mut cuda_options)?];
			// Drop guard releases the options struct on every exit path, including the
			// `?` early returns from the fallible `ortsys!` calls below.
			let _guard = util::run_on_drop(|| {
				ortsys![unsafe ReleaseCUDAProviderOptions(cuda_options)];
			});
			// Convert the accumulated string options into parallel key/value pointer arrays
			// for the C API. `ffi_options` must stay alive until after the update call.
			let ffi_options = self.options.to_ffi();
			ortsys![unsafe UpdateCUDAProviderOptions(
				cuda_options,
				ffi_options.key_ptrs(),
				ffi_options.value_ptrs(),
				ffi_options.len()
			)?];
			ortsys![unsafe SessionOptionsAppendExecutionProvider_CUDA_V2(session_builder.ptr_mut(), cuda_options)?];
			return Ok(());
		}
		// Only reachable when the cfg block above was compiled out.
		Err(RegisterError::MissingFeature)
	}
}
/// CUDA Toolkit 12.x runtime libraries that `preload_dylibs` loads on Windows.
#[cfg(windows)]
pub const CUDA_DYLIBS: &[&str] = &["cudart64_12.dll", "cublasLt64_12.dll", "cublas64_12.dll", "cufft64_11.dll"];
/// CUDA Toolkit 12.x runtime libraries that `preload_dylibs` loads on non-Windows platforms.
#[cfg(not(windows))]
pub const CUDA_DYLIBS: &[&str] = &["libcudart.so.12", "libcublasLt.so.12", "libcublas.so.12", "libnvrtc.so.12", "libcurand.so.10", "libcufft.so.11"];
/// cuDNN 9.x libraries that `preload_dylibs` loads on Windows.
#[cfg(windows)]
pub const CUDNN_DYLIBS: &[&str] = &[
	"cudnn64_9.dll",
	"cudnn_graph64_9.dll",
	"cudnn_ops64_9.dll",
	"cudnn_heuristic64_9.dll",
	"cudnn_adv64_9.dll",
	"cudnn_cnn64_9.dll",
	"cudnn_engines_precompiled64_9.dll",
	"cudnn_engines_runtime_compiled64_9.dll"
];
/// cuDNN 9.x libraries that `preload_dylibs` loads on non-Windows platforms.
#[cfg(not(windows))]
pub const CUDNN_DYLIBS: &[&str] = &[
	"libcudnn.so.9",
	"libcudnn_graph.so.9",
	"libcudnn_ops.so.9",
	"libcudnn_heuristic.so.9",
	"libcudnn_adv.so.9",
	"libcudnn_cnn.so.9",
	"libcudnn_engines_precompiled.so.9",
	"libcudnn_engines_runtime_compiled.so.9"
];
/// Eagerly loads the CUDA and/or cuDNN dynamic libraries from the given root directories.
///
/// Each argument is optional; passing `None` skips that library set entirely. Libraries are
/// loaded in the order listed in [`CUDA_DYLIBS`] / [`CUDNN_DYLIBS`], CUDA first.
///
/// # Errors
/// Returns an error naming the first library that failed to load.
#[cfg_attr(docsrs, doc(cfg(any(feature = "preload-dylibs", feature = "load-dynamic"))))]
#[cfg(all(feature = "preload-dylibs", not(target_arch = "wasm32")))]
pub fn preload_dylibs(cuda_root_dir: Option<&std::path::Path>, cudnn_root_dir: Option<&std::path::Path>) -> Result<()> {
	use crate::util::preload_dylib;
	// Pair each optional root directory with the library list expected beneath it, then run
	// the identical load loop over both pairs.
	for (root, dylibs) in [(cuda_root_dir, CUDA_DYLIBS), (cudnn_root_dir, CUDNN_DYLIBS)] {
		if let Some(root) = root {
			for dylib in dylibs {
				preload_dylib(root.join(dylib)).map_err(|e| crate::Error::new(format!("Failed to preload `{dylib}`: {e}")))?;
			}
		}
	}
	Ok(())
}