1use core::panic;
2use std::fmt::Display;
3
4use crate::{
5 Dialect,
6 shared::{
7 self, AtomicKind, Binding, Component, CubeIndexFlags, DialectBindings, DialectCubeBuiltins,
8 DialectIncludes, DialectInstructions, DialectTypes, DialectWmmaCompiler, Elem, Flags,
9 FmtLeft, Fragment, FragmentIdent, FragmentLayout, Instruction, Item, SharedMemory,
10 SupportedWmmaCombinations, Variable, WarpInstruction, WmmaInstruction,
11 },
12};
13use cubecl_core::{
14 compute::{Location, Visibility},
15 ir::{self as gpu, Id},
16};
17
18use super::{
19 AddressSpace, Extension,
20 arch::MetalArchitecture,
21 extension::{format_ffs, format_mulhi},
22 format_erf, format_global_binding_arg, format_metal_builtin_binding_arg, format_safe_tanh,
23};
24
/// Marker type selecting the Metal Shading Language (MSL) dialect.
///
/// The struct has no fields; everything dialect-specific is supplied by the
/// trait implementations that follow in this file.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct MslDialect {}
27
// `Dialect` is implemented with an empty body; the MSL-specific behaviour is
// provided by the other trait implementations in this file.
impl Dialect for MslDialect {}
31
32impl DialectIncludes<Self> for MslDialect {
35 type Extension = Extension<Self>;
36
37 fn compile_includes(f: &mut std::fmt::Formatter<'_>, _flags: &Flags) -> std::fmt::Result {
38 write!(
39 f,
40 "
41#include <metal_stdlib>
42using namespace metal;
43"
44 )?;
45 Ok(())
46 }
47
48 fn compile_extensions(
49 f: &mut std::fmt::Formatter<'_>,
50 extensions: &[Self::Extension],
51 ) -> std::fmt::Result {
52 for extension in extensions {
53 match extension {
54 Extension::Erf(input, output) => format_erf::<Self>(f, input, output)?,
55 Extension::Ffs(elem) => format_ffs(f, elem)?,
56 Extension::MulHi(elem) => format_mulhi(f, elem)?,
57 Extension::SafeTanh(item) => format_safe_tanh::<Self>(f, item)?,
58 Extension::NoExtension => {}
59 }
60 }
61 Ok(())
62 }
63
64 fn register_instruction_extension(
65 extensions: &mut Vec<Self::Extension>,
66 instruction: &Instruction<Self>,
67 ) {
68 let mut register_extension = |extension: Self::Extension| {
69 if !extensions.contains(&extension) {
70 extensions.push(extension);
71 }
72 };
73 #[allow(clippy::single_match)]
74 match instruction {
75 shared::Instruction::<Self>::Erf(instruction) => {
76 register_extension(Extension::Erf(
77 instruction.input.elem(),
78 instruction.out.elem(),
79 ));
80 }
81 shared::Instruction::<Self>::FindFirstSet(instruction) => {
82 let input_elem = instruction.input.elem();
83 match input_elem {
84 Elem::U32 | Elem::U64 => {
85 register_extension(Extension::Ffs(instruction.input.elem()));
86 }
87 Elem::I32 => {
88 register_extension(Extension::Ffs(Elem::<Self>::U32));
89 register_extension(Extension::Ffs(instruction.input.elem()));
90 }
91 Elem::I64 => {
92 register_extension(Extension::Ffs(Elem::<Self>::U64));
93 register_extension(Extension::Ffs(instruction.input.elem()));
94 }
95 _ => {
96 register_extension(Extension::Ffs(Elem::<Self>::U32));
97 }
98 }
99 }
100 shared::Instruction::<Self>::HiMul(instruction) => {
101 register_extension(Extension::MulHi(instruction.out.elem()));
102 }
103 shared::Instruction::<Self>::Tanh(instruction) => {
104 register_extension(Extension::SafeTanh(instruction.input.item()));
105 }
106 _ => {}
107 }
108 }
109
110 fn register_warp_instruction_extension(
111 _extensions: &mut Vec<Self::Extension>,
112 _instruction: &WarpInstruction<Self>,
113 ) {
114 }
115}
116
117impl DialectTypes<Self> for MslDialect {
120 fn item_can_be_optimized() -> bool {
121 false
122 }
123
124 fn compile_type_definitions(
125 f: &mut std::fmt::Formatter<'_>,
126 items: &std::collections::HashSet<crate::shared::Item<Self>>,
127 _scalars: &[(Elem<Self>, usize)],
128 _flags: &Flags,
129 ) -> std::fmt::Result {
130 for item in items.iter() {
131 let elem = item.elem;
132 let size = item.vectorization;
133 let alignment = elem.size() * size;
134 if size > 1 {
135 write!(
136 f,
137 "
138struct alignas({alignment}) {item} {{"
139 )?;
140
141 for i in 0..size {
142 write!(
143 f,
144 "
145 {elem} i_{i};"
146 )?;
147 }
148
149 f.write_str("\n};\n")?;
150 }
151 }
152 Ok(())
153 }
154
155 fn compile_elem(
156 f: &mut std::fmt::Formatter<'_>,
157 elem: &shared::Elem<Self>,
158 _words: bool,
159 ) -> std::fmt::Result {
160 match elem {
162 shared::Elem::F16 => f.write_str("half"),
163 shared::Elem::F162 => panic!("type F162 not supported!"),
164 shared::Elem::F32 => f.write_str("float"),
165 shared::Elem::F64 => panic!("type double not supported!"),
166 shared::Elem::BF16 => f.write_str("bfloat"),
167 shared::Elem::BF162 => panic!("type BF162 not supported!"),
168 shared::Elem::TF32 => f.write_str("float"),
169 shared::Elem::I8 => f.write_str("char"),
170 shared::Elem::I16 => f.write_str("short"),
171 shared::Elem::I32 => f.write_str("int"),
172 shared::Elem::I64 => f.write_str("long"),
173 shared::Elem::U8 => f.write_str("uchar"),
174 shared::Elem::U16 => f.write_str("ushort"),
175 shared::Elem::U32 => f.write_str("uint"),
176 shared::Elem::U64 => f.write_str("uint64_t"), shared::Elem::Bool => f.write_str("bool"),
178 shared::Elem::Atomic(inner) => inner.fmt(f),
179 shared::Elem::_Dialect(_) => Ok(()),
180 }
181 }
182
183 fn compile_item(f: &mut std::fmt::Formatter<'_>, item: &Item<Self>) -> std::fmt::Result {
184 if 1 == item.vectorization {
185 return write!(f, "{}", item.elem);
186 }
187 if item.native {
188 write!(f, "{}{}", item.elem, item.vectorization)
189 } else {
190 write!(f, "{}_{}", item.elem, item.vectorization)
191 }
192 }
193
194 fn compile_atomic_kind(
195 f: &mut std::fmt::Formatter<'_>,
196 kind: &AtomicKind<Self>,
197 ) -> std::fmt::Result {
198 match kind {
199 AtomicKind::I32 => write!(f, "atomic_int"),
200 AtomicKind::I64 => panic!("I64 atomic kind no supported."),
201 AtomicKind::U32 => write!(f, "atomic_uint"),
202 AtomicKind::U64 => write!(f, "atomic_ulong"),
203 AtomicKind::F16 => panic!("F16 atomic kind no supported."),
204 AtomicKind::BF16 => panic!("BF16 atomic kind no supported."),
205 AtomicKind::F32 => write!(f, "atomic_float"), AtomicKind::F64 => panic!("F64 atomic kind no supported."),
207 AtomicKind::_Dialect(_) => Ok(()),
208 }
209 }
210
211 fn address_space_for_variable(variable: &Variable<Self>) -> String {
212 format!("{} ", AddressSpace::from(variable))
213 }
214
215 fn compile_local_memory_qualifier(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
216 write!(f, "thread")
217 }
218
219 fn compile_shared_memory_qualifier(
220 f: &mut std::fmt::Formatter<'_>,
221 _shared: &SharedMemory<Self>,
222 ) -> std::fmt::Result {
223 write!(f, "threadgroup")
224 }
225}
226
227impl DialectBindings<Self> for MslDialect {
230 fn compile_kernel_signature(
231 f: &mut std::fmt::Formatter<'_>,
232 kernel_name: &str,
233 tensor_maps: &[Id],
234 buffers: &[Binding<Self>],
235 scalars: &[(Elem<Self>, usize)],
236 flags: &Flags,
237 ) -> std::fmt::Result {
238 write!(
239 (f),
240 "
241[[kernel]]
242void {}(",
243 kernel_name
244 )?;
245 let mut buffer_idx = 0;
247 debug_assert!(
248 tensor_maps.is_empty(),
249 "Tensor maps aren't supported for metal"
250 );
251 for (i, b) in buffers.iter().enumerate() {
252 format_global_binding_arg("buffer", b, Some(&i.to_string()), &mut buffer_idx, f)?;
253 }
254 if flags.static_meta_length > 0 {
255 let binding = Binding {
256 id: 0,
257 item: Item::scalar(Elem::<Self>::U32, true),
258 location: Location::Storage,
259 size: None,
260 vis: Visibility::Read,
261 };
262 format_global_binding_arg("info", &binding, None, &mut buffer_idx, f)?;
263 }
264 for (elem, _) in scalars.iter() {
265 let binding = Binding {
266 id: 0,
267 item: Item::scalar(*elem, true),
268 location: Location::Storage,
269 size: None,
270 vis: Visibility::Read,
271 };
272
273 let name = format!("scalars_{elem}");
274 format_global_binding_arg(&name, &binding, None, &mut buffer_idx, f)?;
275 }
276
277 let builtins = vec![
279 (
280 flags.indexes.absolute_pos_tuple,
281 Variable::<Self>::AbsolutePosBaseName,
282 ),
283 (
284 flags.indexes.cube_dim_tuple,
285 Variable::<Self>::CubeDimBaseName,
286 ),
287 (
288 flags.indexes.cube_count_tuple,
289 Variable::<Self>::CubeCountBaseName,
290 ),
291 (flags.indexes.unit_pos, Variable::<Self>::UnitPos),
292 (
293 flags.indexes.unit_pos_tuple,
294 Variable::<Self>::UnitPosBaseName,
295 ),
296 (
297 flags.indexes.cube_pos_tuple,
298 Variable::<Self>::CubePosBaseName,
299 ),
300 (flags.indexes.unit_pos_plane, Variable::<Self>::UnitPosPlane),
301 (flags.indexes.plane_dim, Variable::<Self>::PlaneDim),
302 (flags.indexes.plane_index, Variable::<Self>::PlanePos),
303 ];
304 let comma = !buffers.is_empty() || flags.static_meta_length > 0 || !scalars.is_empty();
305 builtins
306 .iter()
307 .filter(|(cond, _)| *cond)
308 .try_for_each(|(_, var)| format_metal_builtin_binding_arg(f, var, comma))?;
309 f.write_str("\n)")
310 }
311}
312
313impl DialectCubeBuiltins<Self> for MslDialect {
316 fn builtin_rules(flags: &CubeIndexFlags) -> CubeIndexFlags {
321 let absolute_pos = flags.absolute_pos;
322 let cube_count = flags.cube_count;
323 let cube_dim = flags.cube_dim;
324 let cube_pos = flags.cube_pos;
325 let plane_dim_checked = flags.plane_dim_checked;
326 let plane_index = flags.plane_index;
327 let unit_pos = flags.unit_pos;
328 let absolute_pos_tuple = flags.absolute_pos_tuple || absolute_pos;
329 let cube_count_tuple = flags.cube_count_tuple || cube_count || cube_pos || absolute_pos;
330 let cube_dim_tuple = flags.cube_dim_tuple || cube_dim || absolute_pos || plane_dim_checked;
331 let cube_pos_tuple = flags.cube_pos_tuple || cube_pos;
332 let cluster_pos = flags.cluster_pos;
333 let plane_dim = flags.plane_dim || plane_dim_checked || plane_index;
334 let unit_pos_plane = flags.unit_pos_plane || plane_index;
335 let unit_pos_tuple = flags.unit_pos_tuple || unit_pos;
336 CubeIndexFlags {
337 absolute_pos_tuple,
338 absolute_pos,
339 cube_count_tuple,
340 cube_count,
341 cube_dim_tuple,
342 cube_dim,
343 cube_pos_tuple,
344 cube_pos,
345 plane_dim,
346 plane_dim_checked,
347 plane_index,
348 unit_pos_tuple,
349 unit_pos,
350 unit_pos_plane,
351 cluster_pos,
352 }
353 }
354
355 fn compile_absolute_pos_tuple_computation(
356 _f: &mut std::fmt::Formatter<'_>,
357 ) -> std::fmt::Result {
358 Ok(())
360 }
361
362 fn compile_absolute_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
363 f.write_str("thread_pos_in_grid")
364 }
365
366 fn compile_absolute_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
367 f.write_str("thread_index_in_grid")
368 }
369
370 fn compile_absolute_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
371 Self::compile_absolute_pos_base_name(f)?;
372 write!(f, ".x")
373 }
374
375 fn compile_absolute_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
376 Self::compile_absolute_pos_base_name(f)?;
377 write!(f, ".y")
378 }
379
380 fn compile_absolute_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
381 Self::compile_absolute_pos_base_name(f)?;
382 write!(f, ".z")
383 }
384
385 fn compile_cube_count_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
386 f.write_str("threadgroups_per_grid")
387 }
388
389 fn compile_cube_count(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
390 f.write_str("total_threadgroups_in_grid")
391 }
392
393 fn compile_cube_count_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
394 Self::compile_cube_count_base_name(f)?;
395 write!(f, ".x")
396 }
397
398 fn compile_cube_count_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
399 Self::compile_cube_count_base_name(f)?;
400 write!(f, ".y")
401 }
402
403 fn compile_cube_count_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
404 Self::compile_cube_count_base_name(f)?;
405 write!(f, ".z")
406 }
407
408 fn compile_cube_dim_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
409 f.write_str("threads_per_threadgroup")
410 }
411
412 fn compile_cube_dim(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
413 f.write_str("total_thread_in_threadgroup")
414 }
415
416 fn compile_cube_dim_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
417 Self::compile_cube_dim_base_name(f)?;
418 write!(f, ".x")
419 }
420
421 fn compile_cube_dim_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
422 Self::compile_cube_dim_base_name(f)?;
423 write!(f, ".y")
424 }
425
426 fn compile_cube_dim_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
427 Self::compile_cube_dim_base_name(f)?;
428 write!(f, ".z")
429 }
430
431 fn compile_cube_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
432 f.write_str("threadgroup_pos_in_grid")
433 }
434
435 fn compile_cube_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
436 f.write_str("threadgroup_index_in_grid")
437 }
438
439 fn compile_cube_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
440 Self::compile_cube_pos_base_name(f)?;
441 write!(f, ".x")
442 }
443
444 fn compile_cube_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
445 Self::compile_cube_pos_base_name(f)?;
446 write!(f, ".y")
447 }
448
449 fn compile_cube_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
450 Self::compile_cube_pos_base_name(f)?;
451 write!(f, ".z")
452 }
453
454 fn compile_unit_pos_computation(_f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
455 Ok(())
457 }
458
459 fn compile_unit_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
460 f.write_str("thread_pos_in_threadgroup")
461 }
462
463 fn compile_unit_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
464 f.write_str("thread_index_in_threadgroup")
465 }
466
467 fn compile_unit_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
468 Self::compile_unit_pos_base_name(f)?;
469 write!(f, ".x")
470 }
471
472 fn compile_unit_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
473 Self::compile_unit_pos_base_name(f)?;
474 write!(f, ".y")
475 }
476
477 fn compile_unit_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
478 Self::compile_unit_pos_base_name(f)?;
479 write!(f, ".z")
480 }
481
482 fn compile_plane_dim(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
483 f.write_str("simd_size")
484 }
485
486 fn compile_plane_dim_checked(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
487 f.write_str("threads_per_simdgroup_checked")
488 }
489
490 fn compile_plane_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
491 f.write_str("simd_group_id")
492 }
493
494 fn compile_unit_pos_plane(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
495 f.write_str("simd_lane_id")
496 }
497}
498
499impl DialectInstructions<Self> for MslDialect {
502 fn compile_atomic_add(
504 f: &mut std::fmt::Formatter<'_>,
505 lhs: &Variable<Self>,
506 rhs: &Variable<Self>,
507 out: &Variable<Self>,
508 ) -> std::fmt::Result {
509 let out = out.fmt_left();
510 writeln!(
511 f,
512 "{out} = atomic_fetch_add_explicit({lhs}, {rhs}, memory_order_relaxed);"
513 )
514 }
515
516 fn compile_atomic_and(
517 f: &mut std::fmt::Formatter<'_>,
518 lhs: &Variable<Self>,
519 rhs: &Variable<Self>,
520 out: &Variable<Self>,
521 ) -> std::fmt::Result {
522 let out = out.fmt_left();
523 writeln!(
524 f,
525 "{out} = atomic_fetch_and_explicit({lhs}, {rhs}, memory_order_relaxed);"
526 )
527 }
528
529 fn compile_atomic_cas(
530 f: &mut std::fmt::Formatter<'_>,
531 input: &Variable<Self>,
532 cmp: &Variable<Self>,
533 val: &Variable<Self>,
534 out: &Variable<Self>,
535 ) -> std::fmt::Result {
536 let out = out.fmt_left();
537 writeln!(
538 f,
539 "{out} = atomic_compare_exchange_weak_explicit({input}, &{cmp}, {val}, memory_order_relaxed, memory_order_relaxed);"
540 )
541 }
542
543 fn compile_atomic_load(
544 f: &mut std::fmt::Formatter<'_>,
545 input: &Variable<Self>,
546 out: &Variable<Self>,
547 ) -> std::fmt::Result {
548 let out = out.fmt_left();
549 writeln!(
550 f,
551 "{out} = atomic_load_explicit({input}, memory_order_relaxed);"
552 )
553 }
554
555 fn compile_atomic_max(
556 f: &mut std::fmt::Formatter<'_>,
557 lhs: &Variable<Self>,
558 rhs: &Variable<Self>,
559 out: &Variable<Self>,
560 ) -> std::fmt::Result {
561 let out = out.fmt_left();
562 writeln!(
563 f,
564 "{out} = atomic_fetch_max_explicit({lhs}, {rhs}, memory_order_relaxed);"
565 )
566 }
567
568 fn compile_atomic_min(
569 f: &mut std::fmt::Formatter<'_>,
570 lhs: &Variable<Self>,
571 rhs: &Variable<Self>,
572 out: &Variable<Self>,
573 ) -> std::fmt::Result {
574 let out = out.fmt_left();
575 writeln!(
576 f,
577 "{out} = atomic_fetch_min_explicit({lhs}, {rhs}, memory_order_relaxed);"
578 )
579 }
580
581 fn compile_atomic_or(
582 f: &mut std::fmt::Formatter<'_>,
583 lhs: &Variable<Self>,
584 rhs: &Variable<Self>,
585 out: &Variable<Self>,
586 ) -> std::fmt::Result {
587 let out = out.fmt_left();
588 writeln!(
589 f,
590 "{out} = atomic_fetch_or_explicit({lhs}, {rhs}, memory_order_relaxed);"
591 )
592 }
593
594 fn compile_atomic_store(
595 f: &mut std::fmt::Formatter<'_>,
596 input: &Variable<Self>,
597 out: &Variable<Self>,
598 ) -> std::fmt::Result {
599 writeln!(
600 f,
601 "atomic_store_explicit({out}, {input}, memory_order_relaxed);"
602 )
603 }
604
605 fn compile_atomic_sub(
606 f: &mut std::fmt::Formatter<'_>,
607 lhs: &Variable<Self>,
608 rhs: &Variable<Self>,
609 out: &Variable<Self>,
610 ) -> std::fmt::Result {
611 let out = out.fmt_left();
612 writeln!(
613 f,
614 "{out} = atomic_fetch_sub_explicit({lhs}, {rhs}, memory_order_relaxed);"
615 )
616 }
617
618 fn compile_atomic_swap(
619 f: &mut std::fmt::Formatter<'_>,
620 lhs: &Variable<Self>,
621 rhs: &Variable<Self>,
622 out: &Variable<Self>,
623 ) -> std::fmt::Result {
624 let out = out.fmt_left();
625 writeln!(
626 f,
627 "{out} = atomic_exchange_explicit({lhs}, {rhs}, memory_order_relaxed);"
628 )
629 }
630
631 fn compile_atomic_xor(
632 f: &mut std::fmt::Formatter<'_>,
633 lhs: &Variable<Self>,
634 rhs: &Variable<Self>,
635 out: &Variable<Self>,
636 ) -> std::fmt::Result {
637 let out = out.fmt_left();
638 writeln!(
639 f,
640 "{out} = atomic_fetch_xor_explicit({lhs}, {rhs}, memory_order_relaxed);"
641 )
642 }
643
644 fn compile_instruction_printf(
646 f: &mut std::fmt::Formatter<'_>,
647 format_string: &str,
648 args: &[Variable<Self>],
649 ) -> std::fmt::Result {
650 let format_string = format_string
651 .replace("\t", "\\t")
652 .replace("\n", "\\n")
653 .replace("\r", "\\r");
654 let args = args.iter().map(|arg| format!("{arg}")).collect::<Vec<_>>();
655 let args = match args.is_empty() {
656 true => "".to_string(),
657 false => format!(", {}", args.join(",")),
658 };
659 writeln!(f, "os_log_default.log(\"{format_string}\"{args});")
660 }
661
662 fn compile_instruction_log1p_scalar<T: Component<Self>>(
664 f: &mut std::fmt::Formatter<'_>,
665 input: T,
666 ) -> std::fmt::Result {
667 match input.elem() {
668 Elem::F16 | Elem::F162 | Elem::BF16 | Elem::BF162 => {
669 write!(f, "log(half(1.0f) + {input})")
670 }
671 _ => write!(f, "log(1.0f + {input})"),
672 }
673 }
674
675 fn compile_instruction_sync_threads(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
677 writeln!(f, "threadgroup_barrier(mem_flags::mem_threadgroup);")
678 }
679
680 fn compile_instruction_thread_fence(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
681 writeln!(f, "threadgroup_thread_fence(mem_flags::mem_device);")
682 }
683
684 fn compile_instruction_tanh_scalar<T: Component<Self>>(
686 f: &mut std::fmt::Formatter<'_>,
687 input: T,
688 ) -> std::fmt::Result {
689 write!(f, "safe_tanh_scalar({input})")
690 }
691
692 fn compile_instruction_find_first_set<T: Component<Self>>(
694 f: &mut std::fmt::Formatter<'_>,
695 input: T,
696 out_elem: Elem<Self>,
697 ) -> std::fmt::Result {
698 write!(f, "{out_elem}(")?;
699 match input.elem() {
700 Elem::I32 | Elem::U32 => write!(f, "__ffs({input})"),
701 Elem::I64 | Elem::U64 => write!(f, "__ffsll({input})"),
702 _ => write!(f, "__ffs({}({input}))", Elem::<Self>::I32),
703 }?;
704 write!(f, ")")
705 }
706
707 fn compile_instruction_leading_zeros_scalar<T: Component<Self>>(
708 f: &mut std::fmt::Formatter<'_>,
709 input: T,
710 out_elem: Elem<Self>,
711 ) -> std::fmt::Result {
712 write!(f, "{out_elem}(clz({input}))")
713 }
714
715 fn compile_instruction_popcount_scalar<T: Component<Self>>(
716 f: &mut std::fmt::Formatter<'_>,
717 input: T,
718 out_elem: Elem<Self>,
719 ) -> std::fmt::Result {
720 write!(f, "{out_elem}(")?;
721 match input.elem() {
722 Elem::I32 | Elem::U32 | Elem::I64 | Elem::U64 => write!(f, "popcount({input})"),
723 _ => write!(f, "popcount({})", shared::unary::zero_extend(input)),
724 }?;
725 write!(f, ")")
726 }
727
728 fn compile_instruction_reverse_bits_scalar<T: Component<Self>>(
729 f: &mut std::fmt::Formatter<'_>,
730 input: T,
731 out_elem: Elem<Self>,
732 ) -> std::fmt::Result {
733 write!(f, "{out_elem}(")?;
734 match out_elem {
735 Elem::I32 | Elem::U32 | Elem::I64 | Elem::U64 => write!(f, "reverse_bits({input})"),
736 _ => write!(
737 f,
738 "reverse_bits({}) >> {}",
739 shared::unary::zero_extend(input),
740 (size_of::<u32>() - out_elem.size()) * 8
741 ),
742 }?;
743 write!(f, ")")
744 }
745
746 fn compile_instruction_max_function_name(
748 f: &mut std::fmt::Formatter<'_>,
749 _item: Item<Self>,
750 ) -> std::fmt::Result {
751 write!(f, "max")
752 }
753
754 fn compile_instruction_min_function_name(
755 f: &mut std::fmt::Formatter<'_>,
756 _item: Item<Self>,
757 ) -> std::fmt::Result {
758 write!(f, "min")
759 }
760
761 fn compile_instruction_powf(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
762 write!(f, "pow")
763 }
764
765 fn compile_instruction_half_function_name_prefix() -> &'static str {
766 ""
767 }
768
769 fn compile_instruction_half2_function_name_prefix() -> &'static str {
770 ""
771 }
772
773 fn compile_warp_shuffle(
775 f: &mut std::fmt::Formatter<'_>,
776 var: &str,
777 source: &str,
778 ) -> std::fmt::Result {
779 write!(f, "simd_shuffle({var}, {source})")
780 }
781
782 fn compile_warp_shuffle_xor(
783 f: &mut std::fmt::Formatter<'_>,
784 var: &str,
785 _elem: &Elem<Self>,
786 offset: &str,
787 ) -> std::fmt::Result {
788 write!(f, "simd_shuffle_xor({var}, {offset})")
789 }
790
791 fn compile_warp_shuffle_up(
792 f: &mut std::fmt::Formatter<'_>,
793 var: &str,
794 offset: &str,
795 ) -> std::fmt::Result {
796 write!(f, "simd_shuffle_up({var}, {offset})")
797 }
798
799 fn compile_warp_shuffle_down(
800 f: &mut std::fmt::Formatter<'_>,
801 var: &str,
802 offset: &str,
803 ) -> std::fmt::Result {
804 write!(f, "simd_shuffle_down({var}, {offset})")
805 }
806
807 fn compile_warp_all<T: Component<Self>>(
808 f: &mut std::fmt::Formatter<'_>,
809 input: &T,
810 ) -> std::fmt::Result {
811 write!(f, "simd_all({input})")
812 }
813
814 fn compile_warp_any<T: Component<Self>>(
815 f: &mut std::fmt::Formatter<'_>,
816 input: &T,
817 ) -> std::fmt::Result {
818 write!(f, "simd_any({input})")
819 }
820
821 fn compile_warp_ballot(
822 f: &mut std::fmt::Formatter<'_>,
823 input: &Variable<Self>,
824 out_elem: &Elem<Self>,
825 ) -> std::fmt::Result {
826 write!(f, "{out_elem}(uint64_t(simd_ballot({input})))")
827 }
828}
829
830impl DialectWmmaCompiler<Self> for MslDialect {
833 type Architecture = MetalArchitecture;
834
835 fn compile_wmma_includes(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
836 writeln!(f, "#include <metal_simdgroup_matrix>")
837 }
838
839 fn compile_wmma_type_definitions(_f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
840 Ok(())
842 }
843
844 fn compile_local_variables(_f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
845 Ok(())
847 }
848
849 fn compile_fragment_ident(
850 _ident: &FragmentIdent<Self>,
851 _f: &mut std::fmt::Formatter<'_>,
852 ) -> std::fmt::Result {
853 Ok(())
855 }
856
857 fn compile_fragment_layout(
858 _layout: &FragmentLayout<Self>,
859 _f: &mut std::fmt::Formatter<'_>,
860 ) -> std::fmt::Result {
861 Ok(())
863 }
864
865 fn compile_fragment(
866 fragment: &Fragment<Self>,
867 f: &mut std::fmt::Formatter<'_>,
868 ) -> std::fmt::Result {
869 let ty = fragment.elem;
870 let m = fragment.m;
872 let n = fragment.n;
873 let k = fragment.k;
874 if m != 8 || n != 8 || k != 8 {
875 panic!("{m}x{n}x{k} fragments not supported. Only 8x8x8 fragemts are supported.");
876 }
877 write!(f, "simdgroup_{ty}8x8")
878 }
879
880 fn compile_instruction(
881 instruction: &WmmaInstruction<Self>,
882 f: &mut std::fmt::Formatter<'_>,
883 ) -> std::fmt::Result {
884 match instruction {
885 WmmaInstruction::Fill { frag, value } => {
886 match frag {
887 Variable::WmmaFragment { .. } => {
888 let ty = frag.elem();
889 writeln!(
891 f,
892 "{frag} = make_filled_simdgroup_matrix<{ty}, 8, 8>({value});"
893 )
894 }
895 _ => panic!("should be a fragment"),
896 }
897 }
898 WmmaInstruction::Load {
899 frag,
900 value,
901 stride,
902 ..
903 } => {
904 let transpose = match frag {
905 Variable::WmmaFragment { frag: inner, .. } => match inner.layout {
906 Some(FragmentLayout::RowMajor) => false,
907 Some(FragmentLayout::ColMajor) => true,
908 _ => false,
909 },
910 _ => panic!("should be a fragment"),
911 };
912 let item = value.item();
913 if item.vectorization > 1 {
914 let elem = item.elem;
915 writeln!(
916 f,
917 "simdgroup_load({frag}, reinterpret_cast<threadgroup {elem} *>({value}), {stride}, 0, {transpose});"
918 )
919 } else {
920 writeln!(
921 f,
922 "simdgroup_load({frag}, {value}, {stride}, 0, {transpose});"
923 )
924 }
925 }
926 WmmaInstruction::Execute {
927 frag_a: a,
928 frag_b: b,
929 frag_c: c,
930 frag_d: d,
931 ..
932 } => {
933 writeln!(f, "simdgroup_multiply_accumulate({d}, {a}, {b}, {c});")
934 }
935 WmmaInstruction::Store {
936 output,
937 frag,
938 stride,
939 ..
940 } => {
941 let item = output.item();
942 let mut reinterpret_cast = item.vectorization > 1;
943 let elem = match item.elem {
944 Elem::BF16 => {
945 reinterpret_cast = true;
946 Elem::F16
947 }
948 _ => item.elem,
949 };
950 if reinterpret_cast {
951 writeln!(
952 f,
953 "simdgroup_store({frag}, reinterpret_cast<threadgroup {elem} *>({output}), {stride});"
954 )
955 } else {
956 writeln!(f, "simdgroup_store({frag}, {output}, {stride});")
957 }?;
958 writeln!(f, "threadgroup_barrier(mem_flags::mem_none);")
959 }
960 WmmaInstruction::Cast { input, output } => {
961 writeln!(f, "threadgroup_barrier(mem_flags::mem_none);")?;
962 let ty = match output {
963 Variable::WmmaFragment { frag, .. } => frag.elem,
964 _ => panic!("should be a fragment"),
965 };
966 match ty {
967 Elem::BF16 => {
968 let addr_space = Self::address_space_for_variable(output);
969 let elem = Elem::<Self>::F16;
970 writeln!(
973 f,
974 "for(int e=0; e<8; e++) {{
975 {ty} elem = {ty}({input}.thread_elements()[e]);
976 {output}.thread_elements()[e] = *reinterpret_cast<{addr_space}{elem} *>(&elem);
977}}"
978 )
979 }
980 _ => {
981 writeln!(
982 f,
983 "for(int e=0; e<8; e++) {{
984 {output}.thread_elements()[e] = {ty}({input}.thread_elements()[e]);
985}}"
986 )
987 }
988 }
989 }
990 }
991 }
992
993 fn supported_wmma_combinations(_arch: &Self::Architecture) -> SupportedWmmaCombinations {
994 vec![
995 (
996 gpu::Elem::Float(gpu::FloatKind::F16),
997 gpu::Elem::Float(gpu::FloatKind::F16),
998 gpu::Elem::Float(gpu::FloatKind::F16),
999 vec![(8, 8, 8)],
1000 ),
1001 (
1002 gpu::Elem::Float(gpu::FloatKind::F16),
1003 gpu::Elem::Float(gpu::FloatKind::F16),
1004 gpu::Elem::Float(gpu::FloatKind::F32),
1005 vec![(8, 8, 8)],
1006 ),
1007 (
1008 gpu::Elem::Float(gpu::FloatKind::BF16),
1009 gpu::Elem::Float(gpu::FloatKind::BF16),
1010 gpu::Elem::Float(gpu::FloatKind::BF16),
1011 vec![(8, 8, 8)],
1012 ),
1013 (
1014 gpu::Elem::Float(gpu::FloatKind::F32),
1015 gpu::Elem::Float(gpu::FloatKind::F32),
1016 gpu::Elem::Float(gpu::FloatKind::F32),
1017 vec![(8, 8, 8)],
1018 ),
1019 ]
1020 }
1021}
1022
1023