1use core::panic;
2use std::fmt::Display;
3
4use crate::{
5 Dialect,
6 shared::{
7 self, AtomicKind, Binding, Component, CubeIndexFlags, DialectBindings, DialectCubeBuiltins,
8 DialectIncludes, DialectInstructions, DialectTypes, DialectWmmaCompiler, Elem, Flags,
9 FmtLeft, Fragment, FragmentIdent, FragmentLayout, Instruction, Item, SharedMemory,
10 SupportedWmmaCombinations, Variable, WarpInstruction, WmmaInstruction, wmma_api_base,
11 },
12};
13use cubecl_core::{
14 compute::{Location, Visibility},
15 ir::{self as gpu, Id},
16};
17
18use super::{
19 AddressSpace, Extension,
20 arch::MetalArchitecture,
21 extension::{format_ffs, format_mulhi},
22 format_erf, format_global_binding_arg, format_metal_builtin_binding_arg, format_safe_tanh,
23};
24
/// Marker type selecting the Metal Shading Language (MSL) dialect.
///
/// The type carries no data; every dialect-specific behavior is provided through the
/// `Dialect*` trait implementations in this file.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct MslDialect {}
27
28impl Dialect for MslDialect {
31 type Architecture = MetalArchitecture;
32}
33
34impl DialectIncludes<Self> for MslDialect {
37 type Extension = Extension<Self>;
38
39 fn compile_includes(f: &mut std::fmt::Formatter<'_>, _flags: &Flags) -> std::fmt::Result {
40 write!(
41 f,
42 "
43#include <metal_stdlib>
44using namespace metal;
45"
46 )?;
47 Ok(())
48 }
49
50 fn compile_extensions(
51 f: &mut std::fmt::Formatter<'_>,
52 extensions: &[Self::Extension],
53 ) -> std::fmt::Result {
54 for extension in extensions {
55 match extension {
56 Extension::Erf(input, output) => format_erf::<Self>(f, input, output)?,
57 Extension::Ffs(elem) => format_ffs(f, elem)?,
58 Extension::MulHi(elem) => format_mulhi(f, elem)?,
59 Extension::SafeTanh(item) => format_safe_tanh::<Self>(f, item)?,
60 Extension::NoExtension => {}
61 }
62 }
63 Ok(())
64 }
65
66 fn register_instruction_extension(
67 extensions: &mut Vec<Self::Extension>,
68 instruction: &Instruction<Self>,
69 ) {
70 let mut register_extension = |extension: Self::Extension| {
71 if !extensions.contains(&extension) {
72 extensions.push(extension);
73 }
74 };
75 #[allow(clippy::single_match)]
76 match instruction {
77 shared::Instruction::<Self>::Erf(instruction) => {
78 register_extension(Extension::Erf(
79 instruction.input.elem(),
80 instruction.out.elem(),
81 ));
82 }
83 shared::Instruction::<Self>::FindFirstSet(instruction) => {
84 let input_elem = instruction.input.elem();
85 match input_elem {
86 Elem::U32 | Elem::U64 => {
87 register_extension(Extension::Ffs(instruction.input.elem()));
88 }
89 Elem::I32 => {
90 register_extension(Extension::Ffs(Elem::<Self>::U32));
91 register_extension(Extension::Ffs(instruction.input.elem()));
92 }
93 Elem::I64 => {
94 register_extension(Extension::Ffs(Elem::<Self>::U64));
95 register_extension(Extension::Ffs(instruction.input.elem()));
96 }
97 _ => {
98 register_extension(Extension::Ffs(Elem::<Self>::U32));
99 }
100 }
101 }
102 shared::Instruction::<Self>::HiMul(instruction) => {
103 register_extension(Extension::MulHi(instruction.out.elem()));
104 }
105 shared::Instruction::<Self>::Tanh(instruction) => {
106 register_extension(Extension::SafeTanh(instruction.input.item()));
107 }
108 _ => {}
109 }
110 }
111
112 fn register_warp_instruction_extension(
113 _extensions: &mut Vec<Self::Extension>,
114 _instruction: &WarpInstruction<Self>,
115 ) {
116 }
117}
118
119impl DialectTypes<Self> for MslDialect {
122 fn item_can_be_optimized() -> bool {
123 false
124 }
125
126 fn compile_type_definitions(
127 f: &mut std::fmt::Formatter<'_>,
128 items: &std::collections::HashSet<crate::shared::Item<Self>>,
129 _scalars: &[(Elem<Self>, usize)],
130 _flags: &Flags,
131 ) -> std::fmt::Result {
132 for item in items.iter() {
133 let elem = item.elem;
134 let size = item.vectorization;
135 let alignment = elem.size() * size;
136 if size > 1 {
137 write!(
138 f,
139 "
140struct alignas({alignment}) {item} {{"
141 )?;
142
143 for i in 0..size {
144 write!(
145 f,
146 "
147 {elem} i_{i};"
148 )?;
149 }
150
151 f.write_str("\n};\n")?;
152 }
153 }
154 Ok(())
155 }
156
157 fn compile_elem(
158 f: &mut std::fmt::Formatter<'_>,
159 elem: &shared::Elem<Self>,
160 _words: bool,
161 ) -> std::fmt::Result {
162 match elem {
164 shared::Elem::FP4(_)
165 | shared::Elem::FP4x2(_)
166 | shared::Elem::FP6(_)
167 | shared::Elem::FP6x2(_)
168 | shared::Elem::FP8(_)
169 | shared::Elem::FP8x2(_) => unimplemented!("FP4/FP6/FP8 not supported in Metal"),
170 shared::Elem::F16 => f.write_str("half"),
171 shared::Elem::F16x2 => panic!("type F162 not supported!"),
172 shared::Elem::F32 => f.write_str("float"),
173 shared::Elem::F64 => panic!("type double not supported!"),
174 shared::Elem::BF16 => f.write_str("bfloat"),
175 shared::Elem::BF16x2 => panic!("type BF162 not supported!"),
176 shared::Elem::TF32 => f.write_str("float"),
177 shared::Elem::I8 => f.write_str("char"),
178 shared::Elem::I16 => f.write_str("short"),
179 shared::Elem::I32 => f.write_str("int"),
180 shared::Elem::I64 => f.write_str("long"),
181 shared::Elem::U8 => f.write_str("uchar"),
182 shared::Elem::U16 => f.write_str("ushort"),
183 shared::Elem::U32 => f.write_str("uint"),
184 shared::Elem::U64 => f.write_str("uint64_t"), shared::Elem::Bool => f.write_str("bool"),
186 shared::Elem::Atomic(inner) => inner.fmt(f),
187 shared::Elem::_Dialect(_) => Ok(()),
188 }
189 }
190
191 fn compile_item(f: &mut std::fmt::Formatter<'_>, item: &Item<Self>) -> std::fmt::Result {
192 if 1 == item.vectorization {
193 return write!(f, "{}", item.elem);
194 }
195 if item.native {
196 write!(f, "{}{}", item.elem, item.vectorization)
197 } else {
198 write!(f, "{}_{}", item.elem, item.vectorization)
199 }
200 }
201
202 fn compile_atomic_kind(
203 f: &mut std::fmt::Formatter<'_>,
204 kind: &AtomicKind<Self>,
205 ) -> std::fmt::Result {
206 match kind {
207 AtomicKind::I32 => write!(f, "atomic_int"),
208 AtomicKind::I64 => panic!("I64 atomic kind no supported."),
209 AtomicKind::U32 => write!(f, "atomic_uint"),
210 AtomicKind::U64 => write!(f, "atomic_ulong"),
211 AtomicKind::F16 => panic!("F16 atomic kind no supported."),
212 AtomicKind::BF16 => panic!("BF16 atomic kind no supported."),
213 AtomicKind::F32 => write!(f, "atomic_float"), AtomicKind::F64 => panic!("F64 atomic kind no supported."),
215 AtomicKind::_Dialect(_) => Ok(()),
216 }
217 }
218
219 fn address_space_for_variable(variable: &Variable<Self>) -> String {
220 format!("{} ", AddressSpace::from(variable))
221 }
222
223 fn compile_local_memory_qualifier(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
224 write!(f, "thread")
225 }
226
227 fn compile_shared_memory_declaration(
228 f: &mut std::fmt::Formatter<'_>,
229 shared: &SharedMemory<Self>,
230 ) -> std::fmt::Result {
231 let item = shared.item;
232 let index = shared.index;
233 let size = shared.size;
234 let alignment = shared
235 .align
236 .map(|align| format!("alignas({align})"))
237 .unwrap_or_default();
238 writeln!(
239 f,
240 "threadgroup {alignment} {item} shared_memory_{index}[{size}];",
241 )
242 }
243}
244
245impl DialectBindings<Self> for MslDialect {
248 fn compile_kernel_signature(
249 f: &mut std::fmt::Formatter<'_>,
250 kernel_name: &str,
251 tensor_maps: &[Id],
252 buffers: &[Binding<Self>],
253 scalars: &[(Elem<Self>, usize)],
254 flags: &Flags,
255 ) -> std::fmt::Result {
256 write!(
257 (f),
258 "
259[[kernel]]
260void {kernel_name}("
261 )?;
262 let mut buffer_idx = 0;
264 debug_assert!(
265 tensor_maps.is_empty(),
266 "Tensor maps aren't supported for metal"
267 );
268 for (i, b) in buffers.iter().enumerate() {
269 format_global_binding_arg("buffer", b, Some(&i.to_string()), &mut buffer_idx, f)?;
270 }
271 if flags.static_meta_length > 0 {
272 let binding = Binding {
273 id: 0,
274 item: Item::scalar(Elem::<Self>::U32, true),
275 location: Location::Storage,
276 size: None,
277 vis: Visibility::Read,
278 };
279 format_global_binding_arg("info", &binding, None, &mut buffer_idx, f)?;
280 }
281 for (elem, _) in scalars.iter() {
282 let binding = Binding {
283 id: 0,
284 item: Item::scalar(*elem, true),
285 location: Location::Storage,
286 size: None,
287 vis: Visibility::Read,
288 };
289
290 let name = format!("scalars_{elem}");
291 format_global_binding_arg(&name, &binding, None, &mut buffer_idx, f)?;
292 }
293
294 let builtins = vec![
296 (
297 flags.indexes.absolute_pos_tuple,
298 Variable::<Self>::AbsolutePosBaseName,
299 ),
300 (
301 flags.indexes.cube_dim_tuple,
302 Variable::<Self>::CubeDimBaseName,
303 ),
304 (
305 flags.indexes.cube_count_tuple,
306 Variable::<Self>::CubeCountBaseName,
307 ),
308 (flags.indexes.unit_pos, Variable::<Self>::UnitPos),
309 (
310 flags.indexes.unit_pos_tuple,
311 Variable::<Self>::UnitPosBaseName,
312 ),
313 (
314 flags.indexes.cube_pos_tuple,
315 Variable::<Self>::CubePosBaseName,
316 ),
317 (flags.indexes.unit_pos_plane, Variable::<Self>::UnitPosPlane),
318 (flags.indexes.plane_dim, Variable::<Self>::PlaneDim),
319 (flags.indexes.plane_index, Variable::<Self>::PlanePos),
320 ];
321 let comma = !buffers.is_empty() || flags.static_meta_length > 0 || !scalars.is_empty();
322 builtins
323 .iter()
324 .filter(|(cond, _)| *cond)
325 .try_for_each(|(_, var)| format_metal_builtin_binding_arg(f, var, comma))?;
326 f.write_str("\n)")
327 }
328}
329
330impl DialectCubeBuiltins<Self> for MslDialect {
333 fn builtin_rules(flags: &CubeIndexFlags) -> CubeIndexFlags {
338 let absolute_pos = flags.absolute_pos;
339 let cube_count = flags.cube_count;
340 let cube_dim = flags.cube_dim;
341 let cube_pos = flags.cube_pos;
342 let plane_dim_checked = flags.plane_dim_checked;
343 let plane_index = flags.plane_index;
344 let unit_pos = flags.unit_pos;
345 let absolute_pos_tuple = flags.absolute_pos_tuple || absolute_pos;
346 let cube_count_tuple = flags.cube_count_tuple || cube_count || cube_pos || absolute_pos;
347 let cube_dim_tuple = flags.cube_dim_tuple || cube_dim || absolute_pos || plane_dim_checked;
348 let cube_pos_tuple = flags.cube_pos_tuple || cube_pos;
349 let cluster_pos = flags.cluster_pos;
350 let plane_dim = flags.plane_dim || plane_dim_checked || plane_index;
351 let unit_pos_plane = flags.unit_pos_plane || plane_index;
352 let unit_pos_tuple = flags.unit_pos_tuple || unit_pos;
353 CubeIndexFlags {
354 absolute_pos_tuple,
355 absolute_pos,
356 cube_count_tuple,
357 cube_count,
358 cube_dim_tuple,
359 cube_dim,
360 cube_pos_tuple,
361 cube_pos,
362 plane_dim,
363 plane_dim_checked,
364 plane_index,
365 unit_pos_tuple,
366 unit_pos,
367 unit_pos_plane,
368 cluster_pos,
369 }
370 }
371
372 fn compile_absolute_pos_tuple_computation(
373 _f: &mut std::fmt::Formatter<'_>,
374 ) -> std::fmt::Result {
375 Ok(())
377 }
378
379 fn compile_absolute_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
380 f.write_str("thread_pos_in_grid")
381 }
382
383 fn compile_absolute_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
384 f.write_str("thread_index_in_grid")
385 }
386
387 fn compile_absolute_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
388 Self::compile_absolute_pos_base_name(f)?;
389 write!(f, ".x")
390 }
391
392 fn compile_absolute_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
393 Self::compile_absolute_pos_base_name(f)?;
394 write!(f, ".y")
395 }
396
397 fn compile_absolute_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
398 Self::compile_absolute_pos_base_name(f)?;
399 write!(f, ".z")
400 }
401
402 fn compile_cube_count_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
403 f.write_str("threadgroups_per_grid")
404 }
405
406 fn compile_cube_count(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
407 f.write_str("total_threadgroups_in_grid")
408 }
409
410 fn compile_cube_count_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
411 Self::compile_cube_count_base_name(f)?;
412 write!(f, ".x")
413 }
414
415 fn compile_cube_count_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
416 Self::compile_cube_count_base_name(f)?;
417 write!(f, ".y")
418 }
419
420 fn compile_cube_count_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
421 Self::compile_cube_count_base_name(f)?;
422 write!(f, ".z")
423 }
424
425 fn compile_cube_dim_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
426 f.write_str("threads_per_threadgroup")
427 }
428
429 fn compile_cube_dim(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
430 f.write_str("total_thread_in_threadgroup")
431 }
432
433 fn compile_cube_dim_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
434 Self::compile_cube_dim_base_name(f)?;
435 write!(f, ".x")
436 }
437
438 fn compile_cube_dim_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
439 Self::compile_cube_dim_base_name(f)?;
440 write!(f, ".y")
441 }
442
443 fn compile_cube_dim_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
444 Self::compile_cube_dim_base_name(f)?;
445 write!(f, ".z")
446 }
447
448 fn compile_cube_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
449 f.write_str("threadgroup_pos_in_grid")
450 }
451
452 fn compile_cube_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
453 f.write_str("threadgroup_index_in_grid")
454 }
455
456 fn compile_cube_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
457 Self::compile_cube_pos_base_name(f)?;
458 write!(f, ".x")
459 }
460
461 fn compile_cube_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
462 Self::compile_cube_pos_base_name(f)?;
463 write!(f, ".y")
464 }
465
466 fn compile_cube_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
467 Self::compile_cube_pos_base_name(f)?;
468 write!(f, ".z")
469 }
470
471 fn compile_unit_pos_computation(_f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
472 Ok(())
474 }
475
476 fn compile_unit_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
477 f.write_str("thread_pos_in_threadgroup")
478 }
479
480 fn compile_unit_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
481 f.write_str("thread_index_in_threadgroup")
482 }
483
484 fn compile_unit_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
485 Self::compile_unit_pos_base_name(f)?;
486 write!(f, ".x")
487 }
488
489 fn compile_unit_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
490 Self::compile_unit_pos_base_name(f)?;
491 write!(f, ".y")
492 }
493
494 fn compile_unit_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
495 Self::compile_unit_pos_base_name(f)?;
496 write!(f, ".z")
497 }
498
499 fn compile_plane_dim(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
500 f.write_str("simd_size")
501 }
502
503 fn compile_plane_dim_checked(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
504 f.write_str("threads_per_simdgroup_checked")
505 }
506
507 fn compile_plane_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
508 f.write_str("simd_group_id")
509 }
510
511 fn compile_unit_pos_plane(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
512 f.write_str("simd_lane_id")
513 }
514}
515
516impl DialectInstructions<Self> for MslDialect {
519 fn compile_atomic_add(
521 f: &mut std::fmt::Formatter<'_>,
522 lhs: &Variable<Self>,
523 rhs: &Variable<Self>,
524 out: &Variable<Self>,
525 ) -> std::fmt::Result {
526 let out = out.fmt_left();
527 writeln!(
528 f,
529 "{out} = atomic_fetch_add_explicit({lhs}, {rhs}, memory_order_relaxed);"
530 )
531 }
532
533 fn compile_atomic_and(
534 f: &mut std::fmt::Formatter<'_>,
535 lhs: &Variable<Self>,
536 rhs: &Variable<Self>,
537 out: &Variable<Self>,
538 ) -> std::fmt::Result {
539 let out = out.fmt_left();
540 writeln!(
541 f,
542 "{out} = atomic_fetch_and_explicit({lhs}, {rhs}, memory_order_relaxed);"
543 )
544 }
545
546 fn compile_atomic_cas(
547 f: &mut std::fmt::Formatter<'_>,
548 input: &Variable<Self>,
549 cmp: &Variable<Self>,
550 val: &Variable<Self>,
551 out: &Variable<Self>,
552 ) -> std::fmt::Result {
553 let out = out.fmt_left();
554 writeln!(
555 f,
556 "{out} = atomic_compare_exchange_weak_explicit({input}, &{cmp}, {val}, memory_order_relaxed, memory_order_relaxed);"
557 )
558 }
559
560 fn compile_atomic_load(
561 f: &mut std::fmt::Formatter<'_>,
562 input: &Variable<Self>,
563 out: &Variable<Self>,
564 ) -> std::fmt::Result {
565 let out = out.fmt_left();
566 writeln!(
567 f,
568 "{out} = atomic_load_explicit({input}, memory_order_relaxed);"
569 )
570 }
571
572 fn compile_atomic_max(
573 f: &mut std::fmt::Formatter<'_>,
574 lhs: &Variable<Self>,
575 rhs: &Variable<Self>,
576 out: &Variable<Self>,
577 ) -> std::fmt::Result {
578 let out = out.fmt_left();
579 writeln!(
580 f,
581 "{out} = atomic_fetch_max_explicit({lhs}, {rhs}, memory_order_relaxed);"
582 )
583 }
584
585 fn compile_atomic_min(
586 f: &mut std::fmt::Formatter<'_>,
587 lhs: &Variable<Self>,
588 rhs: &Variable<Self>,
589 out: &Variable<Self>,
590 ) -> std::fmt::Result {
591 let out = out.fmt_left();
592 writeln!(
593 f,
594 "{out} = atomic_fetch_min_explicit({lhs}, {rhs}, memory_order_relaxed);"
595 )
596 }
597
598 fn compile_atomic_or(
599 f: &mut std::fmt::Formatter<'_>,
600 lhs: &Variable<Self>,
601 rhs: &Variable<Self>,
602 out: &Variable<Self>,
603 ) -> std::fmt::Result {
604 let out = out.fmt_left();
605 writeln!(
606 f,
607 "{out} = atomic_fetch_or_explicit({lhs}, {rhs}, memory_order_relaxed);"
608 )
609 }
610
611 fn compile_atomic_store(
612 f: &mut std::fmt::Formatter<'_>,
613 input: &Variable<Self>,
614 out: &Variable<Self>,
615 ) -> std::fmt::Result {
616 writeln!(
617 f,
618 "atomic_store_explicit({out}, {input}, memory_order_relaxed);"
619 )
620 }
621
622 fn compile_atomic_sub(
623 f: &mut std::fmt::Formatter<'_>,
624 lhs: &Variable<Self>,
625 rhs: &Variable<Self>,
626 out: &Variable<Self>,
627 ) -> std::fmt::Result {
628 let out = out.fmt_left();
629 writeln!(
630 f,
631 "{out} = atomic_fetch_sub_explicit({lhs}, {rhs}, memory_order_relaxed);"
632 )
633 }
634
635 fn compile_atomic_swap(
636 f: &mut std::fmt::Formatter<'_>,
637 lhs: &Variable<Self>,
638 rhs: &Variable<Self>,
639 out: &Variable<Self>,
640 ) -> std::fmt::Result {
641 let out = out.fmt_left();
642 writeln!(
643 f,
644 "{out} = atomic_exchange_explicit({lhs}, {rhs}, memory_order_relaxed);"
645 )
646 }
647
648 fn compile_atomic_xor(
649 f: &mut std::fmt::Formatter<'_>,
650 lhs: &Variable<Self>,
651 rhs: &Variable<Self>,
652 out: &Variable<Self>,
653 ) -> std::fmt::Result {
654 let out = out.fmt_left();
655 writeln!(
656 f,
657 "{out} = atomic_fetch_xor_explicit({lhs}, {rhs}, memory_order_relaxed);"
658 )
659 }
660
661 fn compile_instruction_printf(
663 f: &mut std::fmt::Formatter<'_>,
664 format_string: &str,
665 args: &[Variable<Self>],
666 ) -> std::fmt::Result {
667 let format_string = format_string
668 .replace("\t", "\\t")
669 .replace("\n", "\\n")
670 .replace("\r", "\\r");
671 let args = args.iter().map(|arg| format!("{arg}")).collect::<Vec<_>>();
672 let args = match args.is_empty() {
673 true => "".to_string(),
674 false => format!(", {}", args.join(",")),
675 };
676 writeln!(f, "os_log_default.log(\"{format_string}\"{args});")
677 }
678
679 fn compile_instruction_log1p_scalar<T: Component<Self>>(
681 f: &mut std::fmt::Formatter<'_>,
682 input: T,
683 ) -> std::fmt::Result {
684 match input.elem() {
685 Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
686 write!(f, "log(half(1.0f) + {input})")
687 }
688 _ => write!(f, "log(1.0f + {input})"),
689 }
690 }
691
692 fn compile_instruction_sync_threads(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
694 writeln!(f, "threadgroup_barrier(mem_flags::mem_threadgroup);")
695 }
696
697 fn compile_instruction_sync_warp(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
698 writeln!(f, "simdgroup_barrier(mem_flags::mem_none);")
699 }
700
701 fn compile_instruction_thread_fence(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
702 writeln!(f, "threadgroup_thread_fence(mem_flags::mem_device);")
703 }
704
705 fn compile_instruction_tanh_scalar<T: Component<Self>>(
707 f: &mut std::fmt::Formatter<'_>,
708 input: T,
709 ) -> std::fmt::Result {
710 write!(f, "safe_tanh_scalar({input})")
711 }
712
713 fn compile_instruction_find_first_set<T: Component<Self>>(
715 f: &mut std::fmt::Formatter<'_>,
716 input: T,
717 out_elem: Elem<Self>,
718 ) -> std::fmt::Result {
719 write!(f, "{out_elem}(")?;
720 match input.elem() {
721 Elem::I32 | Elem::U32 => write!(f, "__ffs({input})"),
722 Elem::I64 | Elem::U64 => write!(f, "__ffsll({input})"),
723 _ => write!(f, "__ffs({}({input}))", Elem::<Self>::I32),
724 }?;
725 write!(f, ")")
726 }
727
728 fn compile_instruction_leading_zeros_scalar<T: Component<Self>>(
729 f: &mut std::fmt::Formatter<'_>,
730 input: T,
731 out_elem: Elem<Self>,
732 ) -> std::fmt::Result {
733 write!(f, "{out_elem}(clz({input}))")
734 }
735
736 fn compile_instruction_popcount_scalar<T: Component<Self>>(
737 f: &mut std::fmt::Formatter<'_>,
738 input: T,
739 out_elem: Elem<Self>,
740 ) -> std::fmt::Result {
741 write!(f, "{out_elem}(")?;
742 match input.elem() {
743 Elem::I32 | Elem::U32 | Elem::I64 | Elem::U64 => write!(f, "popcount({input})"),
744 _ => write!(f, "popcount({})", shared::unary::zero_extend(input)),
745 }?;
746 write!(f, ")")
747 }
748
749 fn compile_instruction_reverse_bits_scalar<T: Component<Self>>(
750 f: &mut std::fmt::Formatter<'_>,
751 input: T,
752 out_elem: Elem<Self>,
753 ) -> std::fmt::Result {
754 write!(f, "{out_elem}(")?;
755 match out_elem {
756 Elem::I32 | Elem::U32 | Elem::I64 | Elem::U64 => write!(f, "reverse_bits({input})"),
757 _ => write!(
758 f,
759 "reverse_bits({}) >> {}",
760 shared::unary::zero_extend(input),
761 (size_of::<u32>() - out_elem.size()) * 8
762 ),
763 }?;
764 write!(f, ")")
765 }
766
767 fn compile_instruction_max_function_name(
769 f: &mut std::fmt::Formatter<'_>,
770 _item: Item<Self>,
771 ) -> std::fmt::Result {
772 write!(f, "max")
773 }
774
775 fn compile_instruction_min_function_name(
776 f: &mut std::fmt::Formatter<'_>,
777 _item: Item<Self>,
778 ) -> std::fmt::Result {
779 write!(f, "min")
780 }
781
782 fn compile_instruction_powf(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
783 write!(f, "pow")
784 }
785
786 fn compile_instruction_half_function_name_prefix() -> &'static str {
787 ""
788 }
789
790 fn compile_instruction_half2_function_name_prefix() -> &'static str {
791 ""
792 }
793
794 fn compile_warp_shuffle(
796 f: &mut std::fmt::Formatter<'_>,
797 var: &str,
798 source: &str,
799 ) -> std::fmt::Result {
800 write!(f, "simd_shuffle({var}, {source})")
801 }
802
803 fn compile_warp_shuffle_xor(
804 f: &mut std::fmt::Formatter<'_>,
805 var: &str,
806 _elem: &Elem<Self>,
807 offset: &str,
808 ) -> std::fmt::Result {
809 write!(f, "simd_shuffle_xor({var}, {offset})")
810 }
811
812 fn compile_warp_shuffle_up(
813 f: &mut std::fmt::Formatter<'_>,
814 var: &str,
815 offset: &str,
816 ) -> std::fmt::Result {
817 write!(f, "simd_shuffle_up({var}, {offset})")
818 }
819
820 fn compile_warp_shuffle_down(
821 f: &mut std::fmt::Formatter<'_>,
822 var: &str,
823 offset: &str,
824 ) -> std::fmt::Result {
825 write!(f, "simd_shuffle_down({var}, {offset})")
826 }
827
828 fn compile_warp_all<T: Component<Self>>(
829 f: &mut std::fmt::Formatter<'_>,
830 input: &T,
831 ) -> std::fmt::Result {
832 write!(f, "simd_all({input})")
833 }
834
835 fn compile_warp_any<T: Component<Self>>(
836 f: &mut std::fmt::Formatter<'_>,
837 input: &T,
838 ) -> std::fmt::Result {
839 write!(f, "simd_any({input})")
840 }
841
842 fn compile_warp_ballot(
843 f: &mut std::fmt::Formatter<'_>,
844 input: &Variable<Self>,
845 out_elem: &Elem<Self>,
846 ) -> std::fmt::Result {
847 write!(f, "{out_elem}(uint64_t(simd_ballot({input})))")
848 }
849}
850
851impl DialectWmmaCompiler<Self> for MslDialect {
854 fn compile_wmma_includes(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
855 writeln!(f, "#include <metal_simdgroup_matrix>")
856 }
857
858 fn compile_wmma_type_definitions(_f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
859 Ok(())
861 }
862
863 fn compile_wmma_local_variables(_f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
864 Ok(())
866 }
867
868 fn compile_wmma_fragment_declaration(
869 f: &mut std::fmt::Formatter<'_>,
870 var: &crate::shared::Variable<MslDialect>,
871 ) -> std::fmt::Result {
872 wmma_api_base::compile_fragment_declaration(f, var)
873 }
874
875 fn compile_wwma_fragment_ident(
876 _f: &mut std::fmt::Formatter<'_>,
877 _ident: &FragmentIdent<Self>,
878 ) -> std::fmt::Result {
879 Ok(())
881 }
882
883 fn compile_wmma_fragment_layout(
884 _f: &mut std::fmt::Formatter<'_>,
885 _layout: &FragmentLayout<Self>,
886 ) -> std::fmt::Result {
887 Ok(())
889 }
890
891 fn compile_wmma_fragment(
892 f: &mut std::fmt::Formatter<'_>,
893 fragment: &Fragment<Self>,
894 ) -> std::fmt::Result {
895 let ty = fragment.elem;
896 let m = fragment.m;
898 let n = fragment.n;
899 let k = fragment.k;
900 if m != 8 || n != 8 || k != 8 {
901 panic!("{m}x{n}x{k} fragments not supported. Only 8x8x8 fragments are supported.");
902 }
903 write!(f, "simdgroup_{ty}8x8")
904 }
905
906 fn compile_wmma_instruction(
907 f: &mut std::fmt::Formatter<'_>,
908 instruction: &WmmaInstruction<Self>,
909 ) -> std::fmt::Result {
910 match instruction {
911 WmmaInstruction::Fill { frag, value } => {
912 match frag {
913 Variable::WmmaFragment { .. } => {
914 let ty = frag.elem();
915 writeln!(
917 f,
918 "{frag} = make_filled_simdgroup_matrix<{ty}, 8, 8>({value});"
919 )
920 }
921 _ => panic!("should be a fragment"),
922 }
923 }
924 WmmaInstruction::Load {
925 frag,
926 value,
927 stride,
928 offset,
929 layout: _layout,
930 } => {
931 let transpose = match frag {
932 Variable::WmmaFragment { frag: inner, .. } => match inner.layout {
933 Some(FragmentLayout::RowMajor) => false,
934 Some(FragmentLayout::ColMajor) => true,
935 _ => false,
936 },
937 _ => panic!("should be a fragment"),
938 };
939 let item = value.item();
940 if item.vectorization > 1 {
941 let elem = item.elem;
942 writeln!(
943 f,
944 "simdgroup_load({frag}, reinterpret_cast<threadgroup {elem} *>({value} + {offset}), {stride}, 0, {transpose});"
945 )
946 } else {
947 writeln!(
948 f,
949 "simdgroup_load({frag}, {value} + {offset}, {stride}, 0, {transpose});"
950 )
951 }
952 }
953 WmmaInstruction::Execute {
954 frag_a: a,
955 frag_b: b,
956 frag_c: c,
957 frag_d: d,
958 ..
959 } => {
960 writeln!(f, "simdgroup_multiply_accumulate({d}, {a}, {b}, {c});")
961 }
962 WmmaInstruction::Store {
963 output,
964 frag,
965 stride,
966 offset,
967 layout: _layout,
968 } => {
969 let item = output.item();
970 let mut reinterpret_cast = item.vectorization > 1;
971 let elem = match item.elem {
972 Elem::BF16 => {
973 reinterpret_cast = true;
974 Elem::F16
975 }
976 _ => item.elem,
977 };
978 if reinterpret_cast {
979 writeln!(
980 f,
981 "simdgroup_store({frag}, reinterpret_cast<threadgroup {elem} *>({output} + {offset}), {stride});"
982 )
983 } else {
984 writeln!(f, "simdgroup_store({frag}, {output} + {offset}, {stride});")
985 }?;
986 writeln!(f, "simdgroup_barrier(mem_flags::mem_none);")
987 }
988 WmmaInstruction::Cast { input, output } => {
989 writeln!(f, "simdgroup_barrier(mem_flags::mem_none);")?;
990 let ty = match output {
991 Variable::WmmaFragment { frag, .. } => frag.elem,
992 _ => panic!("should be a fragment"),
993 };
994 match ty {
995 Elem::BF16 => {
996 let addr_space = Self::address_space_for_variable(output);
997 let elem = Elem::<Self>::F16;
998 writeln!(
1001 f,
1002 "for(int e=0; e<8; e++) {{
1003 {ty} elem = {ty}({input}.thread_elements()[e]);
1004 {output}.thread_elements()[e] = *reinterpret_cast<{addr_space}{elem} *>(&elem);
1005}}"
1006 )
1007 }
1008 _ => {
1009 writeln!(
1010 f,
1011 "for(int e=0; e<8; e++) {{
1012 {output}.thread_elements()[e] = {ty}({input}.thread_elements()[e]);
1013}}"
1014 )
1015 }
1016 }
1017 }
1018 }
1019 }
1020
1021 fn supported_wmma_combinations(_arch: &MetalArchitecture) -> SupportedWmmaCombinations {
1022 vec![
1023 (
1024 gpu::Elem::Float(gpu::FloatKind::F16),
1025 gpu::Elem::Float(gpu::FloatKind::F16),
1026 gpu::Elem::Float(gpu::FloatKind::F16),
1027 vec![(8, 8, 8)],
1028 ),
1029 (
1030 gpu::Elem::Float(gpu::FloatKind::F16),
1031 gpu::Elem::Float(gpu::FloatKind::F16),
1032 gpu::Elem::Float(gpu::FloatKind::F32),
1033 vec![(8, 8, 8)],
1034 ),
1035 (
1036 gpu::Elem::Float(gpu::FloatKind::BF16),
1037 gpu::Elem::Float(gpu::FloatKind::BF16),
1038 gpu::Elem::Float(gpu::FloatKind::BF16),
1039 vec![(8, 8, 8)],
1040 ),
1041 (
1042 gpu::Elem::Float(gpu::FloatKind::F32),
1043 gpu::Elem::Float(gpu::FloatKind::F32),
1044 gpu::Elem::Float(gpu::FloatKind::F32),
1045 vec![(8, 8, 8)],
1046 ),
1047 ]
1048 }
1049}
1050
1051