Skip to main content

cubecl_cpp/metal/
dialect.rs

1use super::{
2    AddressSpace, Extension,
3    arch::MetalArchitecture,
4    extension::{format_ffs, format_mulhi},
5    format_erf, format_global_binding_arg, format_metal_builtin_binding_arg, format_safe_tanh,
6};
7use crate::{
8    Dialect,
9    shared::{
10        self, AtomicKind, Binding, Component, CubeIndexFlags, DialectBindings, DialectCubeBuiltins,
11        DialectIncludes, DialectInstructions, DialectProcessors, DialectTypes,
12        DialectWarpReduceCompiler, DialectWmmaCompiler, Elem, Flags, FmtLeft, Fragment,
13        FragmentIdent, FragmentLayout, Instruction, Item, ManualMma, SharedMemory,
14        SupportedMmaCombinations, Variable, WarpInstruction, WmmaInstruction, wmma_api_base,
15    },
16};
17use core::panic;
18use cubecl_core::{
19    ir::{self as gpu, features::MmaConfig},
20    prelude::{Location, Visibility},
21};
22use std::fmt::Display;
23
/// Zero-sized marker type selecting the Metal Shading Language (MSL) flavour of the
/// C++-family code generator; all behaviour lives in its trait implementations.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct MslDialect {}
26
27// Base dialect
28
29impl Dialect for MslDialect {
30    type Architecture = MetalArchitecture;
31}
32
33impl MslDialect {
34    fn warp_op_vectorized(
35        f: &mut core::fmt::Formatter<'_>,
36        input: &Variable<Self>,
37        out: &Variable<Self>,
38        simd_op_prefix: &str,
39        simd_op_suffix: &str,
40    ) -> core::fmt::Result {
41        let out = out.fmt_left();
42        let vectorization = input.item().vectorization;
43
44        f.write_fmt(format_args!("{out} = {} {{", input.item()))?;
45
46        for k in 0..vectorization {
47            let index = if vectorization > 1 {
48                format!(".i_{k}")
49            } else {
50                String::new()
51            };
52            let comma = if k + 1 < vectorization { "," } else { "" };
53
54            writeln!(f, "{simd_op_prefix}{input}{index}{simd_op_suffix}{comma}")?;
55        }
56
57        f.write_fmt(format_args!("}};\n"))
58    }
59}
60
61impl DialectWarpReduceCompiler<Self> for MslDialect {
62    fn warp_reduce_sum(
63        f: &mut core::fmt::Formatter<'_>,
64        input: &Variable<Self>,
65        out: &Variable<Self>,
66    ) -> core::fmt::Result {
67        Self::warp_op_vectorized(f, input, out, "simd_sum(", ")")
68    }
69    fn warp_reduce_prod(
70        f: &mut core::fmt::Formatter<'_>,
71        input: &Variable<Self>,
72        out: &Variable<Self>,
73    ) -> core::fmt::Result {
74        Self::warp_op_vectorized(f, input, out, "simd_product(", ")")
75    }
76    fn warp_reduce_max(
77        f: &mut core::fmt::Formatter<'_>,
78        input: &Variable<Self>,
79        out: &Variable<Self>,
80    ) -> core::fmt::Result {
81        Self::warp_op_vectorized(f, input, out, "simd_max(", ")")
82    }
83    fn warp_reduce_min(
84        f: &mut core::fmt::Formatter<'_>,
85        input: &Variable<Self>,
86        out: &Variable<Self>,
87    ) -> core::fmt::Result {
88        Self::warp_op_vectorized(f, input, out, "simd_min(", ")")
89    }
90    fn warp_reduce_all(
91        f: &mut core::fmt::Formatter<'_>,
92        input: &Variable<Self>,
93        out: &Variable<Self>,
94    ) -> core::fmt::Result {
95        Self::warp_op_vectorized(f, input, out, "simd_and(", "? 1u : 0u) != 0u")
96    }
97    fn warp_reduce_any(
98        f: &mut core::fmt::Formatter<'_>,
99        input: &Variable<Self>,
100        out: &Variable<Self>,
101    ) -> core::fmt::Result {
102        Self::warp_op_vectorized(f, input, out, "simd_or(", "? 1u : 0u) != 0u")
103    }
104    fn warp_reduce_sum_inclusive(
105        f: &mut core::fmt::Formatter<'_>,
106        input: &Variable<Self>,
107        out: &Variable<Self>,
108    ) -> core::fmt::Result {
109        Self::warp_op_vectorized(f, input, out, "simd_prefix_inclusive_sum(", ")")
110    }
111    fn warp_reduce_prod_inclusive(
112        f: &mut core::fmt::Formatter<'_>,
113        input: &Variable<Self>,
114        out: &Variable<Self>,
115    ) -> core::fmt::Result {
116        Self::warp_op_vectorized(f, input, out, "simd_prefix_inclusive_product(", ")")
117    }
118    fn warp_reduce_sum_exclusive(
119        f: &mut core::fmt::Formatter<'_>,
120        input: &Variable<Self>,
121        out: &Variable<Self>,
122    ) -> core::fmt::Result {
123        Self::warp_op_vectorized(f, input, out, "simd_prefix_exclusive_sum(", ")")
124    }
125    fn warp_reduce_prod_exclusive(
126        f: &mut core::fmt::Formatter<'_>,
127        input: &Variable<Self>,
128        out: &Variable<Self>,
129    ) -> core::fmt::Result {
130        Self::warp_op_vectorized(f, input, out, "simd_prefix_exclusive_product(", ")")
131    }
132}
133
134// Includes
135
136impl DialectIncludes<Self> for MslDialect {
137    type Extension = Extension<Self>;
138
139    fn compile_includes(f: &mut std::fmt::Formatter<'_>, _flags: &Flags<Self>) -> std::fmt::Result {
140        write!(
141            f,
142            "
143#include <metal_stdlib>
144using namespace metal;
145"
146        )?;
147        Ok(())
148    }
149
150    fn compile_extensions(
151        f: &mut std::fmt::Formatter<'_>,
152        extensions: &[Self::Extension],
153    ) -> std::fmt::Result {
154        for extension in extensions {
155            match extension {
156                Extension::Erf(input, output) => format_erf::<Self>(f, input, output)?,
157                Extension::Ffs(elem) => format_ffs(f, elem)?,
158                Extension::MulHi(elem) => format_mulhi(f, elem)?,
159                Extension::SafeTanh(item) => format_safe_tanh::<Self>(f, item)?,
160                Extension::NoExtension => {}
161            }
162        }
163        Ok(())
164    }
165
166    fn register_instruction_extension(
167        extensions: &mut Vec<Self::Extension>,
168        instruction: &Instruction<Self>,
169    ) {
170        let mut register_extension = |extension: Self::Extension| {
171            if !extensions.contains(&extension) {
172                extensions.push(extension);
173            }
174        };
175        #[allow(clippy::single_match)]
176        match instruction {
177            shared::Instruction::<Self>::Erf(instruction) => {
178                register_extension(Extension::Erf(
179                    instruction.input.elem(),
180                    instruction.out.elem(),
181                ));
182            }
183            shared::Instruction::<Self>::FindFirstSet(instruction) => {
184                let input_elem = instruction.input.elem();
185                match input_elem {
186                    Elem::U32 | Elem::U64 => {
187                        register_extension(Extension::Ffs(instruction.input.elem()));
188                    }
189                    Elem::I32 => {
190                        register_extension(Extension::Ffs(Elem::<Self>::U32));
191                        register_extension(Extension::Ffs(instruction.input.elem()));
192                    }
193                    Elem::I64 => {
194                        register_extension(Extension::Ffs(Elem::<Self>::U64));
195                        register_extension(Extension::Ffs(instruction.input.elem()));
196                    }
197                    _ => {
198                        register_extension(Extension::Ffs(Elem::<Self>::U32));
199                    }
200                }
201            }
202            shared::Instruction::<Self>::HiMul(instruction) => {
203                register_extension(Extension::MulHi(instruction.out.elem()));
204            }
205            shared::Instruction::<Self>::Tanh(instruction) => {
206                register_extension(Extension::SafeTanh(instruction.input.item()));
207            }
208            _ => {}
209        }
210    }
211
212    fn register_warp_instruction_extension(
213        _extensions: &mut Vec<Self::Extension>,
214        _instruction: &WarpInstruction<Self>,
215    ) {
216    }
217}
218
219// Types
220
221impl DialectTypes<Self> for MslDialect {
222    fn item_can_be_optimized() -> bool {
223        false
224    }
225
226    fn compile_type_definitions(
227        f: &mut std::fmt::Formatter<'_>,
228        items: &std::collections::HashSet<crate::shared::Item<Self>>,
229        _scalars: &[(Elem<Self>, usize)],
230        _flags: &Flags<Self>,
231    ) -> std::fmt::Result {
232        for item in items.iter() {
233            let elem = item.elem;
234            let size = item.vectorization;
235            let alignment = elem.size() * size;
236            if size > 1 {
237                write!(
238                    f,
239                    "
240struct alignas({alignment}) {item} {{"
241                )?;
242
243                for i in 0..size {
244                    write!(
245                        f,
246                        "
247    {elem} i_{i};"
248                    )?;
249                }
250
251                f.write_str("\n};\n")?;
252            }
253        }
254        Ok(())
255    }
256
257    fn compile_elem(
258        f: &mut std::fmt::Formatter<'_>,
259        elem: &shared::Elem<Self>,
260        _words: bool,
261    ) -> std::fmt::Result {
262        // we always use the word form of types
263        match elem {
264            shared::Elem::FP4(_)
265            | shared::Elem::FP4x2(_)
266            | shared::Elem::FP6(_)
267            | shared::Elem::FP6x2(_)
268            | shared::Elem::FP8(_)
269            | shared::Elem::FP8x2(_) => f.write_str("#error FP4/FP6/FP8 not supported in Metal\n"),
270            shared::Elem::F16 => f.write_str("half"),
271            shared::Elem::F16x2 => f.write_str("#error type F162 not supported!\n"),
272            shared::Elem::F32 => f.write_str("float"),
273            shared::Elem::F64 => f.write_str("#error type double not supported!\n"),
274            shared::Elem::BF16 => f.write_str("bfloat"),
275            shared::Elem::BF16x2 => f.write_str("#error type BF162 not supported!\n"),
276            shared::Elem::TF32 => f.write_str("float"),
277            shared::Elem::I8 => f.write_str("char"),
278            shared::Elem::I16 => f.write_str("short"),
279            shared::Elem::I32 => f.write_str("int"),
280            shared::Elem::I64 => f.write_str("long"),
281            shared::Elem::U8 => f.write_str("uchar"),
282            shared::Elem::U16 => f.write_str("ushort"),
283            shared::Elem::U32 => f.write_str("uint"),
284            shared::Elem::U64 => f.write_str("uint64_t"), // or unsigned long
285            shared::Elem::Bool => f.write_str("bool"),
286            shared::Elem::Barrier(_) => unimplemented!("metal doesn't support barrier object"),
287            shared::Elem::Atomic(inner) => inner.fmt(f),
288            shared::Elem::_Dialect(_) => Ok(()),
289        }
290    }
291
292    fn compile_item(f: &mut std::fmt::Formatter<'_>, item: &Item<Self>) -> std::fmt::Result {
293        if 1 == item.vectorization {
294            return write!(f, "{}", item.elem);
295        }
296        if item.native {
297            write!(f, "{}{}", item.elem, item.vectorization)
298        } else {
299            write!(f, "{}_{}", item.elem, item.vectorization)
300        }
301    }
302
303    fn compile_atomic_kind(
304        f: &mut std::fmt::Formatter<'_>,
305        kind: &AtomicKind<Self>,
306    ) -> std::fmt::Result {
307        match kind {
308            AtomicKind::I32 => write!(f, "atomic_int"),
309            AtomicKind::I64 => panic!("I64 atomic kind no supported."),
310            AtomicKind::U32 => write!(f, "atomic_uint"),
311            AtomicKind::U64 => write!(f, "atomic_ulong"),
312            AtomicKind::F16 => panic!("F16 atomic kind no supported."),
313            AtomicKind::BF16 => panic!("BF16 atomic kind no supported."),
314            AtomicKind::F32 => write!(f, "atomic_float"), // needs metal 3
315            AtomicKind::F64 => panic!("F64 atomic kind no supported."),
316            AtomicKind::_Dialect(_) => Ok(()),
317        }
318    }
319
320    fn address_space_for_variable(variable: &Variable<Self>) -> String {
321        format!("{} ", AddressSpace::from(variable))
322    }
323
324    fn compile_local_memory_qualifier(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
325        write!(f, "thread")
326    }
327
328    fn compile_shared_memory_declaration(
329        f: &mut std::fmt::Formatter<'_>,
330        shared: &SharedMemory<Self>,
331    ) -> std::fmt::Result {
332        match shared {
333            SharedMemory::Array {
334                index,
335                item,
336                length,
337                offset,
338                ..
339            } => {
340                let size_bytes = length * item.size();
341                writeln!(f, "// Shared array size: {length}, {size_bytes} bytes")?;
342                writeln!(
343                    f,
344                    "threadgroup {item}* shared_memory_{index} = reinterpret_cast<threadgroup {item}*>(&dynamic_shared_mem[{offset}]);"
345                )
346            }
347            SharedMemory::Value {
348                index,
349                item,
350                offset,
351                ..
352            } => {
353                let size_bytes = item.size();
354                writeln!(f, "// Shared value size: {size_bytes} bytes")?;
355                writeln!(
356                    f,
357                    "threadgroup {item}& shared_memory_{index} = reinterpret_cast<threadgroup {item}&>(dynamic_shared_mem[{offset}]);"
358                )
359            }
360        }
361    }
362}
363
364// Kernel argument bindings
365
366impl DialectBindings<Self> for MslDialect {
367    fn compile_kernel_signature(
368        f: &mut std::fmt::Formatter<'_>,
369        kernel_name: &str,
370        tensor_maps: &[Binding<Self>],
371        buffers: &[Binding<Self>],
372        scalars: &[(Elem<Self>, usize)],
373        flags: &Flags<Self>,
374    ) -> std::fmt::Result {
375        write!(
376            (f),
377            "
378[[kernel]]
379void {kernel_name}("
380        )?;
381        // Global bindings args
382        let mut buffer_idx = 0;
383        debug_assert!(
384            tensor_maps.is_empty(),
385            "Tensor maps aren't supported for metal"
386        );
387        for (i, b) in buffers.iter().enumerate() {
388            format_global_binding_arg("buffer", b, Some(&i.to_string()), &mut buffer_idx, f)?;
389        }
390        if flags.static_meta_length > 0 {
391            let binding = Binding {
392                id: 0,
393                item: Item::scalar(Elem::<Self>::U32, true),
394                location: Location::Storage,
395                size: None,
396                vis: Visibility::Read,
397            };
398            format_global_binding_arg("info", &binding, None, &mut buffer_idx, f)?;
399        }
400        for (elem, _) in scalars.iter() {
401            let binding = Binding {
402                id: 0,
403                item: Item::scalar(*elem, true),
404                location: Location::Storage,
405                size: None,
406                vis: Visibility::Read,
407            };
408
409            let name = format!("scalars_{elem}");
410            format_global_binding_arg(&name, &binding, None, &mut buffer_idx, f)?;
411        }
412
413        // Global metal builtins args
414        let builtins = vec![
415            (
416                flags.indexes.absolute_pos_tuple,
417                Variable::<Self>::AbsolutePosBaseName,
418            ),
419            (
420                flags.indexes.cube_dim_tuple,
421                Variable::<Self>::CubeDimBaseName,
422            ),
423            (
424                flags.indexes.cube_count_tuple,
425                Variable::<Self>::CubeCountBaseName,
426            ),
427            (flags.indexes.unit_pos, Variable::<Self>::UnitPos),
428            (
429                flags.indexes.unit_pos_tuple,
430                Variable::<Self>::UnitPosBaseName,
431            ),
432            (
433                flags.indexes.cube_pos_tuple,
434                Variable::<Self>::CubePosBaseName,
435            ),
436            (flags.indexes.unit_pos_plane, Variable::<Self>::UnitPosPlane),
437            (flags.indexes.plane_dim, Variable::<Self>::PlaneDim),
438            (flags.indexes.plane_pos, Variable::<Self>::PlanePos),
439        ];
440        let comma = !buffers.is_empty() || flags.static_meta_length > 0 || !scalars.is_empty();
441        builtins
442            .iter()
443            .filter(|(cond, _)| *cond)
444            .try_for_each(|(_, var)| format_metal_builtin_binding_arg(f, var, comma))?;
445        f.write_str("\n)")
446    }
447
448    fn compile_bindings_body(
449        f: &mut std::fmt::Formatter<'_>,
450        body: &shared::Body<Self>,
451    ) -> std::fmt::Result {
452        if !body.shared_memories.is_empty() {
453            let size = body
454                .shared_memories
455                .iter()
456                .map(|it| it.offset() + it.size())
457                .max()
458                .unwrap();
459
460            writeln!(f, "threadgroup uchar dynamic_shared_mem[{size}];",)?;
461        }
462        Ok(())
463    }
464}
465
466// Cube builtins dialect
467
468impl DialectCubeBuiltins<Self> for MslDialect {
469    /// Depending on the dialect available built-in variables the
470    /// inclusion rules might change.
471    /// For instance in metal we have a built-in for the Unit plane position
472    /// so we don't rely on other builtins.
473    fn builtin_rules(flags: &CubeIndexFlags) -> CubeIndexFlags {
474        let absolute_pos = flags.absolute_pos;
475        let cube_count = flags.cube_count;
476        let cube_dim = flags.cube_dim;
477        let cube_pos = flags.cube_pos;
478        let plane_dim_checked = flags.plane_dim_checked;
479        let plane_index = flags.plane_pos;
480        let unit_pos = flags.unit_pos;
481        let absolute_pos_tuple = flags.absolute_pos_tuple || absolute_pos;
482        let cube_count_tuple = flags.cube_count_tuple || cube_count || cube_pos || absolute_pos;
483        let cube_dim_tuple = flags.cube_dim_tuple || cube_dim || absolute_pos || plane_dim_checked;
484        let cube_pos_tuple = flags.cube_pos_tuple || cube_pos;
485        let cluster_pos = flags.cluster_pos;
486        let plane_dim = flags.plane_dim || plane_dim_checked || plane_index;
487        let unit_pos_plane = flags.unit_pos_plane || plane_index;
488        let unit_pos_tuple = flags.unit_pos_tuple || unit_pos;
489        CubeIndexFlags {
490            absolute_pos_tuple,
491            absolute_pos,
492            cube_count_tuple,
493            cube_count,
494            cube_dim_tuple,
495            cube_dim,
496            cube_pos_tuple,
497            cube_pos,
498            plane_dim,
499            plane_dim_checked,
500            plane_pos: plane_index,
501            unit_pos_tuple,
502            unit_pos,
503            unit_pos_plane,
504            cluster_pos,
505        }
506    }
507
508    fn compile_absolute_pos_tuple_computation(
509        _f: &mut std::fmt::Formatter<'_>,
510    ) -> std::fmt::Result {
511        // no need to compute it on metal as there is y a built-in for it
512        Ok(())
513    }
514
515    fn compile_absolute_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
516        f.write_str("thread_pos_in_grid")
517    }
518
519    fn compile_absolute_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
520        f.write_str("thread_index_in_grid")
521    }
522
523    fn compile_absolute_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
524        Self::compile_absolute_pos_base_name(f)?;
525        write!(f, ".x")
526    }
527
528    fn compile_absolute_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
529        Self::compile_absolute_pos_base_name(f)?;
530        write!(f, ".y")
531    }
532
533    fn compile_absolute_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
534        Self::compile_absolute_pos_base_name(f)?;
535        write!(f, ".z")
536    }
537
538    fn compile_cube_count_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
539        f.write_str("threadgroups_per_grid")
540    }
541
542    fn compile_cube_count(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
543        f.write_str("total_threadgroups_in_grid")
544    }
545
546    fn compile_cube_count_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
547        Self::compile_cube_count_base_name(f)?;
548        write!(f, ".x")
549    }
550
551    fn compile_cube_count_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
552        Self::compile_cube_count_base_name(f)?;
553        write!(f, ".y")
554    }
555
556    fn compile_cube_count_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
557        Self::compile_cube_count_base_name(f)?;
558        write!(f, ".z")
559    }
560
561    fn compile_cube_dim_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
562        f.write_str("threads_per_threadgroup")
563    }
564
565    fn compile_cube_dim(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
566        f.write_str("total_thread_in_threadgroup")
567    }
568
569    fn compile_cube_dim_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
570        Self::compile_cube_dim_base_name(f)?;
571        write!(f, ".x")
572    }
573
574    fn compile_cube_dim_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
575        Self::compile_cube_dim_base_name(f)?;
576        write!(f, ".y")
577    }
578
579    fn compile_cube_dim_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
580        Self::compile_cube_dim_base_name(f)?;
581        write!(f, ".z")
582    }
583
584    fn compile_cube_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
585        f.write_str("threadgroup_pos_in_grid")
586    }
587
588    fn compile_cube_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
589        f.write_str("threadgroup_index_in_grid")
590    }
591
592    fn compile_cube_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
593        Self::compile_cube_pos_base_name(f)?;
594        write!(f, ".x")
595    }
596
597    fn compile_cube_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
598        Self::compile_cube_pos_base_name(f)?;
599        write!(f, ".y")
600    }
601
602    fn compile_cube_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
603        Self::compile_cube_pos_base_name(f)?;
604        write!(f, ".z")
605    }
606
607    fn compile_unit_pos_computation(_f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
608        // no need to compute it on metal as there is y a built-in for it
609        Ok(())
610    }
611
612    fn compile_unit_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
613        f.write_str("thread_pos_in_threadgroup")
614    }
615
616    fn compile_unit_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
617        f.write_str("thread_index_in_threadgroup")
618    }
619
620    fn compile_unit_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
621        Self::compile_unit_pos_base_name(f)?;
622        write!(f, ".x")
623    }
624
625    fn compile_unit_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
626        Self::compile_unit_pos_base_name(f)?;
627        write!(f, ".y")
628    }
629
630    fn compile_unit_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
631        Self::compile_unit_pos_base_name(f)?;
632        write!(f, ".z")
633    }
634
635    fn compile_plane_dim(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
636        f.write_str("simd_size")
637    }
638
639    fn compile_plane_dim_checked(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
640        f.write_str("threads_per_simdgroup_checked")
641    }
642
643    fn compile_plane_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
644        f.write_str("simd_group_id")
645    }
646
647    fn compile_unit_pos_plane(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
648        f.write_str("simd_lane_id")
649    }
650}
651
652// Instructions
653
654impl DialectInstructions<Self> for MslDialect {
655    // atomics
656    fn compile_atomic_add(
657        f: &mut std::fmt::Formatter<'_>,
658        lhs: &Variable<Self>,
659        rhs: &Variable<Self>,
660        out: &Variable<Self>,
661    ) -> std::fmt::Result {
662        let out = out.fmt_left();
663        writeln!(
664            f,
665            "{out} = atomic_fetch_add_explicit({lhs}, {rhs}, memory_order_relaxed);"
666        )
667    }
668
669    fn compile_atomic_and(
670        f: &mut std::fmt::Formatter<'_>,
671        lhs: &Variable<Self>,
672        rhs: &Variable<Self>,
673        out: &Variable<Self>,
674    ) -> std::fmt::Result {
675        let out = out.fmt_left();
676        writeln!(
677            f,
678            "{out} = atomic_fetch_and_explicit({lhs}, {rhs}, memory_order_relaxed);"
679        )
680    }
681
682    fn compile_atomic_cas(
683        f: &mut std::fmt::Formatter<'_>,
684        input: &Variable<Self>,
685        cmp: &Variable<Self>,
686        val: &Variable<Self>,
687        out: &Variable<Self>,
688    ) -> std::fmt::Result {
689        let out = out.fmt_left();
690        writeln!(
691            f,
692            "{out} = atomic_compare_exchange_weak_explicit({input}, &{cmp}, {val}, memory_order_relaxed, memory_order_relaxed);"
693        )
694    }
695
696    fn compile_atomic_load(
697        f: &mut std::fmt::Formatter<'_>,
698        input: &Variable<Self>,
699        out: &Variable<Self>,
700    ) -> std::fmt::Result {
701        let out = out.fmt_left();
702        writeln!(
703            f,
704            "{out} = atomic_load_explicit({input}, memory_order_relaxed);"
705        )
706    }
707
708    fn compile_atomic_max(
709        f: &mut std::fmt::Formatter<'_>,
710        lhs: &Variable<Self>,
711        rhs: &Variable<Self>,
712        out: &Variable<Self>,
713    ) -> std::fmt::Result {
714        let out = out.fmt_left();
715        writeln!(
716            f,
717            "{out} = atomic_fetch_max_explicit({lhs}, {rhs}, memory_order_relaxed);"
718        )
719    }
720
721    fn compile_atomic_min(
722        f: &mut std::fmt::Formatter<'_>,
723        lhs: &Variable<Self>,
724        rhs: &Variable<Self>,
725        out: &Variable<Self>,
726    ) -> std::fmt::Result {
727        let out = out.fmt_left();
728        writeln!(
729            f,
730            "{out} = atomic_fetch_min_explicit({lhs}, {rhs}, memory_order_relaxed);"
731        )
732    }
733
734    fn compile_atomic_or(
735        f: &mut std::fmt::Formatter<'_>,
736        lhs: &Variable<Self>,
737        rhs: &Variable<Self>,
738        out: &Variable<Self>,
739    ) -> std::fmt::Result {
740        let out = out.fmt_left();
741        writeln!(
742            f,
743            "{out} = atomic_fetch_or_explicit({lhs}, {rhs}, memory_order_relaxed);"
744        )
745    }
746
747    fn compile_atomic_store(
748        f: &mut std::fmt::Formatter<'_>,
749        input: &Variable<Self>,
750        out: &Variable<Self>,
751    ) -> std::fmt::Result {
752        writeln!(
753            f,
754            "atomic_store_explicit({out}, {input}, memory_order_relaxed);"
755        )
756    }
757
758    fn compile_atomic_sub(
759        f: &mut std::fmt::Formatter<'_>,
760        lhs: &Variable<Self>,
761        rhs: &Variable<Self>,
762        out: &Variable<Self>,
763    ) -> std::fmt::Result {
764        let out = out.fmt_left();
765        writeln!(
766            f,
767            "{out} = atomic_fetch_sub_explicit({lhs}, {rhs}, memory_order_relaxed);"
768        )
769    }
770
771    fn compile_atomic_swap(
772        f: &mut std::fmt::Formatter<'_>,
773        lhs: &Variable<Self>,
774        rhs: &Variable<Self>,
775        out: &Variable<Self>,
776    ) -> std::fmt::Result {
777        let out = out.fmt_left();
778        writeln!(
779            f,
780            "{out} = atomic_exchange_explicit({lhs}, {rhs}, memory_order_relaxed);"
781        )
782    }
783
784    fn compile_atomic_xor(
785        f: &mut std::fmt::Formatter<'_>,
786        lhs: &Variable<Self>,
787        rhs: &Variable<Self>,
788        out: &Variable<Self>,
789    ) -> std::fmt::Result {
790        let out = out.fmt_left();
791        writeln!(
792            f,
793            "{out} = atomic_fetch_xor_explicit({lhs}, {rhs}, memory_order_relaxed);"
794        )
795    }
796
797    fn compile_saturating_add(
798        f: &mut std::fmt::Formatter<'_>,
799        lhs: impl Display,
800        rhs: impl Display,
801        _item: Item<Self>,
802    ) -> std::fmt::Result {
803        write!(f, "addsat({lhs}, {rhs})")
804    }
805
806    fn compile_saturating_sub(
807        f: &mut std::fmt::Formatter<'_>,
808        lhs: impl Display,
809        rhs: impl Display,
810        _item: Item<Self>,
811    ) -> std::fmt::Result {
812        write!(f, "subsat({lhs}, {rhs})")
813    }
814
815    // debug
816    fn compile_instruction_printf(
817        f: &mut std::fmt::Formatter<'_>,
818        format_string: &str,
819        args: &[Variable<Self>],
820    ) -> std::fmt::Result {
821        let args = args.iter().map(|arg| format!("{arg}")).collect::<Vec<_>>();
822        let args = match args.is_empty() {
823            true => "".to_string(),
824            false => format!(", {}", args.join(",")),
825        };
826        writeln!(f, "os_log_default.log({format_string:?}{args});")
827    }
828
829    // logs
830    fn compile_instruction_log1p_scalar<T: Component<Self>>(
831        f: &mut std::fmt::Formatter<'_>,
832        input: T,
833    ) -> std::fmt::Result {
834        match input.elem() {
835            Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
836                write!(f, "log(half(1.0f) + {input})")
837            }
838            _ => write!(f, "log(1.0f + {input})"),
839        }
840    }
841
// sync
/// Emits a threadgroup barrier covering `threadgroup` memory.
fn compile_instruction_sync_threads(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    f.write_str("threadgroup_barrier(mem_flags::mem_threadgroup);\n")
}
846
/// Emits a SIMD-group execution barrier with no memory flags (`mem_none`).
fn compile_instruction_sync_warp(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    f.write_str("simdgroup_barrier(mem_flags::mem_none);\n")
}
850
/// Emits the device-memory thread fence call.
fn compile_instruction_thread_fence(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    f.write_str("threadgroup_thread_fence(mem_flags::mem_device);\n")
}
854
855    // trigo
856    fn compile_instruction_tanh_scalar<T: Component<Self>>(
857        f: &mut std::fmt::Formatter<'_>,
858        input: T,
859    ) -> std::fmt::Result {
860        write!(f, "safe_tanh_scalar({input})")
861    }
862
863    // unary
864    fn compile_instruction_find_first_set<T: Component<Self>>(
865        f: &mut std::fmt::Formatter<'_>,
866        input: T,
867        out_elem: Elem<Self>,
868    ) -> std::fmt::Result {
869        write!(f, "{out_elem}(")?;
870        match input.elem() {
871            Elem::I32 | Elem::U32 => write!(f, "__ffs({input})"),
872            Elem::I64 | Elem::U64 => write!(f, "__ffsll({input})"),
873            _ => write!(f, "__ffs({}({input}))", Elem::<Self>::I32),
874        }?;
875        write!(f, ")")
876    }
877
878    fn compile_instruction_leading_zeros_scalar<T: Component<Self>>(
879        f: &mut std::fmt::Formatter<'_>,
880        input: T,
881        out_elem: Elem<Self>,
882    ) -> std::fmt::Result {
883        write!(f, "{out_elem}(clz({input}))")
884    }
885
886    fn compile_instruction_trailing_zeros_scalar<T: Component<Self>>(
887        f: &mut std::fmt::Formatter<'_>,
888        input: T,
889        out_elem: Elem<Self>,
890    ) -> std::fmt::Result {
891        write!(f, "{out_elem}(ctz({input}))")
892    }
893
894    fn compile_instruction_popcount_scalar<T: Component<Self>>(
895        f: &mut std::fmt::Formatter<'_>,
896        input: T,
897        out_elem: Elem<Self>,
898    ) -> std::fmt::Result {
899        write!(f, "{out_elem}(")?;
900        match input.elem() {
901            Elem::I32 | Elem::U32 | Elem::I64 | Elem::U64 => write!(f, "popcount({input})"),
902            _ => write!(f, "popcount({})", shared::unary::zero_extend(input)),
903        }?;
904        write!(f, ")")
905    }
906
907    fn compile_instruction_reverse_bits_scalar<T: Component<Self>>(
908        f: &mut std::fmt::Formatter<'_>,
909        input: T,
910        out_elem: Elem<Self>,
911    ) -> std::fmt::Result {
912        write!(f, "{out_elem}(")?;
913        match out_elem {
914            Elem::I32 | Elem::U32 | Elem::I64 | Elem::U64 => write!(f, "reverse_bits({input})"),
915            _ => write!(
916                f,
917                "reverse_bits({}) >> {}",
918                shared::unary::zero_extend(input),
919                (size_of::<u32>() - out_elem.size()) * 8
920            ),
921        }?;
922        write!(f, ")")
923    }
924
925    // others
926    fn compile_instruction_max_function_name(
927        f: &mut std::fmt::Formatter<'_>,
928        _item: Item<Self>,
929    ) -> std::fmt::Result {
930        write!(f, "max")
931    }
932
933    fn compile_instruction_min_function_name(
934        f: &mut std::fmt::Formatter<'_>,
935        _item: Item<Self>,
936    ) -> std::fmt::Result {
937        write!(f, "min")
938    }
939
940    fn compile_instruction_powf(
941        f: &mut std::fmt::Formatter<'_>,
942        lhs: &str,
943        rhs: &str,
944        elem: Elem<Self>,
945    ) -> std::fmt::Result {
946        write!(f, "pow({lhs}, {elem}({rhs}))")
947    }
948
949    fn compile_instruction_half_function_name_prefix() -> &'static str {
950        ""
951    }
952
953    fn compile_instruction_half2_function_name_prefix() -> &'static str {
954        ""
955    }
956
957    // Warp
958    fn compile_warp_shuffle(
959        f: &mut std::fmt::Formatter<'_>,
960        var: &str,
961        source: &str,
962    ) -> std::fmt::Result {
963        write!(f, "simd_shuffle({var}, {source})")
964    }
965
966    fn compile_warp_shuffle_xor(
967        f: &mut std::fmt::Formatter<'_>,
968        var: &str,
969        _elem: &Elem<Self>,
970        offset: &str,
971    ) -> std::fmt::Result {
972        write!(f, "simd_shuffle_xor({var}, {offset})")
973    }
974
975    fn compile_warp_shuffle_up(
976        f: &mut std::fmt::Formatter<'_>,
977        var: &str,
978        offset: &str,
979    ) -> std::fmt::Result {
980        write!(f, "simd_shuffle_up({var}, {offset})")
981    }
982
983    fn compile_warp_shuffle_down(
984        f: &mut std::fmt::Formatter<'_>,
985        var: &str,
986        offset: &str,
987    ) -> std::fmt::Result {
988        write!(f, "simd_shuffle_down({var}, {offset})")
989    }
990
991    fn compile_warp_all<T: Component<Self>>(
992        f: &mut std::fmt::Formatter<'_>,
993        input: &T,
994    ) -> std::fmt::Result {
995        write!(f, "simd_all({input})")
996    }
997
998    fn compile_warp_any<T: Component<Self>>(
999        f: &mut std::fmt::Formatter<'_>,
1000        input: &T,
1001    ) -> std::fmt::Result {
1002        write!(f, "simd_any({input})")
1003    }
1004
1005    fn compile_warp_ballot(
1006        f: &mut std::fmt::Formatter<'_>,
1007        input: &Variable<Self>,
1008        out_elem: &Elem<Self>,
1009    ) -> std::fmt::Result {
1010        write!(f, "{out_elem}(uint64_t(simd_ballot({input})))")
1011    }
1012}
1013
1014// Coop Matrices dialect
1015
1016impl DialectWmmaCompiler<Self> for MslDialect {
1017    fn compile_wmma_includes(
1018        f: &mut std::fmt::Formatter<'_>,
1019        _flags: &Flags<Self>,
1020    ) -> std::fmt::Result {
1021        writeln!(f, "#include <metal_simdgroup_matrix>")
1022    }
1023
1024    fn compile_wmma_local_variables(_f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1025        // not used
1026        Ok(())
1027    }
1028
1029    fn compile_wmma_fragment_declaration(
1030        f: &mut std::fmt::Formatter<'_>,
1031        var: &crate::shared::Variable<MslDialect>,
1032    ) -> std::fmt::Result {
1033        wmma_api_base::compile_fragment_declaration(f, var)
1034    }
1035
1036    fn compile_wwma_fragment_ident(
1037        _f: &mut std::fmt::Formatter<'_>,
1038        _ident: &FragmentIdent<Self>,
1039    ) -> std::fmt::Result {
1040        // not used
1041        Ok(())
1042    }
1043
1044    fn compile_wmma_fragment_layout(
1045        _f: &mut std::fmt::Formatter<'_>,
1046        _layout: &FragmentLayout<Self>,
1047    ) -> std::fmt::Result {
1048        // not used
1049        Ok(())
1050    }
1051
1052    fn compile_wmma_fragment(
1053        f: &mut std::fmt::Formatter<'_>,
1054        fragment: &Fragment<Self>,
1055    ) -> std::fmt::Result {
1056        let ty = fragment.elem;
1057        // currently as of Metal 3.2 only fragments of 8x8x8 are supported
1058        let m = fragment.m;
1059        let n = fragment.n;
1060        let k = fragment.k;
1061        if m != 8 || n != 8 || k != 8 {
1062            panic!("{m}x{n}x{k} fragments not supported. Only 8x8x8 fragments are supported.");
1063        }
1064        write!(f, "simdgroup_{ty}8x8")
1065    }
1066
1067    fn compile_wmma_instruction(
1068        f: &mut std::fmt::Formatter<'_>,
1069        instruction: &WmmaInstruction<Self>,
1070    ) -> std::fmt::Result {
1071        match instruction {
1072            WmmaInstruction::Fill { frag, value } => {
1073                match frag {
1074                    Variable::WmmaFragment { .. } => {
1075                        let ty = frag.elem();
1076                        // Only 8x8x8 fragemts are supported. Check is done at fragment compilation time.
1077                        writeln!(
1078                            f,
1079                            "{frag} = make_filled_simdgroup_matrix<{ty}, 8, 8>({value});"
1080                        )
1081                    }
1082                    _ => panic!("should be a fragment"),
1083                }
1084            }
1085            WmmaInstruction::Load {
1086                frag,
1087                value,
1088                stride,
1089                offset,
1090                layout: _layout,
1091            } => {
1092                let transpose = match frag {
1093                    Variable::WmmaFragment { frag: inner, .. } => match inner.layout {
1094                        Some(FragmentLayout::RowMajor) => false,
1095                        Some(FragmentLayout::ColMajor) => true,
1096                        _ => false,
1097                    },
1098                    _ => panic!("should be a fragment"),
1099                };
1100                let item = value.item();
1101                if item.vectorization > 1 {
1102                    let elem = item.elem;
1103                    match value {
1104                        Variable::GlobalInputArray(..) => writeln!(
1105                            f,
1106                            "simdgroup_load({frag}, (device {elem}*)({value} + {offset}), {stride}, 0, {transpose});"
1107                        ),
1108                        Variable::SharedArray(..) => writeln!(
1109                            f,
1110                            "simdgroup_load({frag}, reinterpret_cast<threadgroup {elem} *>({value} + {offset}), {stride}, 0, {transpose});"
1111                        ),
1112                        _ => panic!(
1113                            "Vectorized wmma load is only supported from global or shared memory."
1114                        ),
1115                    }
1116                } else {
1117                    writeln!(
1118                        f,
1119                        "simdgroup_load({frag}, {value} + {offset}, {stride}, 0, {transpose});"
1120                    )
1121                }
1122            }
1123            WmmaInstruction::Execute {
1124                frag_a: a,
1125                frag_b: b,
1126                frag_c: c,
1127                frag_d: d,
1128                ..
1129            } => {
1130                writeln!(f, "simdgroup_multiply_accumulate({d}, {a}, {b}, {c});")
1131            }
1132            WmmaInstruction::Store {
1133                output,
1134                frag,
1135                stride,
1136                offset,
1137                layout: _layout,
1138            } => {
1139                let item = output.item();
1140                let mut reinterpret_cast = item.vectorization > 1;
1141                let elem = match item.elem {
1142                    Elem::BF16 => {
1143                        reinterpret_cast = true;
1144                        Elem::F16
1145                    }
1146                    _ => item.elem,
1147                };
1148                if reinterpret_cast {
1149                    writeln!(
1150                        f,
1151                        "simdgroup_store({frag}, reinterpret_cast<threadgroup {elem} *>({output} + {offset}), {stride});"
1152                    )
1153                } else {
1154                    writeln!(f, "simdgroup_store({frag}, {output} + {offset}, {stride});")
1155                }?;
1156                writeln!(f, "simdgroup_barrier(mem_flags::mem_none);")
1157            }
1158            WmmaInstruction::Cast { input, output } => {
1159                writeln!(f, "simdgroup_barrier(mem_flags::mem_none);")?;
1160                let ty = match output {
1161                    Variable::WmmaFragment { frag, .. } => frag.elem,
1162                    _ => panic!("should be a fragment"),
1163                };
1164                match ty {
1165                    Elem::BF16 => {
1166                        let addr_space = Self::address_space_for_variable(output);
1167                        let elem = Elem::<Self>::F16;
1168                        // TODO: to test with benchmarks
1169
1170                        writeln!(
1171                            f,
1172                            "for(int e=0; e<8; e++) {{
1173    {ty} elem = {ty}({input}.thread_elements()[e]);
1174    {output}.thread_elements()[e] = *reinterpret_cast<{addr_space}{elem} *>(&elem);
1175}}"
1176                        )
1177                    }
1178                    _ => {
1179                        writeln!(
1180                            f,
1181                            "for(int e=0; e<8; e++) {{
1182    {output}.thread_elements()[e] = {ty}({input}.thread_elements()[e]);
1183}}"
1184                        )
1185                    }
1186                }
1187            }
1188            WmmaInstruction::ExecuteManual {
1189                shape,
1190                frag_a,
1191                frag_b,
1192                frag_c,
1193                frag_d,
1194            } => {
1195                Self::compile_manual_mma(f, ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d))
1196            }
1197            WmmaInstruction::ExecuteScaled {
1198                shape,
1199                frag_a,
1200                frag_b,
1201                frag_c,
1202                frag_d,
1203                scales_a,
1204                scales_b,
1205                scales_factor,
1206            } => Self::compile_scaled_mma(
1207                f,
1208                ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d),
1209                *scales_a,
1210                *scales_b,
1211                *scales_factor,
1212            ),
1213            WmmaInstruction::LdMatrix { .. } | WmmaInstruction::StMatrix { .. } => {
1214                f.write_str("#error WmmaInstruction Ld & St Matrix not supported on Metal\n")
1215            }
1216        }
1217    }
1218
1219    fn compile_manual_mma(
1220        f: &mut std::fmt::Formatter<'_>,
1221        _mma: shared::ManualMma<Self>,
1222    ) -> std::fmt::Result {
1223        f.write_str("#error manual mma not supported on Metal\n")
1224    }
1225
1226    fn compile_scaled_mma(
1227        f: &mut std::fmt::Formatter<'_>,
1228        _mma: shared::ManualMma<Self>,
1229        _scales_a: Variable<Self>,
1230        _scales_b: Variable<Self>,
1231        _scales_factor: u32,
1232    ) -> std::fmt::Result {
1233        f.write_str("#error scaled mma not supported on Metal\n")
1234    }
1235
1236    fn supported_wmma_combinations(_arch: &MetalArchitecture) -> SupportedMmaCombinations {
1237        let types = vec![
1238            (
1239                gpu::ElemType::Float(gpu::FloatKind::F16).into(),
1240                gpu::ElemType::Float(gpu::FloatKind::F16).into(),
1241                gpu::ElemType::Float(gpu::FloatKind::F16).into(),
1242            ),
1243            (
1244                gpu::ElemType::Float(gpu::FloatKind::F16).into(),
1245                gpu::ElemType::Float(gpu::FloatKind::F16).into(),
1246                gpu::ElemType::Float(gpu::FloatKind::F32).into(),
1247            ),
1248            (
1249                gpu::ElemType::Float(gpu::FloatKind::BF16).into(),
1250                gpu::ElemType::Float(gpu::FloatKind::BF16).into(),
1251                gpu::ElemType::Float(gpu::FloatKind::BF16).into(),
1252            ),
1253            (
1254                gpu::ElemType::Float(gpu::FloatKind::F32).into(),
1255                gpu::ElemType::Float(gpu::FloatKind::F32).into(),
1256                gpu::ElemType::Float(gpu::FloatKind::F32).into(),
1257            ),
1258        ];
1259        types
1260            .into_iter()
1261            .map(|(a_type, b_type, cd_type)| MmaConfig {
1262                a_type,
1263                b_type,
1264                cd_type,
1265                m: 8,
1266                n: 8,
1267                k: 8,
1268            })
1269            .collect()
1270    }
1271
1272    fn supported_mma_combinations(_arch: &MetalArchitecture) -> SupportedMmaCombinations {
1273        Vec::new()
1274    }
1275}
1276
// Processors dialect
1278
1279impl DialectProcessors<Self> for MslDialect {
1280    fn processors() -> Vec<Box<dyn gpu::Processor>> {
1281        Vec::new()
1282    }
1283}