1use super::{
2 AddressSpace, Extension,
3 arch::MetalArchitecture,
4 extension::{format_ffs, format_mulhi},
5 format_erf, format_global_binding_arg, format_metal_builtin_binding_arg, format_safe_tanh,
6};
7use crate::{
8 Dialect,
9 shared::{
10 self, AtomicKind, Binding, Component, CubeIndexFlags, DialectBindings, DialectCubeBuiltins,
11 DialectIncludes, DialectInstructions, DialectProcessors, DialectTypes,
12 DialectWarpReduceCompiler, DialectWmmaCompiler, Elem, Flags, FmtLeft, Fragment,
13 FragmentIdent, FragmentLayout, Instruction, Item, ManualMma, SharedMemory,
14 SupportedMmaCombinations, Variable, WarpInstruction, WmmaInstruction, wmma_api_base,
15 },
16};
17use core::panic;
18use cubecl_core::{
19 ir::{self as gpu, features::MmaConfig},
20 prelude::{Location, Visibility},
21};
22use std::fmt::Display;
23
/// Zero-sized marker type selecting the Metal (MSL) code-generation dialect.
///
/// The type carries no data; its behaviour comes entirely from the `Dialect*`
/// trait implementations attached to it.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct MslDialect {}
26
impl Dialect for MslDialect {
    // Architecture descriptor type paired with this dialect.
    type Architecture = MetalArchitecture;
}
32
33impl MslDialect {
34 fn warp_op_vectorized(
35 f: &mut core::fmt::Formatter<'_>,
36 input: &Variable<Self>,
37 out: &Variable<Self>,
38 simd_op_prefix: &str,
39 simd_op_suffix: &str,
40 ) -> core::fmt::Result {
41 let out = out.fmt_left();
42 let vectorization = input.item().vectorization;
43
44 f.write_fmt(format_args!("{out} = {} {{", input.item()))?;
45
46 for k in 0..vectorization {
47 let index = if vectorization > 1 {
48 format!(".i_{k}")
49 } else {
50 String::new()
51 };
52 let comma = if k + 1 < vectorization { "," } else { "" };
53
54 writeln!(f, "{simd_op_prefix}{input}{index}{simd_op_suffix}{comma}")?;
55 }
56
57 f.write_fmt(format_args!("}};\n"))
58 }
59}
60
61impl DialectWarpReduceCompiler<Self> for MslDialect {
62 fn warp_reduce_sum(
63 f: &mut core::fmt::Formatter<'_>,
64 input: &Variable<Self>,
65 out: &Variable<Self>,
66 ) -> core::fmt::Result {
67 Self::warp_op_vectorized(f, input, out, "simd_sum(", ")")
68 }
69 fn warp_reduce_prod(
70 f: &mut core::fmt::Formatter<'_>,
71 input: &Variable<Self>,
72 out: &Variable<Self>,
73 ) -> core::fmt::Result {
74 Self::warp_op_vectorized(f, input, out, "simd_product(", ")")
75 }
76 fn warp_reduce_max(
77 f: &mut core::fmt::Formatter<'_>,
78 input: &Variable<Self>,
79 out: &Variable<Self>,
80 ) -> core::fmt::Result {
81 Self::warp_op_vectorized(f, input, out, "simd_max(", ")")
82 }
83 fn warp_reduce_min(
84 f: &mut core::fmt::Formatter<'_>,
85 input: &Variable<Self>,
86 out: &Variable<Self>,
87 ) -> core::fmt::Result {
88 Self::warp_op_vectorized(f, input, out, "simd_min(", ")")
89 }
90 fn warp_reduce_all(
91 f: &mut core::fmt::Formatter<'_>,
92 input: &Variable<Self>,
93 out: &Variable<Self>,
94 ) -> core::fmt::Result {
95 Self::warp_op_vectorized(f, input, out, "simd_and(", "? 1u : 0u) != 0u")
96 }
97 fn warp_reduce_any(
98 f: &mut core::fmt::Formatter<'_>,
99 input: &Variable<Self>,
100 out: &Variable<Self>,
101 ) -> core::fmt::Result {
102 Self::warp_op_vectorized(f, input, out, "simd_or(", "? 1u : 0u) != 0u")
103 }
104 fn warp_reduce_sum_inclusive(
105 f: &mut core::fmt::Formatter<'_>,
106 input: &Variable<Self>,
107 out: &Variable<Self>,
108 ) -> core::fmt::Result {
109 Self::warp_op_vectorized(f, input, out, "simd_prefix_inclusive_sum(", ")")
110 }
111 fn warp_reduce_prod_inclusive(
112 f: &mut core::fmt::Formatter<'_>,
113 input: &Variable<Self>,
114 out: &Variable<Self>,
115 ) -> core::fmt::Result {
116 Self::warp_op_vectorized(f, input, out, "simd_prefix_inclusive_product(", ")")
117 }
118 fn warp_reduce_sum_exclusive(
119 f: &mut core::fmt::Formatter<'_>,
120 input: &Variable<Self>,
121 out: &Variable<Self>,
122 ) -> core::fmt::Result {
123 Self::warp_op_vectorized(f, input, out, "simd_prefix_exclusive_sum(", ")")
124 }
125 fn warp_reduce_prod_exclusive(
126 f: &mut core::fmt::Formatter<'_>,
127 input: &Variable<Self>,
128 out: &Variable<Self>,
129 ) -> core::fmt::Result {
130 Self::warp_op_vectorized(f, input, out, "simd_prefix_exclusive_product(", ")")
131 }
132}
133
134impl DialectIncludes<Self> for MslDialect {
137 type Extension = Extension<Self>;
138
139 fn compile_includes(f: &mut std::fmt::Formatter<'_>, _flags: &Flags<Self>) -> std::fmt::Result {
140 write!(
141 f,
142 "
143#include <metal_stdlib>
144using namespace metal;
145"
146 )?;
147 Ok(())
148 }
149
150 fn compile_extensions(
151 f: &mut std::fmt::Formatter<'_>,
152 extensions: &[Self::Extension],
153 ) -> std::fmt::Result {
154 for extension in extensions {
155 match extension {
156 Extension::Erf(input, output) => format_erf::<Self>(f, input, output)?,
157 Extension::Ffs(elem) => format_ffs(f, elem)?,
158 Extension::MulHi(elem) => format_mulhi(f, elem)?,
159 Extension::SafeTanh(item) => format_safe_tanh::<Self>(f, item)?,
160 Extension::NoExtension => {}
161 }
162 }
163 Ok(())
164 }
165
166 fn register_instruction_extension(
167 extensions: &mut Vec<Self::Extension>,
168 instruction: &Instruction<Self>,
169 ) {
170 let mut register_extension = |extension: Self::Extension| {
171 if !extensions.contains(&extension) {
172 extensions.push(extension);
173 }
174 };
175 #[allow(clippy::single_match)]
176 match instruction {
177 shared::Instruction::<Self>::Erf(instruction) => {
178 register_extension(Extension::Erf(
179 instruction.input.elem(),
180 instruction.out.elem(),
181 ));
182 }
183 shared::Instruction::<Self>::FindFirstSet(instruction) => {
184 let input_elem = instruction.input.elem();
185 match input_elem {
186 Elem::U32 | Elem::U64 => {
187 register_extension(Extension::Ffs(instruction.input.elem()));
188 }
189 Elem::I32 => {
190 register_extension(Extension::Ffs(Elem::<Self>::U32));
191 register_extension(Extension::Ffs(instruction.input.elem()));
192 }
193 Elem::I64 => {
194 register_extension(Extension::Ffs(Elem::<Self>::U64));
195 register_extension(Extension::Ffs(instruction.input.elem()));
196 }
197 _ => {
198 register_extension(Extension::Ffs(Elem::<Self>::U32));
199 }
200 }
201 }
202 shared::Instruction::<Self>::HiMul(instruction) => {
203 register_extension(Extension::MulHi(instruction.out.elem()));
204 }
205 shared::Instruction::<Self>::Tanh(instruction) => {
206 register_extension(Extension::SafeTanh(instruction.input.item()));
207 }
208 _ => {}
209 }
210 }
211
212 fn register_warp_instruction_extension(
213 _extensions: &mut Vec<Self::Extension>,
214 _instruction: &WarpInstruction<Self>,
215 ) {
216 }
217}
218
219impl DialectTypes<Self> for MslDialect {
222 fn item_can_be_optimized() -> bool {
223 false
224 }
225
226 fn compile_type_definitions(
227 f: &mut std::fmt::Formatter<'_>,
228 items: &std::collections::HashSet<crate::shared::Item<Self>>,
229 _scalars: &[(Elem<Self>, usize)],
230 _flags: &Flags<Self>,
231 ) -> std::fmt::Result {
232 for item in items.iter() {
233 let elem = item.elem;
234 let size = item.vectorization;
235 let alignment = elem.size() * size;
236 if size > 1 {
237 write!(
238 f,
239 "
240struct alignas({alignment}) {item} {{"
241 )?;
242
243 for i in 0..size {
244 write!(
245 f,
246 "
247 {elem} i_{i};"
248 )?;
249 }
250
251 f.write_str("\n};\n")?;
252 }
253 }
254 Ok(())
255 }
256
257 fn compile_elem(
258 f: &mut std::fmt::Formatter<'_>,
259 elem: &shared::Elem<Self>,
260 _words: bool,
261 ) -> std::fmt::Result {
262 match elem {
264 shared::Elem::FP4(_)
265 | shared::Elem::FP4x2(_)
266 | shared::Elem::FP6(_)
267 | shared::Elem::FP6x2(_)
268 | shared::Elem::FP8(_)
269 | shared::Elem::FP8x2(_) => f.write_str("#error FP4/FP6/FP8 not supported in Metal\n"),
270 shared::Elem::F16 => f.write_str("half"),
271 shared::Elem::F16x2 => f.write_str("#error type F162 not supported!\n"),
272 shared::Elem::F32 => f.write_str("float"),
273 shared::Elem::F64 => f.write_str("#error type double not supported!\n"),
274 shared::Elem::BF16 => f.write_str("bfloat"),
275 shared::Elem::BF16x2 => f.write_str("#error type BF162 not supported!\n"),
276 shared::Elem::TF32 => f.write_str("float"),
277 shared::Elem::I8 => f.write_str("char"),
278 shared::Elem::I16 => f.write_str("short"),
279 shared::Elem::I32 => f.write_str("int"),
280 shared::Elem::I64 => f.write_str("long"),
281 shared::Elem::U8 => f.write_str("uchar"),
282 shared::Elem::U16 => f.write_str("ushort"),
283 shared::Elem::U32 => f.write_str("uint"),
284 shared::Elem::U64 => f.write_str("uint64_t"), shared::Elem::Bool => f.write_str("bool"),
286 shared::Elem::Barrier(_) => unimplemented!("metal doesn't support barrier object"),
287 shared::Elem::Atomic(inner) => inner.fmt(f),
288 shared::Elem::_Dialect(_) => Ok(()),
289 }
290 }
291
292 fn compile_item(f: &mut std::fmt::Formatter<'_>, item: &Item<Self>) -> std::fmt::Result {
293 if 1 == item.vectorization {
294 return write!(f, "{}", item.elem);
295 }
296 if item.native {
297 write!(f, "{}{}", item.elem, item.vectorization)
298 } else {
299 write!(f, "{}_{}", item.elem, item.vectorization)
300 }
301 }
302
303 fn compile_atomic_kind(
304 f: &mut std::fmt::Formatter<'_>,
305 kind: &AtomicKind<Self>,
306 ) -> std::fmt::Result {
307 match kind {
308 AtomicKind::I32 => write!(f, "atomic_int"),
309 AtomicKind::I64 => panic!("I64 atomic kind no supported."),
310 AtomicKind::U32 => write!(f, "atomic_uint"),
311 AtomicKind::U64 => write!(f, "atomic_ulong"),
312 AtomicKind::F16 => panic!("F16 atomic kind no supported."),
313 AtomicKind::BF16 => panic!("BF16 atomic kind no supported."),
314 AtomicKind::F32 => write!(f, "atomic_float"), AtomicKind::F64 => panic!("F64 atomic kind no supported."),
316 AtomicKind::_Dialect(_) => Ok(()),
317 }
318 }
319
320 fn address_space_for_variable(variable: &Variable<Self>) -> String {
321 format!("{} ", AddressSpace::from(variable))
322 }
323
324 fn compile_local_memory_qualifier(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
325 write!(f, "thread")
326 }
327
328 fn compile_shared_memory_declaration(
329 f: &mut std::fmt::Formatter<'_>,
330 shared: &SharedMemory<Self>,
331 ) -> std::fmt::Result {
332 match shared {
333 SharedMemory::Array {
334 index,
335 item,
336 length,
337 offset,
338 ..
339 } => {
340 let size_bytes = length * item.size();
341 writeln!(f, "// Shared array size: {length}, {size_bytes} bytes")?;
342 writeln!(
343 f,
344 "threadgroup {item}* shared_memory_{index} = reinterpret_cast<threadgroup {item}*>(&dynamic_shared_mem[{offset}]);"
345 )
346 }
347 SharedMemory::Value {
348 index,
349 item,
350 offset,
351 ..
352 } => {
353 let size_bytes = item.size();
354 writeln!(f, "// Shared value size: {size_bytes} bytes")?;
355 writeln!(
356 f,
357 "threadgroup {item}& shared_memory_{index} = reinterpret_cast<threadgroup {item}&>(dynamic_shared_mem[{offset}]);"
358 )
359 }
360 }
361 }
362}
363
364impl DialectBindings<Self> for MslDialect {
367 fn compile_kernel_signature(
368 f: &mut std::fmt::Formatter<'_>,
369 kernel_name: &str,
370 tensor_maps: &[Binding<Self>],
371 buffers: &[Binding<Self>],
372 scalars: &[(Elem<Self>, usize)],
373 flags: &Flags<Self>,
374 ) -> std::fmt::Result {
375 write!(
376 (f),
377 "
378[[kernel]]
379void {kernel_name}("
380 )?;
381 let mut buffer_idx = 0;
383 debug_assert!(
384 tensor_maps.is_empty(),
385 "Tensor maps aren't supported for metal"
386 );
387 for (i, b) in buffers.iter().enumerate() {
388 format_global_binding_arg("buffer", b, Some(&i.to_string()), &mut buffer_idx, f)?;
389 }
390 if flags.static_meta_length > 0 {
391 let binding = Binding {
392 id: 0,
393 item: Item::scalar(Elem::<Self>::U32, true),
394 location: Location::Storage,
395 size: None,
396 vis: Visibility::Read,
397 };
398 format_global_binding_arg("info", &binding, None, &mut buffer_idx, f)?;
399 }
400 for (elem, _) in scalars.iter() {
401 let binding = Binding {
402 id: 0,
403 item: Item::scalar(*elem, true),
404 location: Location::Storage,
405 size: None,
406 vis: Visibility::Read,
407 };
408
409 let name = format!("scalars_{elem}");
410 format_global_binding_arg(&name, &binding, None, &mut buffer_idx, f)?;
411 }
412
413 let builtins = vec![
415 (
416 flags.indexes.absolute_pos_tuple,
417 Variable::<Self>::AbsolutePosBaseName,
418 ),
419 (
420 flags.indexes.cube_dim_tuple,
421 Variable::<Self>::CubeDimBaseName,
422 ),
423 (
424 flags.indexes.cube_count_tuple,
425 Variable::<Self>::CubeCountBaseName,
426 ),
427 (flags.indexes.unit_pos, Variable::<Self>::UnitPos),
428 (
429 flags.indexes.unit_pos_tuple,
430 Variable::<Self>::UnitPosBaseName,
431 ),
432 (
433 flags.indexes.cube_pos_tuple,
434 Variable::<Self>::CubePosBaseName,
435 ),
436 (flags.indexes.unit_pos_plane, Variable::<Self>::UnitPosPlane),
437 (flags.indexes.plane_dim, Variable::<Self>::PlaneDim),
438 (flags.indexes.plane_pos, Variable::<Self>::PlanePos),
439 ];
440 let comma = !buffers.is_empty() || flags.static_meta_length > 0 || !scalars.is_empty();
441 builtins
442 .iter()
443 .filter(|(cond, _)| *cond)
444 .try_for_each(|(_, var)| format_metal_builtin_binding_arg(f, var, comma))?;
445 f.write_str("\n)")
446 }
447
448 fn compile_bindings_body(
449 f: &mut std::fmt::Formatter<'_>,
450 body: &shared::Body<Self>,
451 ) -> std::fmt::Result {
452 if !body.shared_memories.is_empty() {
453 let size = body
454 .shared_memories
455 .iter()
456 .map(|it| it.offset() + it.size())
457 .max()
458 .unwrap();
459
460 writeln!(f, "threadgroup uchar dynamic_shared_mem[{size}];",)?;
461 }
462 Ok(())
463 }
464}
465
466impl DialectCubeBuiltins<Self> for MslDialect {
469 fn builtin_rules(flags: &CubeIndexFlags) -> CubeIndexFlags {
474 let absolute_pos = flags.absolute_pos;
475 let cube_count = flags.cube_count;
476 let cube_dim = flags.cube_dim;
477 let cube_pos = flags.cube_pos;
478 let plane_dim_checked = flags.plane_dim_checked;
479 let plane_index = flags.plane_pos;
480 let unit_pos = flags.unit_pos;
481 let absolute_pos_tuple = flags.absolute_pos_tuple || absolute_pos;
482 let cube_count_tuple = flags.cube_count_tuple || cube_count || cube_pos || absolute_pos;
483 let cube_dim_tuple = flags.cube_dim_tuple || cube_dim || absolute_pos || plane_dim_checked;
484 let cube_pos_tuple = flags.cube_pos_tuple || cube_pos;
485 let cluster_pos = flags.cluster_pos;
486 let plane_dim = flags.plane_dim || plane_dim_checked || plane_index;
487 let unit_pos_plane = flags.unit_pos_plane || plane_index;
488 let unit_pos_tuple = flags.unit_pos_tuple || unit_pos;
489 CubeIndexFlags {
490 absolute_pos_tuple,
491 absolute_pos,
492 cube_count_tuple,
493 cube_count,
494 cube_dim_tuple,
495 cube_dim,
496 cube_pos_tuple,
497 cube_pos,
498 plane_dim,
499 plane_dim_checked,
500 plane_pos: plane_index,
501 unit_pos_tuple,
502 unit_pos,
503 unit_pos_plane,
504 cluster_pos,
505 }
506 }
507
508 fn compile_absolute_pos_tuple_computation(
509 _f: &mut std::fmt::Formatter<'_>,
510 ) -> std::fmt::Result {
511 Ok(())
513 }
514
515 fn compile_absolute_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
516 f.write_str("thread_pos_in_grid")
517 }
518
519 fn compile_absolute_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
520 f.write_str("thread_index_in_grid")
521 }
522
523 fn compile_absolute_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
524 Self::compile_absolute_pos_base_name(f)?;
525 write!(f, ".x")
526 }
527
528 fn compile_absolute_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
529 Self::compile_absolute_pos_base_name(f)?;
530 write!(f, ".y")
531 }
532
533 fn compile_absolute_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
534 Self::compile_absolute_pos_base_name(f)?;
535 write!(f, ".z")
536 }
537
538 fn compile_cube_count_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
539 f.write_str("threadgroups_per_grid")
540 }
541
542 fn compile_cube_count(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
543 f.write_str("total_threadgroups_in_grid")
544 }
545
546 fn compile_cube_count_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
547 Self::compile_cube_count_base_name(f)?;
548 write!(f, ".x")
549 }
550
551 fn compile_cube_count_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
552 Self::compile_cube_count_base_name(f)?;
553 write!(f, ".y")
554 }
555
556 fn compile_cube_count_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
557 Self::compile_cube_count_base_name(f)?;
558 write!(f, ".z")
559 }
560
561 fn compile_cube_dim_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
562 f.write_str("threads_per_threadgroup")
563 }
564
565 fn compile_cube_dim(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
566 f.write_str("total_thread_in_threadgroup")
567 }
568
569 fn compile_cube_dim_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
570 Self::compile_cube_dim_base_name(f)?;
571 write!(f, ".x")
572 }
573
574 fn compile_cube_dim_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
575 Self::compile_cube_dim_base_name(f)?;
576 write!(f, ".y")
577 }
578
579 fn compile_cube_dim_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
580 Self::compile_cube_dim_base_name(f)?;
581 write!(f, ".z")
582 }
583
584 fn compile_cube_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
585 f.write_str("threadgroup_pos_in_grid")
586 }
587
588 fn compile_cube_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
589 f.write_str("threadgroup_index_in_grid")
590 }
591
592 fn compile_cube_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
593 Self::compile_cube_pos_base_name(f)?;
594 write!(f, ".x")
595 }
596
597 fn compile_cube_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
598 Self::compile_cube_pos_base_name(f)?;
599 write!(f, ".y")
600 }
601
602 fn compile_cube_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
603 Self::compile_cube_pos_base_name(f)?;
604 write!(f, ".z")
605 }
606
607 fn compile_unit_pos_computation(_f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
608 Ok(())
610 }
611
612 fn compile_unit_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
613 f.write_str("thread_pos_in_threadgroup")
614 }
615
616 fn compile_unit_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
617 f.write_str("thread_index_in_threadgroup")
618 }
619
620 fn compile_unit_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
621 Self::compile_unit_pos_base_name(f)?;
622 write!(f, ".x")
623 }
624
625 fn compile_unit_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
626 Self::compile_unit_pos_base_name(f)?;
627 write!(f, ".y")
628 }
629
630 fn compile_unit_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
631 Self::compile_unit_pos_base_name(f)?;
632 write!(f, ".z")
633 }
634
635 fn compile_plane_dim(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
636 f.write_str("simd_size")
637 }
638
639 fn compile_plane_dim_checked(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
640 f.write_str("threads_per_simdgroup_checked")
641 }
642
643 fn compile_plane_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
644 f.write_str("simd_group_id")
645 }
646
647 fn compile_unit_pos_plane(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
648 f.write_str("simd_lane_id")
649 }
650}
651
652impl DialectInstructions<Self> for MslDialect {
655 fn compile_atomic_add(
657 f: &mut std::fmt::Formatter<'_>,
658 lhs: &Variable<Self>,
659 rhs: &Variable<Self>,
660 out: &Variable<Self>,
661 ) -> std::fmt::Result {
662 let out = out.fmt_left();
663 writeln!(
664 f,
665 "{out} = atomic_fetch_add_explicit({lhs}, {rhs}, memory_order_relaxed);"
666 )
667 }
668
669 fn compile_atomic_and(
670 f: &mut std::fmt::Formatter<'_>,
671 lhs: &Variable<Self>,
672 rhs: &Variable<Self>,
673 out: &Variable<Self>,
674 ) -> std::fmt::Result {
675 let out = out.fmt_left();
676 writeln!(
677 f,
678 "{out} = atomic_fetch_and_explicit({lhs}, {rhs}, memory_order_relaxed);"
679 )
680 }
681
682 fn compile_atomic_cas(
683 f: &mut std::fmt::Formatter<'_>,
684 input: &Variable<Self>,
685 cmp: &Variable<Self>,
686 val: &Variable<Self>,
687 out: &Variable<Self>,
688 ) -> std::fmt::Result {
689 let out = out.fmt_left();
690 writeln!(
691 f,
692 "{out} = atomic_compare_exchange_weak_explicit({input}, &{cmp}, {val}, memory_order_relaxed, memory_order_relaxed);"
693 )
694 }
695
696 fn compile_atomic_load(
697 f: &mut std::fmt::Formatter<'_>,
698 input: &Variable<Self>,
699 out: &Variable<Self>,
700 ) -> std::fmt::Result {
701 let out = out.fmt_left();
702 writeln!(
703 f,
704 "{out} = atomic_load_explicit({input}, memory_order_relaxed);"
705 )
706 }
707
708 fn compile_atomic_max(
709 f: &mut std::fmt::Formatter<'_>,
710 lhs: &Variable<Self>,
711 rhs: &Variable<Self>,
712 out: &Variable<Self>,
713 ) -> std::fmt::Result {
714 let out = out.fmt_left();
715 writeln!(
716 f,
717 "{out} = atomic_fetch_max_explicit({lhs}, {rhs}, memory_order_relaxed);"
718 )
719 }
720
721 fn compile_atomic_min(
722 f: &mut std::fmt::Formatter<'_>,
723 lhs: &Variable<Self>,
724 rhs: &Variable<Self>,
725 out: &Variable<Self>,
726 ) -> std::fmt::Result {
727 let out = out.fmt_left();
728 writeln!(
729 f,
730 "{out} = atomic_fetch_min_explicit({lhs}, {rhs}, memory_order_relaxed);"
731 )
732 }
733
734 fn compile_atomic_or(
735 f: &mut std::fmt::Formatter<'_>,
736 lhs: &Variable<Self>,
737 rhs: &Variable<Self>,
738 out: &Variable<Self>,
739 ) -> std::fmt::Result {
740 let out = out.fmt_left();
741 writeln!(
742 f,
743 "{out} = atomic_fetch_or_explicit({lhs}, {rhs}, memory_order_relaxed);"
744 )
745 }
746
747 fn compile_atomic_store(
748 f: &mut std::fmt::Formatter<'_>,
749 input: &Variable<Self>,
750 out: &Variable<Self>,
751 ) -> std::fmt::Result {
752 writeln!(
753 f,
754 "atomic_store_explicit({out}, {input}, memory_order_relaxed);"
755 )
756 }
757
758 fn compile_atomic_sub(
759 f: &mut std::fmt::Formatter<'_>,
760 lhs: &Variable<Self>,
761 rhs: &Variable<Self>,
762 out: &Variable<Self>,
763 ) -> std::fmt::Result {
764 let out = out.fmt_left();
765 writeln!(
766 f,
767 "{out} = atomic_fetch_sub_explicit({lhs}, {rhs}, memory_order_relaxed);"
768 )
769 }
770
771 fn compile_atomic_swap(
772 f: &mut std::fmt::Formatter<'_>,
773 lhs: &Variable<Self>,
774 rhs: &Variable<Self>,
775 out: &Variable<Self>,
776 ) -> std::fmt::Result {
777 let out = out.fmt_left();
778 writeln!(
779 f,
780 "{out} = atomic_exchange_explicit({lhs}, {rhs}, memory_order_relaxed);"
781 )
782 }
783
784 fn compile_atomic_xor(
785 f: &mut std::fmt::Formatter<'_>,
786 lhs: &Variable<Self>,
787 rhs: &Variable<Self>,
788 out: &Variable<Self>,
789 ) -> std::fmt::Result {
790 let out = out.fmt_left();
791 writeln!(
792 f,
793 "{out} = atomic_fetch_xor_explicit({lhs}, {rhs}, memory_order_relaxed);"
794 )
795 }
796
797 fn compile_saturating_add(
798 f: &mut std::fmt::Formatter<'_>,
799 lhs: impl Display,
800 rhs: impl Display,
801 _item: Item<Self>,
802 ) -> std::fmt::Result {
803 write!(f, "addsat({lhs}, {rhs})")
804 }
805
806 fn compile_saturating_sub(
807 f: &mut std::fmt::Formatter<'_>,
808 lhs: impl Display,
809 rhs: impl Display,
810 _item: Item<Self>,
811 ) -> std::fmt::Result {
812 write!(f, "subsat({lhs}, {rhs})")
813 }
814
815 fn compile_instruction_printf(
817 f: &mut std::fmt::Formatter<'_>,
818 format_string: &str,
819 args: &[Variable<Self>],
820 ) -> std::fmt::Result {
821 let args = args.iter().map(|arg| format!("{arg}")).collect::<Vec<_>>();
822 let args = match args.is_empty() {
823 true => "".to_string(),
824 false => format!(", {}", args.join(",")),
825 };
826 writeln!(f, "os_log_default.log({format_string:?}{args});")
827 }
828
829 fn compile_instruction_log1p_scalar<T: Component<Self>>(
831 f: &mut std::fmt::Formatter<'_>,
832 input: T,
833 ) -> std::fmt::Result {
834 match input.elem() {
835 Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
836 write!(f, "log(half(1.0f) + {input})")
837 }
838 _ => write!(f, "log(1.0f + {input})"),
839 }
840 }
841
842 fn compile_instruction_sync_threads(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
844 writeln!(f, "threadgroup_barrier(mem_flags::mem_threadgroup);")
845 }
846
847 fn compile_instruction_sync_warp(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
848 writeln!(f, "simdgroup_barrier(mem_flags::mem_none);")
849 }
850
851 fn compile_instruction_thread_fence(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
852 writeln!(f, "threadgroup_thread_fence(mem_flags::mem_device);")
853 }
854
855 fn compile_instruction_tanh_scalar<T: Component<Self>>(
857 f: &mut std::fmt::Formatter<'_>,
858 input: T,
859 ) -> std::fmt::Result {
860 write!(f, "safe_tanh_scalar({input})")
861 }
862
863 fn compile_instruction_find_first_set<T: Component<Self>>(
865 f: &mut std::fmt::Formatter<'_>,
866 input: T,
867 out_elem: Elem<Self>,
868 ) -> std::fmt::Result {
869 write!(f, "{out_elem}(")?;
870 match input.elem() {
871 Elem::I32 | Elem::U32 => write!(f, "__ffs({input})"),
872 Elem::I64 | Elem::U64 => write!(f, "__ffsll({input})"),
873 _ => write!(f, "__ffs({}({input}))", Elem::<Self>::I32),
874 }?;
875 write!(f, ")")
876 }
877
878 fn compile_instruction_leading_zeros_scalar<T: Component<Self>>(
879 f: &mut std::fmt::Formatter<'_>,
880 input: T,
881 out_elem: Elem<Self>,
882 ) -> std::fmt::Result {
883 write!(f, "{out_elem}(clz({input}))")
884 }
885
886 fn compile_instruction_trailing_zeros_scalar<T: Component<Self>>(
887 f: &mut std::fmt::Formatter<'_>,
888 input: T,
889 out_elem: Elem<Self>,
890 ) -> std::fmt::Result {
891 write!(f, "{out_elem}(ctz({input}))")
892 }
893
894 fn compile_instruction_popcount_scalar<T: Component<Self>>(
895 f: &mut std::fmt::Formatter<'_>,
896 input: T,
897 out_elem: Elem<Self>,
898 ) -> std::fmt::Result {
899 write!(f, "{out_elem}(")?;
900 match input.elem() {
901 Elem::I32 | Elem::U32 | Elem::I64 | Elem::U64 => write!(f, "popcount({input})"),
902 _ => write!(f, "popcount({})", shared::unary::zero_extend(input)),
903 }?;
904 write!(f, ")")
905 }
906
907 fn compile_instruction_reverse_bits_scalar<T: Component<Self>>(
908 f: &mut std::fmt::Formatter<'_>,
909 input: T,
910 out_elem: Elem<Self>,
911 ) -> std::fmt::Result {
912 write!(f, "{out_elem}(")?;
913 match out_elem {
914 Elem::I32 | Elem::U32 | Elem::I64 | Elem::U64 => write!(f, "reverse_bits({input})"),
915 _ => write!(
916 f,
917 "reverse_bits({}) >> {}",
918 shared::unary::zero_extend(input),
919 (size_of::<u32>() - out_elem.size()) * 8
920 ),
921 }?;
922 write!(f, ")")
923 }
924
925 fn compile_instruction_max_function_name(
927 f: &mut std::fmt::Formatter<'_>,
928 _item: Item<Self>,
929 ) -> std::fmt::Result {
930 write!(f, "max")
931 }
932
933 fn compile_instruction_min_function_name(
934 f: &mut std::fmt::Formatter<'_>,
935 _item: Item<Self>,
936 ) -> std::fmt::Result {
937 write!(f, "min")
938 }
939
940 fn compile_instruction_powf(
941 f: &mut std::fmt::Formatter<'_>,
942 lhs: &str,
943 rhs: &str,
944 elem: Elem<Self>,
945 ) -> std::fmt::Result {
946 write!(f, "pow({lhs}, {elem}({rhs}))")
947 }
948
949 fn compile_instruction_half_function_name_prefix() -> &'static str {
950 ""
951 }
952
953 fn compile_instruction_half2_function_name_prefix() -> &'static str {
954 ""
955 }
956
957 fn compile_warp_shuffle(
959 f: &mut std::fmt::Formatter<'_>,
960 var: &str,
961 source: &str,
962 ) -> std::fmt::Result {
963 write!(f, "simd_shuffle({var}, {source})")
964 }
965
966 fn compile_warp_shuffle_xor(
967 f: &mut std::fmt::Formatter<'_>,
968 var: &str,
969 _elem: &Elem<Self>,
970 offset: &str,
971 ) -> std::fmt::Result {
972 write!(f, "simd_shuffle_xor({var}, {offset})")
973 }
974
975 fn compile_warp_shuffle_up(
976 f: &mut std::fmt::Formatter<'_>,
977 var: &str,
978 offset: &str,
979 ) -> std::fmt::Result {
980 write!(f, "simd_shuffle_up({var}, {offset})")
981 }
982
983 fn compile_warp_shuffle_down(
984 f: &mut std::fmt::Formatter<'_>,
985 var: &str,
986 offset: &str,
987 ) -> std::fmt::Result {
988 write!(f, "simd_shuffle_down({var}, {offset})")
989 }
990
991 fn compile_warp_all<T: Component<Self>>(
992 f: &mut std::fmt::Formatter<'_>,
993 input: &T,
994 ) -> std::fmt::Result {
995 write!(f, "simd_all({input})")
996 }
997
998 fn compile_warp_any<T: Component<Self>>(
999 f: &mut std::fmt::Formatter<'_>,
1000 input: &T,
1001 ) -> std::fmt::Result {
1002 write!(f, "simd_any({input})")
1003 }
1004
1005 fn compile_warp_ballot(
1006 f: &mut std::fmt::Formatter<'_>,
1007 input: &Variable<Self>,
1008 out_elem: &Elem<Self>,
1009 ) -> std::fmt::Result {
1010 write!(f, "{out_elem}(uint64_t(simd_ballot({input})))")
1011 }
1012}
1013
1014impl DialectWmmaCompiler<Self> for MslDialect {
1017 fn compile_wmma_includes(
1018 f: &mut std::fmt::Formatter<'_>,
1019 _flags: &Flags<Self>,
1020 ) -> std::fmt::Result {
1021 writeln!(f, "#include <metal_simdgroup_matrix>")
1022 }
1023
1024 fn compile_wmma_local_variables(_f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1025 Ok(())
1027 }
1028
1029 fn compile_wmma_fragment_declaration(
1030 f: &mut std::fmt::Formatter<'_>,
1031 var: &crate::shared::Variable<MslDialect>,
1032 ) -> std::fmt::Result {
1033 wmma_api_base::compile_fragment_declaration(f, var)
1034 }
1035
1036 fn compile_wwma_fragment_ident(
1037 _f: &mut std::fmt::Formatter<'_>,
1038 _ident: &FragmentIdent<Self>,
1039 ) -> std::fmt::Result {
1040 Ok(())
1042 }
1043
1044 fn compile_wmma_fragment_layout(
1045 _f: &mut std::fmt::Formatter<'_>,
1046 _layout: &FragmentLayout<Self>,
1047 ) -> std::fmt::Result {
1048 Ok(())
1050 }
1051
1052 fn compile_wmma_fragment(
1053 f: &mut std::fmt::Formatter<'_>,
1054 fragment: &Fragment<Self>,
1055 ) -> std::fmt::Result {
1056 let ty = fragment.elem;
1057 let m = fragment.m;
1059 let n = fragment.n;
1060 let k = fragment.k;
1061 if m != 8 || n != 8 || k != 8 {
1062 panic!("{m}x{n}x{k} fragments not supported. Only 8x8x8 fragments are supported.");
1063 }
1064 write!(f, "simdgroup_{ty}8x8")
1065 }
1066
1067 fn compile_wmma_instruction(
1068 f: &mut std::fmt::Formatter<'_>,
1069 instruction: &WmmaInstruction<Self>,
1070 ) -> std::fmt::Result {
1071 match instruction {
1072 WmmaInstruction::Fill { frag, value } => {
1073 match frag {
1074 Variable::WmmaFragment { .. } => {
1075 let ty = frag.elem();
1076 writeln!(
1078 f,
1079 "{frag} = make_filled_simdgroup_matrix<{ty}, 8, 8>({value});"
1080 )
1081 }
1082 _ => panic!("should be a fragment"),
1083 }
1084 }
1085 WmmaInstruction::Load {
1086 frag,
1087 value,
1088 stride,
1089 offset,
1090 layout: _layout,
1091 } => {
1092 let transpose = match frag {
1093 Variable::WmmaFragment { frag: inner, .. } => match inner.layout {
1094 Some(FragmentLayout::RowMajor) => false,
1095 Some(FragmentLayout::ColMajor) => true,
1096 _ => false,
1097 },
1098 _ => panic!("should be a fragment"),
1099 };
1100 let item = value.item();
1101 if item.vectorization > 1 {
1102 let elem = item.elem;
1103 match value {
1104 Variable::GlobalInputArray(..) => writeln!(
1105 f,
1106 "simdgroup_load({frag}, (device {elem}*)({value} + {offset}), {stride}, 0, {transpose});"
1107 ),
1108 Variable::SharedArray(..) => writeln!(
1109 f,
1110 "simdgroup_load({frag}, reinterpret_cast<threadgroup {elem} *>({value} + {offset}), {stride}, 0, {transpose});"
1111 ),
1112 _ => panic!(
1113 "Vectorized wmma load is only supported from global or shared memory."
1114 ),
1115 }
1116 } else {
1117 writeln!(
1118 f,
1119 "simdgroup_load({frag}, {value} + {offset}, {stride}, 0, {transpose});"
1120 )
1121 }
1122 }
1123 WmmaInstruction::Execute {
1124 frag_a: a,
1125 frag_b: b,
1126 frag_c: c,
1127 frag_d: d,
1128 ..
1129 } => {
1130 writeln!(f, "simdgroup_multiply_accumulate({d}, {a}, {b}, {c});")
1131 }
1132 WmmaInstruction::Store {
1133 output,
1134 frag,
1135 stride,
1136 offset,
1137 layout: _layout,
1138 } => {
1139 let item = output.item();
1140 let mut reinterpret_cast = item.vectorization > 1;
1141 let elem = match item.elem {
1142 Elem::BF16 => {
1143 reinterpret_cast = true;
1144 Elem::F16
1145 }
1146 _ => item.elem,
1147 };
1148 if reinterpret_cast {
1149 writeln!(
1150 f,
1151 "simdgroup_store({frag}, reinterpret_cast<threadgroup {elem} *>({output} + {offset}), {stride});"
1152 )
1153 } else {
1154 writeln!(f, "simdgroup_store({frag}, {output} + {offset}, {stride});")
1155 }?;
1156 writeln!(f, "simdgroup_barrier(mem_flags::mem_none);")
1157 }
1158 WmmaInstruction::Cast { input, output } => {
1159 writeln!(f, "simdgroup_barrier(mem_flags::mem_none);")?;
1160 let ty = match output {
1161 Variable::WmmaFragment { frag, .. } => frag.elem,
1162 _ => panic!("should be a fragment"),
1163 };
1164 match ty {
1165 Elem::BF16 => {
1166 let addr_space = Self::address_space_for_variable(output);
1167 let elem = Elem::<Self>::F16;
1168 writeln!(
1171 f,
1172 "for(int e=0; e<8; e++) {{
1173 {ty} elem = {ty}({input}.thread_elements()[e]);
1174 {output}.thread_elements()[e] = *reinterpret_cast<{addr_space}{elem} *>(&elem);
1175}}"
1176 )
1177 }
1178 _ => {
1179 writeln!(
1180 f,
1181 "for(int e=0; e<8; e++) {{
1182 {output}.thread_elements()[e] = {ty}({input}.thread_elements()[e]);
1183}}"
1184 )
1185 }
1186 }
1187 }
1188 WmmaInstruction::ExecuteManual {
1189 shape,
1190 frag_a,
1191 frag_b,
1192 frag_c,
1193 frag_d,
1194 } => {
1195 Self::compile_manual_mma(f, ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d))
1196 }
1197 WmmaInstruction::ExecuteScaled {
1198 shape,
1199 frag_a,
1200 frag_b,
1201 frag_c,
1202 frag_d,
1203 scales_a,
1204 scales_b,
1205 scales_factor,
1206 } => Self::compile_scaled_mma(
1207 f,
1208 ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d),
1209 *scales_a,
1210 *scales_b,
1211 *scales_factor,
1212 ),
1213 WmmaInstruction::LdMatrix { .. } | WmmaInstruction::StMatrix { .. } => {
1214 f.write_str("#error WmmaInstruction Ld & St Matrix not supported on Metal\n")
1215 }
1216 }
1217 }
1218
1219 fn compile_manual_mma(
1220 f: &mut std::fmt::Formatter<'_>,
1221 _mma: shared::ManualMma<Self>,
1222 ) -> std::fmt::Result {
1223 f.write_str("#error manual mma not supported on Metal\n")
1224 }
1225
1226 fn compile_scaled_mma(
1227 f: &mut std::fmt::Formatter<'_>,
1228 _mma: shared::ManualMma<Self>,
1229 _scales_a: Variable<Self>,
1230 _scales_b: Variable<Self>,
1231 _scales_factor: u32,
1232 ) -> std::fmt::Result {
1233 f.write_str("#error scaled mma not supported on Metal\n")
1234 }
1235
1236 fn supported_wmma_combinations(_arch: &MetalArchitecture) -> SupportedMmaCombinations {
1237 let types = vec![
1238 (
1239 gpu::ElemType::Float(gpu::FloatKind::F16).into(),
1240 gpu::ElemType::Float(gpu::FloatKind::F16).into(),
1241 gpu::ElemType::Float(gpu::FloatKind::F16).into(),
1242 ),
1243 (
1244 gpu::ElemType::Float(gpu::FloatKind::F16).into(),
1245 gpu::ElemType::Float(gpu::FloatKind::F16).into(),
1246 gpu::ElemType::Float(gpu::FloatKind::F32).into(),
1247 ),
1248 (
1249 gpu::ElemType::Float(gpu::FloatKind::BF16).into(),
1250 gpu::ElemType::Float(gpu::FloatKind::BF16).into(),
1251 gpu::ElemType::Float(gpu::FloatKind::BF16).into(),
1252 ),
1253 (
1254 gpu::ElemType::Float(gpu::FloatKind::F32).into(),
1255 gpu::ElemType::Float(gpu::FloatKind::F32).into(),
1256 gpu::ElemType::Float(gpu::FloatKind::F32).into(),
1257 ),
1258 ];
1259 types
1260 .into_iter()
1261 .map(|(a_type, b_type, cd_type)| MmaConfig {
1262 a_type,
1263 b_type,
1264 cd_type,
1265 m: 8,
1266 n: 8,
1267 k: 8,
1268 })
1269 .collect()
1270 }
1271
1272 fn supported_mma_combinations(_arch: &MetalArchitecture) -> SupportedMmaCombinations {
1273 Vec::new()
1274 }
1275}
1276
1277impl DialectProcessors<Self> for MslDialect {
1280 fn processors() -> Vec<Box<dyn gpu::Processor>> {
1281 Vec::new()
1282 }
1283}