use crate::errors::{Result, TrustformersError};
use crate::tensor::Tensor;
use std::sync::{Arc, OnceLock};
#[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
use cudarc::driver::{CudaDevice, CudaSlice, LaunchAsync, LaunchConfig};
#[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
use std::collections::HashMap;
#[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
use std::sync::Mutex;
#[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
use cudarc::nvrtc::compile_ptx;
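/// CUDA backend for tensor operations, built on `cudarc`.
///
/// Kernels are JIT-compiled with NVRTC on first use and cached by name, and
/// device allocations are routed through a simple memory pool. On platforms
/// without CUDA support the struct still exists so callers compile, but every
/// operation returns a hardware error.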
pub struct CudaImpl {
#[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
context: Arc<CudaContext>,
#[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
stream: Arc<CudaStream>,
#[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
kernel_cache: Arc<Mutex<HashMap<String, CudaKernel>>>,
#[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
memory_pool: Arc<Mutex<MemoryPool>>,
#[cfg(not(all(feature = "cuda", any(target_os = "linux", target_os = "windows"))))]
_placeholder: (),
}
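/// A compiled kernel plus the launch geometry it was first compiled with.
///
/// Callers may override `grid_config`/`block_config` per launch; the cached
/// values are only a default.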
#[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
#[derive(Clone)]
pub struct CudaKernel {
func: cudarc::driver::CudaFunction,
name: String,
grid_config: (u32, u32, u32),
block_config: (u32, u32, u32),
shared_memory: u32,
}
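/// A minimal free-list allocator for device memory. Sizes are tracked in
/// bytes; `peak_memory` records the high-water mark of `total_allocated`.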
#[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
pub struct MemoryPool {
available_blocks: Vec<MemoryBlock>,
allocated_blocks: HashMap<usize, MemoryBlock>,
total_allocated: usize,
peak_memory: usize,
}
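/// A single pooled device allocation (`size` in bytes).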
#[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
#[derive(Clone)]
pub struct MemoryBlock {
slice: CudaSlice<f32>,
size: usize,
id: usize,
}
#[allow(dead_code)]
static CUDA_INSTANCE: OnceLock<std::result::Result<Arc<CudaImpl>, TrustformersError>> =
    OnceLock::new();
#[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
impl CudaImpl {
    pub fn new() -> Result<Self> {
        let device = CudaDevice::new(0).map_err(|e| {
            TrustformersError::hardware_error(
                &format!("Failed to initialize CUDA context: {}", e),
                "cuda_init",
            )
        })?;
        Ok(Self {
            device,
            kernel_cache: Arc::new(Mutex::new(HashMap::new())),
            memory_pool: Arc::new(Mutex::new(MemoryPool::new())),
        })
    }
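    /// Returns the process-wide `CudaImpl`, initializing it on first call.
    /// An initialization failure is cached as well, so repeated calls after a
    /// failed init keep returning the same error cheaply.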
    pub fn global() -> Result<&'static Arc<CudaImpl>> {
        // OnceLock::get_or_init is race-free and avoids the unsoundness of a
        // `static mut` guarded by `Once`.
        match CUDA_INSTANCE.get_or_init(|| Self::new().map(Arc::new)) {
            Ok(instance) => Ok(instance),
            Err(e) => Err(e.clone()),
        }
    }
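    /// Compiles `source` with NVRTC and caches the resulting function under
    /// `name`. The `__global__` entry point in `source` must also be called
    /// `name`; `grid`/`block` are stored as the kernel's default geometry.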
    pub fn compile_kernel(
        &self,
        name: &'static str,
        source: &str,
        grid: (u32, u32, u32),
        block: (u32, u32, u32),
    ) -> Result<CudaKernel> {
        {
            let cache = self.kernel_cache.lock().expect("Lock poisoned");
            if let Some(kernel) = cache.get(name) {
                return Ok(kernel.clone());
            }
        }
        let ptx = compile_ptx(source).map_err(|e| {
            TrustformersError::hardware_error(
                &format!("Failed to compile CUDA kernel: {}", e),
                "cuda_compile",
            )
        })?;
        // The generated sources name their __global__ entry point after the
        // kernel, so the module and function are both registered under `name`
        // rather than a fixed "module"/"kernel" pair (which would never match).
        self.device.load_ptx(ptx, name, &[name]).map_err(|e| {
            TrustformersError::hardware_error(
                &format!("Failed to load PTX: {}", e),
                "cuda_load_ptx",
            )
        })?;
        let func = self.device.get_func(name, name).ok_or_else(|| {
            TrustformersError::hardware_error(
                &format!("Failed to get kernel function: {}", name),
                "cuda_get_func",
            )
        })?;
let kernel = CudaKernel {
func,
name: name.to_string(),
grid_config: grid,
block_config: block,
shared_memory: 0,
};
{
let mut cache = self.kernel_cache.lock().expect("Lock poisoned");
cache.insert(name.to_string(), kernel.clone());
}
Ok(kernel)
}
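    /// Allocates a zeroed f32 buffer of at least `size` bytes, reusing a
    /// pooled block when one is large enough.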
pub fn allocate_memory(&self, size: usize) -> Result<CudaSlice<f32>> {
let elements = (size + 3) / 4;
{
let mut pool = self.memory_pool.lock().expect("Lock poisoned");
if let Some(block) = pool.get_block(elements) {
return Ok(block.slice);
}
}
let slice = self.device.alloc_zeros::<f32>(elements).map_err(|e| {
TrustformersError::hardware_error(
&format!("Failed to allocate GPU memory: {}", e),
"cuda_alloc",
)
})?;
{
let mut pool = self.memory_pool.lock().expect("Lock poisoned");
pool.total_allocated += size;
pool.peak_memory = pool.peak_memory.max(pool.total_allocated);
}
Ok(slice)
}
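    /// Copies a tensor's f32 data into a freshly allocated device buffer.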
pub fn copy_to_gpu(&self, tensor: &Tensor) -> Result<CudaSlice<f32>> {
let data = tensor.data_f32()?;
let mut gpu_slice = self.allocate_memory(data.len() * 4)?;
self.device.htod_copy_into(data, &mut gpu_slice).map_err(|e| {
TrustformersError::hardware_error(
&format!("Failed to copy data to GPU: {}", e),
"cuda_htod",
)
})?;
Ok(gpu_slice)
}
pub fn copy_from_gpu(&self, gpu_slice: &CudaSlice<f32>, tensor: &mut Tensor) -> Result<()> {
let mut data = vec![0.0f32; gpu_slice.len()];
self.device.dtoh_sync_copy_into(gpu_slice, &mut data).map_err(|e| {
TrustformersError::hardware_error(
&format!("Failed to copy data from GPU: {}", e),
"cuda_dtoh",
)
})?;
tensor.set_data_f32(&data)?;
Ok(())
}
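    /// Computes `C = A * B` for 2D tensors using a 16x16 shared-memory tiled
    /// kernel.
    ///
    /// A minimal usage sketch (marked `ignore`: it requires the `cuda` feature
    /// and a visible GPU):
    ///
    /// ```ignore
    /// let cuda = CudaImpl::new()?;
    /// let a = Tensor::ones(&[4, 8])?;
    /// let b = Tensor::ones(&[8, 2])?;
    /// let mut c = Tensor::zeros(&[4, 2])?;
    /// cuda.matmul(&a, &b, &mut c)?; // every element of c becomes 8.0
    /// ```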
pub fn matmul(&self, a: &Tensor, b: &Tensor, c: &mut Tensor) -> Result<()> {
let a_shape = a.shape();
let b_shape = b.shape();
let c_shape = c.shape();
        if a_shape.len() != 2 || b_shape.len() != 2 || c_shape.len() != 2 {
            return Err(TrustformersError::tensor_op_error(
                "Matrix multiplication requires 2D tensors",
                "CudaImpl::matmul",
            ));
        }
        if a_shape[1] != b_shape[0] || c_shape[0] != a_shape[0] || c_shape[1] != b_shape[1] {
            return Err(TrustformersError::tensor_op_error(
                "Incompatible shapes: expected A[m, k] * B[k, n] = C[m, n]",
                "CudaImpl::matmul",
            ));
        }
        let m = a_shape[0] as u32;
        let k = a_shape[1] as u32;
        let n = b_shape[1] as u32;
        let kernel_source = self.generate_optimized_matmul_kernel();
let kernel = self.compile_kernel(
"matmul_kernel",
&kernel_source,
((n + 15) / 16, (m + 15) / 16, 1),
(16, 16, 1),
)?;
let a_gpu = self.copy_to_gpu(a)?;
let b_gpu = self.copy_to_gpu(b)?;
let mut c_gpu = self.allocate_memory(c_shape[0] * c_shape[1] * 4)?;
        // Build the launch configuration from this call's dimensions; the cached
        // kernel keeps the geometry it was first compiled with, which may not
        // match the current matrix sizes.
        let launch_config = LaunchConfig {
            grid_dim: ((n + 15) / 16, (m + 15) / 16, 1),
            block_dim: (16, 16, 1),
            shared_mem_bytes: kernel.shared_memory,
        };
let args = (&a_gpu, &b_gpu, &mut c_gpu, m, k, n);
unsafe {
kernel.func.launch(launch_config, args).map_err(|e| {
TrustformersError::hardware_error(
&format!("Failed to launch CUDA kernel: {}", e),
"kernel_launch",
)
})?;
}
self.device.synchronize().map_err(|e| {
TrustformersError::hardware_error(
&format!("CUDA synchronization failed: {}", e),
"cuda_synchronize_matmul",
)
})?;
self.copy_from_gpu(&c_gpu, c)?;
Ok(())
}
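    /// Scaled dot-product attention over `[batch, seq_len, head_dim]` tensors.
    ///
    /// Despite the name, this is a simplified kernel: the full score row for
    /// each query is materialized in shared memory rather than streamed in
    /// tiles as in the Flash Attention paper, so `seq_len` is bounded by the
    /// dynamic shared-memory budget (the default 48KB per block on most
    /// devices allows roughly 12K key positions).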
pub fn flash_attention(
&self,
query: &Tensor,
key: &Tensor,
value: &Tensor,
output: &mut Tensor,
) -> Result<()> {
        let q_shape = query.shape();
        if q_shape.len() != 3 {
            return Err(TrustformersError::tensor_op_error(
                "Flash attention expects [batch, seq_len, head_dim] tensors",
                "CudaImpl::flash_attention",
            ));
        }
        let batch_size = q_shape[0] as u32;
        let seq_len = q_shape[1] as u32;
        let head_dim = q_shape[2] as u32;
        let kernel_source = self.generate_flash_attention_kernel();
        let kernel = self.compile_kernel(
            "flash_attention_kernel",
            &kernel_source,
            (batch_size, seq_len, 1),
            (256, 1, 1),
        )?;
let q_gpu = self.copy_to_gpu(query)?;
let k_gpu = self.copy_to_gpu(key)?;
let v_gpu = self.copy_to_gpu(value)?;
let mut o_gpu = self.allocate_memory(output.memory_usage())?;
        let launch_config = LaunchConfig {
            // Per-call geometry; the cached kernel's config may be stale.
            grid_dim: (batch_size, seq_len, 1),
            block_dim: (256, 1, 1),
            // Dynamic shared memory holds one f32 score per key position.
            shared_mem_bytes: seq_len * std::mem::size_of::<f32>() as u32,
        };
let args = (
&q_gpu, &k_gpu, &v_gpu, &mut o_gpu, batch_size, seq_len, head_dim,
);
unsafe {
kernel.func.launch(launch_config, args).map_err(|e| {
TrustformersError::hardware_error(
&format!("Failed to launch Flash Attention kernel: {}", e),
"flash_attention_launch",
)
})?;
}
self.device.synchronize().map_err(|e| {
TrustformersError::hardware_error(
&format!("CUDA synchronization failed: {}", e),
"cuda_synchronize_flash_attention",
)
})?;
self.copy_from_gpu(&o_gpu, output)?;
Ok(())
}
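    /// Applies GELU element-wise on the GPU.
    ///
    /// `approximate = false` uses the exact form `0.5 * x * (1 + erf(x / sqrt(2)))`;
    /// `approximate = true` uses the tanh approximation
    /// `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.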
    pub fn fused_gelu(&self, input: &Tensor, output: &mut Tensor, approximate: bool) -> Result<()> {
        let shape = input.shape();
        let total_elements = shape.iter().product::<usize>();
        // The two GELU variants are distinct kernels, so they must be cached
        // under distinct names; the entry point in the generated source matches.
        let kernel_name = if approximate {
            "fused_gelu_tanh_kernel"
        } else {
            "fused_gelu_erf_kernel"
        };
        let kernel_source = self.generate_fused_gelu_kernel(approximate);
        let grid_size = ((total_elements + 255) / 256) as u32;
        let kernel = self.compile_kernel(
            kernel_name,
            &kernel_source,
            (grid_size, 1, 1),
            (256, 1, 1),
        )?;
        let input_data = self.device.htod_copy(input.data_f32()?).map_err(|e| {
            TrustformersError::hardware_error(
                &format!("Failed to copy input to device: {}", e),
                "cuda_htod_gelu",
            )
        })?;
        let mut output_data = self.allocate_memory(total_elements * 4)?;
unsafe {
kernel
.func
.launch(
                    // Per-call grid size; the cached kernel's config may be stale.
                    LaunchConfig {
                        grid_dim: (grid_size, 1, 1),
                        block_dim: (256, 1, 1),
                        shared_mem_bytes: 0,
                    },
(&input_data, &mut output_data, total_elements as u32),
)
.map_err(|e| {
TrustformersError::hardware_error(
&format!("Failed to launch GELU kernel: {}", e),
"gelu_launch",
)
})?;
}
self.device.synchronize().map_err(|e| {
TrustformersError::hardware_error(
&format!("CUDA synchronization failed: {}", e),
"cuda_synchronize_gelu",
)
})?;
let result_data = self.device.dtoh_sync_copy(&output_data).map_err(|e| {
TrustformersError::hardware_error(
&format!("Failed to copy result from device: {}", e),
"cuda_dtoh_gelu",
)
})?;
*output = Tensor::from_vec(result_data, &shape)?;
Ok(())
}
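    /// Adds a broadcast bias and applies `activation` in a single kernel.
    /// Supported activations: "relu", "gelu" (tanh approximation), "silu",
    /// "tanh", and "none"; unknown names fall back to identity.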
    pub fn fused_bias_activation(
        &self,
        input: &Tensor,
        bias: &Tensor,
        output: &mut Tensor,
        activation: &str,
    ) -> Result<()> {
        let shape = input.shape();
        let total_elements = shape.iter().product::<usize>();
        let bias_size = bias.shape().iter().product::<usize>();
        // Each activation is a distinct kernel; cache it under its own name so
        // a previously compiled variant is not reused for a different one.
        let kernel_name: &'static str = match activation {
            "relu" => "fused_bias_relu_kernel",
            "gelu" => "fused_bias_gelu_kernel",
            "silu" => "fused_bias_silu_kernel",
            "tanh" => "fused_bias_tanh_kernel",
            _ => "fused_bias_identity_kernel",
        };
        let kernel_source = self.generate_fused_bias_activation_kernel(activation, kernel_name);
        let grid_size = ((total_elements + 255) / 256) as u32;
        let kernel = self.compile_kernel(
            kernel_name,
            &kernel_source,
            (grid_size, 1, 1),
            (256, 1, 1),
        )?;
        let input_data = self.device.htod_copy(input.data_f32()?).map_err(|e| {
            TrustformersError::hardware_error(
                &format!("Failed to copy input to device: {}", e),
                "cuda_htod_bias",
            )
        })?;
        let bias_data = self.device.htod_copy(bias.data_f32()?).map_err(|e| {
            TrustformersError::hardware_error(
                &format!("Failed to copy bias to device: {}", e),
                "cuda_htod_bias",
            )
        })?;
        let mut output_data = self.allocate_memory(total_elements * 4)?;
unsafe {
kernel
.func
.launch(
                    // Per-call grid size; the cached kernel's config may be stale.
                    LaunchConfig {
                        grid_dim: (grid_size, 1, 1),
                        block_dim: (256, 1, 1),
                        shared_mem_bytes: 0,
                    },
(
&input_data,
&bias_data,
&mut output_data,
total_elements as u32,
bias_size as u32,
),
)
.map_err(|e| {
TrustformersError::hardware_error(
&format!("Failed to launch bias activation kernel: {}", e),
"bias_activation_launch",
)
})?;
}
self.device.synchronize().map_err(|e| {
TrustformersError::hardware_error(
&format!("CUDA synchronization failed: {}", e),
"cuda_synchronize_bias",
)
})?;
let result_data = self.device.dtoh_sync_copy(&output_data).map_err(|e| {
TrustformersError::hardware_error(
&format!("Failed to copy result from device: {}", e),
"cuda_dtoh_bias",
)
})?;
*output = Tensor::from_vec(result_data, &shape)?;
Ok(())
}
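    /// Emits the CUDA source for the GELU kernel; the entry-point name encodes
    /// the variant so the two forms are cached separately.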
    fn generate_fused_gelu_kernel(&self, approximate: bool) -> String {
        if approximate {
            r#"
extern "C" __global__ void fused_gelu_tanh_kernel(const float* input, float* output, unsigned int n) {
    unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        float x = input[idx];
        float x3 = x * x * x;
        float arg = 0.7978845608f * (x + 0.044715f * x3); // sqrt(2/π) ≈ 0.7978845608
        float tanh_val = tanhf(arg);
        output[idx] = 0.5f * x * (1.0f + tanh_val);
    }
}
"#
            .to_string()
        } else {
            r#"
extern "C" __global__ void fused_gelu_erf_kernel(const float* input, float* output, unsigned int n) {
    unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        float x = input[idx];
        float scaled = x * 0.7071067812f; // 1/sqrt(2) ≈ 0.7071067812
        float erf_val = erff(scaled);
        output[idx] = 0.5f * x * (1.0f + erf_val);
    }
}
"#
            .to_string()
        }
    }
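    /// Emits the CUDA source for a bias + activation kernel whose entry point
    /// is named `kernel_name`, matching the cache key used by the caller.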
    fn generate_fused_bias_activation_kernel(&self, activation: &str, kernel_name: &str) -> String {
        let activation_code = match activation {
            "relu" => "fmaxf(value, 0.0f)",
            "gelu" => "0.5f * value * (1.0f + tanhf(0.7978845608f * (value + 0.044715f * value * value * value)))",
            "silu" => "value / (1.0f + expf(-value))",
            "tanh" => "tanhf(value)",
            "none" => "value",
            _ => "value",
        };
        format!(
            r#"
extern "C" __global__ void {kernel_name}(
    const float* input,
    const float* bias,
    float* output,
    unsigned int n,
    unsigned int bias_size
) {{
    unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {{
        float bias_val = bias[idx % bias_size]; // Broadcast bias
        float value = input[idx] + bias_val;
        output[idx] = {activation_code};
    }}
}}
"#
        )
    }
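    /// Emits a tiled matmul kernel: each 16x16 thread block stages tiles of A
    /// and B in shared memory and accumulates one output element per thread,
    /// cutting global-memory traffic by roughly a factor of the tile size.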
    fn generate_optimized_matmul_kernel(&self) -> String {
        format!(
r#"
extern "C" __global__ void matmul_kernel(
const float* __restrict__ A,
const float* __restrict__ B,
float* __restrict__ C,
const unsigned int M,
const unsigned int K,
const unsigned int N
) {{
// Optimized matrix multiplication with tiling
const int TILE_SIZE = 16;
const int tx = threadIdx.x;
const int ty = threadIdx.y;
const int bx = blockIdx.x;
const int by = blockIdx.y;
// Calculate global thread indices
const int row = by * TILE_SIZE + ty;
const int col = bx * TILE_SIZE + tx;
// Shared memory for tiling
__shared__ float As[TILE_SIZE][TILE_SIZE];
__shared__ float Bs[TILE_SIZE][TILE_SIZE];
float sum = 0.0f;
// Tile loop
for (int t = 0; t < (K + TILE_SIZE - 1) / TILE_SIZE; ++t) {{
// Load tile into shared memory
int a_col = t * TILE_SIZE + tx;
int b_row = t * TILE_SIZE + ty;
As[ty][tx] = (row < M && a_col < K) ? A[row * K + a_col] : 0.0f;
Bs[ty][tx] = (b_row < K && col < N) ? B[b_row * N + col] : 0.0f;
__syncthreads();
// Compute partial dot product
#pragma unroll
for (int i = 0; i < TILE_SIZE; ++i) {{
sum += As[ty][i] * Bs[i][tx];
}}
__syncthreads();
}}
// Write result
if (row < M && col < N) {{
C[row * N + col] = sum;
}}
}}
"#
)
}
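    /// Emits the attention kernel: one block per (batch, query) pair, with the
    /// score row in dynamic shared memory and block-wide tree reductions for
    /// the softmax max and sum.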
    fn generate_flash_attention_kernel(&self) -> String {
        format!(
r#"
extern "C" __global__ void flash_attention_kernel(
const float* __restrict__ Q,
const float* __restrict__ K,
const float* __restrict__ V,
float* __restrict__ O,
const unsigned int batch_size,
const unsigned int seq_len,
const unsigned int head_dim
) {{
            // Simplified attention: the full score row for this query is
            // materialized in dynamic shared memory, then softmax-normalized.
            const int batch_id = blockIdx.x;
            const int seq_id = blockIdx.y;
            const int tid = threadIdx.x;
            if (batch_id >= batch_size || seq_id >= seq_len) return;
            // Dynamic shared memory: one score per key position (seq_len floats).
            extern __shared__ float shared_mem[];
            float* shared_scores = shared_mem;
            // Compute scaled QK^T scores for this query row
            float max_score = -INFINITY;
            const float scale = rsqrtf((float)head_dim); // conventional 1/sqrt(d_k) scaling
            for (int k = tid; k < seq_len; k += blockDim.x) {{
                float score = 0.0f;
                for (int d = 0; d < head_dim; d++) {{
                    int q_idx = batch_id * seq_len * head_dim + seq_id * head_dim + d;
                    int k_idx = batch_id * seq_len * head_dim + k * head_dim + d;
                    score += Q[q_idx] * K[k_idx];
                }}
                score *= scale;
                shared_scores[k] = score;
                max_score = fmaxf(max_score, score);
            }}
// Reduce to find global maximum
__shared__ float max_shared[256];
max_shared[tid] = max_score;
__syncthreads();
for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {{
if (tid < stride) {{
max_shared[tid] = fmaxf(max_shared[tid], max_shared[tid + stride]);
}}
__syncthreads();
}}
float global_max = max_shared[0];
// Compute softmax
float sum_exp = 0.0f;
for (int k = tid; k < seq_len; k += blockDim.x) {{
float exp_score = expf(shared_scores[k] - global_max);
shared_scores[k] = exp_score;
sum_exp += exp_score;
}}
// Reduce sum
__shared__ float sum_shared[256];
sum_shared[tid] = sum_exp;
__syncthreads();
for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {{
if (tid < stride) {{
sum_shared[tid] += sum_shared[tid + stride];
}}
__syncthreads();
}}
float global_sum = sum_shared[0];
// Normalize attention weights
for (int k = tid; k < seq_len; k += blockDim.x) {{
shared_scores[k] /= global_sum;
}}
__syncthreads();
// Compute output
for (int d = tid; d < head_dim; d += blockDim.x) {{
float output_val = 0.0f;
for (int k = 0; k < seq_len; k++) {{
int v_idx = batch_id * seq_len * head_dim + k * head_dim + d;
output_val += shared_scores[k] * V[v_idx];
}}
int o_idx = batch_id * seq_len * head_dim + seq_id * head_dim + d;
O[o_idx] = output_val;
}}
}}
"#
)
}
    pub fn device_info(&self) -> String {
        format!(
            "CUDA Device: {}",
            self.device.name().unwrap_or_else(|_| "Unknown".to_string())
        )
    }
pub fn memory_stats(&self) -> (usize, usize) {
let pool = self.memory_pool.lock().expect("Lock poisoned");
(pool.total_allocated, pool.peak_memory)
}
}
#[cfg(not(all(feature = "cuda", any(target_os = "linux", target_os = "windows"))))]
impl CudaImpl {
pub fn new() -> Result<Self> {
Err(TrustformersError::hardware_error(
"CUDA support not available on this platform",
"CudaImpl::new",
))
}
pub fn global() -> Result<&'static Arc<CudaImpl>> {
Err(TrustformersError::hardware_error(
"CUDA support not available on this platform",
"CudaImpl::global",
))
}
pub fn memory_stats(&self) -> (usize, usize) {
(0, 0)
}
pub fn matmul(&self, _a: &Tensor, _b: &Tensor, _c: &mut Tensor) -> Result<()> {
Err(TrustformersError::hardware_error(
"CUDA support not available",
"CudaImpl::matmul",
))
}
pub fn fused_gelu(
&self,
_input: &Tensor,
_output: &mut Tensor,
_approximate: bool,
) -> Result<()> {
Err(TrustformersError::hardware_error(
"CUDA support not available",
"CudaImpl::fused_gelu",
))
}
pub fn fused_bias_activation(
&self,
_input: &Tensor,
_bias: &Tensor,
_output: &mut Tensor,
_activation: &str,
) -> Result<()> {
Err(TrustformersError::hardware_error(
"CUDA support not available",
"CudaImpl::fused_bias_activation",
))
}
pub fn flash_attention(
&self,
_query: &Tensor,
_key: &Tensor,
_value: &Tensor,
_output: &mut Tensor,
) -> Result<()> {
Err(TrustformersError::hardware_error(
"CUDA support not available",
"CudaImpl::flash_attention",
))
}
pub fn device_info(&self) -> String {
"CUDA not available".to_string()
}
}
#[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
impl MemoryPool {
pub fn new() -> Self {
Self {
available_blocks: Vec::new(),
allocated_blocks: HashMap::new(),
total_allocated: 0,
peak_memory: 0,
}
}
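    /// First-fit search: returns the first free block holding at least
    /// `elements * 4` bytes, removing it from the free list.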
pub fn get_block(&mut self, elements: usize) -> Option<MemoryBlock> {
let pos = self.available_blocks.iter().position(|block| block.size >= elements * 4)?;
Some(self.available_blocks.remove(pos))
}
pub fn return_block(&mut self, block: MemoryBlock) {
self.available_blocks.push(block);
}
}
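/// Free-function wrappers around the global [`CudaImpl`] instance.
///
/// A minimal sketch of the intended call pattern (marked `ignore` because it
/// needs the `cuda` feature, a visible GPU, and the crate's `Tensor` type):
///
/// ```ignore
/// api::init_cuda()?;
/// let a = Tensor::ones(&[2, 3])?;
/// let b = Tensor::ones(&[3, 4])?;
/// let mut c = Tensor::zeros(&[2, 4])?;
/// api::cuda_matmul(&a, &b, &mut c)?;
/// println!("{}", api::cuda_device_info()?);
/// ```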
pub mod api {
use super::*;
pub fn init_cuda() -> Result<()> {
CudaImpl::global().map(|_| ())
}
pub fn cuda_matmul(a: &Tensor, b: &Tensor, c: &mut Tensor) -> Result<()> {
CudaImpl::global()?.matmul(a, b, c)
}
pub fn cuda_flash_attention(
query: &Tensor,
key: &Tensor,
value: &Tensor,
output: &mut Tensor,
) -> Result<()> {
CudaImpl::global()?.flash_attention(query, key, value, output)
}
pub fn cuda_device_info() -> Result<String> {
Ok(CudaImpl::global()?.device_info())
}
pub fn cuda_memory_stats() -> Result<(usize, usize)> {
Ok(CudaImpl::global()?.memory_stats())
}
pub fn cuda_fused_gelu(input: &Tensor, output: &mut Tensor, approximate: bool) -> Result<()> {
CudaImpl::global()?.fused_gelu(input, output, approximate)
}
pub fn cuda_fused_bias_activation(
input: &Tensor,
bias: &Tensor,
output: &mut Tensor,
activation: &str,
) -> Result<()> {
CudaImpl::global()?.fused_bias_activation(input, bias, output, activation)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tensor::Tensor;
#[test]
fn test_cuda_initialization() {
match CudaImpl::new() {
Ok(_) => println!("CUDA initialized successfully"),
Err(_) => println!("CUDA not available, skipping test"),
}
}
#[test]
fn test_cuda_matmul() {
if let Ok(cuda) = CudaImpl::new() {
let a = Tensor::ones(&[4, 4]).expect("Failed to create ones tensor");
let b = Tensor::ones(&[4, 4]).expect("Failed to create ones tensor");
let mut c = Tensor::zeros(&[4, 4]).expect("Failed to create zero tensor");
cuda.matmul(&a, &b, &mut c).expect("Matrix multiplication failed");
let data = c.data_f32().expect("operation failed in test");
assert!(data.iter().all(|&x| (x - 4.0).abs() < 1e-6));
}
}
#[test]
#[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
fn test_memory_pool() {
let mut pool = MemoryPool::new();
assert_eq!(pool.total_allocated, 0);
assert_eq!(pool.peak_memory, 0);
}
#[test]
#[cfg(not(all(feature = "cuda", any(target_os = "linux", target_os = "windows"))))]
fn test_cuda_new_returns_error_without_feature() {
let result = CudaImpl::new();
assert!(result.is_err());
let err_msg = format!("{}", result.err().expect("expected error"));
assert!(err_msg.contains("CUDA support not available"));
}
#[test]
#[cfg(not(all(feature = "cuda", any(target_os = "linux", target_os = "windows"))))]
fn test_cuda_global_returns_error_without_feature() {
let result = CudaImpl::global();
assert!(result.is_err());
}
#[test]
fn test_cuda_init_api() {
let result = api::init_cuda();
let _ = result;
}
#[test]
fn test_cuda_device_info_api() {
let result = api::cuda_device_info();
let _ = result;
}
#[test]
fn test_cuda_memory_stats_api() {
let result = api::cuda_memory_stats();
let _ = result;
}
#[test]
fn test_cuda_matmul_api() {
let a = Tensor::ones(&[2, 2]).expect("tensor failed");
let b = Tensor::ones(&[2, 2]).expect("tensor failed");
let mut c = Tensor::zeros(&[2, 2]).expect("tensor failed");
let result = api::cuda_matmul(&a, &b, &mut c);
let _ = result;
}
#[test]
fn test_cuda_flash_attention_api() {
let q = Tensor::ones(&[1, 4, 8]).expect("tensor failed");
let k = Tensor::ones(&[1, 4, 8]).expect("tensor failed");
let v = Tensor::ones(&[1, 4, 8]).expect("tensor failed");
let mut out = Tensor::zeros(&[1, 4, 8]).expect("tensor failed");
let result = api::cuda_flash_attention(&q, &k, &v, &mut out);
let _ = result;
}
#[test]
fn test_cuda_fused_gelu_api() {
let input = Tensor::ones(&[4, 4]).expect("tensor failed");
let mut output = Tensor::zeros(&[4, 4]).expect("tensor failed");
let result = api::cuda_fused_gelu(&input, &mut output, false);
let _ = result;
}
#[test]
fn test_cuda_fused_gelu_approximate_api() {
let input = Tensor::ones(&[4, 4]).expect("tensor failed");
let mut output = Tensor::zeros(&[4, 4]).expect("tensor failed");
let result = api::cuda_fused_gelu(&input, &mut output, true);
let _ = result;
}
#[test]
fn test_cuda_fused_bias_activation_relu_api() {
let input = Tensor::ones(&[4, 4]).expect("tensor failed");
let bias = Tensor::ones(&[4]).expect("tensor failed");
let mut output = Tensor::zeros(&[4, 4]).expect("tensor failed");
let result = api::cuda_fused_bias_activation(&input, &bias, &mut output, "relu");
let _ = result;
}
#[test]
fn test_cuda_fused_bias_activation_gelu_api() {
let input = Tensor::ones(&[4, 4]).expect("tensor failed");
let bias = Tensor::ones(&[4]).expect("tensor failed");
let mut output = Tensor::zeros(&[4, 4]).expect("tensor failed");
let result = api::cuda_fused_bias_activation(&input, &bias, &mut output, "gelu");
let _ = result;
}
#[test]
fn test_cuda_fused_bias_activation_silu_api() {
let input = Tensor::ones(&[2, 3]).expect("tensor failed");
let bias = Tensor::ones(&[3]).expect("tensor failed");
let mut output = Tensor::zeros(&[2, 3]).expect("tensor failed");
let result = api::cuda_fused_bias_activation(&input, &bias, &mut output, "silu");
let _ = result;
}
#[test]
fn test_tensor_creation_for_cuda() {
let query = Tensor::ones(&[2, 8, 64]).expect("tensor failed");
let key = Tensor::ones(&[2, 8, 64]).expect("tensor failed");
let value = Tensor::ones(&[2, 8, 64]).expect("tensor failed");
assert_eq!(query.shape(), &[2, 8, 64]);
assert_eq!(key.shape(), &[2, 8, 64]);
assert_eq!(value.shape(), &[2, 8, 64]);
}
#[test]
fn test_tensor_shapes_for_matmul() {
let a = Tensor::ones(&[32, 64]).expect("tensor failed");
let b = Tensor::ones(&[64, 128]).expect("tensor failed");
let c = Tensor::zeros(&[32, 128]).expect("tensor failed");
assert_eq!(a.shape(), &[32, 64]);
assert_eq!(b.shape(), &[64, 128]);
assert_eq!(c.shape(), &[32, 128]);
}
#[test]
fn test_tensor_memory_usage_for_gpu() {
let small = Tensor::ones(&[4, 4]).expect("tensor failed");
let large = Tensor::ones(&[64, 64]).expect("tensor failed");
assert!(large.memory_usage() > small.memory_usage());
}
#[test]
fn test_multiple_tensor_sizes_for_kernels() {
let sizes: Vec<(usize, usize)> = vec![
(1, 1), (4, 4), (16, 16), (32, 32), (64, 64), (128, 128), (256, 256),
];
for (rows, cols) in sizes {
let t = Tensor::ones(&[rows, cols]).expect("tensor failed");
assert_eq!(t.shape(), &[rows, cols]);
}
}
#[test]
fn test_tensor_from_vec_for_cuda_input() {
let data: Vec<f32> = (0..64).map(|i| i as f32 * 0.1).collect();
let tensor = Tensor::from_vec(data.clone(), &[8, 8]).expect("tensor failed");
let retrieved = tensor.data_f32().expect("data failed");
for (original, retrieved_val) in data.iter().zip(retrieved.iter()) {
assert!((original - retrieved_val).abs() < f32::EPSILON);
}
}
#[test]
fn test_cuda_batch_dims_validation() {
let batch_sizes = [1, 2, 4, 8];
let seq_lengths = [16, 32, 64];
let head_dims = [32, 64];
for &batch in &batch_sizes {
for &seq_len in &seq_lengths {
for &head_dim in &head_dims {
let t = Tensor::ones(&[batch, seq_len, head_dim]).expect("tensor failed");
assert_eq!(t.shape()[0], batch);
assert_eq!(t.shape()[1], seq_len);
assert_eq!(t.shape()[2], head_dim);
}
}
}
}
#[test]
fn test_tensor_zeros_for_output_buffer() {
let output = Tensor::zeros(&[16, 32]).expect("tensor failed");
let data = output.data_f32().expect("data failed");
assert!(data.iter().all(|&x| x.abs() < f32::EPSILON));
}
#[test]
fn test_tensor_ones_for_weight_init() {
let weight = Tensor::ones(&[64, 128]).expect("tensor failed");
let data = weight.data_f32().expect("data failed");
assert!(data.iter().all(|&x| (x - 1.0).abs() < f32::EPSILON));
}
#[test]
fn test_cuda_stub_matmul_error_message() {
match CudaImpl::new() {
Ok(cuda) => {
let a = Tensor::ones(&[2, 2]).expect("tensor failed");
let b = Tensor::ones(&[2, 2]).expect("tensor failed");
let mut c = Tensor::zeros(&[2, 2]).expect("tensor failed");
let _ = cuda.matmul(&a, &b, &mut c);
},
Err(e) => {
let msg = format!("{}", e);
assert!(!msg.is_empty());
},
}
}
#[test]
fn test_cuda_stub_device_info() {
match CudaImpl::new() {
Ok(cuda) => {
let info = cuda.device_info();
assert!(!info.is_empty());
},
            Err(_) => {}
}
}
#[test]
fn test_cuda_stub_memory_stats() {
match CudaImpl::new() {
Ok(cuda) => {
let (total, peak) = cuda.memory_stats();
assert!(peak >= total || total == 0);
},
            Err(_) => {}
}
}
}