use crate::Device;
use torsh_core::dtype::DType;
#[cfg(not(feature = "std"))]
use alloc::{boxed::Box, string::String, vec::Vec};
/// A compiled compute kernel bound to a specific device.
///
/// Bundles the numeric id and name supplied at construction, the
/// [`KernelDescriptor`] it was compiled from, the backend-specific
/// [`KernelHandle`] used to launch it, and the [`KernelMetadata`]
/// gathered during compilation.
#[derive(Debug)]
pub struct Kernel {
    /// Numeric identifier supplied at construction.
    pub id: usize,
    /// Device this kernel was compiled for.
    pub device: Device,
    /// Human-readable kernel name.
    pub name: String,
    /// Source descriptor this kernel was built from.
    pub descriptor: KernelDescriptor,
    /// Backend-specific launch handle.
    pub handle: KernelHandle,
    /// Compiler statistics and diagnostics.
    pub metadata: KernelMetadata,
}
impl Kernel {
    /// Assembles a kernel from its compiled parts; stores every argument as-is.
    pub fn new(
        id: usize,
        device: Device,
        name: String,
        descriptor: KernelDescriptor,
        handle: KernelHandle,
        metadata: KernelMetadata,
    ) -> Self {
        Self { id, device, name, descriptor, handle, metadata }
    }

    /// Numeric identifier supplied at construction.
    pub fn id(&self) -> usize {
        self.id
    }

    /// Human-readable kernel name.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Device this kernel was compiled for.
    pub fn device(&self) -> &Device {
        &self.device
    }

    /// Backend-specific handle used to launch this kernel.
    pub fn handle(&self) -> &KernelHandle {
        &self.handle
    }

    /// Compiler statistics and diagnostics recorded at compile time.
    pub fn metadata(&self) -> &KernelMetadata {
        &self.metadata
    }
}
/// Everything needed to compile a kernel: name, source, compiler options,
/// parameter layout, an optional workgroup-size hint, and whether the
/// compiled result may be cached ([`KernelDescriptor::new`] defaults this
/// to `true`).
#[derive(Debug, Clone)]
pub struct KernelDescriptor {
    /// Name the kernel will be registered under.
    pub name: String,
    /// Source text, bytecode, or binary to compile/load.
    pub source: KernelSource,
    /// Extra flags passed to the backend compiler.
    pub compile_options: Vec<String>,
    /// Declared kernel arguments, in order.
    pub parameters: Vec<KernelParameter>,
    /// Preferred workgroup dimensions, if the author has one.
    pub workgroup_size_hint: Option<(u32, u32, u32)>,
    /// Whether the compiled kernel may be cached.
    pub cache: bool,
}
impl KernelDescriptor {
    /// Starts a descriptor for `name`/`source` with no compile options,
    /// no parameters, no workgroup hint, and caching enabled.
    pub fn new(name: String, source: KernelSource) -> Self {
        Self {
            cache: true,
            workgroup_size_hint: None,
            parameters: Vec::new(),
            compile_options: Vec::new(),
            name,
            source,
        }
    }

    /// Appends one compiler flag (builder style).
    pub fn with_compile_option(mut self, option: String) -> Self {
        self.compile_options.push(option);
        self
    }

    /// Appends one declared kernel argument (builder style).
    pub fn with_parameter(mut self, param: KernelParameter) -> Self {
        self.parameters.push(param);
        self
    }

    /// Records a preferred workgroup size for the compiler.
    pub fn with_workgroup_size_hint(mut self, size: (u32, u32, u32)) -> Self {
        self.workgroup_size_hint = Some(size);
        self
    }

    /// Opts this kernel out of compilation caching.
    pub fn without_cache(mut self) -> Self {
        self.cache = false;
        self
    }
}
/// The form a kernel program arrives in before compilation/loading.
#[derive(Debug, Clone)]
pub enum KernelSource {
    /// Plain source text in the given kernel language.
    Source {
        code: String,
        language: KernelLanguage,
    },
    /// Pre-compiled bytecode in one of the [`BytecodeFormat`]s.
    Bytecode {
        data: Vec<u8>,
        format: BytecodeFormat,
    },
    /// SPIR-V module as a sequence of 32-bit words.
    SpirV { data: Vec<u32> },
    /// Opaque platform-specific binary; `platform` names the target.
    Binary { data: Vec<u8>, platform: String },
}
/// Source languages accepted by [`KernelSource::Source`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum KernelLanguage {
    /// WebGPU Shading Language.
    Wgsl,
    /// DirectX High-Level Shading Language.
    Hlsl,
    /// OpenGL Shading Language.
    Glsl,
    /// Metal Shading Language.
    Metal,
    /// CUDA source.
    Cuda,
    /// OpenCL C source.
    OpenCl,
    /// Any other language, identified by name.
    Custom(String),
}
/// Bytecode containers accepted by [`KernelSource::Bytecode`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BytecodeFormat {
    /// Khronos SPIR-V.
    SpirV,
    /// DirectX Intermediate Language.
    Dxil,
    /// Apple Metal AIR.
    MetalAir,
    /// NVIDIA PTX.
    Ptx,
    /// Any other format, identified by name.
    Custom(String),
}
/// Declares one kernel argument: its name, resource kind, optional
/// explicit binding slot, and whether the kernel only reads it.
#[derive(Debug, Clone)]
pub struct KernelParameter {
    /// Argument name as referenced by the kernel source.
    pub name: String,
    /// What kind of resource this argument binds.
    pub param_type: KernelParameterType,
    /// Explicit binding slot; `None` when no slot was requested
    /// (set via [`KernelParameter::with_binding`]).
    pub binding: Option<u32>,
    /// `true` when the kernel only reads this argument.
    pub readonly: bool,
}
impl KernelParameter {
    /// Buffer argument holding elements of `dtype`; `readonly` marks
    /// whether the kernel may write it. No binding slot is assigned.
    pub fn buffer(name: String, dtype: DType, readonly: bool) -> Self {
        Self {
            param_type: KernelParameterType::Buffer { dtype },
            binding: None,
            name,
            readonly,
        }
    }

    /// Uniform argument of `dtype`; uniforms are always read-only.
    pub fn uniform(name: String, dtype: DType) -> Self {
        Self {
            param_type: KernelParameterType::Uniform { dtype },
            binding: None,
            readonly: true,
            name,
        }
    }

    /// Pins this parameter to an explicit binding slot.
    pub fn with_binding(mut self, binding: u32) -> Self {
        self.binding = Some(binding);
        self
    }
}
/// The kind of resource a [`KernelParameter`] binds.
#[derive(Debug, Clone)]
pub enum KernelParameterType {
    /// Device buffer of elements with the given dtype.
    Buffer { dtype: DType },
    /// Uniform/constant block of the given dtype.
    Uniform { dtype: DType },
    /// Texture with the given dimensionality (e.g. 2 for 2-D) and dtype.
    Texture { dimensions: u32, dtype: DType },
    /// Texture sampler (carries no dtype).
    Sampler,
    /// Single scalar value of the given dtype.
    Scalar { dtype: DType },
}
/// Backend-specific handle to a compiled kernel.
///
/// Only the variants for compiled-in backends exist at runtime; `Generic`
/// is a type-erased escape hatch for other backends.
#[derive(Debug)]
pub enum KernelHandle {
    /// Host function pointer for CPU execution.
    Cpu { function: *const () },
    /// Raw CUDA module/function handle values.
    #[cfg(feature = "cuda")]
    Cuda { module: u64, function: u64 },
    /// Metal library/function identifiers.
    #[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
    Metal { library_id: u64, function_id: u64 },
    /// WebGPU shader-module id plus entry-point name.
    #[cfg(feature = "webgpu")]
    WebGpu {
        shader_module_id: String,
        entry_point: String,
    },
    /// Type-erased handle for any other backend.
    ///
    /// Named as `core::any::Any` (rather than `std::any::Any`) so the path
    /// still resolves when the crate is built without the `std` feature —
    /// this file already supports that configuration via its `alloc`
    /// imports. `core::any::Any` is the same trait that `std` re-exports.
    Generic {
        handle: Box<dyn core::any::Any + Send + Sync>,
    },
}
/// Hand-written `Clone`: `#[derive(Clone)]` is not usable because the
/// `Generic` variant's type-erased `Box<dyn Any>` payload has no `Clone`.
impl Clone for KernelHandle {
    fn clone(&self) -> Self {
        match self {
            // The raw function pointer is `Copy`; duplicate it directly.
            KernelHandle::Cpu { function } => KernelHandle::Cpu {
                function: *function,
            },
            // NOTE(review): copying the raw module/function ids duplicates
            // the handle without re-acquiring anything — assumes these ids
            // remain valid independently of this value; confirm against the
            // CUDA backend's ownership rules.
            #[cfg(feature = "cuda")]
            KernelHandle::Cuda { module, function } => KernelHandle::Cuda {
                module: *module,
                function: *function,
            },
            // Same assumption for the Metal numeric identifiers.
            #[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
            KernelHandle::Metal {
                library_id,
                function_id,
            } => KernelHandle::Metal {
                library_id: *library_id,
                function_id: *function_id,
            },
            // String identifiers are deep-cloned.
            #[cfg(feature = "webgpu")]
            KernelHandle::WebGpu {
                shader_module_id,
                entry_point,
            } => KernelHandle::WebGpu {
                shader_module_id: shader_module_id.clone(),
                entry_point: entry_point.clone(),
            },
            // The type-erased payload cannot be cloned; fail loudly rather
            // than produce a bogus shallow copy.
            KernelHandle::Generic { .. } => {
                panic!("Cannot clone Generic kernel handles")
            }
        }
    }
}
// SAFETY: NOTE(review) — these blanket impls assert that every variant may be
// moved/shared across threads. The `Generic` payload already requires
// `Send + Sync` in its trait object, and the numeric/string handles carry no
// thread affinity visible here. The `Cpu` variant holds a raw `*const ()`
// function pointer, which is what forces the manual impls; this assumes the
// pointee is immutable code that is safe to call from any thread — TODO
// confirm against the CPU backend before relying on cross-thread launches.
unsafe impl Send for KernelHandle {}
unsafe impl Sync for KernelHandle {}
/// Compiler statistics and diagnostics captured when a kernel is built.
#[derive(Debug, Clone)]
pub struct KernelMetadata {
    /// Wall-clock compile time in milliseconds.
    pub compile_time_ms: f64,
    /// Size of the compiled artifact in bytes.
    pub binary_size: usize,
    /// Registers used per thread, when the compiler reports it.
    pub registers_per_thread: Option<u32>,
    /// Static shared-memory usage in bytes, when reported.
    pub shared_memory_usage: Option<usize>,
    /// Largest workgroup the compiled kernel supports, when reported.
    pub max_workgroup_size: Option<(u32, u32, u32)>,
    /// Compiler identification string ("Unknown" when unavailable).
    pub compiler_version: String,
    /// Warnings emitted during compilation.
    pub warnings: Vec<String>,
    /// Compiler performance suggestions.
    pub performance_hints: Vec<String>,
}

impl Default for KernelMetadata {
    /// Zeroed counters, no optional statistics, unknown compiler,
    /// empty diagnostics.
    fn default() -> Self {
        KernelMetadata {
            compiler_version: String::from("Unknown"),
            compile_time_ms: 0.0,
            binary_size: 0,
            registers_per_thread: None,
            shared_memory_usage: None,
            max_workgroup_size: None,
            warnings: Vec::new(),
            performance_hints: Vec::new(),
        }
    }
}
/// Shape of a kernel launch: threads per workgroup, workgroup grid,
/// optional dynamic shared memory, and an optional target stream.
#[derive(Debug, Clone)]
pub struct KernelLaunchConfig {
    /// Threads per workgroup along (x, y, z).
    pub workgroup_size: (u32, u32, u32),
    /// Number of workgroups along (x, y, z).
    pub workgroup_count: (u32, u32, u32),
    /// Dynamic shared memory per workgroup, in bytes.
    pub shared_memory_size: Option<usize>,
    /// Stream/queue to launch on; `None` means the backend default.
    pub stream_id: Option<usize>,
}

impl KernelLaunchConfig {
    /// 1-D launch covering at least `global_size` threads. The workgroup
    /// width defaults to 256 when not supplied; the count rounds up.
    ///
    /// # Panics
    /// Panics if `workgroup_size` is `Some(0)` (`div_ceil` divides by zero).
    pub fn linear(global_size: u32, workgroup_size: Option<u32>) -> Self {
        let width = workgroup_size.unwrap_or(256);
        Self {
            workgroup_size: (width, 1, 1),
            workgroup_count: (global_size.div_ceil(width), 1, 1),
            shared_memory_size: None,
            stream_id: None,
        }
    }

    /// 2-D launch covering at least `global_size` threads per axis.
    /// Workgroup dimensions default to 16x16; counts round up per axis.
    ///
    /// # Panics
    /// Panics if any supplied workgroup dimension is zero.
    pub fn grid_2d(global_size: (u32, u32), workgroup_size: Option<(u32, u32)>) -> Self {
        let (wx, wy) = workgroup_size.unwrap_or((16, 16));
        Self {
            workgroup_size: (wx, wy, 1),
            workgroup_count: (global_size.0.div_ceil(wx), global_size.1.div_ceil(wy), 1),
            shared_memory_size: None,
            stream_id: None,
        }
    }

    /// 3-D launch covering at least `global_size` threads per axis.
    /// Workgroup dimensions default to 8x8x8; counts round up per axis.
    ///
    /// # Panics
    /// Panics if any supplied workgroup dimension is zero.
    pub fn grid_3d(global_size: (u32, u32, u32), workgroup_size: Option<(u32, u32, u32)>) -> Self {
        let (wx, wy, wz) = workgroup_size.unwrap_or((8, 8, 8));
        Self {
            workgroup_size: (wx, wy, wz),
            workgroup_count: (
                global_size.0.div_ceil(wx),
                global_size.1.div_ceil(wy),
                global_size.2.div_ceil(wz),
            ),
            shared_memory_size: None,
            stream_id: None,
        }
    }

    /// Requests `size` bytes of dynamic shared memory per workgroup.
    pub fn with_shared_memory(mut self, size: usize) -> Self {
        self.shared_memory_size = Some(size);
        self
    }

    /// Targets a specific stream/queue.
    pub fn with_stream(mut self, stream_id: usize) -> Self {
        self.stream_id = Some(stream_id);
        self
    }

    /// Total threads across the whole grid, widened to `u64` so the
    /// six-way product cannot overflow.
    pub fn total_threads(&self) -> u64 {
        let (sx, sy, sz) = self.workgroup_size;
        let (cx, cy, cz) = self.workgroup_count;
        u64::from(sx)
            * u64::from(sy)
            * u64::from(sz)
            * u64::from(cx)
            * u64::from(cy)
            * u64::from(cz)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::{Device, DeviceInfo};
    use torsh_core::{device::DeviceType, dtype::DType};

    /// Minimal CPU device used by the kernel tests below.
    fn create_test_device() -> Device {
        Device::new(0, DeviceType::Cpu, "Test CPU".to_string(), DeviceInfo::default())
    }

    #[test]
    fn test_kernel_descriptor_creation() {
        let src = KernelSource::Source {
            code: "void main() {}".to_string(),
            language: KernelLanguage::Hlsl,
        };
        let descriptor = KernelDescriptor::new("test_kernel".to_string(), src);
        // A fresh descriptor carries only name/source; caching defaults on.
        assert_eq!(descriptor.name, "test_kernel");
        assert!(descriptor.compile_options.is_empty());
        assert!(descriptor.parameters.is_empty());
        assert_eq!(descriptor.workgroup_size_hint, None);
        assert!(descriptor.cache);
    }

    #[test]
    fn test_kernel_descriptor_builder() {
        let src = KernelSource::Source {
            code: "void main() {}".to_string(),
            language: KernelLanguage::Cuda,
        };
        let input = KernelParameter::buffer("input".to_string(), DType::F32, true);
        let descriptor = KernelDescriptor::new("complex_kernel".to_string(), src)
            .with_compile_option("-O3".to_string())
            .with_compile_option("--fast-math".to_string())
            .with_parameter(input)
            .with_workgroup_size_hint((256, 1, 1))
            .without_cache();
        assert_eq!(descriptor.name, "complex_kernel");
        assert_eq!(descriptor.compile_options.len(), 2);
        assert!(descriptor.compile_options.contains(&"-O3".to_string()));
        assert!(descriptor.compile_options.contains(&"--fast-math".to_string()));
        assert_eq!(descriptor.parameters.len(), 1);
        assert_eq!(descriptor.workgroup_size_hint, Some((256, 1, 1)));
        assert!(!descriptor.cache);
    }

    #[test]
    fn test_kernel_source_variants() {
        // One value per `KernelSource` variant.
        let text = KernelSource::Source {
            code: "vertex main() {}".to_string(),
            language: KernelLanguage::Metal,
        };
        let bytecode = KernelSource::Bytecode {
            data: vec![0x12, 0x34, 0x56, 0x78],
            format: BytecodeFormat::SpirV,
        };
        let spirv = KernelSource::SpirV {
            data: vec![0x07230203, 0x00010000],
        };
        let binary = KernelSource::Binary {
            data: vec![0xCA, 0xFE, 0xBA, 0xBE],
            platform: "cuda".to_string(),
        };
        match text {
            KernelSource::Source { language, .. } => assert_eq!(language, KernelLanguage::Metal),
            _ => panic!("Wrong variant"),
        }
        match bytecode {
            KernelSource::Bytecode { format, .. } => assert_eq!(format, BytecodeFormat::SpirV),
            _ => panic!("Wrong variant"),
        }
        match spirv {
            KernelSource::SpirV { .. } => {}
            _ => panic!("Wrong variant"),
        }
        match binary {
            KernelSource::Binary { platform, .. } => assert_eq!(platform, "cuda"),
            _ => panic!("Wrong variant"),
        }
    }

    #[test]
    fn test_kernel_language_variants() {
        let languages = [
            KernelLanguage::Wgsl,
            KernelLanguage::Hlsl,
            KernelLanguage::Glsl,
            KernelLanguage::Metal,
            KernelLanguage::Cuda,
            KernelLanguage::OpenCl,
            KernelLanguage::Custom("MyLang".to_string()),
        ];
        // Every pair of distinct variants must compare unequal.
        for (a, left) in languages.iter().enumerate() {
            for (b, right) in languages.iter().enumerate() {
                if a != b {
                    assert_ne!(left, right);
                }
            }
        }
    }

    #[test]
    fn test_bytecode_format_variants() {
        let formats = [
            BytecodeFormat::SpirV,
            BytecodeFormat::Dxil,
            BytecodeFormat::MetalAir,
            BytecodeFormat::Ptx,
            BytecodeFormat::Custom("MyFormat".to_string()),
        ];
        // Every pair of distinct variants must compare unequal.
        for (a, left) in formats.iter().enumerate() {
            for (b, right) in formats.iter().enumerate() {
                if a != b {
                    assert_ne!(left, right);
                }
            }
        }
    }

    #[test]
    fn test_kernel_parameter_creation() {
        let data = KernelParameter::buffer("data".to_string(), DType::F32, false);
        assert_eq!(data.name, "data");
        assert!(!data.readonly);
        assert_eq!(data.binding, None);
        match data.param_type {
            KernelParameterType::Buffer { dtype } => assert_eq!(dtype, DType::F32),
            _ => panic!("Wrong parameter type"),
        }

        let scale = KernelParameter::uniform("scale".to_string(), DType::F32);
        assert_eq!(scale.name, "scale");
        assert!(scale.readonly);
        match scale.param_type {
            KernelParameterType::Uniform { dtype } => assert_eq!(dtype, DType::F32),
            _ => panic!("Wrong parameter type"),
        }

        let bound = data.with_binding(0);
        assert_eq!(bound.binding, Some(0));
    }

    #[test]
    fn test_kernel_parameter_types() {
        use std::mem::discriminant;
        let buffer = KernelParameterType::Buffer { dtype: DType::I32 };
        let uniform = KernelParameterType::Uniform { dtype: DType::F64 };
        let texture = KernelParameterType::Texture {
            dimensions: 2,
            dtype: DType::F32,
        };
        let sampler = KernelParameterType::Sampler;
        let scalar = KernelParameterType::Scalar { dtype: DType::U8 };
        // Each adjacent pair is a different variant.
        assert_ne!(discriminant(&buffer), discriminant(&uniform));
        assert_ne!(discriminant(&uniform), discriminant(&texture));
        assert_ne!(discriminant(&texture), discriminant(&sampler));
        assert_ne!(discriminant(&sampler), discriminant(&scalar));
    }

    #[test]
    fn test_kernel_handle_cpu() {
        let handle = KernelHandle::Cpu {
            function: std::ptr::null(),
        };
        match handle {
            KernelHandle::Cpu { function } => assert!(function.is_null()),
            _ => panic!("Wrong handle type"),
        }
    }

    #[test]
    fn test_kernel_metadata_default() {
        let meta = KernelMetadata::default();
        assert_eq!(meta.compile_time_ms, 0.0);
        assert_eq!(meta.binary_size, 0);
        assert_eq!(meta.registers_per_thread, None);
        assert_eq!(meta.shared_memory_usage, None);
        assert_eq!(meta.max_workgroup_size, None);
        assert_eq!(meta.compiler_version, "Unknown");
        assert!(meta.warnings.is_empty());
        assert!(meta.performance_hints.is_empty());
    }

    #[test]
    fn test_kernel_creation() {
        let device = create_test_device();
        let source = KernelSource::Source {
            code: "void main() {}".to_string(),
            language: KernelLanguage::Hlsl,
        };
        let kernel = Kernel::new(
            1,
            device.clone(),
            "test_kernel".to_string(),
            KernelDescriptor::new("test".to_string(), source),
            KernelHandle::Cpu {
                function: std::ptr::null(),
            },
            KernelMetadata::default(),
        );
        assert_eq!(kernel.id(), 1);
        assert_eq!(kernel.name(), "test_kernel");
        assert_eq!(kernel.device().id(), device.id());
    }

    #[test]
    fn test_kernel_launch_config_linear() {
        let explicit = KernelLaunchConfig::linear(1000, Some(64));
        assert_eq!(explicit.workgroup_size, (64, 1, 1));
        assert_eq!(explicit.workgroup_count, (16, 1, 1)); // ceil(1000 / 64)
        assert_eq!(explicit.shared_memory_size, None);
        assert_eq!(explicit.stream_id, None);
        assert_eq!(explicit.total_threads(), 64 * 16);

        let defaulted = KernelLaunchConfig::linear(1000, None);
        assert_eq!(defaulted.workgroup_size, (256, 1, 1));
        assert_eq!(defaulted.workgroup_count, (4, 1, 1)); // ceil(1000 / 256)
    }

    #[test]
    fn test_kernel_launch_config_2d() {
        let explicit = KernelLaunchConfig::grid_2d((100, 50), Some((10, 5)));
        assert_eq!(explicit.workgroup_size, (10, 5, 1));
        assert_eq!(explicit.workgroup_count, (10, 10, 1));
        assert_eq!(explicit.total_threads(), 10 * 5 * 10 * 10);

        let defaulted = KernelLaunchConfig::grid_2d((100, 50), None);
        assert_eq!(defaulted.workgroup_size, (16, 16, 1));
        assert_eq!(defaulted.workgroup_count, (7, 4, 1)); // ceil(100/16), ceil(50/16)
    }

    #[test]
    fn test_kernel_launch_config_3d() {
        let explicit = KernelLaunchConfig::grid_3d((64, 32, 16), Some((8, 4, 2)));
        assert_eq!(explicit.workgroup_size, (8, 4, 2));
        assert_eq!(explicit.workgroup_count, (8, 8, 8));
        assert_eq!(explicit.total_threads(), 8 * 4 * 2 * 8 * 8 * 8);

        let defaulted = KernelLaunchConfig::grid_3d((64, 32, 16), None);
        assert_eq!(defaulted.workgroup_size, (8, 8, 8));
        assert_eq!(defaulted.workgroup_count, (8, 4, 2));
    }

    #[test]
    fn test_kernel_launch_config_builder() {
        let config = KernelLaunchConfig::linear(1000, Some(128))
            .with_shared_memory(4096)
            .with_stream(1);
        assert_eq!(config.shared_memory_size, Some(4096));
        assert_eq!(config.stream_id, Some(1));
    }
}