1use morok_dtype::DType;
8use smallvec::SmallVec;
9
/// One entry of a tensor-core recipe's `opts` list.
///
/// Both variants wrap a `usize`, which [`TcOpt::dim`] returns unchanged. The
/// `Display` impl prints an entry as `u<n>` or `l<n>`.
///
/// NOTE(review): the variant names suggest "upcast this dimension" versus
/// "bind this dimension to a local index" — confirm with the code that
/// consumes the list.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum TcOpt {
    /// `Upcast` step carrying a dimension number.
    Upcast(usize),
    /// `Local` step carrying a dimension number.
    Local(usize),
}
21
22impl TcOpt {
23 pub const fn dim(&self) -> usize {
25 match self {
26 Self::Upcast(dim) | Self::Local(dim) => *dim,
27 }
28 }
29
30 pub const fn is_upcast(&self) -> bool {
32 matches!(self, Self::Upcast(_))
33 }
34
35 pub const fn is_local(&self) -> bool {
37 matches!(self, Self::Local(_))
38 }
39}
40
41impl std::fmt::Display for TcOpt {
42 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43 match self {
44 Self::Upcast(dim) => write!(f, "u{}", dim),
45 Self::Local(dim) => write!(f, "l{}", dim),
46 }
47 }
48}
49
/// One axis reference inside a swizzle table.
///
/// The variant names the axis family and the `usize` is an index within that
/// family. The `Display` impl prints `u<n>`, `l<n>` or `r<n>`.
///
/// NOTE(review): the families presumably correspond to upcast, local and
/// reduce loop axes of a kernel — confirm with the code that reads the tables.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum SwizzleAxis {
    /// Index into the `Upcast` family.
    Upcast(usize),
    /// Index into the `Local` family.
    Local(usize),
    /// Index into the `Reduce` family.
    Reduce(usize),
}
64
65impl std::fmt::Display for SwizzleAxis {
66 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67 match self {
68 Self::Upcast(idx) => write!(f, "u{}", idx),
69 Self::Local(idx) => write!(f, "l{}", idx),
70 Self::Reduce(idx) => write!(f, "r{}", idx),
71 }
72 }
73}
74
75#[derive(Debug, Clone)]
80pub struct Renderer {
81 pub device: String,
83
84 pub has_local: bool,
86
87 pub has_shared: bool,
89
90 pub has_threads: bool,
92
93 pub shared_max: usize,
98
99 pub global_max: Option<Vec<usize>>,
105
106 pub local_max: Option<usize>,
111
112 pub upcast_max: usize,
117
118 pub buffer_max: Option<usize>,
123
124 pub tensor_cores: Vec<TensorCore>,
129}
130
131impl Renderer {
132 pub fn cpu() -> Self {
134 let cores = std::thread::available_parallelism().map(|p| p.get()).unwrap_or(8);
135 Self {
136 device: "CPU".to_string(),
137 has_local: false,
138 has_shared: false,
139 has_threads: true,
140 shared_max: 0,
141 global_max: Some(vec![cores]), local_max: None,
143 upcast_max: 16, buffer_max: None,
145 tensor_cores: vec![],
146 }
147 }
148
149 pub fn cuda() -> Self {
153 Self::cuda_sm80(false) }
155
156 pub fn cuda_sm75() -> Self {
158 Self {
159 device: "CUDA_SM75".to_string(),
160 has_local: true,
161 has_shared: true,
162 has_threads: false,
163 shared_max: 49152,
164 global_max: Some(vec![2147483647, 65535, 65535]),
165 local_max: Some(1024),
166 upcast_max: 8,
167 buffer_max: None,
168 tensor_cores: TensorCore::sm75_tensor_cores(),
169 }
170 }
171
172 pub fn cuda_sm80(allow_tf32: bool) -> Self {
174 Self {
175 device: "CUDA_SM80".to_string(),
176 has_local: true,
177 has_shared: true,
178 has_threads: false,
179 shared_max: 49152,
180 global_max: Some(vec![2147483647, 65535, 65535]),
181 local_max: Some(1024),
182 upcast_max: 8,
183 buffer_max: None,
184 tensor_cores: TensorCore::sm80_tensor_cores(allow_tf32),
185 }
186 }
187
188 pub fn cuda_sm89(allow_tf32: bool) -> Self {
190 Self {
191 device: "CUDA_SM89".to_string(),
192 has_local: true,
193 has_shared: true,
194 has_threads: false,
195 shared_max: 49152,
196 global_max: Some(vec![2147483647, 65535, 65535]),
197 local_max: Some(1024),
198 upcast_max: 8,
199 buffer_max: None,
200 tensor_cores: TensorCore::sm89_tensor_cores(allow_tf32),
201 }
202 }
203
204 pub fn metal() -> Self {
206 Self {
207 device: "Metal".to_string(),
208 has_local: true,
209 has_shared: true,
210 has_threads: false,
211 shared_max: 32768, global_max: None,
213 local_max: Some(1024),
214 upcast_max: 4, buffer_max: Some(31), tensor_cores: TensorCore::metal_tensor_cores(),
217 }
218 }
219
220 pub fn apple_amx() -> Self {
222 Self {
223 device: "AppleAMX".to_string(),
224 has_local: false, has_shared: false,
226 has_threads: true, shared_max: 0,
228 global_max: Some(vec![256]),
229 local_max: None,
230 upcast_max: 16,
231 buffer_max: None,
232 tensor_cores: TensorCore::amx_tensor_cores(),
233 }
234 }
235
236 pub fn is_amx(&self) -> bool {
238 self.device == "AppleAMX"
239 }
240
241 pub fn amd_rdna3() -> Self {
243 Self {
244 device: "AMD_RDNA3".to_string(),
245 has_local: true,
246 has_shared: true,
247 has_threads: false,
248 shared_max: 65536, global_max: Some(vec![2147483647, 65535, 65535]),
250 local_max: Some(1024),
251 upcast_max: 8,
252 buffer_max: None,
253 tensor_cores: TensorCore::rdna3_tensor_cores(),
254 }
255 }
256
257 pub fn amd_rdna4() -> Self {
259 Self {
260 device: "AMD_RDNA4".to_string(),
261 has_local: true,
262 has_shared: true,
263 has_threads: false,
264 shared_max: 65536,
265 global_max: Some(vec![2147483647, 65535, 65535]),
266 local_max: Some(1024),
267 upcast_max: 8,
268 buffer_max: None,
269 tensor_cores: TensorCore::rdna4_tensor_cores(),
270 }
271 }
272
273 pub fn amd_cdna3() -> Self {
275 Self {
276 device: "AMD_CDNA3".to_string(),
277 has_local: true,
278 has_shared: true,
279 has_threads: false,
280 shared_max: 65536, global_max: Some(vec![2147483647, 65535, 65535]),
282 local_max: Some(1024),
283 upcast_max: 8,
284 buffer_max: None,
285 tensor_cores: TensorCore::cdna3_tensor_cores(),
286 }
287 }
288
289 pub fn amd_cdna4() -> Self {
291 Self {
292 device: "AMD_CDNA4".to_string(),
293 has_local: true,
294 has_shared: true,
295 has_threads: false,
296 shared_max: 65536,
297 global_max: Some(vec![2147483647, 65535, 65535]),
298 local_max: Some(1024),
299 upcast_max: 8,
300 buffer_max: None,
301 tensor_cores: TensorCore::cdna4_tensor_cores(),
302 }
303 }
304
305 pub fn intel_xe() -> Self {
307 Self {
308 device: "IntelXe".to_string(),
309 has_local: true,
310 has_shared: true,
311 has_threads: false,
312 shared_max: 65536, global_max: Some(vec![2147483647, 65535, 65535]),
314 local_max: Some(512),
315 upcast_max: 8,
316 buffer_max: None,
317 tensor_cores: TensorCore::intel_tensor_cores(),
318 }
319 }
320
321 pub fn webgpu() -> Self {
323 Self {
324 device: "WebGPU".to_string(),
325 has_local: true,
326 has_shared: true,
327 has_threads: false,
328 shared_max: 16384, global_max: Some(vec![65535, 65535, 65535]),
330 local_max: Some(256),
331 upcast_max: 4,
332 buffer_max: Some(8), tensor_cores: vec![],
334 }
335 }
336}
337
338#[derive(Debug, Clone)]
358pub struct TensorCore {
359 pub dims: (usize, usize, usize),
361
362 pub threads: usize,
364
365 pub elements_per_thread: (usize, usize, usize),
371
372 pub dtype_in: DType,
374
375 pub dtype_out: DType,
377
378 pub opts: SmallVec<[TcOpt; 8]>,
391
392 #[allow(clippy::type_complexity)]
402 pub swizzle: (
403 (SmallVec<[SwizzleAxis; 8]>, SmallVec<[SwizzleAxis; 8]>, SmallVec<[SwizzleAxis; 8]>),
404 (SmallVec<[SwizzleAxis; 8]>, SmallVec<[SwizzleAxis; 8]>, SmallVec<[SwizzleAxis; 8]>),
405 ),
406
407 pub pack_a: bool,
410
411 pub tile_grid: (usize, usize),
417}
418
419pub struct TcConfig {
428 dims: (usize, usize, usize),
429 threads: usize,
430 ept: (usize, usize, usize),
431 opts: &'static [TcOpt],
432 swizzle_a: (&'static [SwizzleAxis], &'static [SwizzleAxis], &'static [SwizzleAxis]),
433 swizzle_b: (&'static [SwizzleAxis], &'static [SwizzleAxis], &'static [SwizzleAxis]),
434 pack_a: bool,
435 tile_grid: (usize, usize),
436}
437
438impl TcConfig {
439 pub fn build(&self, dtype_in: DType, dtype_out: DType) -> TensorCore {
441 TensorCore {
442 dims: self.dims,
443 threads: self.threads,
444 elements_per_thread: self.ept,
445 dtype_in,
446 dtype_out,
447 opts: self.opts.iter().copied().collect(),
448 swizzle: (
449 (
450 self.swizzle_a.0.iter().copied().collect(),
451 self.swizzle_a.1.iter().copied().collect(),
452 self.swizzle_a.2.iter().copied().collect(),
453 ),
454 (
455 self.swizzle_b.0.iter().copied().collect(),
456 self.swizzle_b.1.iter().copied().collect(),
457 self.swizzle_b.2.iter().copied().collect(),
458 ),
459 ),
460 pack_a: self.pack_a,
461 tile_grid: self.tile_grid,
462 }
463 }
464}
465
466use SwizzleAxis::{Local as SL, Reduce as R, Upcast as SU};
468use TcOpt::{Local as L, Upcast as U};
469
470pub const CUDA_81616: TcConfig = TcConfig {
472 dims: (8, 16, 16),
473 threads: 32,
474 ept: (8, 4, 4),
475 opts: &[U(0), L(0), L(0), L(1), L(1), L(1), U(1)],
476 swizzle_a: (&[R(1), R(2), SL(2), SL(3), SL(4)], &[SU(1), R(3)], &[SL(0), SL(1), SU(0), R(0)]),
477 swizzle_b: (&[R(1), R(2), SU(0), SL(0), SL(1)], &[R(0), R(3)], &[SL(2), SL(3), SL(4), SU(1)]),
478 pack_a: false,
479 tile_grid: (1, 1),
480};
481
482pub const CUDA_81632: TcConfig = TcConfig {
483 dims: (8, 16, 32),
484 threads: 32,
485 ept: (16, 8, 4),
486 opts: &[U(0), L(0), L(0), L(1), L(1), L(1), U(1)],
487 swizzle_a: (&[R(2), R(3), SL(2), SL(3), SL(4)], &[SU(1), R(4)], &[SL(0), SL(1), SU(0), R(0), R(1)]),
488 swizzle_b: (&[R(2), R(3), SU(0), SL(0), SL(1)], &[R(1), R(4)], &[SL(2), SL(3), SL(4), SU(1), R(0)]),
489 pack_a: false,
490 tile_grid: (1, 1),
491};
492
493pub const CUDA_8168: TcConfig = TcConfig {
494 dims: (8, 16, 8),
495 threads: 32,
496 ept: (4, 2, 4),
497 opts: &[U(0), L(0), L(0), L(1), L(1), L(1), U(1)],
498 swizzle_a: (&[R(1), R(2), SL(2), SL(3), SL(4)], &[R(0), SU(1)], &[SL(0), SL(1), SU(0)]),
499 swizzle_b: (&[R(1), R(2), SU(0), SL(0), SL(1)], &[SU(1), R(0)], &[SL(2), SL(3), SL(4)]),
500 pack_a: false,
501 tile_grid: (1, 1),
502};
503
504pub const CUDA_8168_TF32: TcConfig = TcConfig {
505 dims: (8, 16, 8),
506 threads: 32,
507 ept: (4, 2, 4),
508 opts: &[U(0), L(0), L(0), L(1), L(1), L(1), U(1)],
509 swizzle_a: (&[R(0), R(1), SL(2), SL(3), SL(4)], &[SU(1), R(2)], &[SL(0), SL(1), SU(0)]),
510 swizzle_b: (&[R(0), R(1), SU(0), SL(0), SL(1)], &[SU(1), R(2)], &[SL(2), SL(3), SL(4)]),
511 pack_a: false,
512 tile_grid: (1, 1),
513};
514
515pub const AMD_RDNA3: TcConfig = TcConfig {
517 dims: (16, 16, 16),
518 threads: 32,
519 ept: (16, 16, 8),
520 opts: &[L(0), L(0), L(0), L(0), L(1), U(1), U(1), U(1)],
521 swizzle_a: (&[SL(4), SU(0), SU(1), SU(2), SL(0)], &[R(1), R(2), R(3)], &[SL(1), SL(2), SL(3), R(0)]),
522 swizzle_b: (&[SL(0), SL(1), SL(2), SL(3), SL(4)], &[R(1), R(2), R(3)], &[SU(0), SU(1), SU(2), R(0)]),
523 pack_a: false,
524 tile_grid: (1, 1),
525};
526
527pub const AMD_RDNA4: TcConfig = TcConfig {
528 dims: (16, 16, 16),
529 threads: 32,
530 ept: (8, 8, 8),
531 opts: &[L(0), L(0), L(0), L(0), U(1), U(1), U(1), L(1)],
532 swizzle_a: (&[SU(0), SU(1), SU(2), SL(4), R(2)], &[R(0), R(1), R(3)], &[SL(0), SL(1), SL(2), SL(3)]),
533 swizzle_b: (&[SL(0), SL(1), SL(2), SL(3), R(2)], &[R(0), R(1), R(3)], &[SL(4), SU(0), SU(1), SU(2)]),
534 pack_a: false,
535 tile_grid: (1, 1),
536};
537
538pub const AMD_CDNA_161616: TcConfig = TcConfig {
539 dims: (16, 16, 16),
540 threads: 64,
541 ept: (4, 4, 4),
542 opts: &[L(0), L(0), L(0), L(0), U(1), U(1), L(1), L(1)],
543 swizzle_a: (&[SU(0), SU(1), SL(4), SL(5), R(2), R(3)], &[R(0), R(1)], &[SL(0), SL(1), SL(2), SL(3)]),
544 swizzle_b: (&[SL(0), SL(1), SL(2), SL(3), R(2), R(3)], &[R(0), R(1)], &[SL(4), SL(5), SU(0), SU(1)]),
545 pack_a: false,
546 tile_grid: (1, 1),
547};
548
549pub const AMD_CDNA_161632: TcConfig = TcConfig {
550 dims: (16, 16, 32),
551 threads: 64,
552 ept: (8, 8, 4),
553 opts: &[L(0), L(0), L(0), L(0), U(1), U(1), L(1), L(1)],
554 swizzle_a: (&[SU(0), SU(1), SL(4), SL(5), R(3), R(4)], &[R(0), R(1)], &[SL(0), SL(1), SL(2), SL(3), R(2)]),
555 swizzle_b: (&[SL(0), SL(1), SL(2), SL(3), R(3), R(4)], &[R(0), R(1)], &[SL(4), SL(5), SU(0), SU(1), R(2)]),
556 pack_a: false,
557 tile_grid: (1, 1),
558};
559
560pub const METAL_888: TcConfig = TcConfig {
562 dims: (8, 8, 8),
563 threads: 32,
564 ept: (2, 2, 2),
565 opts: &[U(0), L(0), L(1), L(1), L(0), L(1)],
566 swizzle_a: (&[R(1), SL(1), SL(2), R(2), SL(4)], &[R(0)], &[SU(0), SL(0), SL(3)]),
567 swizzle_b: (&[SL(0), R(0), R(1), SL(3), R(2)], &[SU(0)], &[SL(1), SL(2), SL(4)]),
568 pack_a: false,
569 tile_grid: (1, 1),
570};
571
572pub const APPLE_AMX: TcConfig = TcConfig {
576 dims: (16, 16, 1),
577 threads: 1,
578 ept: (16, 16, 256),
579 opts: &[U(0), U(0), U(0), U(0), U(1), U(1), U(1), U(1)],
580 swizzle_a: (&[], &[SU(0), SU(1), SU(2), SU(3), SU(4), SU(5), SU(6), SU(7)], &[]),
581 swizzle_b: (&[], &[SU(4), SU(5), SU(6), SU(7), SU(0), SU(1), SU(2), SU(3)], &[]),
582 pack_a: true,
583 tile_grid: (1, 1),
584};
585
586pub const APPLE_AMX_F16_F32: TcConfig = TcConfig {
587 dims: (32, 32, 1),
588 threads: 1,
589 ept: (32, 32, 1024),
590 opts: &[U(0), U(0), U(0), U(0), U(0), U(1), U(1), U(1), U(1), U(1)],
591 swizzle_a: (&[], &[SU(0), SU(1), SU(2), SU(3), SU(4), SU(5), SU(6), SU(7), SU(8), SU(9)], &[]),
592 swizzle_b: (&[], &[SU(5), SU(6), SU(7), SU(8), SU(9), SU(0), SU(1), SU(2), SU(3), SU(4)], &[]),
593 pack_a: true,
594 tile_grid: (1, 1),
595};
596
597pub const APPLE_AMX_F16: TcConfig = TcConfig {
598 dims: (32, 32, 1),
599 threads: 1,
600 ept: (32, 32, 1024),
601 opts: &[U(0), U(0), U(0), U(0), U(0), U(1), U(1), U(1), U(1), U(1)],
602 swizzle_a: (&[], &[SU(0), SU(1), SU(2), SU(3), SU(4), SU(5), SU(6), SU(7), SU(8), SU(9)], &[]),
603 swizzle_b: (&[], &[SU(5), SU(6), SU(7), SU(8), SU(9), SU(0), SU(1), SU(2), SU(3), SU(4)], &[]),
604 pack_a: true,
605 tile_grid: (1, 1),
606};
607
608pub const APPLE_AMX_F64: TcConfig = TcConfig {
609 dims: (8, 8, 1),
610 threads: 1,
611 ept: (8, 8, 64),
612 opts: &[U(0), U(0), U(0), U(1), U(1), U(1)],
613 swizzle_a: (&[], &[SU(0), SU(1), SU(2), SU(3), SU(4), SU(5)], &[]),
614 swizzle_b: (&[], &[SU(3), SU(4), SU(5), SU(0), SU(1), SU(2)], &[]),
615 pack_a: true,
616 tile_grid: (1, 1),
617};
618
619pub const APPLE_AMX_I16: TcConfig = TcConfig {
621 dims: (32, 32, 1),
622 threads: 1,
623 ept: (32, 32, 1024),
624 opts: &[U(0), U(0), U(0), U(0), U(0), U(1), U(1), U(1), U(1), U(1)],
625 swizzle_a: (&[], &[SU(0), SU(1), SU(2), SU(3), SU(4), SU(5), SU(6), SU(7), SU(8), SU(9)], &[]),
626 swizzle_b: (&[], &[SU(5), SU(6), SU(7), SU(8), SU(9), SU(0), SU(1), SU(2), SU(3), SU(4)], &[]),
627 pack_a: true,
628 tile_grid: (1, 1),
629};
630
631pub const INTEL_XE_8816: TcConfig = TcConfig {
633 dims: (8, 8, 16),
634 threads: 8,
635 ept: (16, 16, 8),
636 opts: &[L(0), L(0), L(0), U(1), U(1), U(1)],
637 swizzle_a: (&[R(1), R(2), R(3)], &[SU(0), SU(1), SU(2)], &[SL(0), SL(1), SL(2), R(0)]),
638 swizzle_b: (&[SL(0), SL(1), SL(2)], &[R(1), R(2), R(3)], &[SU(0), SU(1), SU(2), R(0)]),
639 pack_a: false,
640 tile_grid: (1, 1),
641};
642
643impl TensorCore {
644 pub fn get_reduce_axes(&self) -> Vec<(usize, usize)> {
651 (0..(self.dims.2 as f64).log2().floor() as usize).map(|i| (i, 2)).collect()
652 }
653
654 pub fn upcast_axes(&self) -> (Vec<usize>, Vec<usize>, Vec<usize>) {
659 (vec![0, 1], vec![0, 1], vec![0, 1])
662 }
663
664 pub fn sm75_tensor_cores() -> Vec<TensorCore> {
668 vec![CUDA_8168.build(DType::Float16, DType::Float32), CUDA_8168.build(DType::Float16, DType::Float16)]
669 }
670
671 pub fn sm80_tensor_cores(allow_tf32: bool) -> Vec<TensorCore> {
673 let mut tcs = vec![
674 CUDA_81616.build(DType::Float16, DType::Float32),
675 CUDA_81616.build(DType::BFloat16, DType::Float32),
676 CUDA_81616.build(DType::Float16, DType::Float16),
677 CUDA_8168.build(DType::Float16, DType::Float32),
678 CUDA_8168.build(DType::Float16, DType::Float16),
679 ];
680 if allow_tf32 {
681 tcs.push(CUDA_8168_TF32.build(DType::Float32, DType::Float32));
682 }
683 tcs
684 }
685
686 pub fn sm89_tensor_cores(allow_tf32: bool) -> Vec<TensorCore> {
688 let mut tcs = Self::sm80_tensor_cores(allow_tf32);
689 tcs.push(CUDA_81632.build(DType::FP8E4M3, DType::Float32));
690 tcs.push(CUDA_81632.build(DType::FP8E5M2, DType::Float32));
691 tcs
692 }
693
694 pub fn rdna3_tensor_cores() -> Vec<TensorCore> {
696 vec![
697 AMD_RDNA3.build(DType::Float16, DType::Float32),
698 AMD_RDNA3.build(DType::Float16, DType::Float16),
699 AMD_RDNA3.build(DType::BFloat16, DType::Float32),
700 ]
701 }
702
703 pub fn rdna4_tensor_cores() -> Vec<TensorCore> {
705 vec![
706 AMD_RDNA4.build(DType::Float16, DType::Float32),
707 AMD_RDNA4.build(DType::Float16, DType::Float16),
708 AMD_RDNA4.build(DType::BFloat16, DType::Float32),
709 AMD_RDNA4.build(DType::BFloat16, DType::BFloat16),
710 ]
711 }
712
713 pub fn cdna3_tensor_cores() -> Vec<TensorCore> {
715 vec![
716 AMD_CDNA_161632.build(DType::FP8E5M2, DType::Float32),
717 AMD_CDNA_161632.build(DType::FP8E4M3, DType::Float32),
718 AMD_CDNA_161616.build(DType::Float16, DType::Float32),
719 AMD_CDNA_161616.build(DType::BFloat16, DType::Float32),
720 ]
721 }
722
723 pub fn cdna4_tensor_cores() -> Vec<TensorCore> {
725 vec![
726 AMD_CDNA_161632.build(DType::FP8E5M2, DType::Float32),
727 AMD_CDNA_161632.build(DType::FP8E4M3, DType::Float32),
728 AMD_CDNA_161632.build(DType::Float16, DType::Float32),
729 AMD_CDNA_161632.build(DType::BFloat16, DType::Float32),
730 AMD_CDNA_161616.build(DType::Float16, DType::Float32),
731 AMD_CDNA_161616.build(DType::BFloat16, DType::Float32),
732 ]
733 }
734
735 pub fn metal_tensor_cores() -> Vec<TensorCore> {
737 vec![
738 METAL_888.build(DType::Float32, DType::Float32),
739 METAL_888.build(DType::Float16, DType::Float32),
740 METAL_888.build(DType::Float16, DType::Float16),
741 METAL_888.build(DType::BFloat16, DType::Float32),
742 METAL_888.build(DType::BFloat16, DType::BFloat16),
743 ]
744 }
745
746 pub fn amx_tensor_cores() -> Vec<TensorCore> {
748 vec![
749 APPLE_AMX.build(DType::Float32, DType::Float32),
750 APPLE_AMX_F16.build(DType::Float16, DType::Float16),
751 APPLE_AMX_F16_F32.build(DType::Float16, DType::Float32), APPLE_AMX_F64.build(DType::Float64, DType::Float64),
753 APPLE_AMX_I16.build(DType::Int16, DType::Int16),
754 ]
755 }
756
757 pub fn intel_tensor_cores() -> Vec<TensorCore> {
759 vec![INTEL_XE_8816.build(DType::Float16, DType::Float32)]
760 }
761}
762
#[cfg(test)]
mod tests {
    use super::*;

    /// The CPU preset is thread-based, has no local indexing and no tensor cores.
    #[test]
    fn test_renderer_cpu() {
        let cpu = Renderer::cpu();
        assert_eq!(cpu.device, "CPU");
        assert!(cpu.has_threads);
        assert!(!cpu.has_local);
        assert!(cpu.tensor_cores.is_empty());
    }

    /// The default CUDA preset forwards to the SM80 preset.
    #[test]
    fn test_renderer_cuda() {
        let cuda = Renderer::cuda();
        assert_eq!(cuda.device, "CUDA_SM80");
        assert!(cuda.has_local && cuda.has_shared);
        assert!(!cuda.has_threads);
        assert!(cuda.shared_max > 0);
        assert!(!cuda.tensor_cores.is_empty());
    }

    /// Building from `CUDA_81616` carries over the layout and the given dtypes.
    #[test]
    fn test_tensor_core_cuda() {
        let built = CUDA_81616.build(DType::Float16, DType::Float32);
        assert_eq!((built.dims, built.threads), ((8, 16, 16), 32));
        assert_eq!(built.dtype_in, DType::Float16);
        assert_eq!(built.dtype_out, DType::Float32);
        assert!(!built.opts.is_empty());
    }
}