Skip to main content

cubecl_cpp/metal/
dialect.rs

1use super::{
2    AddressSpace, Extension,
3    arch::MetalArchitecture,
4    extension::{format_ffs, format_mulhi},
5    format_erf, format_global_binding_arg, format_metal_builtin_binding_arg, format_safe_tanh,
6};
7use crate::{
8    Dialect,
9    shared::{
10        self, AtomicKind, Component, CubeIndexFlags, DialectBindings, DialectCubeBuiltins,
11        DialectIncludes, DialectInstructions, DialectProcessors, DialectTypes,
12        DialectWarpReduceCompiler, DialectWmmaCompiler, Elem, Flags, FmtLeft, Fragment,
13        FragmentIdent, FragmentLayout, Instruction, Item, KernelArg, ManualMma, SharedMemory,
14        SupportedMmaCombinations, Variable, WarpInstruction, WmmaInstruction, wmma_api_base,
15    },
16};
17use core::panic;
18use cubecl_core::ir::{self as gpu, features::MmaConfig};
19use std::fmt::Display;
20
21#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
22pub struct MslDialect {}
23
24// Base dialect
25
26impl Dialect for MslDialect {
27    type Architecture = MetalArchitecture;
28}
29
30impl MslDialect {
31    fn warp_op_vectorized(
32        f: &mut core::fmt::Formatter<'_>,
33        input: &Variable<Self>,
34        out: &Variable<Self>,
35        simd_op_prefix: &str,
36        simd_op_suffix: &str,
37    ) -> core::fmt::Result {
38        let out = out.fmt_left();
39        let vectorization = input.item().vectorization;
40
41        f.write_fmt(format_args!("{out} = {} {{", input.item()))?;
42
43        for k in 0..vectorization {
44            let index = if vectorization > 1 {
45                format!(".i_{k}")
46            } else {
47                String::new()
48            };
49            let comma = if k + 1 < vectorization { "," } else { "" };
50
51            writeln!(f, "{simd_op_prefix}{input}{index}{simd_op_suffix}{comma}")?;
52        }
53
54        f.write_fmt(format_args!("}};\n"))
55    }
56}
57
58impl DialectWarpReduceCompiler<Self> for MslDialect {
59    fn warp_reduce_sum(
60        f: &mut core::fmt::Formatter<'_>,
61        input: &Variable<Self>,
62        out: &Variable<Self>,
63    ) -> core::fmt::Result {
64        Self::warp_op_vectorized(f, input, out, "simd_sum(", ")")
65    }
66    fn warp_reduce_prod(
67        f: &mut core::fmt::Formatter<'_>,
68        input: &Variable<Self>,
69        out: &Variable<Self>,
70    ) -> core::fmt::Result {
71        Self::warp_op_vectorized(f, input, out, "simd_product(", ")")
72    }
73    fn warp_reduce_max(
74        f: &mut core::fmt::Formatter<'_>,
75        input: &Variable<Self>,
76        out: &Variable<Self>,
77    ) -> core::fmt::Result {
78        Self::warp_op_vectorized(f, input, out, "simd_max(", ")")
79    }
80    fn warp_reduce_min(
81        f: &mut core::fmt::Formatter<'_>,
82        input: &Variable<Self>,
83        out: &Variable<Self>,
84    ) -> core::fmt::Result {
85        Self::warp_op_vectorized(f, input, out, "simd_min(", ")")
86    }
87    fn warp_reduce_all(
88        f: &mut core::fmt::Formatter<'_>,
89        input: &Variable<Self>,
90        out: &Variable<Self>,
91    ) -> core::fmt::Result {
92        Self::warp_op_vectorized(f, input, out, "simd_and(", "? 1u : 0u) != 0u")
93    }
94    fn warp_reduce_any(
95        f: &mut core::fmt::Formatter<'_>,
96        input: &Variable<Self>,
97        out: &Variable<Self>,
98    ) -> core::fmt::Result {
99        Self::warp_op_vectorized(f, input, out, "simd_or(", "? 1u : 0u) != 0u")
100    }
101    fn warp_reduce_sum_inclusive(
102        f: &mut core::fmt::Formatter<'_>,
103        input: &Variable<Self>,
104        out: &Variable<Self>,
105    ) -> core::fmt::Result {
106        Self::warp_op_vectorized(f, input, out, "simd_prefix_inclusive_sum(", ")")
107    }
108    fn warp_reduce_prod_inclusive(
109        f: &mut core::fmt::Formatter<'_>,
110        input: &Variable<Self>,
111        out: &Variable<Self>,
112    ) -> core::fmt::Result {
113        Self::warp_op_vectorized(f, input, out, "simd_prefix_inclusive_product(", ")")
114    }
115    fn warp_reduce_sum_exclusive(
116        f: &mut core::fmt::Formatter<'_>,
117        input: &Variable<Self>,
118        out: &Variable<Self>,
119    ) -> core::fmt::Result {
120        Self::warp_op_vectorized(f, input, out, "simd_prefix_exclusive_sum(", ")")
121    }
122    fn warp_reduce_prod_exclusive(
123        f: &mut core::fmt::Formatter<'_>,
124        input: &Variable<Self>,
125        out: &Variable<Self>,
126    ) -> core::fmt::Result {
127        Self::warp_op_vectorized(f, input, out, "simd_prefix_exclusive_product(", ")")
128    }
129}
130
131// Includes
132
133impl DialectIncludes<Self> for MslDialect {
134    type Extension = Extension<Self>;
135
136    fn compile_includes(f: &mut std::fmt::Formatter<'_>, _flags: &Flags<Self>) -> std::fmt::Result {
137        write!(
138            f,
139            "
140#include <metal_stdlib>
141using namespace metal;
142"
143        )?;
144        Ok(())
145    }
146
147    fn compile_extensions(
148        f: &mut std::fmt::Formatter<'_>,
149        extensions: &[Self::Extension],
150    ) -> std::fmt::Result {
151        for extension in extensions {
152            match extension {
153                Extension::Erf(input, output) => format_erf::<Self>(f, input, output)?,
154                Extension::Ffs(elem) => format_ffs(f, elem)?,
155                Extension::MulHi(elem) => format_mulhi(f, elem)?,
156                Extension::SafeTanh(item) => format_safe_tanh::<Self>(f, item)?,
157                Extension::NoExtension => {}
158            }
159        }
160        Ok(())
161    }
162
163    fn register_instruction_extension(
164        extensions: &mut Vec<Self::Extension>,
165        instruction: &Instruction<Self>,
166    ) {
167        let mut register_extension = |extension: Self::Extension| {
168            if !extensions.contains(&extension) {
169                extensions.push(extension);
170            }
171        };
172        #[allow(clippy::single_match)]
173        match instruction {
174            shared::Instruction::<Self>::Erf(instruction) => {
175                register_extension(Extension::Erf(
176                    instruction.input.elem(),
177                    instruction.out.elem(),
178                ));
179            }
180            shared::Instruction::<Self>::FindFirstSet(instruction) => {
181                let input_elem = instruction.input.elem();
182                match input_elem {
183                    Elem::U32 | Elem::U64 => {
184                        register_extension(Extension::Ffs(instruction.input.elem()));
185                    }
186                    Elem::I32 => {
187                        register_extension(Extension::Ffs(Elem::<Self>::U32));
188                        register_extension(Extension::Ffs(instruction.input.elem()));
189                    }
190                    Elem::I64 => {
191                        register_extension(Extension::Ffs(Elem::<Self>::U64));
192                        register_extension(Extension::Ffs(instruction.input.elem()));
193                    }
194                    _ => {
195                        register_extension(Extension::Ffs(Elem::<Self>::U32));
196                    }
197                }
198            }
199            shared::Instruction::<Self>::HiMul(instruction) => {
200                register_extension(Extension::MulHi(instruction.out.elem()));
201            }
202            shared::Instruction::<Self>::Tanh(instruction) => {
203                register_extension(Extension::SafeTanh(instruction.input.item()));
204            }
205            _ => {}
206        }
207    }
208
209    fn register_warp_instruction_extension(
210        _extensions: &mut Vec<Self::Extension>,
211        _instruction: &WarpInstruction<Self>,
212    ) {
213    }
214}
215
216// Types
217
218impl DialectTypes<Self> for MslDialect {
219    fn item_can_be_optimized() -> bool {
220        false
221    }
222
223    fn compile_type_definitions(
224        f: &mut std::fmt::Formatter<'_>,
225        items: &std::collections::HashSet<crate::shared::Item<Self>>,
226        scalars: &[(Elem<Self>, usize)],
227        info: &cubecl_core::Info,
228        flags: &Flags<Self>,
229    ) -> std::fmt::Result {
230        for item in items.iter() {
231            let elem = item.elem;
232            let size = item.vectorization;
233            let alignment = elem.size() * size;
234            if size > 1 {
235                write!(
236                    f,
237                    "
238struct alignas({alignment}) {item} {{"
239                )?;
240
241                for i in 0..size {
242                    write!(
243                        f,
244                        "
245    {elem} i_{i};"
246                    )?;
247                }
248
249                f.write_str("\n};\n")?;
250            }
251        }
252
253        shared::type_info_definition_sized(f, info, scalars, flags.address_type)?;
254        Ok(())
255    }
256
257    fn compile_elem(
258        f: &mut std::fmt::Formatter<'_>,
259        elem: &shared::Elem<Self>,
260        _words: bool,
261    ) -> std::fmt::Result {
262        // we always use the word form of types
263        match elem {
264            shared::Elem::FP4(_)
265            | shared::Elem::FP4x2(_)
266            | shared::Elem::FP6(_)
267            | shared::Elem::FP6x2(_)
268            | shared::Elem::FP8(_)
269            | shared::Elem::FP8x2(_) => f.write_str("#error FP4/FP6/FP8 not supported in Metal\n"),
270            shared::Elem::F16 => f.write_str("half"),
271            shared::Elem::F16x2 => f.write_str("#error type F162 not supported!\n"),
272            shared::Elem::F32 => f.write_str("float"),
273            shared::Elem::F64 => f.write_str("#error type double not supported!\n"),
274            shared::Elem::BF16 => f.write_str("bfloat"),
275            shared::Elem::BF16x2 => f.write_str("#error type BF162 not supported!\n"),
276            shared::Elem::TF32 => f.write_str("float"),
277            shared::Elem::I8 => f.write_str("char"),
278            shared::Elem::I16 => f.write_str("short"),
279            shared::Elem::I32 => f.write_str("int"),
280            shared::Elem::I64 => f.write_str("long"),
281            shared::Elem::U8 => f.write_str("uchar"),
282            shared::Elem::U16 => f.write_str("ushort"),
283            shared::Elem::U32 => f.write_str("uint"),
284            shared::Elem::U64 => f.write_str("uint64_t"), // or unsigned long
285            shared::Elem::Bool => f.write_str("bool"),
286            shared::Elem::Barrier(_) => unimplemented!("metal doesn't support barrier object"),
287            shared::Elem::Atomic(inner) => inner.fmt(f),
288            shared::Elem::_Dialect(_) => Ok(()),
289        }
290    }
291
292    fn compile_item(f: &mut std::fmt::Formatter<'_>, item: &Item<Self>) -> std::fmt::Result {
293        if 1 == item.vectorization {
294            return write!(f, "{}", item.elem);
295        }
296        if item.native {
297            write!(f, "{}{}", item.elem, item.vectorization)
298        } else {
299            write!(f, "{}_{}", item.elem, item.vectorization)
300        }
301    }
302
303    fn compile_atomic_kind(
304        f: &mut std::fmt::Formatter<'_>,
305        kind: &AtomicKind<Self>,
306    ) -> std::fmt::Result {
307        match kind {
308            AtomicKind::I32 => write!(f, "atomic_int"),
309            AtomicKind::I64 => panic!("I64 atomic kind no supported."),
310            AtomicKind::U32 => write!(f, "atomic_uint"),
311            AtomicKind::U64 => write!(f, "atomic_ulong"),
312            AtomicKind::F16 | AtomicKind::F16x2 => panic!("F16 atomic kind no supported."),
313            AtomicKind::BF16 | AtomicKind::BF16x2 => panic!("BF16 atomic kind no supported."),
314            AtomicKind::F32 => write!(f, "atomic_float"), // needs metal 3
315            AtomicKind::F64 => panic!("F64 atomic kind no supported."),
316            AtomicKind::_Dialect(_) => Ok(()),
317        }
318    }
319
320    fn address_space_for_variable(variable: &Variable<Self>) -> String {
321        format!("{} ", AddressSpace::from(variable))
322    }
323
324    fn compile_local_memory_qualifier(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
325        write!(f, "thread")
326    }
327
328    fn compile_shared_memory_declaration(
329        f: &mut std::fmt::Formatter<'_>,
330        shared: &SharedMemory<Self>,
331    ) -> std::fmt::Result {
332        match shared {
333            SharedMemory::Array {
334                index,
335                item,
336                length,
337                offset,
338                ..
339            } => {
340                let size_bytes = length * item.size();
341                writeln!(f, "// Shared array size: {length}, {size_bytes} bytes")?;
342                writeln!(
343                    f,
344                    "threadgroup {item}* shared_memory_{index} = reinterpret_cast<threadgroup {item}*>(&dynamic_shared_mem[{offset}]);"
345                )
346            }
347            SharedMemory::Value {
348                index,
349                item,
350                offset,
351                ..
352            } => {
353                let size_bytes = item.size();
354                writeln!(f, "// Shared value size: {size_bytes} bytes")?;
355                writeln!(
356                    f,
357                    "threadgroup {item}& shared_memory_{index} = reinterpret_cast<threadgroup {item}&>(dynamic_shared_mem[{offset}]);"
358                )
359            }
360        }
361    }
362}
363
364// Kernel argument bindings
365
366impl DialectBindings<Self> for MslDialect {
367    fn compile_kernel_signature(
368        f: &mut std::fmt::Formatter<'_>,
369        kernel_name: &str,
370        tensor_maps: &[KernelArg<Self>],
371        buffers: &[KernelArg<Self>],
372        flags: &Flags<Self>,
373    ) -> std::fmt::Result {
374        write!(
375            (f),
376            "
377[[kernel]]
378void {kernel_name}("
379        )?;
380        // Global bindings args
381        let mut buffer_idx = 0;
382        debug_assert!(
383            tensor_maps.is_empty(),
384            "Tensor maps aren't supported for metal"
385        );
386        for (i, b) in buffers.iter().enumerate() {
387            format_global_binding_arg("buffer", b, Some(&i.to_string()), &mut buffer_idx, f)?;
388        }
389
390        if flags.has_info {
391            let comma = if buffer_idx > 0 { "," } else { "" };
392            let (address_space, var) = match flags.has_dynamic_meta {
393                true => (AddressSpace::ConstDevice, "info_st* info_ptr"),
394                false => (AddressSpace::Constant, "info_st& info"),
395            };
396            let attribute = address_space.attribute();
397
398            write!(f, "{comma}\n    {address_space} {var}",)?;
399            // attribute
400            attribute.indexed_fmt(buffer_idx, f)?;
401            buffer_idx += 1;
402        }
403
404        // Global metal builtins args
405        let builtins = vec![
406            (
407                flags.indexes.absolute_pos_tuple,
408                Variable::<Self>::AbsolutePosBaseName,
409            ),
410            (
411                flags.indexes.cube_dim_tuple,
412                Variable::<Self>::CubeDimBaseName,
413            ),
414            (
415                flags.indexes.cube_count_tuple,
416                Variable::<Self>::CubeCountBaseName,
417            ),
418            (flags.indexes.unit_pos, Variable::<Self>::UnitPos),
419            (
420                flags.indexes.unit_pos_tuple,
421                Variable::<Self>::UnitPosBaseName,
422            ),
423            (
424                flags.indexes.cube_pos_tuple,
425                Variable::<Self>::CubePosBaseName,
426            ),
427            (flags.indexes.unit_pos_plane, Variable::<Self>::UnitPosPlane),
428            (flags.indexes.plane_dim, Variable::<Self>::PlaneDim),
429            (flags.indexes.plane_pos, Variable::<Self>::PlanePos),
430        ];
431        let comma = buffer_idx > 0;
432        builtins
433            .iter()
434            .filter(|(cond, _)| *cond)
435            .try_for_each(|(_, var)| format_metal_builtin_binding_arg(f, var, comma))?;
436        f.write_str("\n)")
437    }
438
439    fn compile_bindings_body(
440        f: &mut std::fmt::Formatter<'_>,
441        body: &shared::Body<Self>,
442    ) -> std::fmt::Result {
443        if !body.shared_memories.is_empty() {
444            let size = body
445                .shared_memories
446                .iter()
447                .map(|it| it.offset() + it.size())
448                .max()
449                .unwrap();
450
451            writeln!(f, "threadgroup uchar dynamic_shared_mem[{size}];",)?;
452        }
453        if body.info_by_ptr && body.has_dynamic_meta {
454            let address_space = AddressSpace::ConstDevice;
455            writeln!(f, "const {address_space} info_st& info = *info_ptr;")?;
456            // Could use `info_ptr + 1` but that seems dirty, so use manual `sizeof` instead
457            writeln!(
458                f,
459                "const {address_space} {addr}* dynamic_meta = reinterpret_cast<const {address_space} {addr}*>(
460                    reinterpret_cast<const {address_space} char*>(info_ptr) + sizeof(info_st)
461                );\n",
462                addr = body.address_type,
463            )?;
464        }
465        Ok(())
466    }
467}
468
469// Cube builtins dialect
470
471impl DialectCubeBuiltins<Self> for MslDialect {
472    /// Depending on the dialect available built-in variables the
473    /// inclusion rules might change.
474    /// For instance in metal we have a built-in for the Unit plane position
475    /// so we don't rely on other builtins.
476    fn builtin_rules(flags: &CubeIndexFlags) -> CubeIndexFlags {
477        let absolute_pos = flags.absolute_pos;
478        let cube_count = flags.cube_count;
479        let cube_dim = flags.cube_dim;
480        let cube_pos = flags.cube_pos;
481        let plane_dim_checked = flags.plane_dim_checked;
482        let plane_index = flags.plane_pos;
483        let unit_pos = flags.unit_pos;
484        let absolute_pos_tuple = flags.absolute_pos_tuple || absolute_pos;
485        let cube_count_tuple = flags.cube_count_tuple || cube_count || cube_pos || absolute_pos;
486        let cube_dim_tuple = flags.cube_dim_tuple || cube_dim || absolute_pos || plane_dim_checked;
487        let cube_pos_tuple = flags.cube_pos_tuple || cube_pos;
488        let cluster_pos = flags.cluster_pos;
489        let plane_dim = flags.plane_dim || plane_dim_checked || plane_index;
490        let unit_pos_plane = flags.unit_pos_plane || plane_index;
491        let unit_pos_tuple = flags.unit_pos_tuple || unit_pos;
492        CubeIndexFlags {
493            absolute_pos_tuple,
494            absolute_pos,
495            cube_count_tuple,
496            cube_count,
497            cube_dim_tuple,
498            cube_dim,
499            cube_pos_tuple,
500            cube_pos,
501            plane_dim,
502            plane_dim_checked,
503            plane_pos: plane_index,
504            unit_pos_tuple,
505            unit_pos,
506            unit_pos_plane,
507            cluster_pos,
508        }
509    }
510
511    fn compile_absolute_pos_tuple_computation(
512        _f: &mut std::fmt::Formatter<'_>,
513    ) -> std::fmt::Result {
514        // no need to compute it on metal as there is y a built-in for it
515        Ok(())
516    }
517
518    fn compile_absolute_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
519        f.write_str("thread_pos_in_grid")
520    }
521
522    fn compile_absolute_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
523        f.write_str("thread_index_in_grid")
524    }
525
526    fn compile_absolute_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
527        Self::compile_absolute_pos_base_name(f)?;
528        write!(f, ".x")
529    }
530
531    fn compile_absolute_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
532        Self::compile_absolute_pos_base_name(f)?;
533        write!(f, ".y")
534    }
535
536    fn compile_absolute_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
537        Self::compile_absolute_pos_base_name(f)?;
538        write!(f, ".z")
539    }
540
541    fn compile_cube_count_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
542        f.write_str("threadgroups_per_grid")
543    }
544
545    fn compile_cube_count(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
546        f.write_str("total_threadgroups_in_grid")
547    }
548
549    fn compile_cube_count_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
550        Self::compile_cube_count_base_name(f)?;
551        write!(f, ".x")
552    }
553
554    fn compile_cube_count_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
555        Self::compile_cube_count_base_name(f)?;
556        write!(f, ".y")
557    }
558
559    fn compile_cube_count_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
560        Self::compile_cube_count_base_name(f)?;
561        write!(f, ".z")
562    }
563
564    fn compile_cube_dim_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
565        f.write_str("threads_per_threadgroup")
566    }
567
568    fn compile_cube_dim(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
569        f.write_str("total_thread_in_threadgroup")
570    }
571
572    fn compile_cube_dim_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
573        Self::compile_cube_dim_base_name(f)?;
574        write!(f, ".x")
575    }
576
577    fn compile_cube_dim_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
578        Self::compile_cube_dim_base_name(f)?;
579        write!(f, ".y")
580    }
581
582    fn compile_cube_dim_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
583        Self::compile_cube_dim_base_name(f)?;
584        write!(f, ".z")
585    }
586
587    fn compile_cube_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
588        f.write_str("threadgroup_pos_in_grid")
589    }
590
591    fn compile_cube_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
592        f.write_str("threadgroup_index_in_grid")
593    }
594
595    fn compile_cube_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
596        Self::compile_cube_pos_base_name(f)?;
597        write!(f, ".x")
598    }
599
600    fn compile_cube_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
601        Self::compile_cube_pos_base_name(f)?;
602        write!(f, ".y")
603    }
604
605    fn compile_cube_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
606        Self::compile_cube_pos_base_name(f)?;
607        write!(f, ".z")
608    }
609
610    fn compile_unit_pos_computation(_f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
611        // no need to compute it on metal as there is y a built-in for it
612        Ok(())
613    }
614
615    fn compile_unit_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
616        f.write_str("thread_pos_in_threadgroup")
617    }
618
619    fn compile_unit_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
620        f.write_str("thread_index_in_threadgroup")
621    }
622
623    fn compile_unit_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
624        Self::compile_unit_pos_base_name(f)?;
625        write!(f, ".x")
626    }
627
628    fn compile_unit_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
629        Self::compile_unit_pos_base_name(f)?;
630        write!(f, ".y")
631    }
632
633    fn compile_unit_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
634        Self::compile_unit_pos_base_name(f)?;
635        write!(f, ".z")
636    }
637
638    fn compile_plane_dim(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
639        f.write_str("simd_size")
640    }
641
642    fn compile_plane_dim_checked(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
643        f.write_str("threads_per_simdgroup_checked")
644    }
645
646    fn compile_plane_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
647        f.write_str("simd_group_id")
648    }
649
650    fn compile_unit_pos_plane(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
651        f.write_str("simd_lane_id")
652    }
653}
654
655// Instructions
656
657impl DialectInstructions<Self> for MslDialect {
658    // atomics
659    fn compile_atomic_add(
660        f: &mut std::fmt::Formatter<'_>,
661        lhs: &Variable<Self>,
662        rhs: &Variable<Self>,
663        out: &Variable<Self>,
664    ) -> std::fmt::Result {
665        let out = out.fmt_left();
666        writeln!(
667            f,
668            "{out} = atomic_fetch_add_explicit({lhs}, {rhs}, memory_order_relaxed);"
669        )
670    }
671
672    fn compile_atomic_and(
673        f: &mut std::fmt::Formatter<'_>,
674        lhs: &Variable<Self>,
675        rhs: &Variable<Self>,
676        out: &Variable<Self>,
677    ) -> std::fmt::Result {
678        let out = out.fmt_left();
679        writeln!(
680            f,
681            "{out} = atomic_fetch_and_explicit({lhs}, {rhs}, memory_order_relaxed);"
682        )
683    }
684
685    fn compile_atomic_cas(
686        f: &mut std::fmt::Formatter<'_>,
687        input: &Variable<Self>,
688        cmp: &Variable<Self>,
689        val: &Variable<Self>,
690        out: &Variable<Self>,
691    ) -> std::fmt::Result {
692        let out = out.fmt_left();
693        writeln!(
694            f,
695            "{out} = atomic_compare_exchange_weak_explicit({input}, &{cmp}, {val}, memory_order_relaxed, memory_order_relaxed);"
696        )
697    }
698
699    fn compile_atomic_load(
700        f: &mut std::fmt::Formatter<'_>,
701        input: &Variable<Self>,
702        out: &Variable<Self>,
703    ) -> std::fmt::Result {
704        let out = out.fmt_left();
705        writeln!(
706            f,
707            "{out} = atomic_load_explicit({input}, memory_order_relaxed);"
708        )
709    }
710
711    fn compile_atomic_max(
712        f: &mut std::fmt::Formatter<'_>,
713        lhs: &Variable<Self>,
714        rhs: &Variable<Self>,
715        out: &Variable<Self>,
716    ) -> std::fmt::Result {
717        let out = out.fmt_left();
718        writeln!(
719            f,
720            "{out} = atomic_fetch_max_explicit({lhs}, {rhs}, memory_order_relaxed);"
721        )
722    }
723
724    fn compile_atomic_min(
725        f: &mut std::fmt::Formatter<'_>,
726        lhs: &Variable<Self>,
727        rhs: &Variable<Self>,
728        out: &Variable<Self>,
729    ) -> std::fmt::Result {
730        let out = out.fmt_left();
731        writeln!(
732            f,
733            "{out} = atomic_fetch_min_explicit({lhs}, {rhs}, memory_order_relaxed);"
734        )
735    }
736
737    fn compile_atomic_or(
738        f: &mut std::fmt::Formatter<'_>,
739        lhs: &Variable<Self>,
740        rhs: &Variable<Self>,
741        out: &Variable<Self>,
742    ) -> std::fmt::Result {
743        let out = out.fmt_left();
744        writeln!(
745            f,
746            "{out} = atomic_fetch_or_explicit({lhs}, {rhs}, memory_order_relaxed);"
747        )
748    }
749
750    fn compile_atomic_store(
751        f: &mut std::fmt::Formatter<'_>,
752        input: &Variable<Self>,
753        out: &Variable<Self>,
754    ) -> std::fmt::Result {
755        writeln!(
756            f,
757            "atomic_store_explicit({out}, {input}, memory_order_relaxed);"
758        )
759    }
760
761    fn compile_atomic_sub(
762        f: &mut std::fmt::Formatter<'_>,
763        lhs: &Variable<Self>,
764        rhs: &Variable<Self>,
765        out: &Variable<Self>,
766    ) -> std::fmt::Result {
767        let out = out.fmt_left();
768        writeln!(
769            f,
770            "{out} = atomic_fetch_sub_explicit({lhs}, {rhs}, memory_order_relaxed);"
771        )
772    }
773
774    fn compile_atomic_swap(
775        f: &mut std::fmt::Formatter<'_>,
776        lhs: &Variable<Self>,
777        rhs: &Variable<Self>,
778        out: &Variable<Self>,
779    ) -> std::fmt::Result {
780        let out = out.fmt_left();
781        writeln!(
782            f,
783            "{out} = atomic_exchange_explicit({lhs}, {rhs}, memory_order_relaxed);"
784        )
785    }
786
787    fn compile_atomic_xor(
788        f: &mut std::fmt::Formatter<'_>,
789        lhs: &Variable<Self>,
790        rhs: &Variable<Self>,
791        out: &Variable<Self>,
792    ) -> std::fmt::Result {
793        let out = out.fmt_left();
794        writeln!(
795            f,
796            "{out} = atomic_fetch_xor_explicit({lhs}, {rhs}, memory_order_relaxed);"
797        )
798    }
799
800    fn compile_saturating_add(
801        f: &mut std::fmt::Formatter<'_>,
802        lhs: impl Display,
803        rhs: impl Display,
804        _item: Item<Self>,
805    ) -> std::fmt::Result {
806        write!(f, "addsat({lhs}, {rhs})")
807    }
808
809    fn compile_saturating_sub(
810        f: &mut std::fmt::Formatter<'_>,
811        lhs: impl Display,
812        rhs: impl Display,
813        _item: Item<Self>,
814    ) -> std::fmt::Result {
815        write!(f, "subsat({lhs}, {rhs})")
816    }
817
818    // debug
819    fn compile_instruction_printf(
820        f: &mut std::fmt::Formatter<'_>,
821        format_string: &str,
822        args: &[Variable<Self>],
823    ) -> std::fmt::Result {
824        let args = args.iter().map(|arg| format!("{arg}")).collect::<Vec<_>>();
825        let args = match args.is_empty() {
826            true => "".to_string(),
827            false => format!(", {}", args.join(",")),
828        };
829        writeln!(f, "os_log_default.log({format_string:?}{args});")
830    }
831
832    // logs
833    fn compile_instruction_log1p_scalar<T: Component<Self>>(
834        f: &mut std::fmt::Formatter<'_>,
835        input: T,
836    ) -> std::fmt::Result {
837        match input.elem() {
838            Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
839                write!(f, "log(half(1.0f) + {input})")
840            }
841            _ => write!(f, "log(1.0f + {input})"),
842        }
843    }
844
845    // sync
846    fn compile_instruction_sync_threads(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
847        writeln!(f, "threadgroup_barrier(mem_flags::mem_threadgroup);")
848    }
849
850    fn compile_instruction_sync_warp(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
851        writeln!(f, "simdgroup_barrier(mem_flags::mem_none);")
852    }
853
854    fn compile_instruction_thread_fence(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
855        writeln!(f, "threadgroup_thread_fence(mem_flags::mem_device);")
856    }
857
858    // trigo
859    fn compile_instruction_tanh_scalar<T: Component<Self>>(
860        f: &mut std::fmt::Formatter<'_>,
861        input: T,
862    ) -> std::fmt::Result {
863        write!(f, "safe_tanh_scalar({input})")
864    }
865
866    // unary
867    fn compile_instruction_find_first_set<T: Component<Self>>(
868        f: &mut std::fmt::Formatter<'_>,
869        input: T,
870        out_elem: Elem<Self>,
871    ) -> std::fmt::Result {
872        write!(f, "{out_elem}(")?;
873        match input.elem() {
874            Elem::I32 | Elem::U32 => write!(f, "__ffs({input})"),
875            Elem::I64 | Elem::U64 => write!(f, "__ffsll({input})"),
876            _ => write!(f, "__ffs({}({input}))", Elem::<Self>::I32),
877        }?;
878        write!(f, ")")
879    }
880
881    fn compile_instruction_leading_zeros_scalar<T: Component<Self>>(
882        f: &mut std::fmt::Formatter<'_>,
883        input: T,
884        out_elem: Elem<Self>,
885    ) -> std::fmt::Result {
886        write!(f, "{out_elem}(clz({input}))")
887    }
888
889    fn compile_instruction_trailing_zeros_scalar<T: Component<Self>>(
890        f: &mut std::fmt::Formatter<'_>,
891        input: T,
892        out_elem: Elem<Self>,
893    ) -> std::fmt::Result {
894        write!(f, "{out_elem}(ctz({input}))")
895    }
896
897    fn compile_instruction_popcount_scalar<T: Component<Self>>(
898        f: &mut std::fmt::Formatter<'_>,
899        input: T,
900        out_elem: Elem<Self>,
901    ) -> std::fmt::Result {
902        write!(f, "{out_elem}(")?;
903        match input.elem() {
904            Elem::I32 | Elem::U32 | Elem::I64 | Elem::U64 => write!(f, "popcount({input})"),
905            _ => write!(f, "popcount({})", shared::unary::zero_extend(input)),
906        }?;
907        write!(f, ")")
908    }
909
910    fn compile_instruction_reverse_bits_scalar<T: Component<Self>>(
911        f: &mut std::fmt::Formatter<'_>,
912        input: T,
913        out_elem: Elem<Self>,
914    ) -> std::fmt::Result {
915        write!(f, "{out_elem}(")?;
916        match out_elem {
917            Elem::I32 | Elem::U32 | Elem::I64 | Elem::U64 => write!(f, "reverse_bits({input})"),
918            _ => write!(
919                f,
920                "reverse_bits({}) >> {}",
921                shared::unary::zero_extend(input),
922                (size_of::<u32>() - out_elem.size()) * 8
923            ),
924        }?;
925        write!(f, ")")
926    }
927
928    // others
929    fn compile_instruction_max_function_name(
930        f: &mut std::fmt::Formatter<'_>,
931        _item: Item<Self>,
932    ) -> std::fmt::Result {
933        write!(f, "max")
934    }
935
936    fn compile_instruction_min_function_name(
937        f: &mut std::fmt::Formatter<'_>,
938        _item: Item<Self>,
939    ) -> std::fmt::Result {
940        write!(f, "min")
941    }
942
943    fn compile_instruction_powf(
944        f: &mut std::fmt::Formatter<'_>,
945        lhs: &str,
946        rhs: &str,
947        elem: Elem<Self>,
948    ) -> std::fmt::Result {
949        write!(f, "pow({lhs}, {elem}({rhs}))")
950    }
951
952    fn compile_instruction_hypot(
953        f: &mut std::fmt::Formatter<'_>,
954        lhs: &str,
955        rhs: &str,
956        elem: Elem<Self>,
957    ) -> std::fmt::Result {
958        match elem {
959            Elem::F32 => write!(f, "length(float2({lhs}, {rhs}))"),
960            _ => write!(f, "#error Unsupported type for hypot: {elem}"),
961        }
962    }
963
964    fn compile_instruction_rhypot(
965        f: &mut std::fmt::Formatter<'_>,
966        lhs: &str,
967        rhs: &str,
968        elem: Elem<Self>,
969    ) -> std::fmt::Result {
970        match elem {
971            Elem::F32 => write!(f, "rsqrt({lhs} * {lhs} + {rhs} * {rhs})"),
972            _ => write!(f, "#error Unsupported type for hypot: {elem}"),
973        }
974    }
975
976    fn compile_instruction_half_function_name_prefix() -> &'static str {
977        ""
978    }
979
980    fn compile_instruction_half2_function_name_prefix() -> &'static str {
981        ""
982    }
983
984    // Warp
985    fn compile_warp_shuffle(
986        f: &mut std::fmt::Formatter<'_>,
987        var: &str,
988        source: &str,
989    ) -> std::fmt::Result {
990        write!(f, "simd_shuffle({var}, {source})")
991    }
992
993    fn compile_warp_shuffle_xor(
994        f: &mut std::fmt::Formatter<'_>,
995        var: &str,
996        _elem: &Elem<Self>,
997        offset: &str,
998    ) -> std::fmt::Result {
999        write!(f, "simd_shuffle_xor({var}, {offset})")
1000    }
1001
1002    fn compile_warp_shuffle_up(
1003        f: &mut std::fmt::Formatter<'_>,
1004        var: &str,
1005        offset: &str,
1006    ) -> std::fmt::Result {
1007        write!(f, "simd_shuffle_up({var}, {offset})")
1008    }
1009
1010    fn compile_warp_shuffle_down(
1011        f: &mut std::fmt::Formatter<'_>,
1012        var: &str,
1013        offset: &str,
1014    ) -> std::fmt::Result {
1015        write!(f, "simd_shuffle_down({var}, {offset})")
1016    }
1017
1018    fn compile_warp_all<T: Component<Self>>(
1019        f: &mut std::fmt::Formatter<'_>,
1020        input: &T,
1021    ) -> std::fmt::Result {
1022        write!(f, "simd_all({input})")
1023    }
1024
1025    fn compile_warp_any<T: Component<Self>>(
1026        f: &mut std::fmt::Formatter<'_>,
1027        input: &T,
1028    ) -> std::fmt::Result {
1029        write!(f, "simd_any({input})")
1030    }
1031
1032    fn compile_warp_ballot(
1033        f: &mut std::fmt::Formatter<'_>,
1034        input: &Variable<Self>,
1035        out_elem: &Elem<Self>,
1036    ) -> std::fmt::Result {
1037        write!(f, "{out_elem}(uint64_t(simd_ballot({input})))")
1038    }
1039
1040    fn compile_unreachable(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1041        write!(f, "__builtin_unreachable();")
1042    }
1043}
1044
1045// Coop Matrices dialect
1046
1047impl DialectWmmaCompiler<Self> for MslDialect {
1048    fn compile_wmma_includes(
1049        f: &mut std::fmt::Formatter<'_>,
1050        _flags: &Flags<Self>,
1051    ) -> std::fmt::Result {
1052        writeln!(f, "#include <metal_simdgroup_matrix>")
1053    }
1054
1055    fn compile_wmma_local_variables(_f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1056        // not used
1057        Ok(())
1058    }
1059
1060    fn compile_wmma_fragment_declaration(
1061        f: &mut std::fmt::Formatter<'_>,
1062        var: &crate::shared::Variable<MslDialect>,
1063    ) -> std::fmt::Result {
1064        wmma_api_base::compile_fragment_declaration(f, var)
1065    }
1066
1067    fn compile_wwma_fragment_ident(
1068        _f: &mut std::fmt::Formatter<'_>,
1069        _ident: &FragmentIdent<Self>,
1070    ) -> std::fmt::Result {
1071        // not used
1072        Ok(())
1073    }
1074
1075    fn compile_wmma_fragment_layout(
1076        _f: &mut std::fmt::Formatter<'_>,
1077        _layout: &FragmentLayout<Self>,
1078    ) -> std::fmt::Result {
1079        // not used
1080        Ok(())
1081    }
1082
1083    fn compile_wmma_fragment(
1084        f: &mut std::fmt::Formatter<'_>,
1085        fragment: &Fragment<Self>,
1086    ) -> std::fmt::Result {
1087        let ty = fragment.elem;
1088        // currently as of Metal 3.2 only fragments of 8x8x8 are supported
1089        let m = fragment.m;
1090        let n = fragment.n;
1091        let k = fragment.k;
1092        if m != 8 || n != 8 || k != 8 {
1093            panic!("{m}x{n}x{k} fragments not supported. Only 8x8x8 fragments are supported.");
1094        }
1095        write!(f, "simdgroup_{ty}8x8")
1096    }
1097
1098    fn compile_wmma_instruction(
1099        f: &mut std::fmt::Formatter<'_>,
1100        instruction: &WmmaInstruction<Self>,
1101    ) -> std::fmt::Result {
1102        match instruction {
1103            WmmaInstruction::Fill { frag, value } => {
1104                match frag {
1105                    Variable::WmmaFragment { .. } => {
1106                        let ty = frag.elem();
1107                        // Only 8x8x8 fragemts are supported. Check is done at fragment compilation time.
1108                        writeln!(
1109                            f,
1110                            "{frag} = make_filled_simdgroup_matrix<{ty}, 8, 8>({value});"
1111                        )
1112                    }
1113                    _ => panic!("should be a fragment"),
1114                }
1115            }
1116            WmmaInstruction::Load {
1117                frag,
1118                value,
1119                stride,
1120                offset,
1121                layout: _layout,
1122            } => {
1123                let transpose = match frag {
1124                    Variable::WmmaFragment { frag: inner, .. } => match inner.layout {
1125                        Some(FragmentLayout::RowMajor) => false,
1126                        Some(FragmentLayout::ColMajor) => true,
1127                        _ => false,
1128                    },
1129                    _ => panic!("should be a fragment"),
1130                };
1131                let item = value.item();
1132                if item.vectorization > 1 {
1133                    let elem = item.elem;
1134                    match value {
1135                        Variable::GlobalInputArray(..) => writeln!(
1136                            f,
1137                            "simdgroup_load({frag}, (device {elem}*)({value} + {offset}), {stride}, 0, {transpose});"
1138                        ),
1139                        Variable::SharedArray(..) => writeln!(
1140                            f,
1141                            "simdgroup_load({frag}, reinterpret_cast<threadgroup {elem} *>({value} + {offset}), {stride}, 0, {transpose});"
1142                        ),
1143                        _ => panic!(
1144                            "Vectorized wmma load is only supported from global or shared memory."
1145                        ),
1146                    }
1147                } else {
1148                    writeln!(
1149                        f,
1150                        "simdgroup_load({frag}, {value} + {offset}, {stride}, 0, {transpose});"
1151                    )
1152                }
1153            }
1154            WmmaInstruction::Execute {
1155                frag_a: a,
1156                frag_b: b,
1157                frag_c: c,
1158                frag_d: d,
1159                ..
1160            } => {
1161                writeln!(f, "simdgroup_multiply_accumulate({d}, {a}, {b}, {c});")
1162            }
1163            WmmaInstruction::Store {
1164                output,
1165                frag,
1166                stride,
1167                offset,
1168                layout: _layout,
1169            } => {
1170                let item = output.item();
1171                let mut reinterpret_cast = item.vectorization > 1;
1172                let elem = match item.elem {
1173                    Elem::BF16 => {
1174                        reinterpret_cast = true;
1175                        Elem::F16
1176                    }
1177                    _ => item.elem,
1178                };
1179                if reinterpret_cast {
1180                    writeln!(
1181                        f,
1182                        "simdgroup_store({frag}, reinterpret_cast<threadgroup {elem} *>({output} + {offset}), {stride});"
1183                    )
1184                } else {
1185                    writeln!(f, "simdgroup_store({frag}, {output} + {offset}, {stride});")
1186                }?;
1187                writeln!(f, "simdgroup_barrier(mem_flags::mem_none);")
1188            }
1189            WmmaInstruction::Cast { input, output } => {
1190                writeln!(f, "simdgroup_barrier(mem_flags::mem_none);")?;
1191                let ty = match output {
1192                    Variable::WmmaFragment { frag, .. } => frag.elem,
1193                    _ => panic!("should be a fragment"),
1194                };
1195                match ty {
1196                    Elem::BF16 => {
1197                        let addr_space = Self::address_space_for_variable(output);
1198                        let elem = Elem::<Self>::F16;
1199                        // TODO: to test with benchmarks
1200
1201                        writeln!(
1202                            f,
1203                            "for(int e=0; e<8; e++) {{
1204    {ty} elem = {ty}({input}.thread_elements()[e]);
1205    {output}.thread_elements()[e] = *reinterpret_cast<{addr_space}{elem} *>(&elem);
1206}}"
1207                        )
1208                    }
1209                    _ => {
1210                        writeln!(
1211                            f,
1212                            "for(int e=0; e<8; e++) {{
1213    {output}.thread_elements()[e] = {ty}({input}.thread_elements()[e]);
1214}}"
1215                        )
1216                    }
1217                }
1218            }
1219            WmmaInstruction::ExecuteManual {
1220                shape,
1221                frag_a,
1222                frag_b,
1223                frag_c,
1224                frag_d,
1225            } => {
1226                Self::compile_manual_mma(f, ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d))
1227            }
1228            WmmaInstruction::ExecuteScaled {
1229                shape,
1230                frag_a,
1231                frag_b,
1232                frag_c,
1233                frag_d,
1234                scales_a,
1235                scales_b,
1236                scales_factor,
1237            } => Self::compile_scaled_mma(
1238                f,
1239                ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d),
1240                *scales_a,
1241                *scales_b,
1242                *scales_factor,
1243            ),
1244            WmmaInstruction::LdMatrix { .. } | WmmaInstruction::StMatrix { .. } => {
1245                f.write_str("#error WmmaInstruction Ld & St Matrix not supported on Metal\n")
1246            }
1247        }
1248    }
1249
1250    fn compile_manual_mma(
1251        f: &mut std::fmt::Formatter<'_>,
1252        _mma: shared::ManualMma<Self>,
1253    ) -> std::fmt::Result {
1254        f.write_str("#error manual mma not supported on Metal\n")
1255    }
1256
1257    fn compile_scaled_mma(
1258        f: &mut std::fmt::Formatter<'_>,
1259        _mma: shared::ManualMma<Self>,
1260        _scales_a: Variable<Self>,
1261        _scales_b: Variable<Self>,
1262        _scales_factor: u32,
1263    ) -> std::fmt::Result {
1264        f.write_str("#error scaled mma not supported on Metal\n")
1265    }
1266
1267    fn supported_wmma_combinations(_arch: &MetalArchitecture) -> SupportedMmaCombinations {
1268        let types = vec![
1269            (
1270                gpu::ElemType::Float(gpu::FloatKind::F16).into(),
1271                gpu::ElemType::Float(gpu::FloatKind::F16).into(),
1272                gpu::ElemType::Float(gpu::FloatKind::F16).into(),
1273            ),
1274            (
1275                gpu::ElemType::Float(gpu::FloatKind::F16).into(),
1276                gpu::ElemType::Float(gpu::FloatKind::F16).into(),
1277                gpu::ElemType::Float(gpu::FloatKind::F32).into(),
1278            ),
1279            (
1280                gpu::ElemType::Float(gpu::FloatKind::BF16).into(),
1281                gpu::ElemType::Float(gpu::FloatKind::BF16).into(),
1282                gpu::ElemType::Float(gpu::FloatKind::BF16).into(),
1283            ),
1284            (
1285                gpu::ElemType::Float(gpu::FloatKind::F32).into(),
1286                gpu::ElemType::Float(gpu::FloatKind::F32).into(),
1287                gpu::ElemType::Float(gpu::FloatKind::F32).into(),
1288            ),
1289        ];
1290        types
1291            .into_iter()
1292            .map(|(a_type, b_type, cd_type)| MmaConfig {
1293                a_type,
1294                b_type,
1295                cd_type,
1296                m: 8,
1297                n: 8,
1298                k: 8,
1299            })
1300            .collect()
1301    }
1302
1303    fn supported_mma_combinations(_arch: &MetalArchitecture) -> SupportedMmaCombinations {
1304        Vec::new()
1305    }
1306}
1307
// Processors dialect
1309
1310impl DialectProcessors<Self> for MslDialect {
1311    fn processors() -> Vec<Box<dyn gpu::Processor>> {
1312        Vec::new()
1313    }
1314}