Skip to main content

morok_schedule/optimizer/
renderer.rs

1//! Backend renderer capabilities and tensor core configurations.
2//!
3//! This module defines the interface between the optimizer and backend code generators.
4//! It describes what optimizations a backend supports (local memory, threading, etc.)
5//! and provides tensor core configurations for hardware-accelerated matrix multiplication.
6
7use morok_dtype::DType;
8use smallvec::SmallVec;
9
/// One step of a tensor core optimization recipe.
///
/// Applying a tensor core is expressed as a sequence of these steps.
/// Every step splits one matmul dimension and assigns the new axis
/// to either the upcast group or the local group.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TcOpt {
    /// Upcast (vectorize) the given dimension (0=N, 1=M, 2=K).
    Upcast(usize),
    /// Move the given dimension to local memory (0=N, 1=M, 2=K).
    Local(usize),
}

impl TcOpt {
    /// Dimension index this step acts on (0=N, 1=M, 2=K).
    pub const fn dim(&self) -> usize {
        match *self {
            Self::Upcast(axis) => axis,
            Self::Local(axis) => axis,
        }
    }

    /// `true` when this step is an [`TcOpt::Upcast`].
    pub const fn is_upcast(&self) -> bool {
        match self {
            Self::Upcast(_) => true,
            Self::Local(_) => false,
        }
    }

    /// `true` when this step is a [`TcOpt::Local`].
    pub const fn is_local(&self) -> bool {
        match self {
            Self::Local(_) => true,
            Self::Upcast(_) => false,
        }
    }
}

/// Compact textual form: `u<dim>` for upcasts, `l<dim>` for locals.
impl std::fmt::Display for TcOpt {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let prefix = if self.is_upcast() { 'u' } else { 'l' };
        write!(f, "{}{}", prefix, self.dim())
    }
}
49
/// Axis reference used inside tensor core swizzle patterns.
///
/// A swizzle pattern remaps the data layout of a tensor core operand.
/// Where [`TcOpt`] names an operation to perform, a `SwizzleAxis`
/// names an existing axis (by group and index) within that remapping.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SwizzleAxis {
    /// Upcast axis with index (0, 1, 2, ...).
    Upcast(usize),
    /// Local axis with index (0, 1, 2, ...).
    Local(usize),
    /// Reduce axis with index (0, 1, 2, ...).
    Reduce(usize),
}

/// Compact textual form: `u<idx>`, `l<idx>` or `r<idx>`.
impl std::fmt::Display for SwizzleAxis {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (prefix, index) = match *self {
            Self::Upcast(i) => ('u', i),
            Self::Local(i) => ('l', i),
            Self::Reduce(i) => ('r', i),
        };
        write!(f, "{}{}", prefix, index)
    }
}
74
/// Backend renderer capabilities.
///
/// Describes what features and optimizations a particular backend supports.
/// Used by the optimizer to determine valid transformations and enforce device limits.
///
/// Instances are normally obtained from one of the preset constructors
/// (`Renderer::cpu()`, `Renderer::cuda()`, `Renderer::metal()`, ...).
#[derive(Debug, Clone)]
pub struct Renderer {
    /// Backend device identifier as set by the preset constructors
    /// (e.g., "CPU", "CUDA_SM80", "Metal", "AppleAMX").
    pub device: String,

    /// Whether the backend supports local/shared memory (GPU workgroups).
    ///
    /// NOTE(review): every preset in this file sets `has_local` and `has_shared`
    /// to the same value; confirm which optimizer checks rely on which flag.
    pub has_local: bool,

    /// Whether the backend supports shared memory across threads in a workgroup.
    pub has_shared: bool,

    /// Whether the backend supports CPU-style threading (not GPU threads).
    pub has_threads: bool,

    /// Maximum shared memory size in bytes.
    ///
    /// Used to validate GROUP/GROUPTOP optimizations that allocate shared memory.
    /// Typical values: 48KB-96KB for modern GPUs. Presets without shared memory use 0.
    pub shared_max: usize,

    /// Maximum global work dimensions [x, y, z].
    ///
    /// Maximum size for each global thread dimension.
    /// Used to validate thread count in THREAD optimization.
    /// None if unlimited or not applicable.
    ///
    /// The thread-based presets (CPU, Apple AMX) store a single entry
    /// rather than three.
    pub global_max: Option<Vec<usize>>,

    /// Maximum local work group size.
    ///
    /// Maximum number of threads in a workgroup (product of local dimensions).
    /// Typical values: 256-1024 for GPUs. None for the thread-based presets.
    pub local_max: Option<usize>,

    /// Maximum vectorization width (upcast limit).
    ///
    /// Maximum number of elements that can be processed as a vector.
    /// Typical values: 8-16 for SIMD, 4 for GPU float4.
    pub upcast_max: usize,

    /// Maximum number of buffers/arguments per kernel.
    ///
    /// Some backends have limits on kernel arguments.
    /// Metal: 31, WebGPU: 8, CUDA: typically unlimited.
    /// None means no limit is enforced.
    pub buffer_max: Option<usize>,

    /// Available tensor core configurations.
    ///
    /// Hardware-accelerated matrix multiplication units with specific size constraints.
    /// Empty if tensor cores not available.
    pub tensor_cores: Vec<TensorCore>,
}
130
131impl Renderer {
132    /// Create a CPU renderer configuration.
133    pub fn cpu() -> Self {
134        let cores = std::thread::available_parallelism().map(|p| p.get()).unwrap_or(8);
135        Self {
136            device: "CPU".to_string(),
137            has_local: false,
138            has_shared: false,
139            has_threads: true,
140            shared_max: 0,
141            global_max: Some(vec![cores]), // Actual available CPU cores
142            local_max: None,
143            upcast_max: 16, // AVX512 can do 16-wide float
144            buffer_max: None,
145            tensor_cores: vec![],
146        }
147    }
148
149    /// Create a CUDA GPU renderer configuration (SM80/Ampere by default).
150    ///
151    /// For specific architectures, use `cuda_sm75()`, `cuda_sm80()`, or `cuda_sm89()`.
152    pub fn cuda() -> Self {
153        Self::cuda_sm80(false) // Default to SM80 (A100) without TF32
154    }
155
156    /// Create a CUDA GPU renderer for SM75 (Turing - RTX 20xx, T4).
157    pub fn cuda_sm75() -> Self {
158        Self {
159            device: "CUDA_SM75".to_string(),
160            has_local: true,
161            has_shared: true,
162            has_threads: false,
163            shared_max: 49152,
164            global_max: Some(vec![2147483647, 65535, 65535]),
165            local_max: Some(1024),
166            upcast_max: 8,
167            buffer_max: None,
168            tensor_cores: TensorCore::sm75_tensor_cores(),
169        }
170    }
171
172    /// Create a CUDA GPU renderer for SM80 (Ampere - A100, RTX 30xx).
173    pub fn cuda_sm80(allow_tf32: bool) -> Self {
174        Self {
175            device: "CUDA_SM80".to_string(),
176            has_local: true,
177            has_shared: true,
178            has_threads: false,
179            shared_max: 49152,
180            global_max: Some(vec![2147483647, 65535, 65535]),
181            local_max: Some(1024),
182            upcast_max: 8,
183            buffer_max: None,
184            tensor_cores: TensorCore::sm80_tensor_cores(allow_tf32),
185        }
186    }
187
188    /// Create a CUDA GPU renderer for SM89 (Hopper - H100).
189    pub fn cuda_sm89(allow_tf32: bool) -> Self {
190        Self {
191            device: "CUDA_SM89".to_string(),
192            has_local: true,
193            has_shared: true,
194            has_threads: false,
195            shared_max: 49152,
196            global_max: Some(vec![2147483647, 65535, 65535]),
197            local_max: Some(1024),
198            upcast_max: 8,
199            buffer_max: None,
200            tensor_cores: TensorCore::sm89_tensor_cores(allow_tf32),
201        }
202    }
203
204    /// Create a Metal GPU renderer configuration (Apple M1/M2/M3).
205    pub fn metal() -> Self {
206        Self {
207            device: "Metal".to_string(),
208            has_local: true,
209            has_shared: true,
210            has_threads: false,
211            shared_max: 32768, // 32KB for Metal
212            global_max: None,
213            local_max: Some(1024),
214            upcast_max: 4,        // float4 for Metal
215            buffer_max: Some(31), // Metal has 31 buffer argument limit
216            tensor_cores: TensorCore::metal_tensor_cores(),
217        }
218    }
219
220    /// Create an Apple AMX renderer configuration (M1/M2/M3 matrix coprocessor).
221    pub fn apple_amx() -> Self {
222        Self {
223            device: "AppleAMX".to_string(),
224            has_local: false, // AMX doesn't use traditional local memory
225            has_shared: false,
226            has_threads: true, // CPU-style threading
227            shared_max: 0,
228            global_max: Some(vec![256]),
229            local_max: None,
230            upcast_max: 16,
231            buffer_max: None,
232            tensor_cores: TensorCore::amx_tensor_cores(),
233        }
234    }
235
236    /// Whether this renderer is for Apple AMX (CPU matrix coprocessor).
237    pub fn is_amx(&self) -> bool {
238        self.device == "AppleAMX"
239    }
240
241    /// Create an AMD RDNA3 GPU renderer (RX 7000 series).
242    pub fn amd_rdna3() -> Self {
243        Self {
244            device: "AMD_RDNA3".to_string(),
245            has_local: true,
246            has_shared: true,
247            has_threads: false,
248            shared_max: 65536, // 64KB for RDNA3
249            global_max: Some(vec![2147483647, 65535, 65535]),
250            local_max: Some(1024),
251            upcast_max: 8,
252            buffer_max: None,
253            tensor_cores: TensorCore::rdna3_tensor_cores(),
254        }
255    }
256
257    /// Create an AMD RDNA4 GPU renderer.
258    pub fn amd_rdna4() -> Self {
259        Self {
260            device: "AMD_RDNA4".to_string(),
261            has_local: true,
262            has_shared: true,
263            has_threads: false,
264            shared_max: 65536,
265            global_max: Some(vec![2147483647, 65535, 65535]),
266            local_max: Some(1024),
267            upcast_max: 8,
268            buffer_max: None,
269            tensor_cores: TensorCore::rdna4_tensor_cores(),
270        }
271    }
272
273    /// Create an AMD CDNA3 GPU renderer (MI300 series).
274    pub fn amd_cdna3() -> Self {
275        Self {
276            device: "AMD_CDNA3".to_string(),
277            has_local: true,
278            has_shared: true,
279            has_threads: false,
280            shared_max: 65536, // 64KB for CDNA
281            global_max: Some(vec![2147483647, 65535, 65535]),
282            local_max: Some(1024),
283            upcast_max: 8,
284            buffer_max: None,
285            tensor_cores: TensorCore::cdna3_tensor_cores(),
286        }
287    }
288
289    /// Create an AMD CDNA4 GPU renderer.
290    pub fn amd_cdna4() -> Self {
291        Self {
292            device: "AMD_CDNA4".to_string(),
293            has_local: true,
294            has_shared: true,
295            has_threads: false,
296            shared_max: 65536,
297            global_max: Some(vec![2147483647, 65535, 65535]),
298            local_max: Some(1024),
299            upcast_max: 8,
300            buffer_max: None,
301            tensor_cores: TensorCore::cdna4_tensor_cores(),
302        }
303    }
304
305    /// Create an Intel Xe GPU renderer.
306    pub fn intel_xe() -> Self {
307        Self {
308            device: "IntelXe".to_string(),
309            has_local: true,
310            has_shared: true,
311            has_threads: false,
312            shared_max: 65536, // 64KB for Xe
313            global_max: Some(vec![2147483647, 65535, 65535]),
314            local_max: Some(512),
315            upcast_max: 8,
316            buffer_max: None,
317            tensor_cores: TensorCore::intel_tensor_cores(),
318        }
319    }
320
321    /// Create a WebGPU renderer configuration.
322    pub fn webgpu() -> Self {
323        Self {
324            device: "WebGPU".to_string(),
325            has_local: true,
326            has_shared: true,
327            has_threads: false,
328            shared_max: 16384, // 16KB typical for WebGPU
329            global_max: Some(vec![65535, 65535, 65535]),
330            local_max: Some(256),
331            upcast_max: 4,
332            buffer_max: Some(8), // WebGPU has 8 buffer limit in some implementations
333            tensor_cores: vec![],
334        }
335    }
336}
337
/// Tensor core configuration for hardware-accelerated matrix multiplication.
///
/// Describes a specific matrix multiplication unit with fixed dimensions and data types.
/// Based on NVIDIA's WMMA (Warp Matrix Multiply-Accumulate) API and similar accelerators.
///
/// # Matrix Dimensions
///
/// Tensor cores perform: `C[M,N] += A[M,K] × B[K,N]`
/// - `dims.0` (N): Number of output columns
/// - `dims.1` (M): Number of output rows
/// - `dims.2` (K): Reduction dimension size
///
/// # Example
///
/// NVIDIA Tensor Core 16x16x16:
/// - Processes 16×16 output tile
/// - Accumulates across 16 K elements
/// - Uses 32 threads (warp size)
/// - Each thread handles multiple elements via opts
#[derive(Debug, Clone)]
pub struct TensorCore {
    /// Matrix dimensions (N, M, K).
    pub dims: (usize, usize, usize),

    /// Number of threads required (typically warp size: 32 for CUDA, 64 for AMD).
    ///
    /// The configurations in this file also use 32 (RDNA, Metal), 8 (Intel Xe)
    /// and 1 (Apple AMX).
    pub threads: usize,

    /// Per-thread element counts, three values per configuration.
    ///
    /// Describes how the matrix is distributed across threads.
    ///
    /// NOTE(review): this field was documented as per-dimension "(N, M, K)" counts,
    /// but the configurations in this file match per-thread element counts of the
    /// A, B and C operands instead. For example `CUDA_81616` has dims (8, 16, 16)
    /// on 32 threads and the value (8, 4, 4) = (16*16/32, 16*8/32, 16*8/32).
    /// Confirm against the code that consumes this field.
    pub elements_per_thread: (usize, usize, usize),

    /// Input matrix data type (A and B matrices).
    pub dtype_in: DType,

    /// Output/accumulator data type (C matrix).
    pub dtype_out: DType,

    /// Optimization sequence for tensor core application.
    ///
    /// A sequence of operations to transform ranges. Each operation splits
    /// a dimension (N, M, or K) and assigns it to a new axis type.
    ///
    /// Example: `[Upcast(0), Local(0), Local(0), Local(1), Local(1), Local(1), Upcast(1)]`
    /// - Upcast N once
    /// - Local split N twice
    /// - Local split M three times
    /// - Upcast M once
    ///
    /// Uses SmallVec to avoid heap allocation for typical tensor cores (≤8 ops).
    /// Some configurations in this file are longer (the 32x32 Apple AMX variants
    /// list 10 ops) and therefore exceed the inline capacity.
    pub opts: SmallVec<[TcOpt; 8]>,

    /// Swizzle patterns for input permutation.
    ///
    /// Describes how to permute input matrices to match hardware layout.
    /// Format: ((A_local, A_upcast, A_reduce), (B_local, B_upcast, B_reduce))
    ///
    /// Each tuple contains axis references that describe the permutation pattern
    /// for optimal memory access. The first tuple is for matrix A, second for B.
    ///
    /// Uses SmallVec to avoid heap allocation for typical swizzles (≤8 axes per vec).
    // The nested tuple type is intentional; a dedicated type alias was not introduced.
    #[allow(clippy::type_complexity)]
    pub swizzle: (
        (SmallVec<[SwizzleAxis; 8]>, SmallVec<[SwizzleAxis; 8]>, SmallVec<[SwizzleAxis; 8]>),
        (SmallVec<[SwizzleAxis; 8]>, SmallVec<[SwizzleAxis; 8]>, SmallVec<[SwizzleAxis; 8]>),
    ),

    /// Pre-pack operand A into contiguous scratch buffer before the reduction loop.
    /// Beneficial when the A operand has non-unit stride access (e.g., AMX row-major matmul).
    /// In this file only the Apple AMX configurations enable it.
    pub pack_a: bool,

    /// Tile grid for multi-FMA batching (tile_y_count, tile_x_count).
    ///
    /// When > (1, 1), the codegen emits load-pair instructions and multiple FMAs
    /// per K iteration to compute a grid of output tiles simultaneously.
    /// Default is (1, 1) for single-tile operation; every configuration in this
    /// file currently uses (1, 1).
    pub tile_grid: (usize, usize),
}
418
419// ============================================================================
420// TENSOR CORE CONFIGURATION (Static Const Data)
421// ============================================================================
422
423/// Static tensor core configuration for const definitions.
424///
425/// Uses static slices instead of SmallVec for const-compatibility.
426/// Use `build()` to convert to runtime `TensorCore`.
427pub struct TcConfig {
428    dims: (usize, usize, usize),
429    threads: usize,
430    ept: (usize, usize, usize),
431    opts: &'static [TcOpt],
432    swizzle_a: (&'static [SwizzleAxis], &'static [SwizzleAxis], &'static [SwizzleAxis]),
433    swizzle_b: (&'static [SwizzleAxis], &'static [SwizzleAxis], &'static [SwizzleAxis]),
434    pack_a: bool,
435    tile_grid: (usize, usize),
436}
437
438impl TcConfig {
439    /// Build a TensorCore from static config with specified dtypes.
440    pub fn build(&self, dtype_in: DType, dtype_out: DType) -> TensorCore {
441        TensorCore {
442            dims: self.dims,
443            threads: self.threads,
444            elements_per_thread: self.ept,
445            dtype_in,
446            dtype_out,
447            opts: self.opts.iter().copied().collect(),
448            swizzle: (
449                (
450                    self.swizzle_a.0.iter().copied().collect(),
451                    self.swizzle_a.1.iter().copied().collect(),
452                    self.swizzle_a.2.iter().copied().collect(),
453                ),
454                (
455                    self.swizzle_b.0.iter().copied().collect(),
456                    self.swizzle_b.1.iter().copied().collect(),
457                    self.swizzle_b.2.iter().copied().collect(),
458                ),
459            ),
460            pack_a: self.pack_a,
461            tile_grid: self.tile_grid,
462        }
463    }
464}
465
466// Aliases for brevity in const definitions
467use SwizzleAxis::{Local as SL, Reduce as R, Upcast as SU};
468use TcOpt::{Local as L, Upcast as U};
469
470// NVIDIA CUDA Tensor Cores
/// NVIDIA layout with dims (N, M, K) = (8, 16, 16) executed by 32 threads.
///
/// Instantiated by [`TensorCore::sm80_tensor_cores`] for Float16 and BFloat16 inputs.
pub const CUDA_81616: TcConfig = TcConfig {
    dims: (8, 16, 16),
    threads: 32,
    ept: (8, 4, 4),
    opts: &[U(0), L(0), L(0), L(1), L(1), L(1), U(1)],
    swizzle_a: (&[R(1), R(2), SL(2), SL(3), SL(4)], &[SU(1), R(3)], &[SL(0), SL(1), SU(0), R(0)]),
    swizzle_b: (&[R(1), R(2), SU(0), SL(0), SL(1)], &[R(0), R(3)], &[SL(2), SL(3), SL(4), SU(1)]),
    pack_a: false,
    tile_grid: (1, 1),
};
481
/// NVIDIA layout with dims (N, M, K) = (8, 16, 32) executed by 32 threads.
///
/// Instantiated by [`TensorCore::sm89_tensor_cores`] for the FP8 input types
/// (FP8E4M3, FP8E5M2) with Float32 output.
pub const CUDA_81632: TcConfig = TcConfig {
    dims: (8, 16, 32),
    threads: 32,
    ept: (16, 8, 4),
    opts: &[U(0), L(0), L(0), L(1), L(1), L(1), U(1)],
    swizzle_a: (&[R(2), R(3), SL(2), SL(3), SL(4)], &[SU(1), R(4)], &[SL(0), SL(1), SU(0), R(0), R(1)]),
    swizzle_b: (&[R(2), R(3), SU(0), SL(0), SL(1)], &[R(1), R(4)], &[SL(2), SL(3), SL(4), SU(1), R(0)]),
    pack_a: false,
    tile_grid: (1, 1),
};
492
/// NVIDIA layout with dims (N, M, K) = (8, 16, 8) executed by 32 threads.
///
/// Instantiated for Float16 inputs by [`TensorCore::sm75_tensor_cores`] and
/// [`TensorCore::sm80_tensor_cores`].
pub const CUDA_8168: TcConfig = TcConfig {
    dims: (8, 16, 8),
    threads: 32,
    ept: (4, 2, 4),
    opts: &[U(0), L(0), L(0), L(1), L(1), L(1), U(1)],
    swizzle_a: (&[R(1), R(2), SL(2), SL(3), SL(4)], &[R(0), SU(1)], &[SL(0), SL(1), SU(0)]),
    swizzle_b: (&[R(1), R(2), SU(0), SL(0), SL(1)], &[SU(1), R(0)], &[SL(2), SL(3), SL(4)]),
    pack_a: false,
    tile_grid: (1, 1),
};
503
/// TF32 variant of the 8x16x8 NVIDIA layout.
///
/// Shares dims, thread count, per-thread element counts and op sequence with
/// `CUDA_8168`; only the swizzle patterns differ. Instantiated with Float32
/// input and output by [`TensorCore::sm80_tensor_cores`] when `allow_tf32` is set.
pub const CUDA_8168_TF32: TcConfig = TcConfig {
    dims: (8, 16, 8),
    threads: 32,
    ept: (4, 2, 4),
    opts: &[U(0), L(0), L(0), L(1), L(1), L(1), U(1)],
    swizzle_a: (&[R(0), R(1), SL(2), SL(3), SL(4)], &[SU(1), R(2)], &[SL(0), SL(1), SU(0)]),
    swizzle_b: (&[R(0), R(1), SU(0), SL(0), SL(1)], &[SU(1), R(2)], &[SL(2), SL(3), SL(4)]),
    pack_a: false,
    tile_grid: (1, 1),
};
514
515// AMD Tensor Cores
/// AMD RDNA3 layout with dims (N, M, K) = (16, 16, 16) executed by 32 threads.
///
/// Instantiated by [`TensorCore::rdna3_tensor_cores`].
pub const AMD_RDNA3: TcConfig = TcConfig {
    dims: (16, 16, 16),
    threads: 32,
    ept: (16, 16, 8),
    opts: &[L(0), L(0), L(0), L(0), L(1), U(1), U(1), U(1)],
    swizzle_a: (&[SL(4), SU(0), SU(1), SU(2), SL(0)], &[R(1), R(2), R(3)], &[SL(1), SL(2), SL(3), R(0)]),
    swizzle_b: (&[SL(0), SL(1), SL(2), SL(3), SL(4)], &[R(1), R(2), R(3)], &[SU(0), SU(1), SU(2), R(0)]),
    pack_a: false,
    tile_grid: (1, 1),
};
526
/// AMD RDNA4 layout with dims (N, M, K) = (16, 16, 16) executed by 32 threads.
///
/// Same dims and thread count as `AMD_RDNA3`, but with different per-thread
/// element counts, op order and swizzles. Instantiated by
/// [`TensorCore::rdna4_tensor_cores`].
pub const AMD_RDNA4: TcConfig = TcConfig {
    dims: (16, 16, 16),
    threads: 32,
    ept: (8, 8, 8),
    opts: &[L(0), L(0), L(0), L(0), U(1), U(1), U(1), L(1)],
    swizzle_a: (&[SU(0), SU(1), SU(2), SL(4), R(2)], &[R(0), R(1), R(3)], &[SL(0), SL(1), SL(2), SL(3)]),
    swizzle_b: (&[SL(0), SL(1), SL(2), SL(3), R(2)], &[R(0), R(1), R(3)], &[SL(4), SU(0), SU(1), SU(2)]),
    pack_a: false,
    tile_grid: (1, 1),
};
537
/// AMD CDNA layout with dims (N, M, K) = (16, 16, 16) executed by 64 threads.
///
/// Instantiated by [`TensorCore::cdna3_tensor_cores`] and
/// [`TensorCore::cdna4_tensor_cores`] for Float16 and BFloat16 inputs.
pub const AMD_CDNA_161616: TcConfig = TcConfig {
    dims: (16, 16, 16),
    threads: 64,
    ept: (4, 4, 4),
    opts: &[L(0), L(0), L(0), L(0), U(1), U(1), L(1), L(1)],
    swizzle_a: (&[SU(0), SU(1), SL(4), SL(5), R(2), R(3)], &[R(0), R(1)], &[SL(0), SL(1), SL(2), SL(3)]),
    swizzle_b: (&[SL(0), SL(1), SL(2), SL(3), R(2), R(3)], &[R(0), R(1)], &[SL(4), SL(5), SU(0), SU(1)]),
    pack_a: false,
    tile_grid: (1, 1),
};
548
/// AMD CDNA layout with dims (N, M, K) = (16, 16, 32) executed by 64 threads.
///
/// Instantiated by [`TensorCore::cdna3_tensor_cores`] for the FP8 input types and by
/// [`TensorCore::cdna4_tensor_cores`] for FP8, Float16 and BFloat16 inputs.
pub const AMD_CDNA_161632: TcConfig = TcConfig {
    dims: (16, 16, 32),
    threads: 64,
    ept: (8, 8, 4),
    opts: &[L(0), L(0), L(0), L(0), U(1), U(1), L(1), L(1)],
    swizzle_a: (&[SU(0), SU(1), SL(4), SL(5), R(3), R(4)], &[R(0), R(1)], &[SL(0), SL(1), SL(2), SL(3), R(2)]),
    swizzle_b: (&[SL(0), SL(1), SL(2), SL(3), R(3), R(4)], &[R(0), R(1)], &[SL(4), SL(5), SU(0), SU(1), R(2)]),
    pack_a: false,
    tile_grid: (1, 1),
};
559
560// Apple Metal Tensor Cores
/// Apple Metal layout with dims (N, M, K) = (8, 8, 8) executed by 32 threads.
///
/// Instantiated by [`TensorCore::metal_tensor_cores`] for Float32, Float16 and
/// BFloat16 inputs.
pub const METAL_888: TcConfig = TcConfig {
    dims: (8, 8, 8),
    threads: 32,
    ept: (2, 2, 2),
    opts: &[U(0), L(0), L(1), L(1), L(0), L(1)],
    swizzle_a: (&[R(1), SL(1), SL(2), R(2), SL(4)], &[R(0)], &[SU(0), SL(0), SL(3)]),
    swizzle_b: (&[SL(0), R(0), R(1), SL(3), R(2)], &[SU(0)], &[SL(1), SL(2), SL(4)]),
    pack_a: false,
    tile_grid: (1, 1),
};
571
572// Apple AMX (64 bytes / 4 bytes per float32 = 16 elements per register)
573// NOTE: tile_grid=(2,2) requires direct memory loads (load-pair from source matrices)
574// Temp buffer approach is incompatible with load-pair. Keep at (1,1) until fixed.
/// Apple AMX layout with dims (N, M, K) = (16, 16, 1), run by a single thread.
///
/// Every op is an upcast (no local axes) and operand A is pre-packed (`pack_a`).
/// Instantiated by [`TensorCore::amx_tensor_cores`] with Float32 input and output.
pub const APPLE_AMX: TcConfig = TcConfig {
    dims: (16, 16, 1),
    threads: 1,
    ept: (16, 16, 256),
    opts: &[U(0), U(0), U(0), U(0), U(1), U(1), U(1), U(1)],
    swizzle_a: (&[], &[SU(0), SU(1), SU(2), SU(3), SU(4), SU(5), SU(6), SU(7)], &[]),
    swizzle_b: (&[], &[SU(4), SU(5), SU(6), SU(7), SU(0), SU(1), SU(2), SU(3)], &[]),
    pack_a: true,
    tile_grid: (1, 1),
};
585
/// Apple AMX layout with dims (N, M, K) = (32, 32, 1) for mixed precision.
///
/// Instantiated by [`TensorCore::amx_tensor_cores`] with Float16 input and
/// Float32 output. The values are currently identical to `APPLE_AMX_F16`.
pub const APPLE_AMX_F16_F32: TcConfig = TcConfig {
    dims: (32, 32, 1),
    threads: 1,
    ept: (32, 32, 1024),
    opts: &[U(0), U(0), U(0), U(0), U(0), U(1), U(1), U(1), U(1), U(1)],
    swizzle_a: (&[], &[SU(0), SU(1), SU(2), SU(3), SU(4), SU(5), SU(6), SU(7), SU(8), SU(9)], &[]),
    swizzle_b: (&[], &[SU(5), SU(6), SU(7), SU(8), SU(9), SU(0), SU(1), SU(2), SU(3), SU(4)], &[]),
    pack_a: true,
    tile_grid: (1, 1),
};
596
/// Apple AMX layout with dims (N, M, K) = (32, 32, 1), run by a single thread.
///
/// Instantiated by [`TensorCore::amx_tensor_cores`] with Float16 input and output.
pub const APPLE_AMX_F16: TcConfig = TcConfig {
    dims: (32, 32, 1),
    threads: 1,
    ept: (32, 32, 1024),
    opts: &[U(0), U(0), U(0), U(0), U(0), U(1), U(1), U(1), U(1), U(1)],
    swizzle_a: (&[], &[SU(0), SU(1), SU(2), SU(3), SU(4), SU(5), SU(6), SU(7), SU(8), SU(9)], &[]),
    swizzle_b: (&[], &[SU(5), SU(6), SU(7), SU(8), SU(9), SU(0), SU(1), SU(2), SU(3), SU(4)], &[]),
    pack_a: true,
    tile_grid: (1, 1),
};
607
/// Apple AMX layout with dims (N, M, K) = (8, 8, 1), run by a single thread.
///
/// Instantiated by [`TensorCore::amx_tensor_cores`] with Float64 input and output.
pub const APPLE_AMX_F64: TcConfig = TcConfig {
    dims: (8, 8, 1),
    threads: 1,
    ept: (8, 8, 64),
    opts: &[U(0), U(0), U(0), U(1), U(1), U(1)],
    swizzle_a: (&[], &[SU(0), SU(1), SU(2), SU(3), SU(4), SU(5)], &[]),
    swizzle_b: (&[], &[SU(3), SU(4), SU(5), SU(0), SU(1), SU(2)], &[]),
    pack_a: true,
    tile_grid: (1, 1),
};
618
619// MAC16: i16×i16→i16, same geometry as FMA16
/// Apple AMX layout with dims (N, M, K) = (32, 32, 1) for 16-bit integers.
///
/// Uses the same values as the 32x32 Float16 AMX layouts. Instantiated by
/// [`TensorCore::amx_tensor_cores`] with Int16 input and output.
pub const APPLE_AMX_I16: TcConfig = TcConfig {
    dims: (32, 32, 1),
    threads: 1,
    ept: (32, 32, 1024),
    opts: &[U(0), U(0), U(0), U(0), U(0), U(1), U(1), U(1), U(1), U(1)],
    swizzle_a: (&[], &[SU(0), SU(1), SU(2), SU(3), SU(4), SU(5), SU(6), SU(7), SU(8), SU(9)], &[]),
    swizzle_b: (&[], &[SU(5), SU(6), SU(7), SU(8), SU(9), SU(0), SU(1), SU(2), SU(3), SU(4)], &[]),
    pack_a: true,
    tile_grid: (1, 1),
};
630
631// Intel Xe Tensor Cores
/// Intel Xe layout with dims (N, M, K) = (8, 8, 16) executed by 8 threads.
///
/// Instantiated by [`TensorCore::intel_tensor_cores`] with Float16 input and
/// Float32 output.
pub const INTEL_XE_8816: TcConfig = TcConfig {
    dims: (8, 8, 16),
    threads: 8,
    ept: (16, 16, 8),
    opts: &[L(0), L(0), L(0), U(1), U(1), U(1)],
    swizzle_a: (&[R(1), R(2), R(3)], &[SU(0), SU(1), SU(2)], &[SL(0), SL(1), SL(2), R(0)]),
    swizzle_b: (&[SL(0), SL(1), SL(2)], &[R(1), R(2), R(3)], &[SU(0), SU(1), SU(2), R(0)]),
    pack_a: false,
    tile_grid: (1, 1),
};
642
643impl TensorCore {
644    // ===== Helper Methods =====
645
646    /// Get the axes for reduction unrolling.
647    ///
648    /// Returns pairs of (dimension_index, unroll_amount) for the K dimension.
649    /// Used during TC application to unroll the reduction dimension.
650    pub fn get_reduce_axes(&self) -> Vec<(usize, usize)> {
651        (0..(self.dims.2 as f64).log2().floor() as usize).map(|i| (i, 2)).collect()
652    }
653
654    /// Get the upcast axes configuration for WMMA construction.
655    ///
656    /// Returns axes configuration for CONTRACT operations.
657    /// Format: (A_axes, B_axes, output_axes)
658    pub fn upcast_axes(&self) -> (Vec<usize>, Vec<usize>, Vec<usize>) {
659        // This is simplified - actual implementation depends on opts sequence
660        // For 16x16x16 WMMA: each has specific upcast patterns
661        (vec![0, 1], vec![0, 1], vec![0, 1])
662    }
663
664    // ===== Hardware-Specific Collections =====
665
666    /// Get all tensor cores for NVIDIA SM75 architecture (Turing).
667    pub fn sm75_tensor_cores() -> Vec<TensorCore> {
668        vec![CUDA_8168.build(DType::Float16, DType::Float32), CUDA_8168.build(DType::Float16, DType::Float16)]
669    }
670
671    /// Get all tensor cores for NVIDIA SM80 architecture (Ampere).
672    pub fn sm80_tensor_cores(allow_tf32: bool) -> Vec<TensorCore> {
673        let mut tcs = vec![
674            CUDA_81616.build(DType::Float16, DType::Float32),
675            CUDA_81616.build(DType::BFloat16, DType::Float32),
676            CUDA_81616.build(DType::Float16, DType::Float16),
677            CUDA_8168.build(DType::Float16, DType::Float32),
678            CUDA_8168.build(DType::Float16, DType::Float16),
679        ];
680        if allow_tf32 {
681            tcs.push(CUDA_8168_TF32.build(DType::Float32, DType::Float32));
682        }
683        tcs
684    }
685
686    /// Get all tensor cores for NVIDIA SM89 architecture (Hopper).
687    pub fn sm89_tensor_cores(allow_tf32: bool) -> Vec<TensorCore> {
688        let mut tcs = Self::sm80_tensor_cores(allow_tf32);
689        tcs.push(CUDA_81632.build(DType::FP8E4M3, DType::Float32));
690        tcs.push(CUDA_81632.build(DType::FP8E5M2, DType::Float32));
691        tcs
692    }
693
694    /// Get all tensor cores for AMD RDNA3 architecture (RX 7000 series).
695    pub fn rdna3_tensor_cores() -> Vec<TensorCore> {
696        vec![
697            AMD_RDNA3.build(DType::Float16, DType::Float32),
698            AMD_RDNA3.build(DType::Float16, DType::Float16),
699            AMD_RDNA3.build(DType::BFloat16, DType::Float32),
700        ]
701    }
702
703    /// Get all tensor cores for AMD RDNA4 architecture.
704    pub fn rdna4_tensor_cores() -> Vec<TensorCore> {
705        vec![
706            AMD_RDNA4.build(DType::Float16, DType::Float32),
707            AMD_RDNA4.build(DType::Float16, DType::Float16),
708            AMD_RDNA4.build(DType::BFloat16, DType::Float32),
709            AMD_RDNA4.build(DType::BFloat16, DType::BFloat16),
710        ]
711    }
712
713    /// Get all tensor cores for AMD CDNA3 architecture (MI300).
714    pub fn cdna3_tensor_cores() -> Vec<TensorCore> {
715        vec![
716            AMD_CDNA_161632.build(DType::FP8E5M2, DType::Float32),
717            AMD_CDNA_161632.build(DType::FP8E4M3, DType::Float32),
718            AMD_CDNA_161616.build(DType::Float16, DType::Float32),
719            AMD_CDNA_161616.build(DType::BFloat16, DType::Float32),
720        ]
721    }
722
723    /// Get all tensor cores for AMD CDNA4 architecture.
724    pub fn cdna4_tensor_cores() -> Vec<TensorCore> {
725        vec![
726            AMD_CDNA_161632.build(DType::FP8E5M2, DType::Float32),
727            AMD_CDNA_161632.build(DType::FP8E4M3, DType::Float32),
728            AMD_CDNA_161632.build(DType::Float16, DType::Float32),
729            AMD_CDNA_161632.build(DType::BFloat16, DType::Float32),
730            AMD_CDNA_161616.build(DType::Float16, DType::Float32),
731            AMD_CDNA_161616.build(DType::BFloat16, DType::Float32),
732        ]
733    }
734
735    /// Get all tensor cores for Apple Metal (M1/M2/M3).
736    pub fn metal_tensor_cores() -> Vec<TensorCore> {
737        vec![
738            METAL_888.build(DType::Float32, DType::Float32),
739            METAL_888.build(DType::Float16, DType::Float32),
740            METAL_888.build(DType::Float16, DType::Float16),
741            METAL_888.build(DType::BFloat16, DType::Float32),
742            METAL_888.build(DType::BFloat16, DType::BFloat16),
743        ]
744    }
745
746    /// Get all tensor cores for Apple AMX (M1/M2/M3 matrix accelerators).
747    pub fn amx_tensor_cores() -> Vec<TensorCore> {
748        vec![
749            APPLE_AMX.build(DType::Float32, DType::Float32),
750            APPLE_AMX_F16.build(DType::Float16, DType::Float16),
751            APPLE_AMX_F16_F32.build(DType::Float16, DType::Float32), // Mixed-precision
752            APPLE_AMX_F64.build(DType::Float64, DType::Float64),
753            APPLE_AMX_I16.build(DType::Int16, DType::Int16),
754        ]
755    }
756
757    /// Get all tensor cores for Intel Xe architecture.
758    pub fn intel_tensor_cores() -> Vec<TensorCore> {
759        vec![INTEL_XE_8816.build(DType::Float16, DType::Float32)]
760    }
761}
762
#[cfg(test)]
mod tests {
    use super::*;

    /// The CPU preset is thread-based and ships without tensor cores.
    #[test]
    fn test_renderer_cpu() {
        let cpu = Renderer::cpu();
        assert_eq!(cpu.device, "CPU");
        assert!(cpu.has_threads);
        assert!(!cpu.has_local);
        assert!(cpu.tensor_cores.is_empty());
    }

    /// The default CUDA preset is the SM80 GPU configuration.
    #[test]
    fn test_renderer_cuda() {
        let cuda = Renderer::cuda();
        assert_eq!(cuda.device, "CUDA_SM80"); // Default is SM80/Ampere
        assert!(cuda.has_local && cuda.has_shared);
        assert!(!cuda.has_threads);
        assert_ne!(cuda.shared_max, 0);
        assert_ne!(cuda.tensor_cores.len(), 0);
    }

    /// Building a static config carries over geometry and the requested dtypes.
    #[test]
    fn test_tensor_core_cuda() {
        let built = CUDA_81616.build(DType::Float16, DType::Float32);
        assert_eq!(built.dims, (8, 16, 16));
        assert_eq!(built.threads, 32);
        assert_eq!(built.dtype_in, DType::Float16);
        assert_eq!(built.dtype_out, DType::Float32);
        assert_ne!(built.opts.len(), 0);
    }
}
796}