cubecl_cpp/metal/
dialect.rs

1use core::panic;
2use std::fmt::Display;
3
4use crate::{
5    Dialect,
6    shared::{
7        self, AtomicKind, Binding, Component, CubeIndexFlags, DialectBindings, DialectCubeBuiltins,
8        DialectIncludes, DialectInstructions, DialectTypes, DialectWmmaCompiler, Elem, Flags,
9        FmtLeft, Fragment, FragmentIdent, FragmentLayout, Instruction, Item, SharedMemory,
10        SupportedWmmaCombinations, Variable, WarpInstruction, WmmaInstruction, wmma_api_base,
11    },
12};
13use cubecl_core::{
14    compute::{Location, Visibility},
15    ir::{self as gpu, Id},
16};
17
18use super::{
19    AddressSpace, Extension,
20    arch::MetalArchitecture,
21    extension::{format_ffs, format_mulhi},
22    format_erf, format_global_binding_arg, format_metal_builtin_binding_arg, format_safe_tanh,
23};
24
/// Marker type for the Metal Shading Language (MSL) code-generation dialect.
///
/// The struct holds no data; its behavior comes entirely from the `Dialect*`
/// trait implementations in this file.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct MslDialect {}
27
28// Base dialect
29
/// Base dialect implementation: associates the MSL dialect with its
/// architecture type.
impl Dialect for MslDialect {
    /// Architecture type used with this dialect.
    type Architecture = MetalArchitecture;
}
33
34// Includes
35
36impl DialectIncludes<Self> for MslDialect {
37    type Extension = Extension<Self>;
38
39    fn compile_includes(f: &mut std::fmt::Formatter<'_>, _flags: &Flags) -> std::fmt::Result {
40        write!(
41            f,
42            "
43#include <metal_stdlib>
44using namespace metal;
45"
46        )?;
47        Ok(())
48    }
49
50    fn compile_extensions(
51        f: &mut std::fmt::Formatter<'_>,
52        extensions: &[Self::Extension],
53    ) -> std::fmt::Result {
54        for extension in extensions {
55            match extension {
56                Extension::Erf(input, output) => format_erf::<Self>(f, input, output)?,
57                Extension::Ffs(elem) => format_ffs(f, elem)?,
58                Extension::MulHi(elem) => format_mulhi(f, elem)?,
59                Extension::SafeTanh(item) => format_safe_tanh::<Self>(f, item)?,
60                Extension::NoExtension => {}
61            }
62        }
63        Ok(())
64    }
65
66    fn register_instruction_extension(
67        extensions: &mut Vec<Self::Extension>,
68        instruction: &Instruction<Self>,
69    ) {
70        let mut register_extension = |extension: Self::Extension| {
71            if !extensions.contains(&extension) {
72                extensions.push(extension);
73            }
74        };
75        #[allow(clippy::single_match)]
76        match instruction {
77            shared::Instruction::<Self>::Erf(instruction) => {
78                register_extension(Extension::Erf(
79                    instruction.input.elem(),
80                    instruction.out.elem(),
81                ));
82            }
83            shared::Instruction::<Self>::FindFirstSet(instruction) => {
84                let input_elem = instruction.input.elem();
85                match input_elem {
86                    Elem::U32 | Elem::U64 => {
87                        register_extension(Extension::Ffs(instruction.input.elem()));
88                    }
89                    Elem::I32 => {
90                        register_extension(Extension::Ffs(Elem::<Self>::U32));
91                        register_extension(Extension::Ffs(instruction.input.elem()));
92                    }
93                    Elem::I64 => {
94                        register_extension(Extension::Ffs(Elem::<Self>::U64));
95                        register_extension(Extension::Ffs(instruction.input.elem()));
96                    }
97                    _ => {
98                        register_extension(Extension::Ffs(Elem::<Self>::U32));
99                    }
100                }
101            }
102            shared::Instruction::<Self>::HiMul(instruction) => {
103                register_extension(Extension::MulHi(instruction.out.elem()));
104            }
105            shared::Instruction::<Self>::Tanh(instruction) => {
106                register_extension(Extension::SafeTanh(instruction.input.item()));
107            }
108            _ => {}
109        }
110    }
111
112    fn register_warp_instruction_extension(
113        _extensions: &mut Vec<Self::Extension>,
114        _instruction: &WarpInstruction<Self>,
115    ) {
116    }
117}
118
119// Types
120
121impl DialectTypes<Self> for MslDialect {
122    fn item_can_be_optimized() -> bool {
123        false
124    }
125
126    fn compile_type_definitions(
127        f: &mut std::fmt::Formatter<'_>,
128        items: &std::collections::HashSet<crate::shared::Item<Self>>,
129        _scalars: &[(Elem<Self>, usize)],
130        _flags: &Flags,
131    ) -> std::fmt::Result {
132        for item in items.iter() {
133            let elem = item.elem;
134            let size = item.vectorization;
135            let alignment = elem.size() * size;
136            if size > 1 {
137                write!(
138                    f,
139                    "
140struct alignas({alignment}) {item} {{"
141                )?;
142
143                for i in 0..size {
144                    write!(
145                        f,
146                        "
147    {elem} i_{i};"
148                    )?;
149                }
150
151                f.write_str("\n};\n")?;
152            }
153        }
154        Ok(())
155    }
156
157    fn compile_elem(
158        f: &mut std::fmt::Formatter<'_>,
159        elem: &shared::Elem<Self>,
160        _words: bool,
161    ) -> std::fmt::Result {
162        // we always use the word form of types
163        match elem {
164            shared::Elem::FP4(_)
165            | shared::Elem::FP4x2(_)
166            | shared::Elem::FP6(_)
167            | shared::Elem::FP6x2(_)
168            | shared::Elem::FP8(_)
169            | shared::Elem::FP8x2(_) => unimplemented!("FP4/FP6/FP8 not supported in Metal"),
170            shared::Elem::F16 => f.write_str("half"),
171            shared::Elem::F16x2 => panic!("type F162 not supported!"),
172            shared::Elem::F32 => f.write_str("float"),
173            shared::Elem::F64 => panic!("type double not supported!"),
174            shared::Elem::BF16 => f.write_str("bfloat"),
175            shared::Elem::BF16x2 => panic!("type BF162 not supported!"),
176            shared::Elem::TF32 => f.write_str("float"),
177            shared::Elem::I8 => f.write_str("char"),
178            shared::Elem::I16 => f.write_str("short"),
179            shared::Elem::I32 => f.write_str("int"),
180            shared::Elem::I64 => f.write_str("long"),
181            shared::Elem::U8 => f.write_str("uchar"),
182            shared::Elem::U16 => f.write_str("ushort"),
183            shared::Elem::U32 => f.write_str("uint"),
184            shared::Elem::U64 => f.write_str("uint64_t"), // or unsigned long
185            shared::Elem::Bool => f.write_str("bool"),
186            shared::Elem::Atomic(inner) => inner.fmt(f),
187            shared::Elem::_Dialect(_) => Ok(()),
188        }
189    }
190
191    fn compile_item(f: &mut std::fmt::Formatter<'_>, item: &Item<Self>) -> std::fmt::Result {
192        if 1 == item.vectorization {
193            return write!(f, "{}", item.elem);
194        }
195        if item.native {
196            write!(f, "{}{}", item.elem, item.vectorization)
197        } else {
198            write!(f, "{}_{}", item.elem, item.vectorization)
199        }
200    }
201
202    fn compile_atomic_kind(
203        f: &mut std::fmt::Formatter<'_>,
204        kind: &AtomicKind<Self>,
205    ) -> std::fmt::Result {
206        match kind {
207            AtomicKind::I32 => write!(f, "atomic_int"),
208            AtomicKind::I64 => panic!("I64 atomic kind no supported."),
209            AtomicKind::U32 => write!(f, "atomic_uint"),
210            AtomicKind::U64 => write!(f, "atomic_ulong"),
211            AtomicKind::F16 => panic!("F16 atomic kind no supported."),
212            AtomicKind::BF16 => panic!("BF16 atomic kind no supported."),
213            AtomicKind::F32 => write!(f, "atomic_float"), // needs metal 3
214            AtomicKind::F64 => panic!("F64 atomic kind no supported."),
215            AtomicKind::_Dialect(_) => Ok(()),
216        }
217    }
218
219    fn address_space_for_variable(variable: &Variable<Self>) -> String {
220        format!("{} ", AddressSpace::from(variable))
221    }
222
223    fn compile_local_memory_qualifier(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
224        write!(f, "thread")
225    }
226
227    fn compile_shared_memory_declaration(
228        f: &mut std::fmt::Formatter<'_>,
229        shared: &SharedMemory<Self>,
230    ) -> std::fmt::Result {
231        let item = shared.item;
232        let index = shared.index;
233        let size = shared.size;
234        let alignment = shared
235            .align
236            .map(|align| format!("alignas({align})"))
237            .unwrap_or_default();
238        writeln!(
239            f,
240            "threadgroup {alignment} {item} shared_memory_{index}[{size}];",
241        )
242    }
243}
244
245// Kernel argument bindings
246
247impl DialectBindings<Self> for MslDialect {
248    fn compile_kernel_signature(
249        f: &mut std::fmt::Formatter<'_>,
250        kernel_name: &str,
251        tensor_maps: &[Id],
252        buffers: &[Binding<Self>],
253        scalars: &[(Elem<Self>, usize)],
254        flags: &Flags,
255    ) -> std::fmt::Result {
256        write!(
257            (f),
258            "
259[[kernel]]
260void {kernel_name}("
261        )?;
262        // Global bindings args
263        let mut buffer_idx = 0;
264        debug_assert!(
265            tensor_maps.is_empty(),
266            "Tensor maps aren't supported for metal"
267        );
268        for (i, b) in buffers.iter().enumerate() {
269            format_global_binding_arg("buffer", b, Some(&i.to_string()), &mut buffer_idx, f)?;
270        }
271        if flags.static_meta_length > 0 {
272            let binding = Binding {
273                id: 0,
274                item: Item::scalar(Elem::<Self>::U32, true),
275                location: Location::Storage,
276                size: None,
277                vis: Visibility::Read,
278            };
279            format_global_binding_arg("info", &binding, None, &mut buffer_idx, f)?;
280        }
281        for (elem, _) in scalars.iter() {
282            let binding = Binding {
283                id: 0,
284                item: Item::scalar(*elem, true),
285                location: Location::Storage,
286                size: None,
287                vis: Visibility::Read,
288            };
289
290            let name = format!("scalars_{elem}");
291            format_global_binding_arg(&name, &binding, None, &mut buffer_idx, f)?;
292        }
293
294        // Global metal builtins args
295        let builtins = vec![
296            (
297                flags.indexes.absolute_pos_tuple,
298                Variable::<Self>::AbsolutePosBaseName,
299            ),
300            (
301                flags.indexes.cube_dim_tuple,
302                Variable::<Self>::CubeDimBaseName,
303            ),
304            (
305                flags.indexes.cube_count_tuple,
306                Variable::<Self>::CubeCountBaseName,
307            ),
308            (flags.indexes.unit_pos, Variable::<Self>::UnitPos),
309            (
310                flags.indexes.unit_pos_tuple,
311                Variable::<Self>::UnitPosBaseName,
312            ),
313            (
314                flags.indexes.cube_pos_tuple,
315                Variable::<Self>::CubePosBaseName,
316            ),
317            (flags.indexes.unit_pos_plane, Variable::<Self>::UnitPosPlane),
318            (flags.indexes.plane_dim, Variable::<Self>::PlaneDim),
319            (flags.indexes.plane_index, Variable::<Self>::PlanePos),
320        ];
321        let comma = !buffers.is_empty() || flags.static_meta_length > 0 || !scalars.is_empty();
322        builtins
323            .iter()
324            .filter(|(cond, _)| *cond)
325            .try_for_each(|(_, var)| format_metal_builtin_binding_arg(f, var, comma))?;
326        f.write_str("\n)")
327    }
328}
329
330// Cube builtins dialect
331
332impl DialectCubeBuiltins<Self> for MslDialect {
333    /// Depending on the dialect available built-in variables the
334    /// inclusion rules might change.
335    /// For instance in metal we have a built-in for the Unit plane position
336    /// so we don't rely on other builtins.
337    fn builtin_rules(flags: &CubeIndexFlags) -> CubeIndexFlags {
338        let absolute_pos = flags.absolute_pos;
339        let cube_count = flags.cube_count;
340        let cube_dim = flags.cube_dim;
341        let cube_pos = flags.cube_pos;
342        let plane_dim_checked = flags.plane_dim_checked;
343        let plane_index = flags.plane_index;
344        let unit_pos = flags.unit_pos;
345        let absolute_pos_tuple = flags.absolute_pos_tuple || absolute_pos;
346        let cube_count_tuple = flags.cube_count_tuple || cube_count || cube_pos || absolute_pos;
347        let cube_dim_tuple = flags.cube_dim_tuple || cube_dim || absolute_pos || plane_dim_checked;
348        let cube_pos_tuple = flags.cube_pos_tuple || cube_pos;
349        let cluster_pos = flags.cluster_pos;
350        let plane_dim = flags.plane_dim || plane_dim_checked || plane_index;
351        let unit_pos_plane = flags.unit_pos_plane || plane_index;
352        let unit_pos_tuple = flags.unit_pos_tuple || unit_pos;
353        CubeIndexFlags {
354            absolute_pos_tuple,
355            absolute_pos,
356            cube_count_tuple,
357            cube_count,
358            cube_dim_tuple,
359            cube_dim,
360            cube_pos_tuple,
361            cube_pos,
362            plane_dim,
363            plane_dim_checked,
364            plane_index,
365            unit_pos_tuple,
366            unit_pos,
367            unit_pos_plane,
368            cluster_pos,
369        }
370    }
371
372    fn compile_absolute_pos_tuple_computation(
373        _f: &mut std::fmt::Formatter<'_>,
374    ) -> std::fmt::Result {
375        // no need to compute it on metal as there is y a built-in for it
376        Ok(())
377    }
378
379    fn compile_absolute_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
380        f.write_str("thread_pos_in_grid")
381    }
382
383    fn compile_absolute_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
384        f.write_str("thread_index_in_grid")
385    }
386
387    fn compile_absolute_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
388        Self::compile_absolute_pos_base_name(f)?;
389        write!(f, ".x")
390    }
391
392    fn compile_absolute_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
393        Self::compile_absolute_pos_base_name(f)?;
394        write!(f, ".y")
395    }
396
397    fn compile_absolute_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
398        Self::compile_absolute_pos_base_name(f)?;
399        write!(f, ".z")
400    }
401
402    fn compile_cube_count_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
403        f.write_str("threadgroups_per_grid")
404    }
405
406    fn compile_cube_count(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
407        f.write_str("total_threadgroups_in_grid")
408    }
409
410    fn compile_cube_count_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
411        Self::compile_cube_count_base_name(f)?;
412        write!(f, ".x")
413    }
414
415    fn compile_cube_count_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
416        Self::compile_cube_count_base_name(f)?;
417        write!(f, ".y")
418    }
419
420    fn compile_cube_count_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
421        Self::compile_cube_count_base_name(f)?;
422        write!(f, ".z")
423    }
424
425    fn compile_cube_dim_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
426        f.write_str("threads_per_threadgroup")
427    }
428
429    fn compile_cube_dim(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
430        f.write_str("total_thread_in_threadgroup")
431    }
432
433    fn compile_cube_dim_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
434        Self::compile_cube_dim_base_name(f)?;
435        write!(f, ".x")
436    }
437
438    fn compile_cube_dim_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
439        Self::compile_cube_dim_base_name(f)?;
440        write!(f, ".y")
441    }
442
443    fn compile_cube_dim_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
444        Self::compile_cube_dim_base_name(f)?;
445        write!(f, ".z")
446    }
447
448    fn compile_cube_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
449        f.write_str("threadgroup_pos_in_grid")
450    }
451
452    fn compile_cube_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
453        f.write_str("threadgroup_index_in_grid")
454    }
455
456    fn compile_cube_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
457        Self::compile_cube_pos_base_name(f)?;
458        write!(f, ".x")
459    }
460
461    fn compile_cube_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
462        Self::compile_cube_pos_base_name(f)?;
463        write!(f, ".y")
464    }
465
466    fn compile_cube_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
467        Self::compile_cube_pos_base_name(f)?;
468        write!(f, ".z")
469    }
470
471    fn compile_unit_pos_computation(_f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
472        // no need to compute it on metal as there is y a built-in for it
473        Ok(())
474    }
475
476    fn compile_unit_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
477        f.write_str("thread_pos_in_threadgroup")
478    }
479
480    fn compile_unit_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
481        f.write_str("thread_index_in_threadgroup")
482    }
483
484    fn compile_unit_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
485        Self::compile_unit_pos_base_name(f)?;
486        write!(f, ".x")
487    }
488
489    fn compile_unit_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
490        Self::compile_unit_pos_base_name(f)?;
491        write!(f, ".y")
492    }
493
494    fn compile_unit_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
495        Self::compile_unit_pos_base_name(f)?;
496        write!(f, ".z")
497    }
498
499    fn compile_plane_dim(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
500        f.write_str("simd_size")
501    }
502
503    fn compile_plane_dim_checked(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
504        f.write_str("threads_per_simdgroup_checked")
505    }
506
507    fn compile_plane_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
508        f.write_str("simd_group_id")
509    }
510
511    fn compile_unit_pos_plane(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
512        f.write_str("simd_lane_id")
513    }
514}
515
516// Instructions
517
518impl DialectInstructions<Self> for MslDialect {
519    // atomics
520    fn compile_atomic_add(
521        f: &mut std::fmt::Formatter<'_>,
522        lhs: &Variable<Self>,
523        rhs: &Variable<Self>,
524        out: &Variable<Self>,
525    ) -> std::fmt::Result {
526        let out = out.fmt_left();
527        writeln!(
528            f,
529            "{out} = atomic_fetch_add_explicit({lhs}, {rhs}, memory_order_relaxed);"
530        )
531    }
532
533    fn compile_atomic_and(
534        f: &mut std::fmt::Formatter<'_>,
535        lhs: &Variable<Self>,
536        rhs: &Variable<Self>,
537        out: &Variable<Self>,
538    ) -> std::fmt::Result {
539        let out = out.fmt_left();
540        writeln!(
541            f,
542            "{out} = atomic_fetch_and_explicit({lhs}, {rhs}, memory_order_relaxed);"
543        )
544    }
545
546    fn compile_atomic_cas(
547        f: &mut std::fmt::Formatter<'_>,
548        input: &Variable<Self>,
549        cmp: &Variable<Self>,
550        val: &Variable<Self>,
551        out: &Variable<Self>,
552    ) -> std::fmt::Result {
553        let out = out.fmt_left();
554        writeln!(
555            f,
556            "{out} = atomic_compare_exchange_weak_explicit({input}, &{cmp}, {val}, memory_order_relaxed, memory_order_relaxed);"
557        )
558    }
559
560    fn compile_atomic_load(
561        f: &mut std::fmt::Formatter<'_>,
562        input: &Variable<Self>,
563        out: &Variable<Self>,
564    ) -> std::fmt::Result {
565        let out = out.fmt_left();
566        writeln!(
567            f,
568            "{out} = atomic_load_explicit({input}, memory_order_relaxed);"
569        )
570    }
571
572    fn compile_atomic_max(
573        f: &mut std::fmt::Formatter<'_>,
574        lhs: &Variable<Self>,
575        rhs: &Variable<Self>,
576        out: &Variable<Self>,
577    ) -> std::fmt::Result {
578        let out = out.fmt_left();
579        writeln!(
580            f,
581            "{out} = atomic_fetch_max_explicit({lhs}, {rhs}, memory_order_relaxed);"
582        )
583    }
584
585    fn compile_atomic_min(
586        f: &mut std::fmt::Formatter<'_>,
587        lhs: &Variable<Self>,
588        rhs: &Variable<Self>,
589        out: &Variable<Self>,
590    ) -> std::fmt::Result {
591        let out = out.fmt_left();
592        writeln!(
593            f,
594            "{out} = atomic_fetch_min_explicit({lhs}, {rhs}, memory_order_relaxed);"
595        )
596    }
597
598    fn compile_atomic_or(
599        f: &mut std::fmt::Formatter<'_>,
600        lhs: &Variable<Self>,
601        rhs: &Variable<Self>,
602        out: &Variable<Self>,
603    ) -> std::fmt::Result {
604        let out = out.fmt_left();
605        writeln!(
606            f,
607            "{out} = atomic_fetch_or_explicit({lhs}, {rhs}, memory_order_relaxed);"
608        )
609    }
610
611    fn compile_atomic_store(
612        f: &mut std::fmt::Formatter<'_>,
613        input: &Variable<Self>,
614        out: &Variable<Self>,
615    ) -> std::fmt::Result {
616        writeln!(
617            f,
618            "atomic_store_explicit({out}, {input}, memory_order_relaxed);"
619        )
620    }
621
622    fn compile_atomic_sub(
623        f: &mut std::fmt::Formatter<'_>,
624        lhs: &Variable<Self>,
625        rhs: &Variable<Self>,
626        out: &Variable<Self>,
627    ) -> std::fmt::Result {
628        let out = out.fmt_left();
629        writeln!(
630            f,
631            "{out} = atomic_fetch_sub_explicit({lhs}, {rhs}, memory_order_relaxed);"
632        )
633    }
634
635    fn compile_atomic_swap(
636        f: &mut std::fmt::Formatter<'_>,
637        lhs: &Variable<Self>,
638        rhs: &Variable<Self>,
639        out: &Variable<Self>,
640    ) -> std::fmt::Result {
641        let out = out.fmt_left();
642        writeln!(
643            f,
644            "{out} = atomic_exchange_explicit({lhs}, {rhs}, memory_order_relaxed);"
645        )
646    }
647
648    fn compile_atomic_xor(
649        f: &mut std::fmt::Formatter<'_>,
650        lhs: &Variable<Self>,
651        rhs: &Variable<Self>,
652        out: &Variable<Self>,
653    ) -> std::fmt::Result {
654        let out = out.fmt_left();
655        writeln!(
656            f,
657            "{out} = atomic_fetch_xor_explicit({lhs}, {rhs}, memory_order_relaxed);"
658        )
659    }
660
661    // debug
662    fn compile_instruction_printf(
663        f: &mut std::fmt::Formatter<'_>,
664        format_string: &str,
665        args: &[Variable<Self>],
666    ) -> std::fmt::Result {
667        let format_string = format_string
668            .replace("\t", "\\t")
669            .replace("\n", "\\n")
670            .replace("\r", "\\r");
671        let args = args.iter().map(|arg| format!("{arg}")).collect::<Vec<_>>();
672        let args = match args.is_empty() {
673            true => "".to_string(),
674            false => format!(", {}", args.join(",")),
675        };
676        writeln!(f, "os_log_default.log(\"{format_string}\"{args});")
677    }
678
679    // logs
680    fn compile_instruction_log1p_scalar<T: Component<Self>>(
681        f: &mut std::fmt::Formatter<'_>,
682        input: T,
683    ) -> std::fmt::Result {
684        match input.elem() {
685            Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
686                write!(f, "log(half(1.0f) + {input})")
687            }
688            _ => write!(f, "log(1.0f + {input})"),
689        }
690    }
691
692    // sync
693    fn compile_instruction_sync_threads(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
694        writeln!(f, "threadgroup_barrier(mem_flags::mem_threadgroup);")
695    }
696
697    fn compile_instruction_sync_warp(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
698        writeln!(f, "simdgroup_barrier(mem_flags::mem_none);")
699    }
700
701    fn compile_instruction_thread_fence(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
702        writeln!(f, "threadgroup_thread_fence(mem_flags::mem_device);")
703    }
704
705    // trigo
706    fn compile_instruction_tanh_scalar<T: Component<Self>>(
707        f: &mut std::fmt::Formatter<'_>,
708        input: T,
709    ) -> std::fmt::Result {
710        write!(f, "safe_tanh_scalar({input})")
711    }
712
713    // unary
714    fn compile_instruction_find_first_set<T: Component<Self>>(
715        f: &mut std::fmt::Formatter<'_>,
716        input: T,
717        out_elem: Elem<Self>,
718    ) -> std::fmt::Result {
719        write!(f, "{out_elem}(")?;
720        match input.elem() {
721            Elem::I32 | Elem::U32 => write!(f, "__ffs({input})"),
722            Elem::I64 | Elem::U64 => write!(f, "__ffsll({input})"),
723            _ => write!(f, "__ffs({}({input}))", Elem::<Self>::I32),
724        }?;
725        write!(f, ")")
726    }
727
728    fn compile_instruction_leading_zeros_scalar<T: Component<Self>>(
729        f: &mut std::fmt::Formatter<'_>,
730        input: T,
731        out_elem: Elem<Self>,
732    ) -> std::fmt::Result {
733        write!(f, "{out_elem}(clz({input}))")
734    }
735
736    fn compile_instruction_popcount_scalar<T: Component<Self>>(
737        f: &mut std::fmt::Formatter<'_>,
738        input: T,
739        out_elem: Elem<Self>,
740    ) -> std::fmt::Result {
741        write!(f, "{out_elem}(")?;
742        match input.elem() {
743            Elem::I32 | Elem::U32 | Elem::I64 | Elem::U64 => write!(f, "popcount({input})"),
744            _ => write!(f, "popcount({})", shared::unary::zero_extend(input)),
745        }?;
746        write!(f, ")")
747    }
748
749    fn compile_instruction_reverse_bits_scalar<T: Component<Self>>(
750        f: &mut std::fmt::Formatter<'_>,
751        input: T,
752        out_elem: Elem<Self>,
753    ) -> std::fmt::Result {
754        write!(f, "{out_elem}(")?;
755        match out_elem {
756            Elem::I32 | Elem::U32 | Elem::I64 | Elem::U64 => write!(f, "reverse_bits({input})"),
757            _ => write!(
758                f,
759                "reverse_bits({}) >> {}",
760                shared::unary::zero_extend(input),
761                (size_of::<u32>() - out_elem.size()) * 8
762            ),
763        }?;
764        write!(f, ")")
765    }
766
767    // others
768    fn compile_instruction_max_function_name(
769        f: &mut std::fmt::Formatter<'_>,
770        _item: Item<Self>,
771    ) -> std::fmt::Result {
772        write!(f, "max")
773    }
774
775    fn compile_instruction_min_function_name(
776        f: &mut std::fmt::Formatter<'_>,
777        _item: Item<Self>,
778    ) -> std::fmt::Result {
779        write!(f, "min")
780    }
781
782    fn compile_instruction_powf(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
783        write!(f, "pow")
784    }
785
786    fn compile_instruction_half_function_name_prefix() -> &'static str {
787        ""
788    }
789
790    fn compile_instruction_half2_function_name_prefix() -> &'static str {
791        ""
792    }
793
794    // Warp
795    fn compile_warp_shuffle(
796        f: &mut std::fmt::Formatter<'_>,
797        var: &str,
798        source: &str,
799    ) -> std::fmt::Result {
800        write!(f, "simd_shuffle({var}, {source})")
801    }
802
803    fn compile_warp_shuffle_xor(
804        f: &mut std::fmt::Formatter<'_>,
805        var: &str,
806        _elem: &Elem<Self>,
807        offset: &str,
808    ) -> std::fmt::Result {
809        write!(f, "simd_shuffle_xor({var}, {offset})")
810    }
811
812    fn compile_warp_shuffle_up(
813        f: &mut std::fmt::Formatter<'_>,
814        var: &str,
815        offset: &str,
816    ) -> std::fmt::Result {
817        write!(f, "simd_shuffle_up({var}, {offset})")
818    }
819
820    fn compile_warp_shuffle_down(
821        f: &mut std::fmt::Formatter<'_>,
822        var: &str,
823        offset: &str,
824    ) -> std::fmt::Result {
825        write!(f, "simd_shuffle_down({var}, {offset})")
826    }
827
828    fn compile_warp_all<T: Component<Self>>(
829        f: &mut std::fmt::Formatter<'_>,
830        input: &T,
831    ) -> std::fmt::Result {
832        write!(f, "simd_all({input})")
833    }
834
835    fn compile_warp_any<T: Component<Self>>(
836        f: &mut std::fmt::Formatter<'_>,
837        input: &T,
838    ) -> std::fmt::Result {
839        write!(f, "simd_any({input})")
840    }
841
842    fn compile_warp_ballot(
843        f: &mut std::fmt::Formatter<'_>,
844        input: &Variable<Self>,
845        out_elem: &Elem<Self>,
846    ) -> std::fmt::Result {
847        write!(f, "{out_elem}(uint64_t(simd_ballot({input})))")
848    }
849}
850
851// Coop Matrices dialect
852
853impl DialectWmmaCompiler<Self> for MslDialect {
854    fn compile_wmma_includes(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
855        writeln!(f, "#include <metal_simdgroup_matrix>")
856    }
857
    /// No-op for this dialect: writes nothing and returns `Ok(())`.
    fn compile_wmma_type_definitions(_f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // not used
        Ok(())
    }
862
    /// No-op for this dialect: writes nothing and returns `Ok(())`.
    fn compile_wmma_local_variables(_f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // not used
        Ok(())
    }
867
    /// Writes the declaration of the fragment variable `var` by delegating to
    /// `wmma_api_base::compile_fragment_declaration`.
    fn compile_wmma_fragment_declaration(
        f: &mut std::fmt::Formatter<'_>,
        var: &crate::shared::Variable<MslDialect>,
    ) -> std::fmt::Result {
        wmma_api_base::compile_fragment_declaration(f, var)
    }
874
    /// No-op for this dialect: the fragment identifier is ignored, nothing is
    /// written and `Ok(())` is returned.
    fn compile_wwma_fragment_ident(
        _f: &mut std::fmt::Formatter<'_>,
        _ident: &FragmentIdent<Self>,
    ) -> std::fmt::Result {
        // not used
        Ok(())
    }
882
    /// No-op for this dialect: the fragment layout is ignored, nothing is
    /// written and `Ok(())` is returned.
    fn compile_wmma_fragment_layout(
        _f: &mut std::fmt::Formatter<'_>,
        _layout: &FragmentLayout<Self>,
    ) -> std::fmt::Result {
        // not used
        Ok(())
    }
890
891    fn compile_wmma_fragment(
892        f: &mut std::fmt::Formatter<'_>,
893        fragment: &Fragment<Self>,
894    ) -> std::fmt::Result {
895        let ty = fragment.elem;
896        // currently as of Metal 3.2 only fragments of 8x8x8 are supported
897        let m = fragment.m;
898        let n = fragment.n;
899        let k = fragment.k;
900        if m != 8 || n != 8 || k != 8 {
901            panic!("{m}x{n}x{k} fragments not supported. Only 8x8x8 fragments are supported.");
902        }
903        write!(f, "simdgroup_{ty}8x8")
904    }
905
906    fn compile_wmma_instruction(
907        f: &mut std::fmt::Formatter<'_>,
908        instruction: &WmmaInstruction<Self>,
909    ) -> std::fmt::Result {
910        match instruction {
911            WmmaInstruction::Fill { frag, value } => {
912                match frag {
913                    Variable::WmmaFragment { .. } => {
914                        let ty = frag.elem();
915                        // Only 8x8x8 fragemts are supported. Check is done at fragment compilation time.
916                        writeln!(
917                            f,
918                            "{frag} = make_filled_simdgroup_matrix<{ty}, 8, 8>({value});"
919                        )
920                    }
921                    _ => panic!("should be a fragment"),
922                }
923            }
924            WmmaInstruction::Load {
925                frag,
926                value,
927                stride,
928                offset,
929                layout: _layout,
930            } => {
931                let transpose = match frag {
932                    Variable::WmmaFragment { frag: inner, .. } => match inner.layout {
933                        Some(FragmentLayout::RowMajor) => false,
934                        Some(FragmentLayout::ColMajor) => true,
935                        _ => false,
936                    },
937                    _ => panic!("should be a fragment"),
938                };
939                let item = value.item();
940                if item.vectorization > 1 {
941                    let elem = item.elem;
942                    writeln!(
943                        f,
944                        "simdgroup_load({frag}, reinterpret_cast<threadgroup {elem} *>({value} + {offset}), {stride}, 0, {transpose});"
945                    )
946                } else {
947                    writeln!(
948                        f,
949                        "simdgroup_load({frag}, {value} + {offset}, {stride}, 0, {transpose});"
950                    )
951                }
952            }
953            WmmaInstruction::Execute {
954                frag_a: a,
955                frag_b: b,
956                frag_c: c,
957                frag_d: d,
958                ..
959            } => {
960                writeln!(f, "simdgroup_multiply_accumulate({d}, {a}, {b}, {c});")
961            }
962            WmmaInstruction::Store {
963                output,
964                frag,
965                stride,
966                offset,
967                layout: _layout,
968            } => {
969                let item = output.item();
970                let mut reinterpret_cast = item.vectorization > 1;
971                let elem = match item.elem {
972                    Elem::BF16 => {
973                        reinterpret_cast = true;
974                        Elem::F16
975                    }
976                    _ => item.elem,
977                };
978                if reinterpret_cast {
979                    writeln!(
980                        f,
981                        "simdgroup_store({frag}, reinterpret_cast<threadgroup {elem} *>({output} + {offset}), {stride});"
982                    )
983                } else {
984                    writeln!(f, "simdgroup_store({frag}, {output} + {offset}, {stride});")
985                }?;
986                writeln!(f, "simdgroup_barrier(mem_flags::mem_none);")
987            }
988            WmmaInstruction::Cast { input, output } => {
989                writeln!(f, "simdgroup_barrier(mem_flags::mem_none);")?;
990                let ty = match output {
991                    Variable::WmmaFragment { frag, .. } => frag.elem,
992                    _ => panic!("should be a fragment"),
993                };
994                match ty {
995                    Elem::BF16 => {
996                        let addr_space = Self::address_space_for_variable(output);
997                        let elem = Elem::<Self>::F16;
998                        // TODO: to test with benchmarks
999
1000                        writeln!(
1001                            f,
1002                            "for(int e=0; e<8; e++) {{
1003    {ty} elem = {ty}({input}.thread_elements()[e]);
1004    {output}.thread_elements()[e] = *reinterpret_cast<{addr_space}{elem} *>(&elem);
1005}}"
1006                        )
1007                    }
1008                    _ => {
1009                        writeln!(
1010                            f,
1011                            "for(int e=0; e<8; e++) {{
1012    {output}.thread_elements()[e] = {ty}({input}.thread_elements()[e]);
1013}}"
1014                        )
1015                    }
1016                }
1017            }
1018        }
1019    }
1020
1021    fn supported_wmma_combinations(_arch: &MetalArchitecture) -> SupportedWmmaCombinations {
1022        vec![
1023            (
1024                gpu::Elem::Float(gpu::FloatKind::F16),
1025                gpu::Elem::Float(gpu::FloatKind::F16),
1026                gpu::Elem::Float(gpu::FloatKind::F16),
1027                vec![(8, 8, 8)],
1028            ),
1029            (
1030                gpu::Elem::Float(gpu::FloatKind::F16),
1031                gpu::Elem::Float(gpu::FloatKind::F16),
1032                gpu::Elem::Float(gpu::FloatKind::F32),
1033                vec![(8, 8, 8)],
1034            ),
1035            (
1036                gpu::Elem::Float(gpu::FloatKind::BF16),
1037                gpu::Elem::Float(gpu::FloatKind::BF16),
1038                gpu::Elem::Float(gpu::FloatKind::BF16),
1039                vec![(8, 8, 8)],
1040            ),
1041            (
1042                gpu::Elem::Float(gpu::FloatKind::F32),
1043                gpu::Elem::Float(gpu::FloatKind::F32),
1044                gpu::Elem::Float(gpu::FloatKind::F32),
1045                vec![(8, 8, 8)],
1046            ),
1047        ]
1048    }
1049}
1050
1051// Coop Matrices dialect