cubecl_cpp/metal/
dialect.rs

1use core::panic;
2use std::fmt::Display;
3
4use crate::{
5    Dialect,
6    shared::{
7        self, AtomicKind, Binding, Component, CubeIndexFlags, DialectBindings, DialectCubeBuiltins,
8        DialectIncludes, DialectInstructions, DialectTypes, DialectWmmaCompiler, Elem, Flags,
9        FmtLeft, Fragment, FragmentIdent, FragmentLayout, Instruction, Item, SharedMemory,
10        SupportedWmmaCombinations, Variable, WarpInstruction, WmmaInstruction,
11    },
12};
13use cubecl_core::{
14    compute::{Location, Visibility},
15    ir::{self as gpu, Id},
16};
17
18use super::{
19    AddressSpace, Extension,
20    arch::MetalArchitecture,
21    extension::{format_ffs, format_mulhi},
22    format_erf, format_global_binding_arg, format_metal_builtin_binding_arg, format_safe_tanh,
23};
24
/// Zero-sized marker type selecting the Metal Shading Language (MSL) dialect.
///
/// It carries no data; the `Dialect*` trait implementations in this file attach the
/// MSL-specific code generation to it.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct MslDialect {}
27
// Base dialect
//
// The impl body is empty: the MSL-specific behavior is provided by the `Dialect*`
// trait implementations that follow in this file.

impl Dialect for MslDialect {}
31
32// Includes
33
34impl DialectIncludes<Self> for MslDialect {
35    type Extension = Extension<Self>;
36
37    fn compile_includes(f: &mut std::fmt::Formatter<'_>, _flags: &Flags) -> std::fmt::Result {
38        write!(
39            f,
40            "
41#include <metal_stdlib>
42using namespace metal;
43"
44        )?;
45        Ok(())
46    }
47
48    fn compile_extensions(
49        f: &mut std::fmt::Formatter<'_>,
50        extensions: &[Self::Extension],
51    ) -> std::fmt::Result {
52        for extension in extensions {
53            match extension {
54                Extension::Erf(input, output) => format_erf::<Self>(f, input, output)?,
55                Extension::Ffs(elem) => format_ffs(f, elem)?,
56                Extension::MulHi(elem) => format_mulhi(f, elem)?,
57                Extension::SafeTanh(item) => format_safe_tanh::<Self>(f, item)?,
58                Extension::NoExtension => {}
59            }
60        }
61        Ok(())
62    }
63
64    fn register_instruction_extension(
65        extensions: &mut Vec<Self::Extension>,
66        instruction: &Instruction<Self>,
67    ) {
68        let mut register_extension = |extension: Self::Extension| {
69            if !extensions.contains(&extension) {
70                extensions.push(extension);
71            }
72        };
73        #[allow(clippy::single_match)]
74        match instruction {
75            shared::Instruction::<Self>::Erf(instruction) => {
76                register_extension(Extension::Erf(
77                    instruction.input.elem(),
78                    instruction.out.elem(),
79                ));
80            }
81            shared::Instruction::<Self>::FindFirstSet(instruction) => {
82                let input_elem = instruction.input.elem();
83                match input_elem {
84                    Elem::U32 | Elem::U64 => {
85                        register_extension(Extension::Ffs(instruction.input.elem()));
86                    }
87                    Elem::I32 => {
88                        register_extension(Extension::Ffs(Elem::<Self>::U32));
89                        register_extension(Extension::Ffs(instruction.input.elem()));
90                    }
91                    Elem::I64 => {
92                        register_extension(Extension::Ffs(Elem::<Self>::U64));
93                        register_extension(Extension::Ffs(instruction.input.elem()));
94                    }
95                    _ => {
96                        register_extension(Extension::Ffs(Elem::<Self>::U32));
97                    }
98                }
99            }
100            shared::Instruction::<Self>::HiMul(instruction) => {
101                register_extension(Extension::MulHi(instruction.out.elem()));
102            }
103            shared::Instruction::<Self>::Tanh(instruction) => {
104                register_extension(Extension::SafeTanh(instruction.input.item()));
105            }
106            _ => {}
107        }
108    }
109
110    fn register_warp_instruction_extension(
111        _extensions: &mut Vec<Self::Extension>,
112        _instruction: &WarpInstruction<Self>,
113    ) {
114    }
115}
116
117// Types
118
119impl DialectTypes<Self> for MslDialect {
120    fn item_can_be_optimized() -> bool {
121        false
122    }
123
124    fn compile_type_definitions(
125        f: &mut std::fmt::Formatter<'_>,
126        items: &std::collections::HashSet<crate::shared::Item<Self>>,
127        _scalars: &[(Elem<Self>, usize)],
128        _flags: &Flags,
129    ) -> std::fmt::Result {
130        for item in items.iter() {
131            let elem = item.elem;
132            let size = item.vectorization;
133            let alignment = elem.size() * size;
134            if size > 1 {
135                write!(
136                    f,
137                    "
138struct alignas({alignment}) {item} {{"
139                )?;
140
141                for i in 0..size {
142                    write!(
143                        f,
144                        "
145    {elem} i_{i};"
146                    )?;
147                }
148
149                f.write_str("\n};\n")?;
150            }
151        }
152        Ok(())
153    }
154
155    fn compile_elem(
156        f: &mut std::fmt::Formatter<'_>,
157        elem: &shared::Elem<Self>,
158        _words: bool,
159    ) -> std::fmt::Result {
160        // we always use the word form of types
161        match elem {
162            shared::Elem::F16 => f.write_str("half"),
163            shared::Elem::F162 => panic!("type F162 not supported!"),
164            shared::Elem::F32 => f.write_str("float"),
165            shared::Elem::F64 => panic!("type double not supported!"),
166            shared::Elem::BF16 => f.write_str("bfloat"),
167            shared::Elem::BF162 => panic!("type BF162 not supported!"),
168            shared::Elem::TF32 => f.write_str("float"),
169            shared::Elem::I8 => f.write_str("char"),
170            shared::Elem::I16 => f.write_str("short"),
171            shared::Elem::I32 => f.write_str("int"),
172            shared::Elem::I64 => f.write_str("long"),
173            shared::Elem::U8 => f.write_str("uchar"),
174            shared::Elem::U16 => f.write_str("ushort"),
175            shared::Elem::U32 => f.write_str("uint"),
176            shared::Elem::U64 => f.write_str("uint64_t"), // or unsigned long
177            shared::Elem::Bool => f.write_str("bool"),
178            shared::Elem::Atomic(inner) => inner.fmt(f),
179            shared::Elem::_Dialect(_) => Ok(()),
180        }
181    }
182
183    fn compile_item(f: &mut std::fmt::Formatter<'_>, item: &Item<Self>) -> std::fmt::Result {
184        if 1 == item.vectorization {
185            return write!(f, "{}", item.elem);
186        }
187        if item.native {
188            write!(f, "{}{}", item.elem, item.vectorization)
189        } else {
190            write!(f, "{}_{}", item.elem, item.vectorization)
191        }
192    }
193
194    fn compile_atomic_kind(
195        f: &mut std::fmt::Formatter<'_>,
196        kind: &AtomicKind<Self>,
197    ) -> std::fmt::Result {
198        match kind {
199            AtomicKind::I32 => write!(f, "atomic_int"),
200            AtomicKind::I64 => panic!("I64 atomic kind no supported."),
201            AtomicKind::U32 => write!(f, "atomic_uint"),
202            AtomicKind::U64 => write!(f, "atomic_ulong"),
203            AtomicKind::F16 => panic!("F16 atomic kind no supported."),
204            AtomicKind::BF16 => panic!("BF16 atomic kind no supported."),
205            AtomicKind::F32 => write!(f, "atomic_float"), // needs metal 3
206            AtomicKind::F64 => panic!("F64 atomic kind no supported."),
207            AtomicKind::_Dialect(_) => Ok(()),
208        }
209    }
210
211    fn address_space_for_variable(variable: &Variable<Self>) -> String {
212        format!("{} ", AddressSpace::from(variable))
213    }
214
215    fn compile_local_memory_qualifier(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
216        write!(f, "thread")
217    }
218
219    fn compile_shared_memory_qualifier(
220        f: &mut std::fmt::Formatter<'_>,
221        _shared: &SharedMemory<Self>,
222    ) -> std::fmt::Result {
223        write!(f, "threadgroup")
224    }
225}
226
227// Kernel argument bindings
228
229impl DialectBindings<Self> for MslDialect {
230    fn compile_kernel_signature(
231        f: &mut std::fmt::Formatter<'_>,
232        kernel_name: &str,
233        tensor_maps: &[Id],
234        buffers: &[Binding<Self>],
235        scalars: &[(Elem<Self>, usize)],
236        flags: &Flags,
237    ) -> std::fmt::Result {
238        write!(
239            (f),
240            "
241[[kernel]]
242void {}(",
243            kernel_name
244        )?;
245        // Global bindings args
246        let mut buffer_idx = 0;
247        debug_assert!(
248            tensor_maps.is_empty(),
249            "Tensor maps aren't supported for metal"
250        );
251        for (i, b) in buffers.iter().enumerate() {
252            format_global_binding_arg("buffer", b, Some(&i.to_string()), &mut buffer_idx, f)?;
253        }
254        if flags.static_meta_length > 0 {
255            let binding = Binding {
256                id: 0,
257                item: Item::scalar(Elem::<Self>::U32, true),
258                location: Location::Storage,
259                size: None,
260                vis: Visibility::Read,
261            };
262            format_global_binding_arg("info", &binding, None, &mut buffer_idx, f)?;
263        }
264        for (elem, _) in scalars.iter() {
265            let binding = Binding {
266                id: 0,
267                item: Item::scalar(*elem, true),
268                location: Location::Storage,
269                size: None,
270                vis: Visibility::Read,
271            };
272
273            let name = format!("scalars_{elem}");
274            format_global_binding_arg(&name, &binding, None, &mut buffer_idx, f)?;
275        }
276
277        // Global metal builtins args
278        let builtins = vec![
279            (
280                flags.indexes.absolute_pos_tuple,
281                Variable::<Self>::AbsolutePosBaseName,
282            ),
283            (
284                flags.indexes.cube_dim_tuple,
285                Variable::<Self>::CubeDimBaseName,
286            ),
287            (
288                flags.indexes.cube_count_tuple,
289                Variable::<Self>::CubeCountBaseName,
290            ),
291            (flags.indexes.unit_pos, Variable::<Self>::UnitPos),
292            (
293                flags.indexes.unit_pos_tuple,
294                Variable::<Self>::UnitPosBaseName,
295            ),
296            (
297                flags.indexes.cube_pos_tuple,
298                Variable::<Self>::CubePosBaseName,
299            ),
300            (flags.indexes.unit_pos_plane, Variable::<Self>::UnitPosPlane),
301            (flags.indexes.plane_dim, Variable::<Self>::PlaneDim),
302            (flags.indexes.plane_index, Variable::<Self>::PlanePos),
303        ];
304        let comma = !buffers.is_empty() || flags.static_meta_length > 0 || !scalars.is_empty();
305        builtins
306            .iter()
307            .filter(|(cond, _)| *cond)
308            .try_for_each(|(_, var)| format_metal_builtin_binding_arg(f, var, comma))?;
309        f.write_str("\n)")
310    }
311}
312
313// Cube builtins dialect
314
315impl DialectCubeBuiltins<Self> for MslDialect {
316    /// Depending on the dialect available built-in variables the
317    /// inclusion rules might change.
318    /// For instance in metal we have a built-in for the Unit plane position
319    /// so we don't rely on other builtins.
320    fn builtin_rules(flags: &CubeIndexFlags) -> CubeIndexFlags {
321        let absolute_pos = flags.absolute_pos;
322        let cube_count = flags.cube_count;
323        let cube_dim = flags.cube_dim;
324        let cube_pos = flags.cube_pos;
325        let plane_dim_checked = flags.plane_dim_checked;
326        let plane_index = flags.plane_index;
327        let unit_pos = flags.unit_pos;
328        let absolute_pos_tuple = flags.absolute_pos_tuple || absolute_pos;
329        let cube_count_tuple = flags.cube_count_tuple || cube_count || cube_pos || absolute_pos;
330        let cube_dim_tuple = flags.cube_dim_tuple || cube_dim || absolute_pos || plane_dim_checked;
331        let cube_pos_tuple = flags.cube_pos_tuple || cube_pos;
332        let cluster_pos = flags.cluster_pos;
333        let plane_dim = flags.plane_dim || plane_dim_checked || plane_index;
334        let unit_pos_plane = flags.unit_pos_plane || plane_index;
335        let unit_pos_tuple = flags.unit_pos_tuple || unit_pos;
336        CubeIndexFlags {
337            absolute_pos_tuple,
338            absolute_pos,
339            cube_count_tuple,
340            cube_count,
341            cube_dim_tuple,
342            cube_dim,
343            cube_pos_tuple,
344            cube_pos,
345            plane_dim,
346            plane_dim_checked,
347            plane_index,
348            unit_pos_tuple,
349            unit_pos,
350            unit_pos_plane,
351            cluster_pos,
352        }
353    }
354
355    fn compile_absolute_pos_tuple_computation(
356        _f: &mut std::fmt::Formatter<'_>,
357    ) -> std::fmt::Result {
358        // no need to compute it on metal as there is y a built-in for it
359        Ok(())
360    }
361
362    fn compile_absolute_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
363        f.write_str("thread_pos_in_grid")
364    }
365
366    fn compile_absolute_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
367        f.write_str("thread_index_in_grid")
368    }
369
370    fn compile_absolute_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
371        Self::compile_absolute_pos_base_name(f)?;
372        write!(f, ".x")
373    }
374
375    fn compile_absolute_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
376        Self::compile_absolute_pos_base_name(f)?;
377        write!(f, ".y")
378    }
379
380    fn compile_absolute_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
381        Self::compile_absolute_pos_base_name(f)?;
382        write!(f, ".z")
383    }
384
385    fn compile_cube_count_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
386        f.write_str("threadgroups_per_grid")
387    }
388
389    fn compile_cube_count(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
390        f.write_str("total_threadgroups_in_grid")
391    }
392
393    fn compile_cube_count_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
394        Self::compile_cube_count_base_name(f)?;
395        write!(f, ".x")
396    }
397
398    fn compile_cube_count_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
399        Self::compile_cube_count_base_name(f)?;
400        write!(f, ".y")
401    }
402
403    fn compile_cube_count_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
404        Self::compile_cube_count_base_name(f)?;
405        write!(f, ".z")
406    }
407
408    fn compile_cube_dim_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
409        f.write_str("threads_per_threadgroup")
410    }
411
412    fn compile_cube_dim(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
413        f.write_str("total_thread_in_threadgroup")
414    }
415
416    fn compile_cube_dim_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
417        Self::compile_cube_dim_base_name(f)?;
418        write!(f, ".x")
419    }
420
421    fn compile_cube_dim_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
422        Self::compile_cube_dim_base_name(f)?;
423        write!(f, ".y")
424    }
425
426    fn compile_cube_dim_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
427        Self::compile_cube_dim_base_name(f)?;
428        write!(f, ".z")
429    }
430
431    fn compile_cube_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
432        f.write_str("threadgroup_pos_in_grid")
433    }
434
435    fn compile_cube_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
436        f.write_str("threadgroup_index_in_grid")
437    }
438
439    fn compile_cube_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
440        Self::compile_cube_pos_base_name(f)?;
441        write!(f, ".x")
442    }
443
444    fn compile_cube_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
445        Self::compile_cube_pos_base_name(f)?;
446        write!(f, ".y")
447    }
448
449    fn compile_cube_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
450        Self::compile_cube_pos_base_name(f)?;
451        write!(f, ".z")
452    }
453
454    fn compile_unit_pos_computation(_f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
455        // no need to compute it on metal as there is y a built-in for it
456        Ok(())
457    }
458
459    fn compile_unit_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
460        f.write_str("thread_pos_in_threadgroup")
461    }
462
463    fn compile_unit_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
464        f.write_str("thread_index_in_threadgroup")
465    }
466
467    fn compile_unit_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
468        Self::compile_unit_pos_base_name(f)?;
469        write!(f, ".x")
470    }
471
472    fn compile_unit_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
473        Self::compile_unit_pos_base_name(f)?;
474        write!(f, ".y")
475    }
476
477    fn compile_unit_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
478        Self::compile_unit_pos_base_name(f)?;
479        write!(f, ".z")
480    }
481
482    fn compile_plane_dim(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
483        f.write_str("simd_size")
484    }
485
486    fn compile_plane_dim_checked(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
487        f.write_str("threads_per_simdgroup_checked")
488    }
489
490    fn compile_plane_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
491        f.write_str("simd_group_id")
492    }
493
494    fn compile_unit_pos_plane(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
495        f.write_str("simd_lane_id")
496    }
497}
498
499// Instructions
500
501impl DialectInstructions<Self> for MslDialect {
502    // atomics
503    fn compile_atomic_add(
504        f: &mut std::fmt::Formatter<'_>,
505        lhs: &Variable<Self>,
506        rhs: &Variable<Self>,
507        out: &Variable<Self>,
508    ) -> std::fmt::Result {
509        let out = out.fmt_left();
510        writeln!(
511            f,
512            "{out} = atomic_fetch_add_explicit({lhs}, {rhs}, memory_order_relaxed);"
513        )
514    }
515
516    fn compile_atomic_and(
517        f: &mut std::fmt::Formatter<'_>,
518        lhs: &Variable<Self>,
519        rhs: &Variable<Self>,
520        out: &Variable<Self>,
521    ) -> std::fmt::Result {
522        let out = out.fmt_left();
523        writeln!(
524            f,
525            "{out} = atomic_fetch_and_explicit({lhs}, {rhs}, memory_order_relaxed);"
526        )
527    }
528
529    fn compile_atomic_cas(
530        f: &mut std::fmt::Formatter<'_>,
531        input: &Variable<Self>,
532        cmp: &Variable<Self>,
533        val: &Variable<Self>,
534        out: &Variable<Self>,
535    ) -> std::fmt::Result {
536        let out = out.fmt_left();
537        writeln!(
538            f,
539            "{out} = atomic_compare_exchange_weak_explicit({input}, &{cmp}, {val}, memory_order_relaxed, memory_order_relaxed);"
540        )
541    }
542
543    fn compile_atomic_load(
544        f: &mut std::fmt::Formatter<'_>,
545        input: &Variable<Self>,
546        out: &Variable<Self>,
547    ) -> std::fmt::Result {
548        let out = out.fmt_left();
549        writeln!(
550            f,
551            "{out} = atomic_load_explicit({input}, memory_order_relaxed);"
552        )
553    }
554
555    fn compile_atomic_max(
556        f: &mut std::fmt::Formatter<'_>,
557        lhs: &Variable<Self>,
558        rhs: &Variable<Self>,
559        out: &Variable<Self>,
560    ) -> std::fmt::Result {
561        let out = out.fmt_left();
562        writeln!(
563            f,
564            "{out} = atomic_fetch_max_explicit({lhs}, {rhs}, memory_order_relaxed);"
565        )
566    }
567
568    fn compile_atomic_min(
569        f: &mut std::fmt::Formatter<'_>,
570        lhs: &Variable<Self>,
571        rhs: &Variable<Self>,
572        out: &Variable<Self>,
573    ) -> std::fmt::Result {
574        let out = out.fmt_left();
575        writeln!(
576            f,
577            "{out} = atomic_fetch_min_explicit({lhs}, {rhs}, memory_order_relaxed);"
578        )
579    }
580
581    fn compile_atomic_or(
582        f: &mut std::fmt::Formatter<'_>,
583        lhs: &Variable<Self>,
584        rhs: &Variable<Self>,
585        out: &Variable<Self>,
586    ) -> std::fmt::Result {
587        let out = out.fmt_left();
588        writeln!(
589            f,
590            "{out} = atomic_fetch_or_explicit({lhs}, {rhs}, memory_order_relaxed);"
591        )
592    }
593
594    fn compile_atomic_store(
595        f: &mut std::fmt::Formatter<'_>,
596        input: &Variable<Self>,
597        out: &Variable<Self>,
598    ) -> std::fmt::Result {
599        writeln!(
600            f,
601            "atomic_store_explicit({out}, {input}, memory_order_relaxed);"
602        )
603    }
604
605    fn compile_atomic_sub(
606        f: &mut std::fmt::Formatter<'_>,
607        lhs: &Variable<Self>,
608        rhs: &Variable<Self>,
609        out: &Variable<Self>,
610    ) -> std::fmt::Result {
611        let out = out.fmt_left();
612        writeln!(
613            f,
614            "{out} = atomic_fetch_sub_explicit({lhs}, {rhs}, memory_order_relaxed);"
615        )
616    }
617
618    fn compile_atomic_swap(
619        f: &mut std::fmt::Formatter<'_>,
620        lhs: &Variable<Self>,
621        rhs: &Variable<Self>,
622        out: &Variable<Self>,
623    ) -> std::fmt::Result {
624        let out = out.fmt_left();
625        writeln!(
626            f,
627            "{out} = atomic_exchange_explicit({lhs}, {rhs}, memory_order_relaxed);"
628        )
629    }
630
631    fn compile_atomic_xor(
632        f: &mut std::fmt::Formatter<'_>,
633        lhs: &Variable<Self>,
634        rhs: &Variable<Self>,
635        out: &Variable<Self>,
636    ) -> std::fmt::Result {
637        let out = out.fmt_left();
638        writeln!(
639            f,
640            "{out} = atomic_fetch_xor_explicit({lhs}, {rhs}, memory_order_relaxed);"
641        )
642    }
643
644    // debug
645    fn compile_instruction_printf(
646        f: &mut std::fmt::Formatter<'_>,
647        format_string: &str,
648        args: &[Variable<Self>],
649    ) -> std::fmt::Result {
650        let format_string = format_string
651            .replace("\t", "\\t")
652            .replace("\n", "\\n")
653            .replace("\r", "\\r");
654        let args = args.iter().map(|arg| format!("{arg}")).collect::<Vec<_>>();
655        let args = match args.is_empty() {
656            true => "".to_string(),
657            false => format!(", {}", args.join(",")),
658        };
659        writeln!(f, "os_log_default.log(\"{format_string}\"{args});")
660    }
661
662    // logs
663    fn compile_instruction_log1p_scalar<T: Component<Self>>(
664        f: &mut std::fmt::Formatter<'_>,
665        input: T,
666    ) -> std::fmt::Result {
667        match input.elem() {
668            Elem::F16 | Elem::F162 | Elem::BF16 | Elem::BF162 => {
669                write!(f, "log(half(1.0f) + {input})")
670            }
671            _ => write!(f, "log(1.0f + {input})"),
672        }
673    }
674
675    // sync
676    fn compile_instruction_sync_threads(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
677        writeln!(f, "threadgroup_barrier(mem_flags::mem_threadgroup);")
678    }
679
680    fn compile_instruction_thread_fence(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
681        writeln!(f, "threadgroup_thread_fence(mem_flags::mem_device);")
682    }
683
684    // trigo
685    fn compile_instruction_tanh_scalar<T: Component<Self>>(
686        f: &mut std::fmt::Formatter<'_>,
687        input: T,
688    ) -> std::fmt::Result {
689        write!(f, "safe_tanh_scalar({input})")
690    }
691
692    // unary
693    fn compile_instruction_find_first_set<T: Component<Self>>(
694        f: &mut std::fmt::Formatter<'_>,
695        input: T,
696        out_elem: Elem<Self>,
697    ) -> std::fmt::Result {
698        write!(f, "{out_elem}(")?;
699        match input.elem() {
700            Elem::I32 | Elem::U32 => write!(f, "__ffs({input})"),
701            Elem::I64 | Elem::U64 => write!(f, "__ffsll({input})"),
702            _ => write!(f, "__ffs({}({input}))", Elem::<Self>::I32),
703        }?;
704        write!(f, ")")
705    }
706
707    fn compile_instruction_leading_zeros_scalar<T: Component<Self>>(
708        f: &mut std::fmt::Formatter<'_>,
709        input: T,
710        out_elem: Elem<Self>,
711    ) -> std::fmt::Result {
712        write!(f, "{out_elem}(clz({input}))")
713    }
714
715    fn compile_instruction_popcount_scalar<T: Component<Self>>(
716        f: &mut std::fmt::Formatter<'_>,
717        input: T,
718        out_elem: Elem<Self>,
719    ) -> std::fmt::Result {
720        write!(f, "{out_elem}(")?;
721        match input.elem() {
722            Elem::I32 | Elem::U32 | Elem::I64 | Elem::U64 => write!(f, "popcount({input})"),
723            _ => write!(f, "popcount({})", shared::unary::zero_extend(input)),
724        }?;
725        write!(f, ")")
726    }
727
728    fn compile_instruction_reverse_bits_scalar<T: Component<Self>>(
729        f: &mut std::fmt::Formatter<'_>,
730        input: T,
731        out_elem: Elem<Self>,
732    ) -> std::fmt::Result {
733        write!(f, "{out_elem}(")?;
734        match out_elem {
735            Elem::I32 | Elem::U32 | Elem::I64 | Elem::U64 => write!(f, "reverse_bits({input})"),
736            _ => write!(
737                f,
738                "reverse_bits({}) >> {}",
739                shared::unary::zero_extend(input),
740                (size_of::<u32>() - out_elem.size()) * 8
741            ),
742        }?;
743        write!(f, ")")
744    }
745
746    // others
747    fn compile_instruction_max_function_name(
748        f: &mut std::fmt::Formatter<'_>,
749        _item: Item<Self>,
750    ) -> std::fmt::Result {
751        write!(f, "max")
752    }
753
754    fn compile_instruction_min_function_name(
755        f: &mut std::fmt::Formatter<'_>,
756        _item: Item<Self>,
757    ) -> std::fmt::Result {
758        write!(f, "min")
759    }
760
761    fn compile_instruction_powf(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
762        write!(f, "pow")
763    }
764
765    fn compile_instruction_half_function_name_prefix() -> &'static str {
766        ""
767    }
768
769    fn compile_instruction_half2_function_name_prefix() -> &'static str {
770        ""
771    }
772
773    // Warp
774    fn compile_warp_shuffle(
775        f: &mut std::fmt::Formatter<'_>,
776        var: &str,
777        source: &str,
778    ) -> std::fmt::Result {
779        write!(f, "simd_shuffle({var}, {source})")
780    }
781
782    fn compile_warp_shuffle_xor(
783        f: &mut std::fmt::Formatter<'_>,
784        var: &str,
785        _elem: &Elem<Self>,
786        offset: &str,
787    ) -> std::fmt::Result {
788        write!(f, "simd_shuffle_xor({var}, {offset})")
789    }
790
791    fn compile_warp_shuffle_up(
792        f: &mut std::fmt::Formatter<'_>,
793        var: &str,
794        offset: &str,
795    ) -> std::fmt::Result {
796        write!(f, "simd_shuffle_up({var}, {offset})")
797    }
798
799    fn compile_warp_shuffle_down(
800        f: &mut std::fmt::Formatter<'_>,
801        var: &str,
802        offset: &str,
803    ) -> std::fmt::Result {
804        write!(f, "simd_shuffle_down({var}, {offset})")
805    }
806
807    fn compile_warp_all<T: Component<Self>>(
808        f: &mut std::fmt::Formatter<'_>,
809        input: &T,
810    ) -> std::fmt::Result {
811        write!(f, "simd_all({input})")
812    }
813
814    fn compile_warp_any<T: Component<Self>>(
815        f: &mut std::fmt::Formatter<'_>,
816        input: &T,
817    ) -> std::fmt::Result {
818        write!(f, "simd_any({input})")
819    }
820
821    fn compile_warp_ballot(
822        f: &mut std::fmt::Formatter<'_>,
823        input: &Variable<Self>,
824        out_elem: &Elem<Self>,
825    ) -> std::fmt::Result {
826        write!(f, "{out_elem}(uint64_t(simd_ballot({input})))")
827    }
828}
829
830// Coop Matrices dialect
831
832impl DialectWmmaCompiler<Self> for MslDialect {
833    type Architecture = MetalArchitecture;
834
835    fn compile_wmma_includes(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
836        writeln!(f, "#include <metal_simdgroup_matrix>")
837    }
838
839    fn compile_wmma_type_definitions(_f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
840        // not used
841        Ok(())
842    }
843
844    fn compile_local_variables(_f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
845        // not used
846        Ok(())
847    }
848
849    fn compile_fragment_ident(
850        _ident: &FragmentIdent<Self>,
851        _f: &mut std::fmt::Formatter<'_>,
852    ) -> std::fmt::Result {
853        // not used
854        Ok(())
855    }
856
857    fn compile_fragment_layout(
858        _layout: &FragmentLayout<Self>,
859        _f: &mut std::fmt::Formatter<'_>,
860    ) -> std::fmt::Result {
861        // not used
862        Ok(())
863    }
864
865    fn compile_fragment(
866        fragment: &Fragment<Self>,
867        f: &mut std::fmt::Formatter<'_>,
868    ) -> std::fmt::Result {
869        let ty = fragment.elem;
870        // currently as of Metal 3.2 only fragments of 8x8x8 are supported
871        let m = fragment.m;
872        let n = fragment.n;
873        let k = fragment.k;
874        if m != 8 || n != 8 || k != 8 {
875            panic!("{m}x{n}x{k} fragments not supported. Only 8x8x8 fragemts are supported.");
876        }
877        write!(f, "simdgroup_{ty}8x8")
878    }
879
880    fn compile_instruction(
881        instruction: &WmmaInstruction<Self>,
882        f: &mut std::fmt::Formatter<'_>,
883    ) -> std::fmt::Result {
884        match instruction {
885            WmmaInstruction::Fill { frag, value } => {
886                match frag {
887                    Variable::WmmaFragment { .. } => {
888                        let ty = frag.elem();
                        // Only 8x8x8 fragments are supported. Check is done at fragment compilation time.
890                        writeln!(
891                            f,
892                            "{frag} = make_filled_simdgroup_matrix<{ty}, 8, 8>({value});"
893                        )
894                    }
895                    _ => panic!("should be a fragment"),
896                }
897            }
898            WmmaInstruction::Load {
899                frag,
900                value,
901                stride,
902                ..
903            } => {
904                let transpose = match frag {
905                    Variable::WmmaFragment { frag: inner, .. } => match inner.layout {
906                        Some(FragmentLayout::RowMajor) => false,
907                        Some(FragmentLayout::ColMajor) => true,
908                        _ => false,
909                    },
910                    _ => panic!("should be a fragment"),
911                };
912                let item = value.item();
913                if item.vectorization > 1 {
914                    let elem = item.elem;
915                    writeln!(
916                        f,
917                        "simdgroup_load({frag}, reinterpret_cast<threadgroup {elem} *>({value}), {stride}, 0, {transpose});"
918                    )
919                } else {
920                    writeln!(
921                        f,
922                        "simdgroup_load({frag}, {value}, {stride}, 0, {transpose});"
923                    )
924                }
925            }
926            WmmaInstruction::Execute {
927                frag_a: a,
928                frag_b: b,
929                frag_c: c,
930                frag_d: d,
931                ..
932            } => {
933                writeln!(f, "simdgroup_multiply_accumulate({d}, {a}, {b}, {c});")
934            }
935            WmmaInstruction::Store {
936                output,
937                frag,
938                stride,
939                ..
940            } => {
941                let item = output.item();
942                let mut reinterpret_cast = item.vectorization > 1;
943                let elem = match item.elem {
944                    Elem::BF16 => {
945                        reinterpret_cast = true;
946                        Elem::F16
947                    }
948                    _ => item.elem,
949                };
950                if reinterpret_cast {
951                    writeln!(
952                        f,
953                        "simdgroup_store({frag}, reinterpret_cast<threadgroup {elem} *>({output}), {stride});"
954                    )
955                } else {
956                    writeln!(f, "simdgroup_store({frag}, {output}, {stride});")
957                }?;
958                writeln!(f, "threadgroup_barrier(mem_flags::mem_none);")
959            }
960            WmmaInstruction::Cast { input, output } => {
961                writeln!(f, "threadgroup_barrier(mem_flags::mem_none);")?;
962                let ty = match output {
963                    Variable::WmmaFragment { frag, .. } => frag.elem,
964                    _ => panic!("should be a fragment"),
965                };
966                match ty {
967                    Elem::BF16 => {
968                        let addr_space = Self::address_space_for_variable(output);
969                        let elem = Elem::<Self>::F16;
970                        // TODO: to test with benchmarks
971
972                        writeln!(
973                            f,
974                            "for(int e=0; e<8; e++) {{
975    {ty} elem = {ty}({input}.thread_elements()[e]);
976    {output}.thread_elements()[e] = *reinterpret_cast<{addr_space}{elem} *>(&elem);
977}}"
978                        )
979                    }
980                    _ => {
981                        writeln!(
982                            f,
983                            "for(int e=0; e<8; e++) {{
984    {output}.thread_elements()[e] = {ty}({input}.thread_elements()[e]);
985}}"
986                        )
987                    }
988                }
989            }
990        }
991    }
992
993    fn supported_wmma_combinations(_arch: &Self::Architecture) -> SupportedWmmaCombinations {
994        vec![
995            (
996                gpu::Elem::Float(gpu::FloatKind::F16),
997                gpu::Elem::Float(gpu::FloatKind::F16),
998                gpu::Elem::Float(gpu::FloatKind::F16),
999                vec![(8, 8, 8)],
1000            ),
1001            (
1002                gpu::Elem::Float(gpu::FloatKind::F16),
1003                gpu::Elem::Float(gpu::FloatKind::F16),
1004                gpu::Elem::Float(gpu::FloatKind::F32),
1005                vec![(8, 8, 8)],
1006            ),
1007            (
1008                gpu::Elem::Float(gpu::FloatKind::BF16),
1009                gpu::Elem::Float(gpu::FloatKind::BF16),
1010                gpu::Elem::Float(gpu::FloatKind::BF16),
1011                vec![(8, 8, 8)],
1012            ),
1013            (
1014                gpu::Elem::Float(gpu::FloatKind::F32),
1015                gpu::Elem::Float(gpu::FloatKind::F32),
1016                gpu::Elem::Float(gpu::FloatKind::F32),
1017                vec![(8, 8, 8)],
1018            ),
1019        ]
1020    }
1021}
1022
1023// Coop Matrices dialect