use morok_dtype::DType;
use smallvec::SmallVec;
/// One step of a tensor-core optimisation recipe: an upcast or a local split,
/// each carrying the index of the dimension it applies to.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TcOpt {
    /// Upcast step on the given dimension index.
    Upcast(usize),
    /// Local step on the given dimension index.
    Local(usize),
}

impl TcOpt {
    /// Dimension index carried by this step, whichever variant it is.
    pub const fn dim(&self) -> usize {
        match *self {
            Self::Upcast(index) => index,
            Self::Local(index) => index,
        }
    }

    /// `true` when this step is [`TcOpt::Upcast`].
    pub const fn is_upcast(&self) -> bool {
        match self {
            Self::Upcast(_) => true,
            Self::Local(_) => false,
        }
    }

    /// `true` when this step is [`TcOpt::Local`].
    pub const fn is_local(&self) -> bool {
        match self {
            Self::Local(_) => true,
            Self::Upcast(_) => false,
        }
    }
}

/// Renders the compact form: `u<dim>` for upcasts, `l<dim>` for locals.
impl std::fmt::Display for TcOpt {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let tag = if self.is_upcast() { 'u' } else { 'l' };
        write!(f, "{}{}", tag, self.dim())
    }
}
/// Reference to one axis inside a swizzle description: an upcast, local or
/// reduce axis identified by its index.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SwizzleAxis {
    /// Upcast axis with the given index.
    Upcast(usize),
    /// Local axis with the given index.
    Local(usize),
    /// Reduce axis with the given index.
    Reduce(usize),
}

/// Renders the compact form: `u<idx>`, `l<idx>` or `r<idx>`.
impl std::fmt::Display for SwizzleAxis {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (tag, index) = match *self {
            Self::Upcast(index) => ('u', index),
            Self::Local(index) => ('l', index),
            Self::Reduce(index) => ('r', index),
        };
        write!(f, "{}{}", tag, index)
    }
}
/// Description of a code-generation target: capability flags, size limits and
/// the tensor-core layouts it offers.
///
/// NOTE(review): the precise meaning and units of each limit are defined by the
/// consumers of this struct, which are not visible in this file. The field
/// comments below only state what the presets in this file establish.
#[derive(Debug, Clone)]
pub struct Renderer {
/// Backend identifier such as "CPU", "CUDA_SM80" or "Metal"; `is_amx` compares
/// it against "AppleAMX".
pub device: String,
/// Set by the CUDA, Metal, AMD, Intel and WebGPU presets; cleared by `cpu` and
/// `apple_amx`.
pub has_local: bool,
/// Set by the same presets that set `has_local`; cleared by `cpu` and
/// `apple_amx`.
pub has_shared: bool,
/// Set only by the `cpu` and `apple_amx` presets.
pub has_threads: bool,
/// Shared-memory limit; 0 in the presets that clear `has_shared`.
/// NOTE(review): presumably a byte count — confirm against the consumer.
pub shared_max: usize,
/// Optional per-dimension limits on the global size; `None` for the Metal
/// preset, one entry for the host presets, three entries otherwise.
pub global_max: Option<Vec<usize>>,
/// Optional limit on the local size; `None` for the host presets.
pub local_max: Option<usize>,
/// Upcast limit (4, 8 or 16 in the presets of this file).
pub upcast_max: usize,
/// Optional limit on the number of buffers; only the Metal (31) and WebGPU (8)
/// presets set it.
pub buffer_max: Option<usize>,
/// Tensor-core layouts available on this target; empty for `cpu` and `webgpu`.
pub tensor_cores: Vec<TensorCore>,
}
impl Renderer {
pub fn cpu() -> Self {
let cores = std::thread::available_parallelism().map(|p| p.get()).unwrap_or(8);
Self {
device: "CPU".to_string(),
has_local: false,
has_shared: false,
has_threads: true,
shared_max: 0,
global_max: Some(vec![cores]), local_max: None,
upcast_max: 16, buffer_max: None,
tensor_cores: vec![],
}
}
pub fn cuda() -> Self {
Self::cuda_sm80(false) }
pub fn cuda_sm75() -> Self {
Self {
device: "CUDA_SM75".to_string(),
has_local: true,
has_shared: true,
has_threads: false,
shared_max: 49152,
global_max: Some(vec![2147483647, 65535, 65535]),
local_max: Some(1024),
upcast_max: 8,
buffer_max: None,
tensor_cores: TensorCore::sm75_tensor_cores(),
}
}
pub fn cuda_sm80(allow_tf32: bool) -> Self {
Self {
device: "CUDA_SM80".to_string(),
has_local: true,
has_shared: true,
has_threads: false,
shared_max: 49152,
global_max: Some(vec![2147483647, 65535, 65535]),
local_max: Some(1024),
upcast_max: 8,
buffer_max: None,
tensor_cores: TensorCore::sm80_tensor_cores(allow_tf32),
}
}
pub fn cuda_sm89(allow_tf32: bool) -> Self {
Self {
device: "CUDA_SM89".to_string(),
has_local: true,
has_shared: true,
has_threads: false,
shared_max: 49152,
global_max: Some(vec![2147483647, 65535, 65535]),
local_max: Some(1024),
upcast_max: 8,
buffer_max: None,
tensor_cores: TensorCore::sm89_tensor_cores(allow_tf32),
}
}
pub fn metal() -> Self {
Self {
device: "Metal".to_string(),
has_local: true,
has_shared: true,
has_threads: false,
shared_max: 32768, global_max: None,
local_max: Some(1024),
upcast_max: 4, buffer_max: Some(31), tensor_cores: TensorCore::metal_tensor_cores(),
}
}
pub fn apple_amx() -> Self {
Self {
device: "AppleAMX".to_string(),
has_local: false, has_shared: false,
has_threads: true, shared_max: 0,
global_max: Some(vec![256]),
local_max: None,
upcast_max: 16,
buffer_max: None,
tensor_cores: TensorCore::amx_tensor_cores(),
}
}
pub fn is_amx(&self) -> bool {
self.device == "AppleAMX"
}
pub fn amd_rdna3() -> Self {
Self {
device: "AMD_RDNA3".to_string(),
has_local: true,
has_shared: true,
has_threads: false,
shared_max: 65536, global_max: Some(vec![2147483647, 65535, 65535]),
local_max: Some(1024),
upcast_max: 8,
buffer_max: None,
tensor_cores: TensorCore::rdna3_tensor_cores(),
}
}
pub fn amd_rdna4() -> Self {
Self {
device: "AMD_RDNA4".to_string(),
has_local: true,
has_shared: true,
has_threads: false,
shared_max: 65536,
global_max: Some(vec![2147483647, 65535, 65535]),
local_max: Some(1024),
upcast_max: 8,
buffer_max: None,
tensor_cores: TensorCore::rdna4_tensor_cores(),
}
}
pub fn amd_cdna3() -> Self {
Self {
device: "AMD_CDNA3".to_string(),
has_local: true,
has_shared: true,
has_threads: false,
shared_max: 65536, global_max: Some(vec![2147483647, 65535, 65535]),
local_max: Some(1024),
upcast_max: 8,
buffer_max: None,
tensor_cores: TensorCore::cdna3_tensor_cores(),
}
}
pub fn amd_cdna4() -> Self {
Self {
device: "AMD_CDNA4".to_string(),
has_local: true,
has_shared: true,
has_threads: false,
shared_max: 65536,
global_max: Some(vec![2147483647, 65535, 65535]),
local_max: Some(1024),
upcast_max: 8,
buffer_max: None,
tensor_cores: TensorCore::cdna4_tensor_cores(),
}
}
pub fn intel_xe() -> Self {
Self {
device: "IntelXe".to_string(),
has_local: true,
has_shared: true,
has_threads: false,
shared_max: 65536, global_max: Some(vec![2147483647, 65535, 65535]),
local_max: Some(512),
upcast_max: 8,
buffer_max: None,
tensor_cores: TensorCore::intel_tensor_cores(),
}
}
pub fn webgpu() -> Self {
Self {
device: "WebGPU".to_string(),
has_local: true,
has_shared: true,
has_threads: false,
shared_max: 16384, global_max: Some(vec![65535, 65535, 65535]),
local_max: Some(256),
upcast_max: 4,
buffer_max: Some(8), tensor_cores: vec![],
}
}
}
/// A tensor-core layout bound to concrete input and output element types.
///
/// Values are produced by `TcConfig::build`, which copies a static `TcConfig`
/// into owned storage and attaches the two dtypes.
#[derive(Debug, Clone)]
pub struct TensorCore {
/// Tile dimensions. The third component is the one `get_reduce_axes` takes the
/// base-2 logarithm of. NOTE(review): presumably ordered (N, M, K) — confirm.
pub dims: (usize, usize, usize),
/// Thread count of the layout (1, 8, 32 or 64 in the configs of this file).
pub threads: usize,
/// Copied from `TcConfig::ept`. NOTE(review): presumably the per-thread
/// element counts of the two inputs and the output — confirm.
pub elements_per_thread: (usize, usize, usize),
/// Input element type passed to `TcConfig::build`.
pub dtype_in: DType,
/// Output element type passed to `TcConfig::build`.
pub dtype_out: DType,
/// Upcast/local steps copied from `TcConfig::opts`.
pub opts: SmallVec<[TcOpt; 8]>,
/// Two triples of axis lists: the first triple is copied from
/// `TcConfig::swizzle_a`, the second from `TcConfig::swizzle_b`.
#[allow(clippy::type_complexity)]
pub swizzle: (
(SmallVec<[SwizzleAxis; 8]>, SmallVec<[SwizzleAxis; 8]>, SmallVec<[SwizzleAxis; 8]>),
(SmallVec<[SwizzleAxis; 8]>, SmallVec<[SwizzleAxis; 8]>, SmallVec<[SwizzleAxis; 8]>),
),
/// Copied from `TcConfig::pack_a`; `true` only for the Apple AMX configs in
/// this file.
pub pack_a: bool,
/// Copied from `TcConfig::tile_grid`; `(1, 1)` for every config in this file.
pub tile_grid: (usize, usize),
}
/// Static, dtype-independent description of a tensor-core layout.
///
/// All slices are `'static`, so values can be declared as `const` tables;
/// `build` turns one into an owned `TensorCore` for a given dtype pair.
pub struct TcConfig {
// Becomes `TensorCore::dims`.
dims: (usize, usize, usize),
// Becomes `TensorCore::threads`.
threads: usize,
// Becomes `TensorCore::elements_per_thread`.
ept: (usize, usize, usize),
// Becomes `TensorCore::opts`.
opts: &'static [TcOpt],
// Becomes the first triple of `TensorCore::swizzle`.
swizzle_a: (&'static [SwizzleAxis], &'static [SwizzleAxis], &'static [SwizzleAxis]),
// Becomes the second triple of `TensorCore::swizzle`.
swizzle_b: (&'static [SwizzleAxis], &'static [SwizzleAxis], &'static [SwizzleAxis]),
// Becomes `TensorCore::pack_a`.
pack_a: bool,
// Becomes `TensorCore::tile_grid`.
tile_grid: (usize, usize),
}
impl TcConfig {
pub fn build(&self, dtype_in: DType, dtype_out: DType) -> TensorCore {
TensorCore {
dims: self.dims,
threads: self.threads,
elements_per_thread: self.ept,
dtype_in,
dtype_out,
opts: self.opts.iter().copied().collect(),
swizzle: (
(
self.swizzle_a.0.iter().copied().collect(),
self.swizzle_a.1.iter().copied().collect(),
self.swizzle_a.2.iter().copied().collect(),
),
(
self.swizzle_b.0.iter().copied().collect(),
self.swizzle_b.1.iter().copied().collect(),
self.swizzle_b.2.iter().copied().collect(),
),
),
pack_a: self.pack_a,
tile_grid: self.tile_grid,
}
}
}
use SwizzleAxis::{Local as SL, Reduce as R, Upcast as SU};
use TcOpt::{Local as L, Upcast as U};
/// CUDA layout with `dims` (8, 16, 16) on 32 threads.
///
/// `TensorCore::sm80_tensor_cores` builds it for Float16 -> Float32,
/// BFloat16 -> Float32 and Float16 -> Float16.
pub const CUDA_81616: TcConfig = TcConfig {
dims: (8, 16, 16),
threads: 32,
ept: (8, 4, 4),
opts: &[U(0), L(0), L(0), L(1), L(1), L(1), U(1)],
swizzle_a: (&[R(1), R(2), SL(2), SL(3), SL(4)], &[SU(1), R(3)], &[SL(0), SL(1), SU(0), R(0)]),
swizzle_b: (&[R(1), R(2), SU(0), SL(0), SL(1)], &[R(0), R(3)], &[SL(2), SL(3), SL(4), SU(1)]),
pack_a: false,
tile_grid: (1, 1),
};
/// CUDA layout with `dims` (8, 16, 32) on 32 threads.
///
/// `TensorCore::sm89_tensor_cores` builds it for FP8E4M3 -> Float32 and
/// FP8E5M2 -> Float32. It uses the same `opts` as the other CUDA layouts.
pub const CUDA_81632: TcConfig = TcConfig {
dims: (8, 16, 32),
threads: 32,
ept: (16, 8, 4),
opts: &[U(0), L(0), L(0), L(1), L(1), L(1), U(1)],
swizzle_a: (&[R(2), R(3), SL(2), SL(3), SL(4)], &[SU(1), R(4)], &[SL(0), SL(1), SU(0), R(0), R(1)]),
swizzle_b: (&[R(2), R(3), SU(0), SL(0), SL(1)], &[R(1), R(4)], &[SL(2), SL(3), SL(4), SU(1), R(0)]),
pack_a: false,
tile_grid: (1, 1),
};
/// CUDA layout with `dims` (8, 16, 8) on 32 threads.
///
/// Built by `TensorCore::sm75_tensor_cores` and `TensorCore::sm80_tensor_cores`
/// for Float16 -> Float32 and Float16 -> Float16.
pub const CUDA_8168: TcConfig = TcConfig {
dims: (8, 16, 8),
threads: 32,
ept: (4, 2, 4),
opts: &[U(0), L(0), L(0), L(1), L(1), L(1), U(1)],
swizzle_a: (&[R(1), R(2), SL(2), SL(3), SL(4)], &[R(0), SU(1)], &[SL(0), SL(1), SU(0)]),
swizzle_b: (&[R(1), R(2), SU(0), SL(0), SL(1)], &[SU(1), R(0)], &[SL(2), SL(3), SL(4)]),
pack_a: false,
tile_grid: (1, 1),
};
/// Variant of `CUDA_8168` with the same `dims`, `threads`, `ept` and `opts`
/// but different swizzles.
///
/// `TensorCore::sm80_tensor_cores` adds it, as Float32 -> Float32, only when
/// `allow_tf32` is set.
pub const CUDA_8168_TF32: TcConfig = TcConfig {
dims: (8, 16, 8),
threads: 32,
ept: (4, 2, 4),
opts: &[U(0), L(0), L(0), L(1), L(1), L(1), U(1)],
swizzle_a: (&[R(0), R(1), SL(2), SL(3), SL(4)], &[SU(1), R(2)], &[SL(0), SL(1), SU(0)]),
swizzle_b: (&[R(0), R(1), SU(0), SL(0), SL(1)], &[SU(1), R(2)], &[SL(2), SL(3), SL(4)]),
pack_a: false,
tile_grid: (1, 1),
};
/// AMD layout with `dims` (16, 16, 16) on 32 threads.
///
/// `TensorCore::rdna3_tensor_cores` builds it for Float16 -> Float32,
/// Float16 -> Float16 and BFloat16 -> Float32.
pub const AMD_RDNA3: TcConfig = TcConfig {
dims: (16, 16, 16),
threads: 32,
ept: (16, 16, 8),
opts: &[L(0), L(0), L(0), L(0), L(1), U(1), U(1), U(1)],
swizzle_a: (&[SL(4), SU(0), SU(1), SU(2), SL(0)], &[R(1), R(2), R(3)], &[SL(1), SL(2), SL(3), R(0)]),
swizzle_b: (&[SL(0), SL(1), SL(2), SL(3), SL(4)], &[R(1), R(2), R(3)], &[SU(0), SU(1), SU(2), R(0)]),
pack_a: false,
tile_grid: (1, 1),
};
/// AMD layout with the same `dims` and `threads` as `AMD_RDNA3` but a
/// different `ept`, `opts` order and swizzles.
///
/// `TensorCore::rdna4_tensor_cores` builds it for Float16 -> Float32,
/// Float16 -> Float16, BFloat16 -> Float32 and BFloat16 -> BFloat16.
pub const AMD_RDNA4: TcConfig = TcConfig {
dims: (16, 16, 16),
threads: 32,
ept: (8, 8, 8),
opts: &[L(0), L(0), L(0), L(0), U(1), U(1), U(1), L(1)],
swizzle_a: (&[SU(0), SU(1), SU(2), SL(4), R(2)], &[R(0), R(1), R(3)], &[SL(0), SL(1), SL(2), SL(3)]),
swizzle_b: (&[SL(0), SL(1), SL(2), SL(3), R(2)], &[R(0), R(1), R(3)], &[SL(4), SU(0), SU(1), SU(2)]),
pack_a: false,
tile_grid: (1, 1),
};
/// AMD layout with `dims` (16, 16, 16) on 64 threads.
///
/// Built by `TensorCore::cdna3_tensor_cores` and
/// `TensorCore::cdna4_tensor_cores` for Float16 -> Float32 and
/// BFloat16 -> Float32.
pub const AMD_CDNA_161616: TcConfig = TcConfig {
dims: (16, 16, 16),
threads: 64,
ept: (4, 4, 4),
opts: &[L(0), L(0), L(0), L(0), U(1), U(1), L(1), L(1)],
swizzle_a: (&[SU(0), SU(1), SL(4), SL(5), R(2), R(3)], &[R(0), R(1)], &[SL(0), SL(1), SL(2), SL(3)]),
swizzle_b: (&[SL(0), SL(1), SL(2), SL(3), R(2), R(3)], &[R(0), R(1)], &[SL(4), SL(5), SU(0), SU(1)]),
pack_a: false,
tile_grid: (1, 1),
};
/// AMD layout with `dims` (16, 16, 32) on 64 threads.
///
/// `TensorCore::cdna3_tensor_cores` builds it for FP8E5M2 -> Float32 and
/// FP8E4M3 -> Float32; `TensorCore::cdna4_tensor_cores` additionally builds it
/// for Float16 -> Float32 and BFloat16 -> Float32.
pub const AMD_CDNA_161632: TcConfig = TcConfig {
dims: (16, 16, 32),
threads: 64,
ept: (8, 8, 4),
opts: &[L(0), L(0), L(0), L(0), U(1), U(1), L(1), L(1)],
swizzle_a: (&[SU(0), SU(1), SL(4), SL(5), R(3), R(4)], &[R(0), R(1)], &[SL(0), SL(1), SL(2), SL(3), R(2)]),
swizzle_b: (&[SL(0), SL(1), SL(2), SL(3), R(3), R(4)], &[R(0), R(1)], &[SL(4), SL(5), SU(0), SU(1), R(2)]),
pack_a: false,
tile_grid: (1, 1),
};
/// Metal layout with `dims` (8, 8, 8) on 32 threads.
///
/// `TensorCore::metal_tensor_cores` builds it for Float32 -> Float32,
/// Float16 -> Float32, Float16 -> Float16, BFloat16 -> Float32 and
/// BFloat16 -> BFloat16.
pub const METAL_888: TcConfig = TcConfig {
dims: (8, 8, 8),
threads: 32,
ept: (2, 2, 2),
opts: &[U(0), L(0), L(1), L(1), L(0), L(1)],
swizzle_a: (&[R(1), SL(1), SL(2), R(2), SL(4)], &[R(0)], &[SU(0), SL(0), SL(3)]),
swizzle_b: (&[SL(0), R(0), R(1), SL(3), R(2)], &[SU(0)], &[SL(1), SL(2), SL(4)]),
pack_a: false,
tile_grid: (1, 1),
};
/// Apple AMX layout with `dims` (16, 16, 1) on a single thread.
///
/// It uses upcast steps only and sets `pack_a`.
/// `TensorCore::amx_tensor_cores` builds it for Float32 -> Float32.
pub const APPLE_AMX: TcConfig = TcConfig {
dims: (16, 16, 1),
threads: 1,
ept: (16, 16, 256),
opts: &[U(0), U(0), U(0), U(0), U(1), U(1), U(1), U(1)],
swizzle_a: (&[], &[SU(0), SU(1), SU(2), SU(3), SU(4), SU(5), SU(6), SU(7)], &[]),
swizzle_b: (&[], &[SU(4), SU(5), SU(6), SU(7), SU(0), SU(1), SU(2), SU(3)], &[]),
pack_a: true,
tile_grid: (1, 1),
};
/// Apple AMX layout with `dims` (32, 32, 1) on a single thread.
///
/// `TensorCore::amx_tensor_cores` builds it for Float16 -> Float32. Its field
/// values are identical to `APPLE_AMX_F16` and `APPLE_AMX_I16`.
pub const APPLE_AMX_F16_F32: TcConfig = TcConfig {
dims: (32, 32, 1),
threads: 1,
ept: (32, 32, 1024),
opts: &[U(0), U(0), U(0), U(0), U(0), U(1), U(1), U(1), U(1), U(1)],
swizzle_a: (&[], &[SU(0), SU(1), SU(2), SU(3), SU(4), SU(5), SU(6), SU(7), SU(8), SU(9)], &[]),
swizzle_b: (&[], &[SU(5), SU(6), SU(7), SU(8), SU(9), SU(0), SU(1), SU(2), SU(3), SU(4)], &[]),
pack_a: true,
tile_grid: (1, 1),
};
/// Apple AMX layout with `dims` (32, 32, 1) on a single thread.
///
/// `TensorCore::amx_tensor_cores` builds it for Float16 -> Float16. Its field
/// values are identical to `APPLE_AMX_F16_F32` and `APPLE_AMX_I16`.
pub const APPLE_AMX_F16: TcConfig = TcConfig {
dims: (32, 32, 1),
threads: 1,
ept: (32, 32, 1024),
opts: &[U(0), U(0), U(0), U(0), U(0), U(1), U(1), U(1), U(1), U(1)],
swizzle_a: (&[], &[SU(0), SU(1), SU(2), SU(3), SU(4), SU(5), SU(6), SU(7), SU(8), SU(9)], &[]),
swizzle_b: (&[], &[SU(5), SU(6), SU(7), SU(8), SU(9), SU(0), SU(1), SU(2), SU(3), SU(4)], &[]),
pack_a: true,
tile_grid: (1, 1),
};
/// Apple AMX layout with `dims` (8, 8, 1) on a single thread.
///
/// `TensorCore::amx_tensor_cores` builds it for Float64 -> Float64.
pub const APPLE_AMX_F64: TcConfig = TcConfig {
dims: (8, 8, 1),
threads: 1,
ept: (8, 8, 64),
opts: &[U(0), U(0), U(0), U(1), U(1), U(1)],
swizzle_a: (&[], &[SU(0), SU(1), SU(2), SU(3), SU(4), SU(5)], &[]),
swizzle_b: (&[], &[SU(3), SU(4), SU(5), SU(0), SU(1), SU(2)], &[]),
pack_a: true,
tile_grid: (1, 1),
};
/// Apple AMX layout with `dims` (32, 32, 1) on a single thread.
///
/// `TensorCore::amx_tensor_cores` builds it for Int16 -> Int16. Its field
/// values are identical to `APPLE_AMX_F16` and `APPLE_AMX_F16_F32`.
pub const APPLE_AMX_I16: TcConfig = TcConfig {
dims: (32, 32, 1),
threads: 1,
ept: (32, 32, 1024),
opts: &[U(0), U(0), U(0), U(0), U(0), U(1), U(1), U(1), U(1), U(1)],
swizzle_a: (&[], &[SU(0), SU(1), SU(2), SU(3), SU(4), SU(5), SU(6), SU(7), SU(8), SU(9)], &[]),
swizzle_b: (&[], &[SU(5), SU(6), SU(7), SU(8), SU(9), SU(0), SU(1), SU(2), SU(3), SU(4)], &[]),
pack_a: true,
tile_grid: (1, 1),
};
/// Intel Xe layout with `dims` (8, 8, 16) on 8 threads.
///
/// `TensorCore::intel_tensor_cores` builds it for Float16 -> Float32.
pub const INTEL_XE_8816: TcConfig = TcConfig {
dims: (8, 8, 16),
threads: 8,
ept: (16, 16, 8),
opts: &[L(0), L(0), L(0), U(1), U(1), U(1)],
swizzle_a: (&[R(1), R(2), R(3)], &[SU(0), SU(1), SU(2)], &[SL(0), SL(1), SL(2), R(0)]),
swizzle_b: (&[SL(0), SL(1), SL(2)], &[R(1), R(2), R(3)], &[SU(0), SU(1), SU(2), R(0)]),
pack_a: false,
tile_grid: (1, 1),
};
impl TensorCore {
    /// Returns `floor(log2(dims.2))` pairs `(i, 2)`, with `i` counting up
    /// from 0.
    ///
    /// A `dims.2` of 0 or 1 yields an empty list.
    pub fn get_reduce_axes(&self) -> Vec<(usize, usize)> {
        let k = self.dims.2;
        // Integer floor(log2(k)): position of the highest set bit. Unlike a
        // round trip through f64 this is exact for every usize, and the
        // saturating subtraction keeps k == 0 at zero axes.
        let log2_k = (usize::BITS - k.leading_zeros()).saturating_sub(1) as usize;
        (0..log2_k).map(|axis| (axis, 2)).collect()
    }

    /// Returns three axis lists, each `[0, 1]`; the result does not depend on
    /// `self`.
    pub fn upcast_axes(&self) -> (Vec<usize>, Vec<usize>, Vec<usize>) {
        (vec![0, 1], vec![0, 1], vec![0, 1])
    }

    /// `CUDA_8168` for Float16 -> Float32 and Float16 -> Float16.
    pub fn sm75_tensor_cores() -> Vec<TensorCore> {
        vec![CUDA_8168.build(DType::Float16, DType::Float32), CUDA_8168.build(DType::Float16, DType::Float16)]
    }

    /// `CUDA_81616` and `CUDA_8168` half-precision layouts, followed by the
    /// Float32 `CUDA_8168_TF32` layout when `allow_tf32` is set.
    pub fn sm80_tensor_cores(allow_tf32: bool) -> Vec<TensorCore> {
        let mut tcs = vec![
            CUDA_81616.build(DType::Float16, DType::Float32),
            CUDA_81616.build(DType::BFloat16, DType::Float32),
            CUDA_81616.build(DType::Float16, DType::Float16),
            CUDA_8168.build(DType::Float16, DType::Float32),
            CUDA_8168.build(DType::Float16, DType::Float16),
        ];
        if allow_tf32 {
            tcs.push(CUDA_8168_TF32.build(DType::Float32, DType::Float32));
        }
        tcs
    }

    /// Everything from `sm80_tensor_cores(allow_tf32)` followed by the two
    /// FP8 `CUDA_81632` layouts.
    pub fn sm89_tensor_cores(allow_tf32: bool) -> Vec<TensorCore> {
        let mut tcs = Self::sm80_tensor_cores(allow_tf32);
        tcs.push(CUDA_81632.build(DType::FP8E4M3, DType::Float32));
        tcs.push(CUDA_81632.build(DType::FP8E5M2, DType::Float32));
        tcs
    }

    /// `AMD_RDNA3` for its three supported dtype pairs.
    pub fn rdna3_tensor_cores() -> Vec<TensorCore> {
        vec![
            AMD_RDNA3.build(DType::Float16, DType::Float32),
            AMD_RDNA3.build(DType::Float16, DType::Float16),
            AMD_RDNA3.build(DType::BFloat16, DType::Float32),
        ]
    }

    /// `AMD_RDNA4` for its four supported dtype pairs.
    pub fn rdna4_tensor_cores() -> Vec<TensorCore> {
        vec![
            AMD_RDNA4.build(DType::Float16, DType::Float32),
            AMD_RDNA4.build(DType::Float16, DType::Float16),
            AMD_RDNA4.build(DType::BFloat16, DType::Float32),
            AMD_RDNA4.build(DType::BFloat16, DType::BFloat16),
        ]
    }

    /// FP8 layouts from `AMD_CDNA_161632`, then 16-bit float layouts from
    /// `AMD_CDNA_161616`.
    pub fn cdna3_tensor_cores() -> Vec<TensorCore> {
        vec![
            AMD_CDNA_161632.build(DType::FP8E5M2, DType::Float32),
            AMD_CDNA_161632.build(DType::FP8E4M3, DType::Float32),
            AMD_CDNA_161616.build(DType::Float16, DType::Float32),
            AMD_CDNA_161616.build(DType::BFloat16, DType::Float32),
        ]
    }

    /// FP8 and 16-bit float layouts from `AMD_CDNA_161632`, then 16-bit float
    /// layouts from `AMD_CDNA_161616`.
    pub fn cdna4_tensor_cores() -> Vec<TensorCore> {
        vec![
            AMD_CDNA_161632.build(DType::FP8E5M2, DType::Float32),
            AMD_CDNA_161632.build(DType::FP8E4M3, DType::Float32),
            AMD_CDNA_161632.build(DType::Float16, DType::Float32),
            AMD_CDNA_161632.build(DType::BFloat16, DType::Float32),
            AMD_CDNA_161616.build(DType::Float16, DType::Float32),
            AMD_CDNA_161616.build(DType::BFloat16, DType::Float32),
        ]
    }

    /// `METAL_888` for its five supported dtype pairs.
    pub fn metal_tensor_cores() -> Vec<TensorCore> {
        vec![
            METAL_888.build(DType::Float32, DType::Float32),
            METAL_888.build(DType::Float16, DType::Float32),
            METAL_888.build(DType::Float16, DType::Float16),
            METAL_888.build(DType::BFloat16, DType::Float32),
            METAL_888.build(DType::BFloat16, DType::BFloat16),
        ]
    }

    /// One Apple AMX layout per supported dtype pair, in the order Float32,
    /// Float16, Float16 -> Float32, Float64, Int16.
    pub fn amx_tensor_cores() -> Vec<TensorCore> {
        vec![
            APPLE_AMX.build(DType::Float32, DType::Float32),
            APPLE_AMX_F16.build(DType::Float16, DType::Float16),
            APPLE_AMX_F16_F32.build(DType::Float16, DType::Float32),
            APPLE_AMX_F64.build(DType::Float64, DType::Float64),
            APPLE_AMX_I16.build(DType::Int16, DType::Int16),
        ]
    }

    /// The single `INTEL_XE_8816` Float16 -> Float32 layout.
    pub fn intel_tensor_cores() -> Vec<TensorCore> {
        vec![INTEL_XE_8816.build(DType::Float16, DType::Float32)]
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The CPU preset is threaded, has no local dimensions and no tensor cores.
    #[test]
    fn test_renderer_cpu() {
        let Renderer { device, has_local, has_threads, tensor_cores, .. } = Renderer::cpu();
        assert_eq!(device, "CPU");
        assert!(!has_local);
        assert!(has_threads);
        assert_eq!(tensor_cores.len(), 0);
    }

    /// The default CUDA preset is the SM80 one and comes with tensor cores.
    #[test]
    fn test_renderer_cuda() {
        let Renderer { device, has_local, has_shared, has_threads, shared_max, tensor_cores, .. } =
            Renderer::cuda();
        assert_eq!(device, "CUDA_SM80");
        assert!(has_local);
        assert!(has_shared);
        assert!(!has_threads);
        assert!(shared_max > 0);
        assert!(!tensor_cores.is_empty());
    }

    /// Building `CUDA_81616` keeps its shape and records the requested dtypes.
    #[test]
    fn test_tensor_core_cuda() {
        let TensorCore { dims, threads, dtype_in, dtype_out, opts, .. } =
            CUDA_81616.build(DType::Float16, DType::Float32);
        assert_eq!(dims, (8, 16, 16));
        assert_eq!(threads, 32);
        assert_eq!(dtype_in, DType::Float16);
        assert_eq!(dtype_out, DType::Float32);
        assert!(!opts.is_empty());
    }
}