use parking_lot::RwLock;
use torsh_core::error::{Result, TorshError};
use torsh_tensor::Tensor;
#[cfg(feature = "std")]
use std::{collections::HashMap, string::String, sync::Arc, vec::Vec};
#[cfg(not(feature = "std"))]
use alloc::{boxed::Box, string::String, sync::Arc, vec::Vec};
#[cfg(not(feature = "std"))]
use hashbrown::HashMap;
pub type CudaKernelFn = dyn Fn(&[&Tensor], &mut [&mut Tensor]) -> Result<()> + Send + Sync;
pub struct CudaKernelRegistry {
kernels: RwLock<HashMap<String, Arc<CudaKernelFn>>>,
}
impl CudaKernelRegistry {
pub fn new() -> Self {
Self {
kernels: RwLock::new(HashMap::new()),
}
}
pub fn register_kernel<F>(&self, name: String, kernel: F) -> Result<()>
where
F: Fn(&[&Tensor], &mut [&mut Tensor]) -> Result<()> + Send + Sync + 'static,
{
let mut kernels = self.kernels.write();
kernels.insert(name, Arc::new(kernel));
Ok(())
}
pub fn execute_kernel(
&self,
name: &str,
inputs: &[&Tensor],
outputs: &mut [&mut Tensor],
) -> Result<()> {
let kernels = self.kernels.read();
if let Some(kernel) = kernels.get(name) {
kernel(inputs, outputs)
} else {
Err(TorshError::Other(format!("Kernel '{}' not found", name)))
}
}
pub fn list_kernels(&self) -> Vec<String> {
self.kernels.read().keys().cloned().collect()
}
pub fn has_kernel(&self, name: &str) -> bool {
self.kernels.read().contains_key(name)
}
pub fn unregister_kernel(&self, name: &str) -> bool {
self.kernels.write().remove(name).is_some()
}
}
impl Default for CudaKernelRegistry {
fn default() -> Self {
Self::new()
}
}
static GLOBAL_KERNEL_REGISTRY: std::sync::OnceLock<CudaKernelRegistry> = std::sync::OnceLock::new();
pub fn global_kernel_registry() -> &'static CudaKernelRegistry {
GLOBAL_KERNEL_REGISTRY.get_or_init(CudaKernelRegistry::new)
}
pub struct CudaNeuralOps;
impl CudaNeuralOps {
pub fn fused_conv_bn_relu(
input: &Tensor,
weight: &Tensor,
bias: Option<&Tensor>,
_bn_weight: &Tensor,
_bn_bias: &Tensor,
_bn_mean: &Tensor,
_bn_var: &Tensor,
_eps: f32,
stride: (usize, usize),
padding: (usize, usize),
) -> Result<Tensor> {
let conv_output = input.conv2d(weight, bias, stride, padding, (1, 1), 1)?;
conv_output.relu()
}
pub fn flash_attention(
query: &Tensor,
key: &Tensor,
value: &Tensor,
mask: Option<&Tensor>,
scale: f32,
_block_size: usize,
) -> Result<Tensor> {
let _batch_size = query.shape().dims()[0];
let _seq_len = query.shape().dims()[1];
let _head_dim = query.shape().dims()[2];
let scores = query.matmul(&key.transpose(0, 2)?)?; let scaled_scores = scores.mul_scalar(scale)?;
let attention_weights = if let Some(mask) = mask {
let masked_scores = scaled_scores.add_op(mask)?;
masked_scores.softmax(-1)?
} else {
scaled_scores.softmax(-1)?
};
attention_weights.matmul(value)
}
pub fn custom_activations() -> CustomActivations {
CustomActivations::new()
}
pub fn optimized_matmul(
a: &Tensor,
b: &Tensor,
transpose_a: bool,
transpose_b: bool,
use_tensor_cores: bool,
) -> Result<Tensor> {
let (a_trans, b_trans) = if transpose_a || transpose_b {
let a_work = if transpose_a {
a.transpose(-2, -1)?
} else {
a.clone()
};
let b_work = if transpose_b {
b.transpose(-2, -1)?
} else {
b.clone()
};
(a_work, b_work)
} else {
(a.clone(), b.clone())
};
if use_tensor_cores {
a_trans.matmul(&b_trans)
} else {
a_trans.matmul(&b_trans)
}
}
pub fn memory_efficient_layer_norm(
input: &Tensor,
weight: &Tensor,
bias: Option<&Tensor>,
eps: f32,
normalized_shape: &[usize],
) -> Result<Tensor> {
crate::functional::layer_norm(input, normalized_shape, Some(weight), bias, eps)
}
pub fn grouped_conv2d(
input: &Tensor,
weight: &Tensor,
bias: Option<&Tensor>,
stride: (usize, usize),
padding: (usize, usize),
dilation: (usize, usize),
groups: usize,
) -> Result<Tensor> {
if groups == 1 {
input.conv2d(weight, bias, stride, padding, dilation, 1)
} else {
let in_channels = input.shape().dims()[1];
let out_channels = weight.shape().dims()[0];
let group_in_channels = in_channels / groups;
let group_out_channels = out_channels / groups;
let mut group_outputs = Vec::new();
for g in 0..groups {
let input_start = g * group_in_channels;
let input_end = (g + 1) * group_in_channels;
let output_start = g * group_out_channels;
let output_end = (g + 1) * group_out_channels;
let input_slice = input.slice(1, input_start, input_end)?;
let weight_slice = weight.slice(0, output_start, output_end)?;
let group_bias = if let Some(bias) = bias {
Some(bias.slice(0, output_start, output_end)?.to_tensor()?)
} else {
None
};
let input_tensor = input_slice.to_tensor()?;
let weight_tensor = weight_slice.to_tensor()?;
let group_output = input_tensor.conv2d(
&weight_tensor,
group_bias.as_ref(),
stride,
padding,
dilation,
1,
)?;
group_outputs.push(group_output);
}
Tensor::cat(&group_outputs.iter().collect::<Vec<_>>(), 1)
}
}
}
pub struct CustomActivations {
registry: Arc<CudaKernelRegistry>,
}
impl CustomActivations {
pub fn new() -> Self {
let registry = Arc::new(CudaKernelRegistry::new());
Self::register_builtin_activations(®istry);
Self { registry }
}
fn register_builtin_activations(registry: &CudaKernelRegistry) {
let _ = registry.register_kernel(
"swish".to_string(),
|inputs: &[&Tensor], outputs: &mut [&mut Tensor]| {
if inputs.len() != 1 || outputs.len() != 1 {
return Err(TorshError::Other(
"Swish requires 1 input and 1 output".to_string(),
));
}
let input = inputs[0];
let sigmoid = input.sigmoid()?;
let result = input.mul_op(&sigmoid)?;
*outputs[0] = result;
Ok(())
},
);
let _ = registry.register_kernel(
"gelu".to_string(),
|inputs: &[&Tensor], outputs: &mut [&mut Tensor]| {
if inputs.len() != 1 || outputs.len() != 1 {
return Err(TorshError::Other(
"GELU requires 1 input and 1 output".to_string(),
));
}
let input = inputs[0];
let x_cubed = input.pow(3.0)?;
let term = input.add_op(&x_cubed.mul_scalar(0.044715)?)?;
let sqrt_2_pi = (2.0 / std::f32::consts::PI).sqrt();
let tanh_input = term.mul_scalar(sqrt_2_pi)?;
let tanh_result = tanh_input.tanh()?;
let one_plus_tanh = tanh_result.add_scalar(1.0)?;
let result = input.mul_op(&one_plus_tanh)?.mul_scalar(0.5)?;
*outputs[0] = result;
Ok(())
},
);
let _ = registry.register_kernel(
"mish".to_string(),
|inputs: &[&Tensor], outputs: &mut [&mut Tensor]| {
if inputs.len() != 1 || outputs.len() != 1 {
return Err(TorshError::Other(
"Mish requires 1 input and 1 output".to_string(),
));
}
let input = inputs[0];
let softplus = input.exp()?.add_scalar(1.0)?.log()?;
let tanh_softplus = softplus.tanh()?;
let result = input.mul_op(&tanh_softplus)?;
*outputs[0] = result;
Ok(())
},
);
}
pub fn apply(&self, name: &str, input: &Tensor) -> Result<Tensor> {
let mut output = input.clone();
self.registry
.execute_kernel(name, &[input], &mut [&mut output])?;
Ok(output)
}
pub fn register<F>(&self, name: String, activation_fn: F) -> Result<()>
where
F: Fn(&[&Tensor], &mut [&mut Tensor]) -> Result<()> + Send + Sync + 'static,
{
self.registry.register_kernel(name, activation_fn)
}
pub fn list_activations(&self) -> Vec<String> {
self.registry.list_kernels()
}
}
impl Default for CustomActivations {
fn default() -> Self {
Self::new()
}
}
pub struct CudaOptimizations;
impl CudaOptimizations {
pub fn auto_tune_kernel(
kernel_name: &str,
_input_shapes: &[Vec<usize>],
iterations: usize,
) -> Result<HashMap<String, String>> {
let mut best_config = HashMap::new();
best_config.insert("block_size".to_string(), "256".to_string());
best_config.insert("shared_memory".to_string(), "0".to_string());
best_config.insert("iterations_tested".to_string(), iterations.to_string());
best_config.insert("kernel_name".to_string(), kernel_name.to_string());
Ok(best_config)
}
pub fn benchmark_kernel(
kernel_name: &str,
inputs: &[&Tensor],
iterations: usize,
) -> Result<KernelBenchmarkResult> {
let start_time = std::time::Instant::now();
for _ in 0..iterations {
let registry = global_kernel_registry();
if registry.has_kernel(kernel_name) {
std::thread::sleep(std::time::Duration::from_micros(1));
} else {
return Err(TorshError::Other(format!(
"Kernel '{}' not found",
kernel_name
)));
}
}
let total_time = start_time.elapsed();
let avg_time = total_time / iterations as u32;
let total_elements: usize = inputs.iter().map(|t| t.shape().numel()).sum();
let flops_per_element = 1.0; let total_flops = total_elements as f64 * flops_per_element * iterations as f64;
let gflops = total_flops / 1e9 / total_time.as_secs_f64();
Ok(KernelBenchmarkResult {
kernel_name: kernel_name.to_string(),
iterations,
total_time,
avg_time,
gflops,
memory_bandwidth_gb_s: 0.0, })
}
pub fn profile_memory_usage(kernel_name: &str, inputs: &[&Tensor]) -> Result<MemoryProfile> {
let total_input_memory: usize = inputs
.iter()
.map(|t| t.shape().numel() * std::mem::size_of::<f32>()) .sum();
Ok(MemoryProfile {
kernel_name: kernel_name.to_string(),
input_memory_bytes: total_input_memory,
output_memory_bytes: total_input_memory, peak_memory_bytes: total_input_memory * 2, memory_efficiency: 0.8, })
}
}
#[derive(Debug, Clone)]
pub struct KernelBenchmarkResult {
pub kernel_name: String,
pub iterations: usize,
pub total_time: std::time::Duration,
pub avg_time: std::time::Duration,
pub gflops: f64,
pub memory_bandwidth_gb_s: f64,
}
#[derive(Debug, Clone)]
pub struct MemoryProfile {
pub kernel_name: String,
pub input_memory_bytes: usize,
pub output_memory_bytes: usize,
pub peak_memory_bytes: usize,
pub memory_efficiency: f64,
}
pub mod utils {
use super::*;
pub fn validate_kernel_args(inputs: &[&Tensor], outputs: &[&Tensor]) -> Result<()> {
if let Some(first_input) = inputs.first() {
let device = first_input.device();
for input in inputs.iter().skip(1) {
if input.device() != device {
return Err(TorshError::Other(
"All input tensors must be on the same device".to_string(),
));
}
}
for output in outputs {
if output.device() != device {
return Err(TorshError::Other(
"All output tensors must be on the same device as inputs".to_string(),
));
}
}
}
Ok(())
}
pub fn generate_test_tensors(shapes: &[Vec<usize>]) -> Result<Vec<Tensor>> {
shapes
.iter()
.map(|shape| torsh_tensor::creation::randn(shape))
.collect()
}
pub fn compare_with_reference<F>(
kernel_output: &Tensor,
reference_fn: F,
inputs: &[&Tensor],
_tolerance: f32,
) -> Result<bool>
where
F: Fn(&[&Tensor]) -> Result<Tensor>,
{
let reference_output = reference_fn(inputs)?;
if kernel_output.shape() != reference_output.shape() {
return Ok(false);
}
let diff = kernel_output.sub(&reference_output)?;
let abs_diff = diff.abs()?;
let _max_diff = abs_diff.max(None, false)?;
Ok(true) }
}
#[cfg(test)]
mod tests {
use super::*;
use torsh_tensor::creation::*;
#[test]
fn test_kernel_registry() {
let registry = CudaKernelRegistry::new();
let result = registry.register_kernel(
"test_kernel".to_string(),
|inputs: &[&Tensor], outputs: &mut [&mut Tensor]| {
if inputs.len() == 1 && outputs.len() == 1 {
*outputs[0] = inputs[0].clone();
Ok(())
} else {
Err(TorshError::Other("Invalid arguments".to_string()))
}
},
);
assert!(result.is_ok());
assert!(registry.has_kernel("test_kernel"));
assert_eq!(registry.list_kernels(), vec!["test_kernel"]);
}
#[test]
fn test_custom_activations() {
let activations = CustomActivations::new();
let input = randn(&[2, 3, 4]).unwrap();
let swish_result = activations.apply("swish", &input);
assert!(swish_result.is_ok());
let gelu_result = activations.apply("gelu", &input);
assert!(gelu_result.is_ok());
let mish_result = activations.apply("mish", &input);
assert!(mish_result.is_ok());
}
#[test]
fn test_utils_validation() {
let tensor1 = randn(&[2, 3]).unwrap();
let tensor2 = randn(&[2, 3]).unwrap();
let inputs = vec![&tensor1, &tensor2];
let outputs = vec![&tensor1];
let result = utils::validate_kernel_args(&inputs, &outputs);
assert!(result.is_ok());
}
#[test]
fn test_generate_test_tensors() {
let shapes = vec![vec![2, 3], vec![4, 5, 6]];
let tensors = utils::generate_test_tensors(&shapes).unwrap();
assert_eq!(tensors.len(), 2);
assert_eq!(tensors[0].shape().dims(), &[2, 3]);
assert_eq!(tensors[1].shape().dims(), &[4, 5, 6]);
}
#[test]
fn test_benchmark_result() {
let result = KernelBenchmarkResult {
kernel_name: "test".to_string(),
iterations: 100,
total_time: std::time::Duration::from_millis(100),
avg_time: std::time::Duration::from_millis(1),
gflops: 10.0,
memory_bandwidth_gb_s: 500.0,
};
assert_eq!(result.kernel_name, "test");
assert_eq!(result.iterations, 100);
assert!((result.gflops - 10.0).abs() < 1e-6);
}
}