cubecl_cpp/metal/
dialect.rs

1use super::{
2    AddressSpace, Extension,
3    arch::MetalArchitecture,
4    extension::{format_ffs, format_mulhi},
5    format_erf, format_global_binding_arg, format_metal_builtin_binding_arg, format_safe_tanh,
6};
7use crate::{
8    Dialect,
9    shared::{
10        self, AtomicKind, Binding, Component, CubeIndexFlags, DialectBindings, DialectCubeBuiltins,
11        DialectIncludes, DialectInstructions, DialectProcessors, DialectTypes,
12        DialectWarpReduceCompiler, DialectWmmaCompiler, Elem, Flags, FmtLeft, Fragment,
13        FragmentIdent, FragmentLayout, Instruction, Item, ManualMma, SharedMemory,
14        SupportedMmaCombinations, Variable, WarpInstruction, WmmaInstruction, wmma_api_base,
15    },
16};
17use core::panic;
18use cubecl_core::{
19    ir::{self as gpu, features::MmaConfig},
20    prelude::{Location, Visibility},
21};
22use std::fmt::Display;
23
/// Marker type for the Metal Shading Language dialect.
///
/// It holds no data; the dialect behaviour is provided by the trait
/// implementations on this type.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct MslDialect {}
26
27// Base dialect
28
29impl Dialect for MslDialect {
30    type Architecture = MetalArchitecture;
31}
32
33impl MslDialect {
34    fn warp_op_vectorized(
35        f: &mut core::fmt::Formatter<'_>,
36        input: &Variable<Self>,
37        out: &Variable<Self>,
38        simd_op_prefix: &str,
39        simd_op_suffix: &str,
40    ) -> core::fmt::Result {
41        let out = out.fmt_left();
42        let vectorization = input.item().vectorization;
43
44        f.write_fmt(format_args!("{out} = {} {{", input.item()))?;
45
46        for k in 0..vectorization {
47            let index = if vectorization > 1 {
48                format!(".i_{k}")
49            } else {
50                String::new()
51            };
52            let comma = if k + 1 < vectorization { "," } else { "" };
53
54            writeln!(f, "{simd_op_prefix}{input}{index}{simd_op_suffix}{comma}")?;
55        }
56
57        f.write_fmt(format_args!("}};\n"))
58    }
59}
60
61impl DialectWarpReduceCompiler<Self> for MslDialect {
62    fn warp_reduce_sum(
63        f: &mut core::fmt::Formatter<'_>,
64        input: &Variable<Self>,
65        out: &Variable<Self>,
66    ) -> core::fmt::Result {
67        Self::warp_op_vectorized(f, input, out, "simd_sum(", ")")
68    }
69    fn warp_reduce_prod(
70        f: &mut core::fmt::Formatter<'_>,
71        input: &Variable<Self>,
72        out: &Variable<Self>,
73    ) -> core::fmt::Result {
74        Self::warp_op_vectorized(f, input, out, "simd_product(", ")")
75    }
76    fn warp_reduce_max(
77        f: &mut core::fmt::Formatter<'_>,
78        input: &Variable<Self>,
79        out: &Variable<Self>,
80    ) -> core::fmt::Result {
81        Self::warp_op_vectorized(f, input, out, "simd_max(", ")")
82    }
83    fn warp_reduce_min(
84        f: &mut core::fmt::Formatter<'_>,
85        input: &Variable<Self>,
86        out: &Variable<Self>,
87    ) -> core::fmt::Result {
88        Self::warp_op_vectorized(f, input, out, "simd_min(", ")")
89    }
90    fn warp_reduce_all(
91        f: &mut core::fmt::Formatter<'_>,
92        input: &Variable<Self>,
93        out: &Variable<Self>,
94    ) -> core::fmt::Result {
95        Self::warp_op_vectorized(f, input, out, "simd_and(", "? 1u : 0u) != 0u")
96    }
97    fn warp_reduce_any(
98        f: &mut core::fmt::Formatter<'_>,
99        input: &Variable<Self>,
100        out: &Variable<Self>,
101    ) -> core::fmt::Result {
102        Self::warp_op_vectorized(f, input, out, "simd_or(", "? 1u : 0u) != 0u")
103    }
104    fn warp_reduce_sum_inclusive(
105        f: &mut core::fmt::Formatter<'_>,
106        input: &Variable<Self>,
107        out: &Variable<Self>,
108    ) -> core::fmt::Result {
109        Self::warp_op_vectorized(f, input, out, "simd_prefix_inclusive_sum(", ")")
110    }
111    fn warp_reduce_prod_inclusive(
112        f: &mut core::fmt::Formatter<'_>,
113        input: &Variable<Self>,
114        out: &Variable<Self>,
115    ) -> core::fmt::Result {
116        Self::warp_op_vectorized(f, input, out, "simd_prefix_inclusive_product(", ")")
117    }
118    fn warp_reduce_sum_exclusive(
119        f: &mut core::fmt::Formatter<'_>,
120        input: &Variable<Self>,
121        out: &Variable<Self>,
122    ) -> core::fmt::Result {
123        Self::warp_op_vectorized(f, input, out, "simd_prefix_exclusive_sum(", ")")
124    }
125    fn warp_reduce_prod_exclusive(
126        f: &mut core::fmt::Formatter<'_>,
127        input: &Variable<Self>,
128        out: &Variable<Self>,
129    ) -> core::fmt::Result {
130        Self::warp_op_vectorized(f, input, out, "simd_prefix_exclusive_product(", ")")
131    }
132}
133
134// Includes
135
136impl DialectIncludes<Self> for MslDialect {
137    type Extension = Extension<Self>;
138
139    fn compile_includes(f: &mut std::fmt::Formatter<'_>, _flags: &Flags<Self>) -> std::fmt::Result {
140        write!(
141            f,
142            "
143#include <metal_stdlib>
144using namespace metal;
145"
146        )?;
147        Ok(())
148    }
149
150    fn compile_extensions(
151        f: &mut std::fmt::Formatter<'_>,
152        extensions: &[Self::Extension],
153    ) -> std::fmt::Result {
154        for extension in extensions {
155            match extension {
156                Extension::Erf(input, output) => format_erf::<Self>(f, input, output)?,
157                Extension::Ffs(elem) => format_ffs(f, elem)?,
158                Extension::MulHi(elem) => format_mulhi(f, elem)?,
159                Extension::SafeTanh(item) => format_safe_tanh::<Self>(f, item)?,
160                Extension::NoExtension => {}
161            }
162        }
163        Ok(())
164    }
165
166    fn register_instruction_extension(
167        extensions: &mut Vec<Self::Extension>,
168        instruction: &Instruction<Self>,
169    ) {
170        let mut register_extension = |extension: Self::Extension| {
171            if !extensions.contains(&extension) {
172                extensions.push(extension);
173            }
174        };
175        #[allow(clippy::single_match)]
176        match instruction {
177            shared::Instruction::<Self>::Erf(instruction) => {
178                register_extension(Extension::Erf(
179                    instruction.input.elem(),
180                    instruction.out.elem(),
181                ));
182            }
183            shared::Instruction::<Self>::FindFirstSet(instruction) => {
184                let input_elem = instruction.input.elem();
185                match input_elem {
186                    Elem::U32 | Elem::U64 => {
187                        register_extension(Extension::Ffs(instruction.input.elem()));
188                    }
189                    Elem::I32 => {
190                        register_extension(Extension::Ffs(Elem::<Self>::U32));
191                        register_extension(Extension::Ffs(instruction.input.elem()));
192                    }
193                    Elem::I64 => {
194                        register_extension(Extension::Ffs(Elem::<Self>::U64));
195                        register_extension(Extension::Ffs(instruction.input.elem()));
196                    }
197                    _ => {
198                        register_extension(Extension::Ffs(Elem::<Self>::U32));
199                    }
200                }
201            }
202            shared::Instruction::<Self>::HiMul(instruction) => {
203                register_extension(Extension::MulHi(instruction.out.elem()));
204            }
205            shared::Instruction::<Self>::Tanh(instruction) => {
206                register_extension(Extension::SafeTanh(instruction.input.item()));
207            }
208            _ => {}
209        }
210    }
211
212    fn register_warp_instruction_extension(
213        _extensions: &mut Vec<Self::Extension>,
214        _instruction: &WarpInstruction<Self>,
215    ) {
216    }
217}
218
219// Types
220
221impl DialectTypes<Self> for MslDialect {
222    fn item_can_be_optimized() -> bool {
223        false
224    }
225
226    fn compile_type_definitions(
227        f: &mut std::fmt::Formatter<'_>,
228        items: &std::collections::HashSet<crate::shared::Item<Self>>,
229        _scalars: &[(Elem<Self>, usize)],
230        _flags: &Flags<Self>,
231    ) -> std::fmt::Result {
232        for item in items.iter() {
233            let elem = item.elem;
234            let size = item.vectorization;
235            let alignment = elem.size() * size;
236            if size > 1 {
237                write!(
238                    f,
239                    "
240struct alignas({alignment}) {item} {{"
241                )?;
242
243                for i in 0..size {
244                    write!(
245                        f,
246                        "
247    {elem} i_{i};"
248                    )?;
249                }
250
251                f.write_str("\n};\n")?;
252            }
253        }
254        Ok(())
255    }
256
257    fn compile_elem(
258        f: &mut std::fmt::Formatter<'_>,
259        elem: &shared::Elem<Self>,
260        _words: bool,
261    ) -> std::fmt::Result {
262        // we always use the word form of types
263        match elem {
264            shared::Elem::FP4(_)
265            | shared::Elem::FP4x2(_)
266            | shared::Elem::FP6(_)
267            | shared::Elem::FP6x2(_)
268            | shared::Elem::FP8(_)
269            | shared::Elem::FP8x2(_) => f.write_str("#error FP4/FP6/FP8 not supported in Metal\n"),
270            shared::Elem::F16 => f.write_str("half"),
271            shared::Elem::F16x2 => f.write_str("#error type F162 not supported!\n"),
272            shared::Elem::F32 => f.write_str("float"),
273            shared::Elem::F64 => f.write_str("#error type double not supported!\n"),
274            shared::Elem::BF16 => f.write_str("bfloat"),
275            shared::Elem::BF16x2 => f.write_str("#error type BF162 not supported!\n"),
276            shared::Elem::TF32 => f.write_str("float"),
277            shared::Elem::I8 => f.write_str("char"),
278            shared::Elem::I16 => f.write_str("short"),
279            shared::Elem::I32 => f.write_str("int"),
280            shared::Elem::I64 => f.write_str("long"),
281            shared::Elem::U8 => f.write_str("uchar"),
282            shared::Elem::U16 => f.write_str("ushort"),
283            shared::Elem::U32 => f.write_str("uint"),
284            shared::Elem::U64 => f.write_str("uint64_t"), // or unsigned long
285            shared::Elem::Bool => f.write_str("bool"),
286            shared::Elem::Barrier(_) => unimplemented!("metal doesn't support barrier object"),
287            shared::Elem::Atomic(inner) => inner.fmt(f),
288            shared::Elem::_Dialect(_) => Ok(()),
289        }
290    }
291
292    fn compile_item(f: &mut std::fmt::Formatter<'_>, item: &Item<Self>) -> std::fmt::Result {
293        if 1 == item.vectorization {
294            return write!(f, "{}", item.elem);
295        }
296        if item.native {
297            write!(f, "{}{}", item.elem, item.vectorization)
298        } else {
299            write!(f, "{}_{}", item.elem, item.vectorization)
300        }
301    }
302
303    fn compile_atomic_kind(
304        f: &mut std::fmt::Formatter<'_>,
305        kind: &AtomicKind<Self>,
306    ) -> std::fmt::Result {
307        match kind {
308            AtomicKind::I32 => write!(f, "atomic_int"),
309            AtomicKind::I64 => panic!("I64 atomic kind no supported."),
310            AtomicKind::U32 => write!(f, "atomic_uint"),
311            AtomicKind::U64 => write!(f, "atomic_ulong"),
312            AtomicKind::F16 => panic!("F16 atomic kind no supported."),
313            AtomicKind::BF16 => panic!("BF16 atomic kind no supported."),
314            AtomicKind::F32 => write!(f, "atomic_float"), // needs metal 3
315            AtomicKind::F64 => panic!("F64 atomic kind no supported."),
316            AtomicKind::_Dialect(_) => Ok(()),
317        }
318    }
319
320    fn address_space_for_variable(variable: &Variable<Self>) -> String {
321        format!("{} ", AddressSpace::from(variable))
322    }
323
324    fn compile_local_memory_qualifier(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
325        write!(f, "thread")
326    }
327
328    fn compile_shared_memory_declaration(
329        f: &mut std::fmt::Formatter<'_>,
330        shared: &SharedMemory<Self>,
331    ) -> std::fmt::Result {
332        match shared {
333            SharedMemory::Array {
334                index,
335                item,
336                length,
337                offset,
338                ..
339            } => {
340                let size_bytes = length * item.size();
341                writeln!(f, "// Shared array size: {length}, {size_bytes} bytes")?;
342                writeln!(
343                    f,
344                    "threadgroup {item}* shared_memory_{index} = reinterpret_cast<threadgroup {item}*>(&dynamic_shared_mem[{offset}]);"
345                )
346            }
347            SharedMemory::Value {
348                index,
349                item,
350                offset,
351                ..
352            } => {
353                let size_bytes = item.size();
354                writeln!(f, "// Shared value size: {size_bytes} bytes")?;
355                writeln!(
356                    f,
357                    "threadgroup {item}& shared_memory_{index} = reinterpret_cast<threadgroup {item}&>(dynamic_shared_mem[{offset}]);"
358                )
359            }
360        }
361    }
362}
363
364// Kernel argument bindings
365
366impl DialectBindings<Self> for MslDialect {
367    fn compile_kernel_signature(
368        f: &mut std::fmt::Formatter<'_>,
369        kernel_name: &str,
370        tensor_maps: &[Binding<Self>],
371        buffers: &[Binding<Self>],
372        scalars: &[(Elem<Self>, usize)],
373        flags: &Flags<Self>,
374    ) -> std::fmt::Result {
375        write!(
376            (f),
377            "
378[[kernel]]
379void {kernel_name}("
380        )?;
381        // Global bindings args
382        let mut buffer_idx = 0;
383        debug_assert!(
384            tensor_maps.is_empty(),
385            "Tensor maps aren't supported for metal"
386        );
387        for (i, b) in buffers.iter().enumerate() {
388            format_global_binding_arg("buffer", b, Some(&i.to_string()), &mut buffer_idx, f)?;
389        }
390        if flags.static_meta_length > 0 {
391            let binding = Binding {
392                id: 0,
393                item: Item::scalar(Elem::<Self>::U32, true),
394                location: Location::Storage,
395                size: None,
396                vis: Visibility::Read,
397            };
398            format_global_binding_arg("info", &binding, None, &mut buffer_idx, f)?;
399        }
400        for (elem, _) in scalars.iter() {
401            let binding = Binding {
402                id: 0,
403                item: Item::scalar(*elem, true),
404                location: Location::Storage,
405                size: None,
406                vis: Visibility::Read,
407            };
408
409            let name = format!("scalars_{elem}");
410            format_global_binding_arg(&name, &binding, None, &mut buffer_idx, f)?;
411        }
412
413        // Global metal builtins args
414        let builtins = vec![
415            (
416                flags.indexes.absolute_pos_tuple,
417                Variable::<Self>::AbsolutePosBaseName,
418            ),
419            (
420                flags.indexes.cube_dim_tuple,
421                Variable::<Self>::CubeDimBaseName,
422            ),
423            (
424                flags.indexes.cube_count_tuple,
425                Variable::<Self>::CubeCountBaseName,
426            ),
427            (flags.indexes.unit_pos, Variable::<Self>::UnitPos),
428            (
429                flags.indexes.unit_pos_tuple,
430                Variable::<Self>::UnitPosBaseName,
431            ),
432            (
433                flags.indexes.cube_pos_tuple,
434                Variable::<Self>::CubePosBaseName,
435            ),
436            (flags.indexes.unit_pos_plane, Variable::<Self>::UnitPosPlane),
437            (flags.indexes.plane_dim, Variable::<Self>::PlaneDim),
438            (flags.indexes.plane_index, Variable::<Self>::PlanePos),
439        ];
440        let comma = !buffers.is_empty() || flags.static_meta_length > 0 || !scalars.is_empty();
441        builtins
442            .iter()
443            .filter(|(cond, _)| *cond)
444            .try_for_each(|(_, var)| format_metal_builtin_binding_arg(f, var, comma))?;
445        f.write_str("\n)")
446    }
447
448    fn compile_bindings_body(
449        f: &mut std::fmt::Formatter<'_>,
450        body: &shared::Body<Self>,
451    ) -> std::fmt::Result {
452        if !body.shared_memories.is_empty() {
453            let size = body
454                .shared_memories
455                .iter()
456                .map(|it| it.offset() + it.size())
457                .max()
458                .unwrap();
459
460            writeln!(f, "threadgroup uchar dynamic_shared_mem[{size}];",)?;
461        }
462        Ok(())
463    }
464}
465
466// Cube builtins dialect
467
468impl DialectCubeBuiltins<Self> for MslDialect {
469    /// Depending on the dialect available built-in variables the
470    /// inclusion rules might change.
471    /// For instance in metal we have a built-in for the Unit plane position
472    /// so we don't rely on other builtins.
473    fn builtin_rules(flags: &CubeIndexFlags) -> CubeIndexFlags {
474        let absolute_pos = flags.absolute_pos;
475        let cube_count = flags.cube_count;
476        let cube_dim = flags.cube_dim;
477        let cube_pos = flags.cube_pos;
478        let plane_dim_checked = flags.plane_dim_checked;
479        let plane_index = flags.plane_index;
480        let unit_pos = flags.unit_pos;
481        let absolute_pos_tuple = flags.absolute_pos_tuple || absolute_pos;
482        let cube_count_tuple = flags.cube_count_tuple || cube_count || cube_pos || absolute_pos;
483        let cube_dim_tuple = flags.cube_dim_tuple || cube_dim || absolute_pos || plane_dim_checked;
484        let cube_pos_tuple = flags.cube_pos_tuple || cube_pos;
485        let cluster_pos = flags.cluster_pos;
486        let plane_dim = flags.plane_dim || plane_dim_checked || plane_index;
487        let unit_pos_plane = flags.unit_pos_plane || plane_index;
488        let unit_pos_tuple = flags.unit_pos_tuple || unit_pos;
489        CubeIndexFlags {
490            absolute_pos_tuple,
491            absolute_pos,
492            cube_count_tuple,
493            cube_count,
494            cube_dim_tuple,
495            cube_dim,
496            cube_pos_tuple,
497            cube_pos,
498            plane_dim,
499            plane_dim_checked,
500            plane_index,
501            unit_pos_tuple,
502            unit_pos,
503            unit_pos_plane,
504            cluster_pos,
505        }
506    }
507
508    fn compile_absolute_pos_tuple_computation(
509        _f: &mut std::fmt::Formatter<'_>,
510    ) -> std::fmt::Result {
511        // no need to compute it on metal as there is y a built-in for it
512        Ok(())
513    }
514
515    fn compile_absolute_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
516        f.write_str("thread_pos_in_grid")
517    }
518
519    fn compile_absolute_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
520        f.write_str("thread_index_in_grid")
521    }
522
523    fn compile_absolute_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
524        Self::compile_absolute_pos_base_name(f)?;
525        write!(f, ".x")
526    }
527
528    fn compile_absolute_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
529        Self::compile_absolute_pos_base_name(f)?;
530        write!(f, ".y")
531    }
532
533    fn compile_absolute_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
534        Self::compile_absolute_pos_base_name(f)?;
535        write!(f, ".z")
536    }
537
538    fn compile_cube_count_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
539        f.write_str("threadgroups_per_grid")
540    }
541
542    fn compile_cube_count(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
543        f.write_str("total_threadgroups_in_grid")
544    }
545
546    fn compile_cube_count_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
547        Self::compile_cube_count_base_name(f)?;
548        write!(f, ".x")
549    }
550
551    fn compile_cube_count_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
552        Self::compile_cube_count_base_name(f)?;
553        write!(f, ".y")
554    }
555
556    fn compile_cube_count_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
557        Self::compile_cube_count_base_name(f)?;
558        write!(f, ".z")
559    }
560
561    fn compile_cube_dim_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
562        f.write_str("threads_per_threadgroup")
563    }
564
565    fn compile_cube_dim(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
566        f.write_str("total_thread_in_threadgroup")
567    }
568
569    fn compile_cube_dim_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
570        Self::compile_cube_dim_base_name(f)?;
571        write!(f, ".x")
572    }
573
574    fn compile_cube_dim_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
575        Self::compile_cube_dim_base_name(f)?;
576        write!(f, ".y")
577    }
578
579    fn compile_cube_dim_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
580        Self::compile_cube_dim_base_name(f)?;
581        write!(f, ".z")
582    }
583
584    fn compile_cube_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
585        f.write_str("threadgroup_pos_in_grid")
586    }
587
588    fn compile_cube_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
589        f.write_str("threadgroup_index_in_grid")
590    }
591
592    fn compile_cube_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
593        Self::compile_cube_pos_base_name(f)?;
594        write!(f, ".x")
595    }
596
597    fn compile_cube_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
598        Self::compile_cube_pos_base_name(f)?;
599        write!(f, ".y")
600    }
601
602    fn compile_cube_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
603        Self::compile_cube_pos_base_name(f)?;
604        write!(f, ".z")
605    }
606
607    fn compile_unit_pos_computation(_f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
608        // no need to compute it on metal as there is y a built-in for it
609        Ok(())
610    }
611
612    fn compile_unit_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
613        f.write_str("thread_pos_in_threadgroup")
614    }
615
616    fn compile_unit_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
617        f.write_str("thread_index_in_threadgroup")
618    }
619
620    fn compile_unit_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
621        Self::compile_unit_pos_base_name(f)?;
622        write!(f, ".x")
623    }
624
625    fn compile_unit_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
626        Self::compile_unit_pos_base_name(f)?;
627        write!(f, ".y")
628    }
629
630    fn compile_unit_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
631        Self::compile_unit_pos_base_name(f)?;
632        write!(f, ".z")
633    }
634
635    fn compile_plane_dim(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
636        f.write_str("simd_size")
637    }
638
639    fn compile_plane_dim_checked(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
640        f.write_str("threads_per_simdgroup_checked")
641    }
642
643    fn compile_plane_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
644        f.write_str("simd_group_id")
645    }
646
647    fn compile_unit_pos_plane(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
648        f.write_str("simd_lane_id")
649    }
650}
651
652// Instructions
653
654impl DialectInstructions<Self> for MslDialect {
655    // atomics
656    fn compile_atomic_add(
657        f: &mut std::fmt::Formatter<'_>,
658        lhs: &Variable<Self>,
659        rhs: &Variable<Self>,
660        out: &Variable<Self>,
661    ) -> std::fmt::Result {
662        let out = out.fmt_left();
663        writeln!(
664            f,
665            "{out} = atomic_fetch_add_explicit({lhs}, {rhs}, memory_order_relaxed);"
666        )
667    }
668
669    fn compile_atomic_and(
670        f: &mut std::fmt::Formatter<'_>,
671        lhs: &Variable<Self>,
672        rhs: &Variable<Self>,
673        out: &Variable<Self>,
674    ) -> std::fmt::Result {
675        let out = out.fmt_left();
676        writeln!(
677            f,
678            "{out} = atomic_fetch_and_explicit({lhs}, {rhs}, memory_order_relaxed);"
679        )
680    }
681
682    fn compile_atomic_cas(
683        f: &mut std::fmt::Formatter<'_>,
684        input: &Variable<Self>,
685        cmp: &Variable<Self>,
686        val: &Variable<Self>,
687        out: &Variable<Self>,
688    ) -> std::fmt::Result {
689        let out = out.fmt_left();
690        writeln!(
691            f,
692            "{out} = atomic_compare_exchange_weak_explicit({input}, &{cmp}, {val}, memory_order_relaxed, memory_order_relaxed);"
693        )
694    }
695
696    fn compile_atomic_load(
697        f: &mut std::fmt::Formatter<'_>,
698        input: &Variable<Self>,
699        out: &Variable<Self>,
700    ) -> std::fmt::Result {
701        let out = out.fmt_left();
702        writeln!(
703            f,
704            "{out} = atomic_load_explicit({input}, memory_order_relaxed);"
705        )
706    }
707
708    fn compile_atomic_max(
709        f: &mut std::fmt::Formatter<'_>,
710        lhs: &Variable<Self>,
711        rhs: &Variable<Self>,
712        out: &Variable<Self>,
713    ) -> std::fmt::Result {
714        let out = out.fmt_left();
715        writeln!(
716            f,
717            "{out} = atomic_fetch_max_explicit({lhs}, {rhs}, memory_order_relaxed);"
718        )
719    }
720
721    fn compile_atomic_min(
722        f: &mut std::fmt::Formatter<'_>,
723        lhs: &Variable<Self>,
724        rhs: &Variable<Self>,
725        out: &Variable<Self>,
726    ) -> std::fmt::Result {
727        let out = out.fmt_left();
728        writeln!(
729            f,
730            "{out} = atomic_fetch_min_explicit({lhs}, {rhs}, memory_order_relaxed);"
731        )
732    }
733
734    fn compile_atomic_or(
735        f: &mut std::fmt::Formatter<'_>,
736        lhs: &Variable<Self>,
737        rhs: &Variable<Self>,
738        out: &Variable<Self>,
739    ) -> std::fmt::Result {
740        let out = out.fmt_left();
741        writeln!(
742            f,
743            "{out} = atomic_fetch_or_explicit({lhs}, {rhs}, memory_order_relaxed);"
744        )
745    }
746
747    fn compile_atomic_store(
748        f: &mut std::fmt::Formatter<'_>,
749        input: &Variable<Self>,
750        out: &Variable<Self>,
751    ) -> std::fmt::Result {
752        writeln!(
753            f,
754            "atomic_store_explicit({out}, {input}, memory_order_relaxed);"
755        )
756    }
757
758    fn compile_atomic_sub(
759        f: &mut std::fmt::Formatter<'_>,
760        lhs: &Variable<Self>,
761        rhs: &Variable<Self>,
762        out: &Variable<Self>,
763    ) -> std::fmt::Result {
764        let out = out.fmt_left();
765        writeln!(
766            f,
767            "{out} = atomic_fetch_sub_explicit({lhs}, {rhs}, memory_order_relaxed);"
768        )
769    }
770
771    fn compile_atomic_swap(
772        f: &mut std::fmt::Formatter<'_>,
773        lhs: &Variable<Self>,
774        rhs: &Variable<Self>,
775        out: &Variable<Self>,
776    ) -> std::fmt::Result {
777        let out = out.fmt_left();
778        writeln!(
779            f,
780            "{out} = atomic_exchange_explicit({lhs}, {rhs}, memory_order_relaxed);"
781        )
782    }
783
784    fn compile_atomic_xor(
785        f: &mut std::fmt::Formatter<'_>,
786        lhs: &Variable<Self>,
787        rhs: &Variable<Self>,
788        out: &Variable<Self>,
789    ) -> std::fmt::Result {
790        let out = out.fmt_left();
791        writeln!(
792            f,
793            "{out} = atomic_fetch_xor_explicit({lhs}, {rhs}, memory_order_relaxed);"
794        )
795    }
796
797    fn compile_saturating_add(
798        f: &mut std::fmt::Formatter<'_>,
799        lhs: impl Display,
800        rhs: impl Display,
801        _item: Item<Self>,
802    ) -> std::fmt::Result {
803        write!(f, "addsat({lhs}, {rhs})")
804    }
805
806    fn compile_saturating_sub(
807        f: &mut std::fmt::Formatter<'_>,
808        lhs: impl Display,
809        rhs: impl Display,
810        _item: Item<Self>,
811    ) -> std::fmt::Result {
812        write!(f, "subsat({lhs}, {rhs})")
813    }
814
815    // debug
816    fn compile_instruction_printf(
817        f: &mut std::fmt::Formatter<'_>,
818        format_string: &str,
819        args: &[Variable<Self>],
820    ) -> std::fmt::Result {
821        let args = args.iter().map(|arg| format!("{arg}")).collect::<Vec<_>>();
822        let args = match args.is_empty() {
823            true => "".to_string(),
824            false => format!(", {}", args.join(",")),
825        };
826        writeln!(f, "os_log_default.log({format_string:?}{args});")
827    }
828
829    // logs
830    fn compile_instruction_log1p_scalar<T: Component<Self>>(
831        f: &mut std::fmt::Formatter<'_>,
832        input: T,
833    ) -> std::fmt::Result {
834        match input.elem() {
835            Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
836                write!(f, "log(half(1.0f) + {input})")
837            }
838            _ => write!(f, "log(1.0f + {input})"),
839        }
840    }
841
842    // sync
843    fn compile_instruction_sync_threads(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
844        writeln!(f, "threadgroup_barrier(mem_flags::mem_threadgroup);")
845    }
846
847    fn compile_instruction_sync_warp(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
848        writeln!(f, "simdgroup_barrier(mem_flags::mem_none);")
849    }
850
851    fn compile_instruction_thread_fence(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
852        writeln!(f, "threadgroup_thread_fence(mem_flags::mem_device);")
853    }
854
855    // trigo
856    fn compile_instruction_tanh_scalar<T: Component<Self>>(
857        f: &mut std::fmt::Formatter<'_>,
858        input: T,
859    ) -> std::fmt::Result {
860        write!(f, "safe_tanh_scalar({input})")
861    }
862
863    // unary
864    fn compile_instruction_find_first_set<T: Component<Self>>(
865        f: &mut std::fmt::Formatter<'_>,
866        input: T,
867        out_elem: Elem<Self>,
868    ) -> std::fmt::Result {
869        write!(f, "{out_elem}(")?;
870        match input.elem() {
871            Elem::I32 | Elem::U32 => write!(f, "__ffs({input})"),
872            Elem::I64 | Elem::U64 => write!(f, "__ffsll({input})"),
873            _ => write!(f, "__ffs({}({input}))", Elem::<Self>::I32),
874        }?;
875        write!(f, ")")
876    }
877
878    fn compile_instruction_leading_zeros_scalar<T: Component<Self>>(
879        f: &mut std::fmt::Formatter<'_>,
880        input: T,
881        out_elem: Elem<Self>,
882    ) -> std::fmt::Result {
883        write!(f, "{out_elem}(clz({input}))")
884    }
885
886    fn compile_instruction_popcount_scalar<T: Component<Self>>(
887        f: &mut std::fmt::Formatter<'_>,
888        input: T,
889        out_elem: Elem<Self>,
890    ) -> std::fmt::Result {
891        write!(f, "{out_elem}(")?;
892        match input.elem() {
893            Elem::I32 | Elem::U32 | Elem::I64 | Elem::U64 => write!(f, "popcount({input})"),
894            _ => write!(f, "popcount({})", shared::unary::zero_extend(input)),
895        }?;
896        write!(f, ")")
897    }
898
899    fn compile_instruction_reverse_bits_scalar<T: Component<Self>>(
900        f: &mut std::fmt::Formatter<'_>,
901        input: T,
902        out_elem: Elem<Self>,
903    ) -> std::fmt::Result {
904        write!(f, "{out_elem}(")?;
905        match out_elem {
906            Elem::I32 | Elem::U32 | Elem::I64 | Elem::U64 => write!(f, "reverse_bits({input})"),
907            _ => write!(
908                f,
909                "reverse_bits({}) >> {}",
910                shared::unary::zero_extend(input),
911                (size_of::<u32>() - out_elem.size()) * 8
912            ),
913        }?;
914        write!(f, ")")
915    }
916
917    // others
918    fn compile_instruction_max_function_name(
919        f: &mut std::fmt::Formatter<'_>,
920        _item: Item<Self>,
921    ) -> std::fmt::Result {
922        write!(f, "max")
923    }
924
925    fn compile_instruction_min_function_name(
926        f: &mut std::fmt::Formatter<'_>,
927        _item: Item<Self>,
928    ) -> std::fmt::Result {
929        write!(f, "min")
930    }
931
932    fn compile_instruction_powf(
933        f: &mut std::fmt::Formatter<'_>,
934        lhs: &str,
935        rhs: &str,
936        elem: Elem<Self>,
937    ) -> std::fmt::Result {
938        write!(f, "pow({lhs}, {elem}({rhs}))")
939    }
940
941    fn compile_instruction_half_function_name_prefix() -> &'static str {
942        ""
943    }
944
945    fn compile_instruction_half2_function_name_prefix() -> &'static str {
946        ""
947    }
948
949    // Warp
950    fn compile_warp_shuffle(
951        f: &mut std::fmt::Formatter<'_>,
952        var: &str,
953        source: &str,
954    ) -> std::fmt::Result {
955        write!(f, "simd_shuffle({var}, {source})")
956    }
957
958    fn compile_warp_shuffle_xor(
959        f: &mut std::fmt::Formatter<'_>,
960        var: &str,
961        _elem: &Elem<Self>,
962        offset: &str,
963    ) -> std::fmt::Result {
964        write!(f, "simd_shuffle_xor({var}, {offset})")
965    }
966
967    fn compile_warp_shuffle_up(
968        f: &mut std::fmt::Formatter<'_>,
969        var: &str,
970        offset: &str,
971    ) -> std::fmt::Result {
972        write!(f, "simd_shuffle_up({var}, {offset})")
973    }
974
975    fn compile_warp_shuffle_down(
976        f: &mut std::fmt::Formatter<'_>,
977        var: &str,
978        offset: &str,
979    ) -> std::fmt::Result {
980        write!(f, "simd_shuffle_down({var}, {offset})")
981    }
982
983    fn compile_warp_all<T: Component<Self>>(
984        f: &mut std::fmt::Formatter<'_>,
985        input: &T,
986    ) -> std::fmt::Result {
987        write!(f, "simd_all({input})")
988    }
989
990    fn compile_warp_any<T: Component<Self>>(
991        f: &mut std::fmt::Formatter<'_>,
992        input: &T,
993    ) -> std::fmt::Result {
994        write!(f, "simd_any({input})")
995    }
996
997    fn compile_warp_ballot(
998        f: &mut std::fmt::Formatter<'_>,
999        input: &Variable<Self>,
1000        out_elem: &Elem<Self>,
1001    ) -> std::fmt::Result {
1002        write!(f, "{out_elem}(uint64_t(simd_ballot({input})))")
1003    }
1004}
1005
1006// Coop Matrices dialect
1007
1008impl DialectWmmaCompiler<Self> for MslDialect {
1009    fn compile_wmma_includes(
1010        f: &mut std::fmt::Formatter<'_>,
1011        _flags: &Flags<Self>,
1012    ) -> std::fmt::Result {
1013        writeln!(f, "#include <metal_simdgroup_matrix>")
1014    }
1015
1016    fn compile_wmma_local_variables(_f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1017        // not used
1018        Ok(())
1019    }
1020
1021    fn compile_wmma_fragment_declaration(
1022        f: &mut std::fmt::Formatter<'_>,
1023        var: &crate::shared::Variable<MslDialect>,
1024    ) -> std::fmt::Result {
1025        wmma_api_base::compile_fragment_declaration(f, var)
1026    }
1027
1028    fn compile_wwma_fragment_ident(
1029        _f: &mut std::fmt::Formatter<'_>,
1030        _ident: &FragmentIdent<Self>,
1031    ) -> std::fmt::Result {
1032        // not used
1033        Ok(())
1034    }
1035
1036    fn compile_wmma_fragment_layout(
1037        _f: &mut std::fmt::Formatter<'_>,
1038        _layout: &FragmentLayout<Self>,
1039    ) -> std::fmt::Result {
1040        // not used
1041        Ok(())
1042    }
1043
1044    fn compile_wmma_fragment(
1045        f: &mut std::fmt::Formatter<'_>,
1046        fragment: &Fragment<Self>,
1047    ) -> std::fmt::Result {
1048        let ty = fragment.elem;
1049        // currently as of Metal 3.2 only fragments of 8x8x8 are supported
1050        let m = fragment.m;
1051        let n = fragment.n;
1052        let k = fragment.k;
1053        if m != 8 || n != 8 || k != 8 {
1054            panic!("{m}x{n}x{k} fragments not supported. Only 8x8x8 fragments are supported.");
1055        }
1056        write!(f, "simdgroup_{ty}8x8")
1057    }
1058
1059    fn compile_wmma_instruction(
1060        f: &mut std::fmt::Formatter<'_>,
1061        instruction: &WmmaInstruction<Self>,
1062    ) -> std::fmt::Result {
1063        match instruction {
1064            WmmaInstruction::Fill { frag, value } => {
1065                match frag {
1066                    Variable::WmmaFragment { .. } => {
1067                        let ty = frag.elem();
1068                        // Only 8x8x8 fragemts are supported. Check is done at fragment compilation time.
1069                        writeln!(
1070                            f,
1071                            "{frag} = make_filled_simdgroup_matrix<{ty}, 8, 8>({value});"
1072                        )
1073                    }
1074                    _ => panic!("should be a fragment"),
1075                }
1076            }
1077            WmmaInstruction::Load {
1078                frag,
1079                value,
1080                stride,
1081                offset,
1082                layout: _layout,
1083            } => {
1084                let transpose = match frag {
1085                    Variable::WmmaFragment { frag: inner, .. } => match inner.layout {
1086                        Some(FragmentLayout::RowMajor) => false,
1087                        Some(FragmentLayout::ColMajor) => true,
1088                        _ => false,
1089                    },
1090                    _ => panic!("should be a fragment"),
1091                };
1092                let item = value.item();
1093                if item.vectorization > 1 {
1094                    let elem = item.elem;
1095                    match value {
1096                        Variable::GlobalInputArray(..) => writeln!(
1097                            f,
1098                            "simdgroup_load({frag}, (device {elem}*)({value} + {offset}), {stride}, 0, {transpose});"
1099                        ),
1100                        Variable::SharedArray(..) => writeln!(
1101                            f,
1102                            "simdgroup_load({frag}, reinterpret_cast<threadgroup {elem} *>({value} + {offset}), {stride}, 0, {transpose});"
1103                        ),
1104                        _ => panic!(
1105                            "Vectorized wmma load is only supported from global or shared memory."
1106                        ),
1107                    }
1108                } else {
1109                    writeln!(
1110                        f,
1111                        "simdgroup_load({frag}, {value} + {offset}, {stride}, 0, {transpose});"
1112                    )
1113                }
1114            }
1115            WmmaInstruction::Execute {
1116                frag_a: a,
1117                frag_b: b,
1118                frag_c: c,
1119                frag_d: d,
1120                ..
1121            } => {
1122                writeln!(f, "simdgroup_multiply_accumulate({d}, {a}, {b}, {c});")
1123            }
1124            WmmaInstruction::Store {
1125                output,
1126                frag,
1127                stride,
1128                offset,
1129                layout: _layout,
1130            } => {
1131                let item = output.item();
1132                let mut reinterpret_cast = item.vectorization > 1;
1133                let elem = match item.elem {
1134                    Elem::BF16 => {
1135                        reinterpret_cast = true;
1136                        Elem::F16
1137                    }
1138                    _ => item.elem,
1139                };
1140                if reinterpret_cast {
1141                    writeln!(
1142                        f,
1143                        "simdgroup_store({frag}, reinterpret_cast<threadgroup {elem} *>({output} + {offset}), {stride});"
1144                    )
1145                } else {
1146                    writeln!(f, "simdgroup_store({frag}, {output} + {offset}, {stride});")
1147                }?;
1148                writeln!(f, "simdgroup_barrier(mem_flags::mem_none);")
1149            }
1150            WmmaInstruction::Cast { input, output } => {
1151                writeln!(f, "simdgroup_barrier(mem_flags::mem_none);")?;
1152                let ty = match output {
1153                    Variable::WmmaFragment { frag, .. } => frag.elem,
1154                    _ => panic!("should be a fragment"),
1155                };
1156                match ty {
1157                    Elem::BF16 => {
1158                        let addr_space = Self::address_space_for_variable(output);
1159                        let elem = Elem::<Self>::F16;
1160                        // TODO: to test with benchmarks
1161
1162                        writeln!(
1163                            f,
1164                            "for(int e=0; e<8; e++) {{
1165    {ty} elem = {ty}({input}.thread_elements()[e]);
1166    {output}.thread_elements()[e] = *reinterpret_cast<{addr_space}{elem} *>(&elem);
1167}}"
1168                        )
1169                    }
1170                    _ => {
1171                        writeln!(
1172                            f,
1173                            "for(int e=0; e<8; e++) {{
1174    {output}.thread_elements()[e] = {ty}({input}.thread_elements()[e]);
1175}}"
1176                        )
1177                    }
1178                }
1179            }
1180            WmmaInstruction::ExecuteManual {
1181                shape,
1182                frag_a,
1183                frag_b,
1184                frag_c,
1185                frag_d,
1186            } => {
1187                Self::compile_manual_mma(f, ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d))
1188            }
1189            WmmaInstruction::ExecuteScaled {
1190                shape,
1191                frag_a,
1192                frag_b,
1193                frag_c,
1194                frag_d,
1195                scales_a,
1196                scales_b,
1197                scales_factor,
1198            } => Self::compile_scaled_mma(
1199                f,
1200                ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d),
1201                *scales_a,
1202                *scales_b,
1203                *scales_factor,
1204            ),
1205            WmmaInstruction::LdMatrix { .. } | WmmaInstruction::StMatrix { .. } => {
1206                f.write_str("#error WmmaInstruction Ld & St Matrix not supported on Metal\n")
1207            }
1208        }
1209    }
1210
1211    fn compile_manual_mma(
1212        f: &mut std::fmt::Formatter<'_>,
1213        _mma: shared::ManualMma<Self>,
1214    ) -> std::fmt::Result {
1215        f.write_str("#error manual mma not supported on Metal\n")
1216    }
1217
1218    fn compile_scaled_mma(
1219        f: &mut std::fmt::Formatter<'_>,
1220        _mma: shared::ManualMma<Self>,
1221        _scales_a: Variable<Self>,
1222        _scales_b: Variable<Self>,
1223        _scales_factor: u32,
1224    ) -> std::fmt::Result {
1225        f.write_str("#error scaled mma not supported on Metal\n")
1226    }
1227
1228    fn supported_wmma_combinations(_arch: &MetalArchitecture) -> SupportedMmaCombinations {
1229        let types = vec![
1230            (
1231                gpu::ElemType::Float(gpu::FloatKind::F16).into(),
1232                gpu::ElemType::Float(gpu::FloatKind::F16).into(),
1233                gpu::ElemType::Float(gpu::FloatKind::F16).into(),
1234            ),
1235            (
1236                gpu::ElemType::Float(gpu::FloatKind::F16).into(),
1237                gpu::ElemType::Float(gpu::FloatKind::F16).into(),
1238                gpu::ElemType::Float(gpu::FloatKind::F32).into(),
1239            ),
1240            (
1241                gpu::ElemType::Float(gpu::FloatKind::BF16).into(),
1242                gpu::ElemType::Float(gpu::FloatKind::BF16).into(),
1243                gpu::ElemType::Float(gpu::FloatKind::BF16).into(),
1244            ),
1245            (
1246                gpu::ElemType::Float(gpu::FloatKind::F32).into(),
1247                gpu::ElemType::Float(gpu::FloatKind::F32).into(),
1248                gpu::ElemType::Float(gpu::FloatKind::F32).into(),
1249            ),
1250        ];
1251        types
1252            .into_iter()
1253            .map(|(a_type, b_type, cd_type)| MmaConfig {
1254                a_type,
1255                b_type,
1256                cd_type,
1257                m: 8,
1258                n: 8,
1259                k: 8,
1260            })
1261            .collect()
1262    }
1263
1264    fn supported_mma_combinations(_arch: &MetalArchitecture) -> SupportedMmaCombinations {
1265        Vec::new()
1266    }
1267}
1268
// Processors dialect
1270
1271impl DialectProcessors<Self> for MslDialect {
1272    fn processors() -> Vec<Box<dyn gpu::Processor>> {
1273        Vec::new()
1274    }
1275}