cubecl_cpp/metal/
dialect.rs

1use core::panic;
2use std::fmt::Display;
3
4use crate::{
5    Dialect,
6    shared::{
7        self, AtomicKind, Binding, Component, CubeIndexFlags, DialectBindings, DialectCubeBuiltins,
8        DialectIncludes, DialectInstructions, DialectProcessors, DialectTypes,
9        DialectWarpReduceCompiler, DialectWmmaCompiler, Elem, Flags, FmtLeft, Fragment,
10        FragmentIdent, FragmentLayout, Instruction, Item, ManualMma, SharedMemory,
11        SupportedMmaCombinations, Variable, WarpInstruction, WmmaInstruction, wmma_api_base,
12    },
13};
14use cubecl_core::{
15    ir::{self as gpu},
16    prelude::{Location, Visibility},
17};
18use cubecl_runtime::MmaConfig;
19
20use super::{
21    AddressSpace, Extension,
22    arch::MetalArchitecture,
23    extension::{format_ffs, format_mulhi},
24    format_erf, format_global_binding_arg, format_metal_builtin_binding_arg, format_safe_tanh,
25};
26
/// Zero-sized marker type identifying the Metal Shading Language (MSL) dialect.
///
/// All behavior is attached through the dialect trait implementations in this file; the type
/// itself carries no data.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct MslDialect {}
29
30// Base dialect
31
// Binds the Metal dialect to the architecture type imported from the sibling `arch` module.
impl Dialect for MslDialect {
    type Architecture = MetalArchitecture;
}
35
36impl MslDialect {
37    fn warp_op_vectorized(
38        f: &mut core::fmt::Formatter<'_>,
39        input: &Variable<Self>,
40        out: &Variable<Self>,
41        simd_op_prefix: &str,
42        simd_op_suffix: &str,
43    ) -> core::fmt::Result {
44        let out = out.fmt_left();
45        let vectorization = input.item().vectorization;
46
47        f.write_fmt(format_args!("{out} = {} {{", input.item()))?;
48
49        for k in 0..vectorization {
50            let index = if vectorization > 1 {
51                format!(".i_{k}")
52            } else {
53                String::new()
54            };
55            let comma = if k + 1 < vectorization { "," } else { "" };
56
57            writeln!(f, "{simd_op_prefix}{input}{index}{simd_op_suffix}{comma}")?;
58        }
59
60        f.write_fmt(format_args!("}};\n"))
61    }
62}
63
64impl DialectWarpReduceCompiler<Self> for MslDialect {
65    fn warp_reduce_sum(
66        f: &mut core::fmt::Formatter<'_>,
67        input: &Variable<Self>,
68        out: &Variable<Self>,
69    ) -> core::fmt::Result {
70        Self::warp_op_vectorized(f, input, out, "simd_sum(", ")")
71    }
72    fn warp_reduce_prod(
73        f: &mut core::fmt::Formatter<'_>,
74        input: &Variable<Self>,
75        out: &Variable<Self>,
76    ) -> core::fmt::Result {
77        Self::warp_op_vectorized(f, input, out, "simd_product(", ")")
78    }
79    fn warp_reduce_max(
80        f: &mut core::fmt::Formatter<'_>,
81        input: &Variable<Self>,
82        out: &Variable<Self>,
83    ) -> core::fmt::Result {
84        Self::warp_op_vectorized(f, input, out, "simd_max(", ")")
85    }
86    fn warp_reduce_min(
87        f: &mut core::fmt::Formatter<'_>,
88        input: &Variable<Self>,
89        out: &Variable<Self>,
90    ) -> core::fmt::Result {
91        Self::warp_op_vectorized(f, input, out, "simd_min(", ")")
92    }
93    fn warp_reduce_all(
94        f: &mut core::fmt::Formatter<'_>,
95        input: &Variable<Self>,
96        out: &Variable<Self>,
97    ) -> core::fmt::Result {
98        Self::warp_op_vectorized(f, input, out, "simd_and(", "? 1u : 0u) != 0u")
99    }
100    fn warp_reduce_any(
101        f: &mut core::fmt::Formatter<'_>,
102        input: &Variable<Self>,
103        out: &Variable<Self>,
104    ) -> core::fmt::Result {
105        Self::warp_op_vectorized(f, input, out, "simd_or(", "? 1u : 0u) != 0u")
106    }
107    fn warp_reduce_sum_inclusive(
108        f: &mut core::fmt::Formatter<'_>,
109        input: &Variable<Self>,
110        out: &Variable<Self>,
111    ) -> core::fmt::Result {
112        Self::warp_op_vectorized(f, input, out, "simd_prefix_inclusive_sum(", ")")
113    }
114    fn warp_reduce_prod_inclusive(
115        f: &mut core::fmt::Formatter<'_>,
116        input: &Variable<Self>,
117        out: &Variable<Self>,
118    ) -> core::fmt::Result {
119        Self::warp_op_vectorized(f, input, out, "simd_prefix_inclusive_product(", ")")
120    }
121    fn warp_reduce_sum_exclusive(
122        f: &mut core::fmt::Formatter<'_>,
123        input: &Variable<Self>,
124        out: &Variable<Self>,
125    ) -> core::fmt::Result {
126        Self::warp_op_vectorized(f, input, out, "simd_prefix_exclusive_sum(", ")")
127    }
128    fn warp_reduce_prod_exclusive(
129        f: &mut core::fmt::Formatter<'_>,
130        input: &Variable<Self>,
131        out: &Variable<Self>,
132    ) -> core::fmt::Result {
133        Self::warp_op_vectorized(f, input, out, "simd_prefix_exclusive_product(", ")")
134    }
135}
136
137// Includes
138
139impl DialectIncludes<Self> for MslDialect {
140    type Extension = Extension<Self>;
141
142    fn compile_includes(f: &mut std::fmt::Formatter<'_>, _flags: &Flags) -> std::fmt::Result {
143        write!(
144            f,
145            "
146#include <metal_stdlib>
147using namespace metal;
148"
149        )?;
150        Ok(())
151    }
152
153    fn compile_extensions(
154        f: &mut std::fmt::Formatter<'_>,
155        extensions: &[Self::Extension],
156    ) -> std::fmt::Result {
157        for extension in extensions {
158            match extension {
159                Extension::Erf(input, output) => format_erf::<Self>(f, input, output)?,
160                Extension::Ffs(elem) => format_ffs(f, elem)?,
161                Extension::MulHi(elem) => format_mulhi(f, elem)?,
162                Extension::SafeTanh(item) => format_safe_tanh::<Self>(f, item)?,
163                Extension::NoExtension => {}
164            }
165        }
166        Ok(())
167    }
168
169    fn register_instruction_extension(
170        extensions: &mut Vec<Self::Extension>,
171        instruction: &Instruction<Self>,
172    ) {
173        let mut register_extension = |extension: Self::Extension| {
174            if !extensions.contains(&extension) {
175                extensions.push(extension);
176            }
177        };
178        #[allow(clippy::single_match)]
179        match instruction {
180            shared::Instruction::<Self>::Erf(instruction) => {
181                register_extension(Extension::Erf(
182                    instruction.input.elem(),
183                    instruction.out.elem(),
184                ));
185            }
186            shared::Instruction::<Self>::FindFirstSet(instruction) => {
187                let input_elem = instruction.input.elem();
188                match input_elem {
189                    Elem::U32 | Elem::U64 => {
190                        register_extension(Extension::Ffs(instruction.input.elem()));
191                    }
192                    Elem::I32 => {
193                        register_extension(Extension::Ffs(Elem::<Self>::U32));
194                        register_extension(Extension::Ffs(instruction.input.elem()));
195                    }
196                    Elem::I64 => {
197                        register_extension(Extension::Ffs(Elem::<Self>::U64));
198                        register_extension(Extension::Ffs(instruction.input.elem()));
199                    }
200                    _ => {
201                        register_extension(Extension::Ffs(Elem::<Self>::U32));
202                    }
203                }
204            }
205            shared::Instruction::<Self>::HiMul(instruction) => {
206                register_extension(Extension::MulHi(instruction.out.elem()));
207            }
208            shared::Instruction::<Self>::Tanh(instruction) => {
209                register_extension(Extension::SafeTanh(instruction.input.item()));
210            }
211            _ => {}
212        }
213    }
214
215    fn register_warp_instruction_extension(
216        _extensions: &mut Vec<Self::Extension>,
217        _instruction: &WarpInstruction<Self>,
218    ) {
219    }
220}
221
222// Types
223
224impl DialectTypes<Self> for MslDialect {
225    fn item_can_be_optimized() -> bool {
226        false
227    }
228
229    fn compile_type_definitions(
230        f: &mut std::fmt::Formatter<'_>,
231        items: &std::collections::HashSet<crate::shared::Item<Self>>,
232        _scalars: &[(Elem<Self>, usize)],
233        _flags: &Flags,
234    ) -> std::fmt::Result {
235        for item in items.iter() {
236            let elem = item.elem;
237            let size = item.vectorization;
238            let alignment = elem.size() * size;
239            if size > 1 {
240                write!(
241                    f,
242                    "
243struct alignas({alignment}) {item} {{"
244                )?;
245
246                for i in 0..size {
247                    write!(
248                        f,
249                        "
250    {elem} i_{i};"
251                    )?;
252                }
253
254                f.write_str("\n};\n")?;
255            }
256        }
257        Ok(())
258    }
259
260    fn compile_elem(
261        f: &mut std::fmt::Formatter<'_>,
262        elem: &shared::Elem<Self>,
263        _words: bool,
264    ) -> std::fmt::Result {
265        // we always use the word form of types
266        match elem {
267            shared::Elem::FP4(_)
268            | shared::Elem::FP4x2(_)
269            | shared::Elem::FP6(_)
270            | shared::Elem::FP6x2(_)
271            | shared::Elem::FP8(_)
272            | shared::Elem::FP8x2(_) => f.write_str("#error FP4/FP6/FP8 not supported in Metal\n"),
273            shared::Elem::F16 => f.write_str("half"),
274            shared::Elem::F16x2 => f.write_str("#error type F162 not supported!\n"),
275            shared::Elem::F32 => f.write_str("float"),
276            shared::Elem::F64 => f.write_str("#error type double not supported!\n"),
277            shared::Elem::BF16 => f.write_str("bfloat"),
278            shared::Elem::BF16x2 => f.write_str("#error type BF162 not supported!\n"),
279            shared::Elem::TF32 => f.write_str("float"),
280            shared::Elem::I8 => f.write_str("char"),
281            shared::Elem::I16 => f.write_str("short"),
282            shared::Elem::I32 => f.write_str("int"),
283            shared::Elem::I64 => f.write_str("long"),
284            shared::Elem::U8 => f.write_str("uchar"),
285            shared::Elem::U16 => f.write_str("ushort"),
286            shared::Elem::U32 => f.write_str("uint"),
287            shared::Elem::U64 => f.write_str("uint64_t"), // or unsigned long
288            shared::Elem::Bool => f.write_str("bool"),
289            shared::Elem::Barrier(_) => unimplemented!("metal doesn't support barrier object"),
290            shared::Elem::Atomic(inner) => inner.fmt(f),
291            shared::Elem::_Dialect(_) => Ok(()),
292        }
293    }
294
295    fn compile_item(f: &mut std::fmt::Formatter<'_>, item: &Item<Self>) -> std::fmt::Result {
296        if 1 == item.vectorization {
297            return write!(f, "{}", item.elem);
298        }
299        if item.native {
300            write!(f, "{}{}", item.elem, item.vectorization)
301        } else {
302            write!(f, "{}_{}", item.elem, item.vectorization)
303        }
304    }
305
306    fn compile_atomic_kind(
307        f: &mut std::fmt::Formatter<'_>,
308        kind: &AtomicKind<Self>,
309    ) -> std::fmt::Result {
310        match kind {
311            AtomicKind::I32 => write!(f, "atomic_int"),
312            AtomicKind::I64 => panic!("I64 atomic kind no supported."),
313            AtomicKind::U32 => write!(f, "atomic_uint"),
314            AtomicKind::U64 => write!(f, "atomic_ulong"),
315            AtomicKind::F16 => panic!("F16 atomic kind no supported."),
316            AtomicKind::BF16 => panic!("BF16 atomic kind no supported."),
317            AtomicKind::F32 => write!(f, "atomic_float"), // needs metal 3
318            AtomicKind::F64 => panic!("F64 atomic kind no supported."),
319            AtomicKind::_Dialect(_) => Ok(()),
320        }
321    }
322
323    fn address_space_for_variable(variable: &Variable<Self>) -> String {
324        format!("{} ", AddressSpace::from(variable))
325    }
326
327    fn compile_local_memory_qualifier(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
328        write!(f, "thread")
329    }
330
331    fn compile_shared_memory_declaration(
332        f: &mut std::fmt::Formatter<'_>,
333        shared: &SharedMemory<Self>,
334    ) -> std::fmt::Result {
335        match shared {
336            SharedMemory::Array {
337                index,
338                item,
339                length,
340                offset,
341                ..
342            } => {
343                let size_bytes = length * item.size() as u32;
344                writeln!(f, "// Shared array size: {length}, {size_bytes} bytes")?;
345                writeln!(
346                    f,
347                    "threadgroup {item}* shared_memory_{index} = reinterpret_cast<threadgroup {item}*>(&dynamic_shared_mem[{offset}]);"
348                )
349            }
350            SharedMemory::Value {
351                index,
352                item,
353                offset,
354                ..
355            } => {
356                let size_bytes = item.size() as u32;
357                writeln!(f, "// Shared value size: {size_bytes} bytes")?;
358                writeln!(
359                    f,
360                    "threadgroup {item}& shared_memory_{index} = reinterpret_cast<threadgroup {item}&>(dynamic_shared_mem[{offset}]);"
361                )
362            }
363        }
364    }
365}
366
367// Kernel argument bindings
368
369impl DialectBindings<Self> for MslDialect {
370    fn compile_kernel_signature(
371        f: &mut std::fmt::Formatter<'_>,
372        kernel_name: &str,
373        tensor_maps: &[Binding<Self>],
374        buffers: &[Binding<Self>],
375        scalars: &[(Elem<Self>, usize)],
376        flags: &Flags,
377    ) -> std::fmt::Result {
378        write!(
379            (f),
380            "
381[[kernel]]
382void {kernel_name}("
383        )?;
384        // Global bindings args
385        let mut buffer_idx = 0;
386        debug_assert!(
387            tensor_maps.is_empty(),
388            "Tensor maps aren't supported for metal"
389        );
390        for (i, b) in buffers.iter().enumerate() {
391            format_global_binding_arg("buffer", b, Some(&i.to_string()), &mut buffer_idx, f)?;
392        }
393        if flags.static_meta_length > 0 {
394            let binding = Binding {
395                id: 0,
396                item: Item::scalar(Elem::<Self>::U32, true),
397                location: Location::Storage,
398                size: None,
399                vis: Visibility::Read,
400            };
401            format_global_binding_arg("info", &binding, None, &mut buffer_idx, f)?;
402        }
403        for (elem, _) in scalars.iter() {
404            let binding = Binding {
405                id: 0,
406                item: Item::scalar(*elem, true),
407                location: Location::Storage,
408                size: None,
409                vis: Visibility::Read,
410            };
411
412            let name = format!("scalars_{elem}");
413            format_global_binding_arg(&name, &binding, None, &mut buffer_idx, f)?;
414        }
415
416        // Global metal builtins args
417        let builtins = vec![
418            (
419                flags.indexes.absolute_pos_tuple,
420                Variable::<Self>::AbsolutePosBaseName,
421            ),
422            (
423                flags.indexes.cube_dim_tuple,
424                Variable::<Self>::CubeDimBaseName,
425            ),
426            (
427                flags.indexes.cube_count_tuple,
428                Variable::<Self>::CubeCountBaseName,
429            ),
430            (flags.indexes.unit_pos, Variable::<Self>::UnitPos),
431            (
432                flags.indexes.unit_pos_tuple,
433                Variable::<Self>::UnitPosBaseName,
434            ),
435            (
436                flags.indexes.cube_pos_tuple,
437                Variable::<Self>::CubePosBaseName,
438            ),
439            (flags.indexes.unit_pos_plane, Variable::<Self>::UnitPosPlane),
440            (flags.indexes.plane_dim, Variable::<Self>::PlaneDim),
441            (flags.indexes.plane_index, Variable::<Self>::PlanePos),
442        ];
443        let comma = !buffers.is_empty() || flags.static_meta_length > 0 || !scalars.is_empty();
444        builtins
445            .iter()
446            .filter(|(cond, _)| *cond)
447            .try_for_each(|(_, var)| format_metal_builtin_binding_arg(f, var, comma))?;
448        f.write_str("\n)")
449    }
450
451    fn compile_bindings_body(
452        f: &mut std::fmt::Formatter<'_>,
453        body: &shared::Body<Self>,
454    ) -> std::fmt::Result {
455        if !body.shared_memories.is_empty() {
456            let size = body
457                .shared_memories
458                .iter()
459                .map(|it| it.offset() + it.size())
460                .max()
461                .unwrap();
462
463            writeln!(f, "threadgroup uchar dynamic_shared_mem[{size}];",)?;
464        }
465        Ok(())
466    }
467}
468
469// Cube builtins dialect
470
471impl DialectCubeBuiltins<Self> for MslDialect {
472    /// Depending on the dialect available built-in variables the
473    /// inclusion rules might change.
474    /// For instance in metal we have a built-in for the Unit plane position
475    /// so we don't rely on other builtins.
476    fn builtin_rules(flags: &CubeIndexFlags) -> CubeIndexFlags {
477        let absolute_pos = flags.absolute_pos;
478        let cube_count = flags.cube_count;
479        let cube_dim = flags.cube_dim;
480        let cube_pos = flags.cube_pos;
481        let plane_dim_checked = flags.plane_dim_checked;
482        let plane_index = flags.plane_index;
483        let unit_pos = flags.unit_pos;
484        let absolute_pos_tuple = flags.absolute_pos_tuple || absolute_pos;
485        let cube_count_tuple = flags.cube_count_tuple || cube_count || cube_pos || absolute_pos;
486        let cube_dim_tuple = flags.cube_dim_tuple || cube_dim || absolute_pos || plane_dim_checked;
487        let cube_pos_tuple = flags.cube_pos_tuple || cube_pos;
488        let cluster_pos = flags.cluster_pos;
489        let plane_dim = flags.plane_dim || plane_dim_checked || plane_index;
490        let unit_pos_plane = flags.unit_pos_plane || plane_index;
491        let unit_pos_tuple = flags.unit_pos_tuple || unit_pos;
492        CubeIndexFlags {
493            absolute_pos_tuple,
494            absolute_pos,
495            cube_count_tuple,
496            cube_count,
497            cube_dim_tuple,
498            cube_dim,
499            cube_pos_tuple,
500            cube_pos,
501            plane_dim,
502            plane_dim_checked,
503            plane_index,
504            unit_pos_tuple,
505            unit_pos,
506            unit_pos_plane,
507            cluster_pos,
508        }
509    }
510
511    fn compile_absolute_pos_tuple_computation(
512        _f: &mut std::fmt::Formatter<'_>,
513    ) -> std::fmt::Result {
514        // no need to compute it on metal as there is y a built-in for it
515        Ok(())
516    }
517
518    fn compile_absolute_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
519        f.write_str("thread_pos_in_grid")
520    }
521
522    fn compile_absolute_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
523        f.write_str("thread_index_in_grid")
524    }
525
526    fn compile_absolute_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
527        Self::compile_absolute_pos_base_name(f)?;
528        write!(f, ".x")
529    }
530
531    fn compile_absolute_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
532        Self::compile_absolute_pos_base_name(f)?;
533        write!(f, ".y")
534    }
535
536    fn compile_absolute_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
537        Self::compile_absolute_pos_base_name(f)?;
538        write!(f, ".z")
539    }
540
541    fn compile_cube_count_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
542        f.write_str("threadgroups_per_grid")
543    }
544
545    fn compile_cube_count(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
546        f.write_str("total_threadgroups_in_grid")
547    }
548
549    fn compile_cube_count_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
550        Self::compile_cube_count_base_name(f)?;
551        write!(f, ".x")
552    }
553
554    fn compile_cube_count_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
555        Self::compile_cube_count_base_name(f)?;
556        write!(f, ".y")
557    }
558
559    fn compile_cube_count_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
560        Self::compile_cube_count_base_name(f)?;
561        write!(f, ".z")
562    }
563
564    fn compile_cube_dim_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
565        f.write_str("threads_per_threadgroup")
566    }
567
568    fn compile_cube_dim(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
569        f.write_str("total_thread_in_threadgroup")
570    }
571
572    fn compile_cube_dim_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
573        Self::compile_cube_dim_base_name(f)?;
574        write!(f, ".x")
575    }
576
577    fn compile_cube_dim_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
578        Self::compile_cube_dim_base_name(f)?;
579        write!(f, ".y")
580    }
581
582    fn compile_cube_dim_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
583        Self::compile_cube_dim_base_name(f)?;
584        write!(f, ".z")
585    }
586
587    fn compile_cube_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
588        f.write_str("threadgroup_pos_in_grid")
589    }
590
591    fn compile_cube_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
592        f.write_str("threadgroup_index_in_grid")
593    }
594
595    fn compile_cube_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
596        Self::compile_cube_pos_base_name(f)?;
597        write!(f, ".x")
598    }
599
600    fn compile_cube_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
601        Self::compile_cube_pos_base_name(f)?;
602        write!(f, ".y")
603    }
604
605    fn compile_cube_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
606        Self::compile_cube_pos_base_name(f)?;
607        write!(f, ".z")
608    }
609
610    fn compile_unit_pos_computation(_f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
611        // no need to compute it on metal as there is y a built-in for it
612        Ok(())
613    }
614
615    fn compile_unit_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
616        f.write_str("thread_pos_in_threadgroup")
617    }
618
619    fn compile_unit_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
620        f.write_str("thread_index_in_threadgroup")
621    }
622
623    fn compile_unit_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
624        Self::compile_unit_pos_base_name(f)?;
625        write!(f, ".x")
626    }
627
628    fn compile_unit_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
629        Self::compile_unit_pos_base_name(f)?;
630        write!(f, ".y")
631    }
632
633    fn compile_unit_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
634        Self::compile_unit_pos_base_name(f)?;
635        write!(f, ".z")
636    }
637
638    fn compile_plane_dim(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
639        f.write_str("simd_size")
640    }
641
642    fn compile_plane_dim_checked(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
643        f.write_str("threads_per_simdgroup_checked")
644    }
645
646    fn compile_plane_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
647        f.write_str("simd_group_id")
648    }
649
650    fn compile_unit_pos_plane(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
651        f.write_str("simd_lane_id")
652    }
653}
654
655// Instructions
656
657impl DialectInstructions<Self> for MslDialect {
658    // atomics
659    fn compile_atomic_add(
660        f: &mut std::fmt::Formatter<'_>,
661        lhs: &Variable<Self>,
662        rhs: &Variable<Self>,
663        out: &Variable<Self>,
664    ) -> std::fmt::Result {
665        let out = out.fmt_left();
666        writeln!(
667            f,
668            "{out} = atomic_fetch_add_explicit({lhs}, {rhs}, memory_order_relaxed);"
669        )
670    }
671
672    fn compile_atomic_and(
673        f: &mut std::fmt::Formatter<'_>,
674        lhs: &Variable<Self>,
675        rhs: &Variable<Self>,
676        out: &Variable<Self>,
677    ) -> std::fmt::Result {
678        let out = out.fmt_left();
679        writeln!(
680            f,
681            "{out} = atomic_fetch_and_explicit({lhs}, {rhs}, memory_order_relaxed);"
682        )
683    }
684
685    fn compile_atomic_cas(
686        f: &mut std::fmt::Formatter<'_>,
687        input: &Variable<Self>,
688        cmp: &Variable<Self>,
689        val: &Variable<Self>,
690        out: &Variable<Self>,
691    ) -> std::fmt::Result {
692        let out = out.fmt_left();
693        writeln!(
694            f,
695            "{out} = atomic_compare_exchange_weak_explicit({input}, &{cmp}, {val}, memory_order_relaxed, memory_order_relaxed);"
696        )
697    }
698
699    fn compile_atomic_load(
700        f: &mut std::fmt::Formatter<'_>,
701        input: &Variable<Self>,
702        out: &Variable<Self>,
703    ) -> std::fmt::Result {
704        let out = out.fmt_left();
705        writeln!(
706            f,
707            "{out} = atomic_load_explicit({input}, memory_order_relaxed);"
708        )
709    }
710
711    fn compile_atomic_max(
712        f: &mut std::fmt::Formatter<'_>,
713        lhs: &Variable<Self>,
714        rhs: &Variable<Self>,
715        out: &Variable<Self>,
716    ) -> std::fmt::Result {
717        let out = out.fmt_left();
718        writeln!(
719            f,
720            "{out} = atomic_fetch_max_explicit({lhs}, {rhs}, memory_order_relaxed);"
721        )
722    }
723
724    fn compile_atomic_min(
725        f: &mut std::fmt::Formatter<'_>,
726        lhs: &Variable<Self>,
727        rhs: &Variable<Self>,
728        out: &Variable<Self>,
729    ) -> std::fmt::Result {
730        let out = out.fmt_left();
731        writeln!(
732            f,
733            "{out} = atomic_fetch_min_explicit({lhs}, {rhs}, memory_order_relaxed);"
734        )
735    }
736
737    fn compile_atomic_or(
738        f: &mut std::fmt::Formatter<'_>,
739        lhs: &Variable<Self>,
740        rhs: &Variable<Self>,
741        out: &Variable<Self>,
742    ) -> std::fmt::Result {
743        let out = out.fmt_left();
744        writeln!(
745            f,
746            "{out} = atomic_fetch_or_explicit({lhs}, {rhs}, memory_order_relaxed);"
747        )
748    }
749
750    fn compile_atomic_store(
751        f: &mut std::fmt::Formatter<'_>,
752        input: &Variable<Self>,
753        out: &Variable<Self>,
754    ) -> std::fmt::Result {
755        writeln!(
756            f,
757            "atomic_store_explicit({out}, {input}, memory_order_relaxed);"
758        )
759    }
760
761    fn compile_atomic_sub(
762        f: &mut std::fmt::Formatter<'_>,
763        lhs: &Variable<Self>,
764        rhs: &Variable<Self>,
765        out: &Variable<Self>,
766    ) -> std::fmt::Result {
767        let out = out.fmt_left();
768        writeln!(
769            f,
770            "{out} = atomic_fetch_sub_explicit({lhs}, {rhs}, memory_order_relaxed);"
771        )
772    }
773
774    fn compile_atomic_swap(
775        f: &mut std::fmt::Formatter<'_>,
776        lhs: &Variable<Self>,
777        rhs: &Variable<Self>,
778        out: &Variable<Self>,
779    ) -> std::fmt::Result {
780        let out = out.fmt_left();
781        writeln!(
782            f,
783            "{out} = atomic_exchange_explicit({lhs}, {rhs}, memory_order_relaxed);"
784        )
785    }
786
787    fn compile_atomic_xor(
788        f: &mut std::fmt::Formatter<'_>,
789        lhs: &Variable<Self>,
790        rhs: &Variable<Self>,
791        out: &Variable<Self>,
792    ) -> std::fmt::Result {
793        let out = out.fmt_left();
794        writeln!(
795            f,
796            "{out} = atomic_fetch_xor_explicit({lhs}, {rhs}, memory_order_relaxed);"
797        )
798    }
799
800    fn compile_saturating_add(
801        f: &mut std::fmt::Formatter<'_>,
802        lhs: impl Display,
803        rhs: impl Display,
804        _item: Item<Self>,
805    ) -> std::fmt::Result {
806        write!(f, "addsat({lhs}, {rhs})")
807    }
808
809    fn compile_saturating_sub(
810        f: &mut std::fmt::Formatter<'_>,
811        lhs: impl Display,
812        rhs: impl Display,
813        _item: Item<Self>,
814    ) -> std::fmt::Result {
815        write!(f, "subsat({lhs}, {rhs})")
816    }
817
818    // debug
819    fn compile_instruction_printf(
820        f: &mut std::fmt::Formatter<'_>,
821        format_string: &str,
822        args: &[Variable<Self>],
823    ) -> std::fmt::Result {
824        let args = args.iter().map(|arg| format!("{arg}")).collect::<Vec<_>>();
825        let args = match args.is_empty() {
826            true => "".to_string(),
827            false => format!(", {}", args.join(",")),
828        };
829        writeln!(f, "os_log_default.log({format_string:?}{args});")
830    }
831
832    // logs
833    fn compile_instruction_log1p_scalar<T: Component<Self>>(
834        f: &mut std::fmt::Formatter<'_>,
835        input: T,
836    ) -> std::fmt::Result {
837        match input.elem() {
838            Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
839                write!(f, "log(half(1.0f) + {input})")
840            }
841            _ => write!(f, "log(1.0f + {input})"),
842        }
843    }
844
845    // sync
846    fn compile_instruction_sync_threads(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
847        writeln!(f, "threadgroup_barrier(mem_flags::mem_threadgroup);")
848    }
849
850    fn compile_instruction_sync_warp(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
851        writeln!(f, "simdgroup_barrier(mem_flags::mem_none);")
852    }
853
854    fn compile_instruction_thread_fence(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
855        writeln!(f, "threadgroup_thread_fence(mem_flags::mem_device);")
856    }
857
858    // trigo
859    fn compile_instruction_tanh_scalar<T: Component<Self>>(
860        f: &mut std::fmt::Formatter<'_>,
861        input: T,
862    ) -> std::fmt::Result {
863        write!(f, "safe_tanh_scalar({input})")
864    }
865
866    // unary
867    fn compile_instruction_find_first_set<T: Component<Self>>(
868        f: &mut std::fmt::Formatter<'_>,
869        input: T,
870        out_elem: Elem<Self>,
871    ) -> std::fmt::Result {
872        write!(f, "{out_elem}(")?;
873        match input.elem() {
874            Elem::I32 | Elem::U32 => write!(f, "__ffs({input})"),
875            Elem::I64 | Elem::U64 => write!(f, "__ffsll({input})"),
876            _ => write!(f, "__ffs({}({input}))", Elem::<Self>::I32),
877        }?;
878        write!(f, ")")
879    }
880
881    fn compile_instruction_leading_zeros_scalar<T: Component<Self>>(
882        f: &mut std::fmt::Formatter<'_>,
883        input: T,
884        out_elem: Elem<Self>,
885    ) -> std::fmt::Result {
886        write!(f, "{out_elem}(clz({input}))")
887    }
888
889    fn compile_instruction_popcount_scalar<T: Component<Self>>(
890        f: &mut std::fmt::Formatter<'_>,
891        input: T,
892        out_elem: Elem<Self>,
893    ) -> std::fmt::Result {
894        write!(f, "{out_elem}(")?;
895        match input.elem() {
896            Elem::I32 | Elem::U32 | Elem::I64 | Elem::U64 => write!(f, "popcount({input})"),
897            _ => write!(f, "popcount({})", shared::unary::zero_extend(input)),
898        }?;
899        write!(f, ")")
900    }
901
902    fn compile_instruction_reverse_bits_scalar<T: Component<Self>>(
903        f: &mut std::fmt::Formatter<'_>,
904        input: T,
905        out_elem: Elem<Self>,
906    ) -> std::fmt::Result {
907        write!(f, "{out_elem}(")?;
908        match out_elem {
909            Elem::I32 | Elem::U32 | Elem::I64 | Elem::U64 => write!(f, "reverse_bits({input})"),
910            _ => write!(
911                f,
912                "reverse_bits({}) >> {}",
913                shared::unary::zero_extend(input),
914                (size_of::<u32>() - out_elem.size()) * 8
915            ),
916        }?;
917        write!(f, ")")
918    }
919
920    // others
921    fn compile_instruction_max_function_name(
922        f: &mut std::fmt::Formatter<'_>,
923        _item: Item<Self>,
924    ) -> std::fmt::Result {
925        write!(f, "max")
926    }
927
928    fn compile_instruction_min_function_name(
929        f: &mut std::fmt::Formatter<'_>,
930        _item: Item<Self>,
931    ) -> std::fmt::Result {
932        write!(f, "min")
933    }
934
935    fn compile_instruction_powf(
936        f: &mut std::fmt::Formatter<'_>,
937        lhs: &str,
938        rhs: &str,
939        elem: Elem<Self>,
940    ) -> std::fmt::Result {
941        write!(f, "pow({lhs}, {elem}({rhs}))")
942    }
943
944    fn compile_instruction_half_function_name_prefix() -> &'static str {
945        ""
946    }
947
948    fn compile_instruction_half2_function_name_prefix() -> &'static str {
949        ""
950    }
951
952    // Warp
953    fn compile_warp_shuffle(
954        f: &mut std::fmt::Formatter<'_>,
955        var: &str,
956        source: &str,
957    ) -> std::fmt::Result {
958        write!(f, "simd_shuffle({var}, {source})")
959    }
960
961    fn compile_warp_shuffle_xor(
962        f: &mut std::fmt::Formatter<'_>,
963        var: &str,
964        _elem: &Elem<Self>,
965        offset: &str,
966    ) -> std::fmt::Result {
967        write!(f, "simd_shuffle_xor({var}, {offset})")
968    }
969
970    fn compile_warp_shuffle_up(
971        f: &mut std::fmt::Formatter<'_>,
972        var: &str,
973        offset: &str,
974    ) -> std::fmt::Result {
975        write!(f, "simd_shuffle_up({var}, {offset})")
976    }
977
978    fn compile_warp_shuffle_down(
979        f: &mut std::fmt::Formatter<'_>,
980        var: &str,
981        offset: &str,
982    ) -> std::fmt::Result {
983        write!(f, "simd_shuffle_down({var}, {offset})")
984    }
985
986    fn compile_warp_all<T: Component<Self>>(
987        f: &mut std::fmt::Formatter<'_>,
988        input: &T,
989    ) -> std::fmt::Result {
990        write!(f, "simd_all({input})")
991    }
992
993    fn compile_warp_any<T: Component<Self>>(
994        f: &mut std::fmt::Formatter<'_>,
995        input: &T,
996    ) -> std::fmt::Result {
997        write!(f, "simd_any({input})")
998    }
999
1000    fn compile_warp_ballot(
1001        f: &mut std::fmt::Formatter<'_>,
1002        input: &Variable<Self>,
1003        out_elem: &Elem<Self>,
1004    ) -> std::fmt::Result {
1005        write!(f, "{out_elem}(uint64_t(simd_ballot({input})))")
1006    }
1007}
1008
1009// Coop Matrices dialect
1010
1011impl DialectWmmaCompiler<Self> for MslDialect {
1012    fn compile_wmma_includes(f: &mut std::fmt::Formatter<'_>, _flags: &Flags) -> std::fmt::Result {
1013        writeln!(f, "#include <metal_simdgroup_matrix>")
1014    }
1015
1016    fn compile_wmma_local_variables(_f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1017        // not used
1018        Ok(())
1019    }
1020
1021    fn compile_wmma_fragment_declaration(
1022        f: &mut std::fmt::Formatter<'_>,
1023        var: &crate::shared::Variable<MslDialect>,
1024    ) -> std::fmt::Result {
1025        wmma_api_base::compile_fragment_declaration(f, var)
1026    }
1027
1028    fn compile_wwma_fragment_ident(
1029        _f: &mut std::fmt::Formatter<'_>,
1030        _ident: &FragmentIdent<Self>,
1031    ) -> std::fmt::Result {
1032        // not used
1033        Ok(())
1034    }
1035
1036    fn compile_wmma_fragment_layout(
1037        _f: &mut std::fmt::Formatter<'_>,
1038        _layout: &FragmentLayout<Self>,
1039    ) -> std::fmt::Result {
1040        // not used
1041        Ok(())
1042    }
1043
1044    fn compile_wmma_fragment(
1045        f: &mut std::fmt::Formatter<'_>,
1046        fragment: &Fragment<Self>,
1047    ) -> std::fmt::Result {
1048        let ty = fragment.elem;
1049        // currently as of Metal 3.2 only fragments of 8x8x8 are supported
1050        let m = fragment.m;
1051        let n = fragment.n;
1052        let k = fragment.k;
1053        if m != 8 || n != 8 || k != 8 {
1054            panic!("{m}x{n}x{k} fragments not supported. Only 8x8x8 fragments are supported.");
1055        }
1056        write!(f, "simdgroup_{ty}8x8")
1057    }
1058
1059    fn compile_wmma_instruction(
1060        f: &mut std::fmt::Formatter<'_>,
1061        instruction: &WmmaInstruction<Self>,
1062    ) -> std::fmt::Result {
1063        match instruction {
1064            WmmaInstruction::Fill { frag, value } => {
1065                match frag {
1066                    Variable::WmmaFragment { .. } => {
1067                        let ty = frag.elem();
1068                        // Only 8x8x8 fragemts are supported. Check is done at fragment compilation time.
1069                        writeln!(
1070                            f,
1071                            "{frag} = make_filled_simdgroup_matrix<{ty}, 8, 8>({value});"
1072                        )
1073                    }
1074                    _ => panic!("should be a fragment"),
1075                }
1076            }
1077            WmmaInstruction::Load {
1078                frag,
1079                value,
1080                stride,
1081                offset,
1082                layout: _layout,
1083            } => {
1084                let transpose = match frag {
1085                    Variable::WmmaFragment { frag: inner, .. } => match inner.layout {
1086                        Some(FragmentLayout::RowMajor) => false,
1087                        Some(FragmentLayout::ColMajor) => true,
1088                        _ => false,
1089                    },
1090                    _ => panic!("should be a fragment"),
1091                };
1092                let item = value.item();
1093                if item.vectorization > 1 {
1094                    let elem = item.elem;
1095                    match value {
1096                        Variable::GlobalInputArray(..) => writeln!(
1097                            f,
1098                            "simdgroup_load({frag}, (device {elem}*)({value} + {offset}), {stride}, 0, {transpose});"
1099                        ),
1100                        Variable::SharedArray(..) => writeln!(
1101                            f,
1102                            "simdgroup_load({frag}, reinterpret_cast<threadgroup {elem} *>({value} + {offset}), {stride}, 0, {transpose});"
1103                        ),
1104                        _ => panic!(
1105                            "Vectorized wmma load is only supported from global or shared memory."
1106                        ),
1107                    }
1108                } else {
1109                    writeln!(
1110                        f,
1111                        "simdgroup_load({frag}, {value} + {offset}, {stride}, 0, {transpose});"
1112                    )
1113                }
1114            }
1115            WmmaInstruction::Execute {
1116                frag_a: a,
1117                frag_b: b,
1118                frag_c: c,
1119                frag_d: d,
1120                ..
1121            } => {
1122                writeln!(f, "simdgroup_multiply_accumulate({d}, {a}, {b}, {c});")
1123            }
1124            WmmaInstruction::Store {
1125                output,
1126                frag,
1127                stride,
1128                offset,
1129                layout: _layout,
1130            } => {
1131                let item = output.item();
1132                let mut reinterpret_cast = item.vectorization > 1;
1133                let elem = match item.elem {
1134                    Elem::BF16 => {
1135                        reinterpret_cast = true;
1136                        Elem::F16
1137                    }
1138                    _ => item.elem,
1139                };
1140                if reinterpret_cast {
1141                    writeln!(
1142                        f,
1143                        "simdgroup_store({frag}, reinterpret_cast<threadgroup {elem} *>({output} + {offset}), {stride});"
1144                    )
1145                } else {
1146                    writeln!(f, "simdgroup_store({frag}, {output} + {offset}, {stride});")
1147                }?;
1148                writeln!(f, "simdgroup_barrier(mem_flags::mem_none);")
1149            }
1150            WmmaInstruction::Cast { input, output } => {
1151                writeln!(f, "simdgroup_barrier(mem_flags::mem_none);")?;
1152                let ty = match output {
1153                    Variable::WmmaFragment { frag, .. } => frag.elem,
1154                    _ => panic!("should be a fragment"),
1155                };
1156                match ty {
1157                    Elem::BF16 => {
1158                        let addr_space = Self::address_space_for_variable(output);
1159                        let elem = Elem::<Self>::F16;
1160                        // TODO: to test with benchmarks
1161
1162                        writeln!(
1163                            f,
1164                            "for(int e=0; e<8; e++) {{
1165    {ty} elem = {ty}({input}.thread_elements()[e]);
1166    {output}.thread_elements()[e] = *reinterpret_cast<{addr_space}{elem} *>(&elem);
1167}}"
1168                        )
1169                    }
1170                    _ => {
1171                        writeln!(
1172                            f,
1173                            "for(int e=0; e<8; e++) {{
1174    {output}.thread_elements()[e] = {ty}({input}.thread_elements()[e]);
1175}}"
1176                        )
1177                    }
1178                }
1179            }
1180            WmmaInstruction::ExecuteManual {
1181                shape,
1182                frag_a,
1183                frag_b,
1184                frag_c,
1185                frag_d,
1186            } => {
1187                Self::compile_manual_mma(f, ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d))
1188            }
1189            WmmaInstruction::ExecuteScaled {
1190                shape,
1191                frag_a,
1192                frag_b,
1193                frag_c,
1194                frag_d,
1195                scales_a,
1196                scales_b,
1197                scales_factor,
1198            } => Self::compile_scaled_mma(
1199                f,
1200                ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d),
1201                *scales_a,
1202                *scales_b,
1203                *scales_factor,
1204            ),
1205            WmmaInstruction::LdMatrix { .. } | WmmaInstruction::StMatrix { .. } => {
1206                f.write_str("#error WmmaInstruction Ld & St Matrix not supported on Metal\n")
1207            }
1208        }
1209    }
1210
1211    fn compile_manual_mma(
1212        f: &mut std::fmt::Formatter<'_>,
1213        _mma: shared::ManualMma<Self>,
1214    ) -> std::fmt::Result {
1215        f.write_str("#error manual mma not supported on Metal\n")
1216    }
1217
1218    fn compile_scaled_mma(
1219        f: &mut std::fmt::Formatter<'_>,
1220        _mma: shared::ManualMma<Self>,
1221        _scales_a: Variable<Self>,
1222        _scales_b: Variable<Self>,
1223        _scales_factor: u32,
1224    ) -> std::fmt::Result {
1225        f.write_str("#error scaled mma not supported on Metal\n")
1226    }
1227
1228    fn supported_wmma_combinations(_arch: &MetalArchitecture) -> SupportedMmaCombinations {
1229        let types = vec![
1230            (
1231                gpu::ElemType::Float(gpu::FloatKind::F16).into(),
1232                gpu::ElemType::Float(gpu::FloatKind::F16).into(),
1233                gpu::ElemType::Float(gpu::FloatKind::F16).into(),
1234            ),
1235            (
1236                gpu::ElemType::Float(gpu::FloatKind::F16).into(),
1237                gpu::ElemType::Float(gpu::FloatKind::F16).into(),
1238                gpu::ElemType::Float(gpu::FloatKind::F32).into(),
1239            ),
1240            (
1241                gpu::ElemType::Float(gpu::FloatKind::BF16).into(),
1242                gpu::ElemType::Float(gpu::FloatKind::BF16).into(),
1243                gpu::ElemType::Float(gpu::FloatKind::BF16).into(),
1244            ),
1245            (
1246                gpu::ElemType::Float(gpu::FloatKind::F32).into(),
1247                gpu::ElemType::Float(gpu::FloatKind::F32).into(),
1248                gpu::ElemType::Float(gpu::FloatKind::F32).into(),
1249            ),
1250        ];
1251        types
1252            .into_iter()
1253            .map(|(a_type, b_type, cd_type)| MmaConfig {
1254                a_type,
1255                b_type,
1256                cd_type,
1257                m: 8,
1258                n: 8,
1259                k: 8,
1260            })
1261            .collect()
1262    }
1263
1264    fn supported_mma_combinations(_arch: &MetalArchitecture) -> SupportedMmaCombinations {
1265        Vec::new()
1266    }
1267}
1268
// Processors dialect
1270
1271impl DialectProcessors<Self> for MslDialect {
1272    fn processors() -> Vec<Box<dyn gpu::Processor>> {
1273        Vec::new()
1274    }
1275}