cubecl_cpp/metal/
dialect.rs

1use core::panic;
2use std::fmt::Display;
3
4use crate::{
5    Dialect,
6    shared::{
7        self, AtomicKind, Binding, Component, CubeIndexFlags, DialectBindings, DialectCubeBuiltins,
8        DialectIncludes, DialectInstructions, DialectProcessors, DialectTypes,
9        DialectWarpReduceCompiler, DialectWmmaCompiler, Elem, Flags, FmtLeft, Fragment,
10        FragmentIdent, FragmentLayout, Instruction, Item, ManualMma, SharedMemory,
11        SupportedMmaCombinations, Variable, WarpInstruction, WmmaInstruction, wmma_api_base,
12    },
13};
14use cubecl_core::{
15    compute::{Location, Visibility},
16    ir::{self as gpu},
17};
18use cubecl_runtime::MmaConfig;
19
20use super::{
21    AddressSpace, Extension,
22    arch::MetalArchitecture,
23    extension::{format_ffs, format_mulhi},
24    format_erf, format_global_binding_arg, format_metal_builtin_binding_arg, format_safe_tanh,
25};
26
/// Code-generation dialect that emits Metal Shading Language (MSL) source.
///
/// The type holds no data; it only selects the Metal-specific implementations
/// of the shared dialect traits.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct MslDialect {}
29
30// Base dialect
31
32impl Dialect for MslDialect {
33    type Architecture = MetalArchitecture;
34}
35
36impl MslDialect {
37    fn warp_op_vectorized(
38        f: &mut core::fmt::Formatter<'_>,
39        input: &Variable<Self>,
40        out: &Variable<Self>,
41        simd_op_prefix: &str,
42        simd_op_suffix: &str,
43    ) -> core::fmt::Result {
44        let out = out.fmt_left();
45        let vectorization = input.item().vectorization;
46
47        f.write_fmt(format_args!("{out} = {} {{", input.item()))?;
48
49        for k in 0..vectorization {
50            let index = if vectorization > 1 {
51                format!(".i_{k}")
52            } else {
53                String::new()
54            };
55            let comma = if k + 1 < vectorization { "," } else { "" };
56
57            writeln!(f, "{simd_op_prefix}{input}{index}{simd_op_suffix}{comma}")?;
58        }
59
60        f.write_fmt(format_args!("}};\n"))
61    }
62}
63
64impl DialectWarpReduceCompiler<Self> for MslDialect {
65    fn warp_reduce_sum(
66        f: &mut core::fmt::Formatter<'_>,
67        input: &Variable<Self>,
68        out: &Variable<Self>,
69    ) -> core::fmt::Result {
70        Self::warp_op_vectorized(f, input, out, "simd_sum(", ")")
71    }
72    fn warp_reduce_prod(
73        f: &mut core::fmt::Formatter<'_>,
74        input: &Variable<Self>,
75        out: &Variable<Self>,
76    ) -> core::fmt::Result {
77        Self::warp_op_vectorized(f, input, out, "simd_product(", ")")
78    }
79    fn warp_reduce_max(
80        f: &mut core::fmt::Formatter<'_>,
81        input: &Variable<Self>,
82        out: &Variable<Self>,
83    ) -> core::fmt::Result {
84        Self::warp_op_vectorized(f, input, out, "simd_max(", ")")
85    }
86    fn warp_reduce_min(
87        f: &mut core::fmt::Formatter<'_>,
88        input: &Variable<Self>,
89        out: &Variable<Self>,
90    ) -> core::fmt::Result {
91        Self::warp_op_vectorized(f, input, out, "simd_min(", ")")
92    }
93    fn warp_reduce_all(
94        f: &mut core::fmt::Formatter<'_>,
95        input: &Variable<Self>,
96        out: &Variable<Self>,
97    ) -> core::fmt::Result {
98        Self::warp_op_vectorized(f, input, out, "simd_and(", "? 1u : 0u) != 0u")
99    }
100    fn warp_reduce_any(
101        f: &mut core::fmt::Formatter<'_>,
102        input: &Variable<Self>,
103        out: &Variable<Self>,
104    ) -> core::fmt::Result {
105        Self::warp_op_vectorized(f, input, out, "simd_or(", "? 1u : 0u) != 0u")
106    }
107    fn warp_reduce_sum_inclusive(
108        f: &mut core::fmt::Formatter<'_>,
109        input: &Variable<Self>,
110        out: &Variable<Self>,
111    ) -> core::fmt::Result {
112        Self::warp_op_vectorized(f, input, out, "simd_prefix_inclusive_sum(", ")")
113    }
114    fn warp_reduce_prod_inclusive(
115        f: &mut core::fmt::Formatter<'_>,
116        input: &Variable<Self>,
117        out: &Variable<Self>,
118    ) -> core::fmt::Result {
119        Self::warp_op_vectorized(f, input, out, "simd_prefix_inclusive_product(", ")")
120    }
121    fn warp_reduce_sum_exclusive(
122        f: &mut core::fmt::Formatter<'_>,
123        input: &Variable<Self>,
124        out: &Variable<Self>,
125    ) -> core::fmt::Result {
126        Self::warp_op_vectorized(f, input, out, "simd_prefix_exclusive_sum(", ")")
127    }
128    fn warp_reduce_prod_exclusive(
129        f: &mut core::fmt::Formatter<'_>,
130        input: &Variable<Self>,
131        out: &Variable<Self>,
132    ) -> core::fmt::Result {
133        Self::warp_op_vectorized(f, input, out, "simd_prefix_exclusive_product(", ")")
134    }
135}
136
137// Includes
138
139impl DialectIncludes<Self> for MslDialect {
140    type Extension = Extension<Self>;
141
142    fn compile_includes(f: &mut std::fmt::Formatter<'_>, _flags: &Flags) -> std::fmt::Result {
143        write!(
144            f,
145            "
146#include <metal_stdlib>
147using namespace metal;
148"
149        )?;
150        Ok(())
151    }
152
153    fn compile_extensions(
154        f: &mut std::fmt::Formatter<'_>,
155        extensions: &[Self::Extension],
156    ) -> std::fmt::Result {
157        for extension in extensions {
158            match extension {
159                Extension::Erf(input, output) => format_erf::<Self>(f, input, output)?,
160                Extension::Ffs(elem) => format_ffs(f, elem)?,
161                Extension::MulHi(elem) => format_mulhi(f, elem)?,
162                Extension::SafeTanh(item) => format_safe_tanh::<Self>(f, item)?,
163                Extension::NoExtension => {}
164            }
165        }
166        Ok(())
167    }
168
169    fn register_instruction_extension(
170        extensions: &mut Vec<Self::Extension>,
171        instruction: &Instruction<Self>,
172    ) {
173        let mut register_extension = |extension: Self::Extension| {
174            if !extensions.contains(&extension) {
175                extensions.push(extension);
176            }
177        };
178        #[allow(clippy::single_match)]
179        match instruction {
180            shared::Instruction::<Self>::Erf(instruction) => {
181                register_extension(Extension::Erf(
182                    instruction.input.elem(),
183                    instruction.out.elem(),
184                ));
185            }
186            shared::Instruction::<Self>::FindFirstSet(instruction) => {
187                let input_elem = instruction.input.elem();
188                match input_elem {
189                    Elem::U32 | Elem::U64 => {
190                        register_extension(Extension::Ffs(instruction.input.elem()));
191                    }
192                    Elem::I32 => {
193                        register_extension(Extension::Ffs(Elem::<Self>::U32));
194                        register_extension(Extension::Ffs(instruction.input.elem()));
195                    }
196                    Elem::I64 => {
197                        register_extension(Extension::Ffs(Elem::<Self>::U64));
198                        register_extension(Extension::Ffs(instruction.input.elem()));
199                    }
200                    _ => {
201                        register_extension(Extension::Ffs(Elem::<Self>::U32));
202                    }
203                }
204            }
205            shared::Instruction::<Self>::HiMul(instruction) => {
206                register_extension(Extension::MulHi(instruction.out.elem()));
207            }
208            shared::Instruction::<Self>::Tanh(instruction) => {
209                register_extension(Extension::SafeTanh(instruction.input.item()));
210            }
211            _ => {}
212        }
213    }
214
215    fn register_warp_instruction_extension(
216        _extensions: &mut Vec<Self::Extension>,
217        _instruction: &WarpInstruction<Self>,
218    ) {
219    }
220}
221
222// Types
223
224impl DialectTypes<Self> for MslDialect {
225    fn item_can_be_optimized() -> bool {
226        false
227    }
228
229    fn compile_type_definitions(
230        f: &mut std::fmt::Formatter<'_>,
231        items: &std::collections::HashSet<crate::shared::Item<Self>>,
232        _scalars: &[(Elem<Self>, usize)],
233        _flags: &Flags,
234    ) -> std::fmt::Result {
235        for item in items.iter() {
236            let elem = item.elem;
237            let size = item.vectorization;
238            let alignment = elem.size() * size;
239            if size > 1 {
240                write!(
241                    f,
242                    "
243struct alignas({alignment}) {item} {{"
244                )?;
245
246                for i in 0..size {
247                    write!(
248                        f,
249                        "
250    {elem} i_{i};"
251                    )?;
252                }
253
254                f.write_str("\n};\n")?;
255            }
256        }
257        Ok(())
258    }
259
260    fn compile_elem(
261        f: &mut std::fmt::Formatter<'_>,
262        elem: &shared::Elem<Self>,
263        _words: bool,
264    ) -> std::fmt::Result {
265        // we always use the word form of types
266        match elem {
267            shared::Elem::FP4(_)
268            | shared::Elem::FP4x2(_)
269            | shared::Elem::FP6(_)
270            | shared::Elem::FP6x2(_)
271            | shared::Elem::FP8(_)
272            | shared::Elem::FP8x2(_) => unimplemented!("FP4/FP6/FP8 not supported in Metal"),
273            shared::Elem::F16 => f.write_str("half"),
274            shared::Elem::F16x2 => panic!("type F162 not supported!"),
275            shared::Elem::F32 => f.write_str("float"),
276            shared::Elem::F64 => panic!("type double not supported!"),
277            shared::Elem::BF16 => f.write_str("bfloat"),
278            shared::Elem::BF16x2 => panic!("type BF162 not supported!"),
279            shared::Elem::TF32 => f.write_str("float"),
280            shared::Elem::I8 => f.write_str("char"),
281            shared::Elem::I16 => f.write_str("short"),
282            shared::Elem::I32 => f.write_str("int"),
283            shared::Elem::I64 => f.write_str("long"),
284            shared::Elem::U8 => f.write_str("uchar"),
285            shared::Elem::U16 => f.write_str("ushort"),
286            shared::Elem::U32 => f.write_str("uint"),
287            shared::Elem::U64 => f.write_str("uint64_t"), // or unsigned long
288            shared::Elem::Bool => f.write_str("bool"),
289            shared::Elem::Atomic(inner) => inner.fmt(f),
290            shared::Elem::_Dialect(_) => Ok(()),
291        }
292    }
293
294    fn compile_item(f: &mut std::fmt::Formatter<'_>, item: &Item<Self>) -> std::fmt::Result {
295        if 1 == item.vectorization {
296            return write!(f, "{}", item.elem);
297        }
298        if item.native {
299            write!(f, "{}{}", item.elem, item.vectorization)
300        } else {
301            write!(f, "{}_{}", item.elem, item.vectorization)
302        }
303    }
304
305    fn compile_atomic_kind(
306        f: &mut std::fmt::Formatter<'_>,
307        kind: &AtomicKind<Self>,
308    ) -> std::fmt::Result {
309        match kind {
310            AtomicKind::I32 => write!(f, "atomic_int"),
311            AtomicKind::I64 => panic!("I64 atomic kind no supported."),
312            AtomicKind::U32 => write!(f, "atomic_uint"),
313            AtomicKind::U64 => write!(f, "atomic_ulong"),
314            AtomicKind::F16 => panic!("F16 atomic kind no supported."),
315            AtomicKind::BF16 => panic!("BF16 atomic kind no supported."),
316            AtomicKind::F32 => write!(f, "atomic_float"), // needs metal 3
317            AtomicKind::F64 => panic!("F64 atomic kind no supported."),
318            AtomicKind::_Dialect(_) => Ok(()),
319        }
320    }
321
322    fn address_space_for_variable(variable: &Variable<Self>) -> String {
323        format!("{} ", AddressSpace::from(variable))
324    }
325
326    fn compile_local_memory_qualifier(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
327        write!(f, "thread")
328    }
329
330    fn compile_shared_memory_declaration(
331        f: &mut std::fmt::Formatter<'_>,
332        shared: &SharedMemory<Self>,
333    ) -> std::fmt::Result {
334        let item = shared.item;
335        let index = shared.index;
336        let offset = shared.offset;
337        let size = shared.length;
338        let size_bytes = size * shared.item.size() as u32;
339        writeln!(f, "// Shared memory size: {size}, {size_bytes} bytes")?;
340        writeln!(
341            f,
342            "threadgroup {item}* shared_memory_{index} = reinterpret_cast<threadgroup {item}*>(&dynamic_shared_mem[{offset}]);"
343        )
344    }
345}
346
347// Kernel argument bindings
348
349impl DialectBindings<Self> for MslDialect {
350    fn compile_kernel_signature(
351        f: &mut std::fmt::Formatter<'_>,
352        kernel_name: &str,
353        tensor_maps: &[Binding<Self>],
354        buffers: &[Binding<Self>],
355        scalars: &[(Elem<Self>, usize)],
356        flags: &Flags,
357    ) -> std::fmt::Result {
358        write!(
359            (f),
360            "
361[[kernel]]
362void {kernel_name}("
363        )?;
364        // Global bindings args
365        let mut buffer_idx = 0;
366        debug_assert!(
367            tensor_maps.is_empty(),
368            "Tensor maps aren't supported for metal"
369        );
370        for (i, b) in buffers.iter().enumerate() {
371            format_global_binding_arg("buffer", b, Some(&i.to_string()), &mut buffer_idx, f)?;
372        }
373        if flags.static_meta_length > 0 {
374            let binding = Binding {
375                id: 0,
376                item: Item::scalar(Elem::<Self>::U32, true),
377                location: Location::Storage,
378                size: None,
379                vis: Visibility::Read,
380            };
381            format_global_binding_arg("info", &binding, None, &mut buffer_idx, f)?;
382        }
383        for (elem, _) in scalars.iter() {
384            let binding = Binding {
385                id: 0,
386                item: Item::scalar(*elem, true),
387                location: Location::Storage,
388                size: None,
389                vis: Visibility::Read,
390            };
391
392            let name = format!("scalars_{elem}");
393            format_global_binding_arg(&name, &binding, None, &mut buffer_idx, f)?;
394        }
395
396        // Global metal builtins args
397        let builtins = vec![
398            (
399                flags.indexes.absolute_pos_tuple,
400                Variable::<Self>::AbsolutePosBaseName,
401            ),
402            (
403                flags.indexes.cube_dim_tuple,
404                Variable::<Self>::CubeDimBaseName,
405            ),
406            (
407                flags.indexes.cube_count_tuple,
408                Variable::<Self>::CubeCountBaseName,
409            ),
410            (flags.indexes.unit_pos, Variable::<Self>::UnitPos),
411            (
412                flags.indexes.unit_pos_tuple,
413                Variable::<Self>::UnitPosBaseName,
414            ),
415            (
416                flags.indexes.cube_pos_tuple,
417                Variable::<Self>::CubePosBaseName,
418            ),
419            (flags.indexes.unit_pos_plane, Variable::<Self>::UnitPosPlane),
420            (flags.indexes.plane_dim, Variable::<Self>::PlaneDim),
421            (flags.indexes.plane_index, Variable::<Self>::PlanePos),
422        ];
423        let comma = !buffers.is_empty() || flags.static_meta_length > 0 || !scalars.is_empty();
424        builtins
425            .iter()
426            .filter(|(cond, _)| *cond)
427            .try_for_each(|(_, var)| format_metal_builtin_binding_arg(f, var, comma))?;
428        f.write_str("\n)")
429    }
430
431    fn compile_bindings_body(
432        f: &mut std::fmt::Formatter<'_>,
433        body: &shared::Body<Self>,
434    ) -> std::fmt::Result {
435        if !body.shared_memories.is_empty() {
436            let size = body
437                .shared_memories
438                .iter()
439                .map(|it| it.offset + it.size())
440                .max()
441                .unwrap();
442
443            writeln!(f, "threadgroup uchar dynamic_shared_mem[{size}];",)?;
444        }
445        Ok(())
446    }
447}
448
449// Cube builtins dialect
450
451impl DialectCubeBuiltins<Self> for MslDialect {
452    /// Depending on the dialect available built-in variables the
453    /// inclusion rules might change.
454    /// For instance in metal we have a built-in for the Unit plane position
455    /// so we don't rely on other builtins.
456    fn builtin_rules(flags: &CubeIndexFlags) -> CubeIndexFlags {
457        let absolute_pos = flags.absolute_pos;
458        let cube_count = flags.cube_count;
459        let cube_dim = flags.cube_dim;
460        let cube_pos = flags.cube_pos;
461        let plane_dim_checked = flags.plane_dim_checked;
462        let plane_index = flags.plane_index;
463        let unit_pos = flags.unit_pos;
464        let absolute_pos_tuple = flags.absolute_pos_tuple || absolute_pos;
465        let cube_count_tuple = flags.cube_count_tuple || cube_count || cube_pos || absolute_pos;
466        let cube_dim_tuple = flags.cube_dim_tuple || cube_dim || absolute_pos || plane_dim_checked;
467        let cube_pos_tuple = flags.cube_pos_tuple || cube_pos;
468        let cluster_pos = flags.cluster_pos;
469        let plane_dim = flags.plane_dim || plane_dim_checked || plane_index;
470        let unit_pos_plane = flags.unit_pos_plane || plane_index;
471        let unit_pos_tuple = flags.unit_pos_tuple || unit_pos;
472        CubeIndexFlags {
473            absolute_pos_tuple,
474            absolute_pos,
475            cube_count_tuple,
476            cube_count,
477            cube_dim_tuple,
478            cube_dim,
479            cube_pos_tuple,
480            cube_pos,
481            plane_dim,
482            plane_dim_checked,
483            plane_index,
484            unit_pos_tuple,
485            unit_pos,
486            unit_pos_plane,
487            cluster_pos,
488        }
489    }
490
491    fn compile_absolute_pos_tuple_computation(
492        _f: &mut std::fmt::Formatter<'_>,
493    ) -> std::fmt::Result {
494        // no need to compute it on metal as there is y a built-in for it
495        Ok(())
496    }
497
498    fn compile_absolute_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
499        f.write_str("thread_pos_in_grid")
500    }
501
502    fn compile_absolute_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
503        f.write_str("thread_index_in_grid")
504    }
505
506    fn compile_absolute_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
507        Self::compile_absolute_pos_base_name(f)?;
508        write!(f, ".x")
509    }
510
511    fn compile_absolute_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
512        Self::compile_absolute_pos_base_name(f)?;
513        write!(f, ".y")
514    }
515
516    fn compile_absolute_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
517        Self::compile_absolute_pos_base_name(f)?;
518        write!(f, ".z")
519    }
520
521    fn compile_cube_count_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
522        f.write_str("threadgroups_per_grid")
523    }
524
525    fn compile_cube_count(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
526        f.write_str("total_threadgroups_in_grid")
527    }
528
529    fn compile_cube_count_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
530        Self::compile_cube_count_base_name(f)?;
531        write!(f, ".x")
532    }
533
534    fn compile_cube_count_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
535        Self::compile_cube_count_base_name(f)?;
536        write!(f, ".y")
537    }
538
539    fn compile_cube_count_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
540        Self::compile_cube_count_base_name(f)?;
541        write!(f, ".z")
542    }
543
544    fn compile_cube_dim_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
545        f.write_str("threads_per_threadgroup")
546    }
547
548    fn compile_cube_dim(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
549        f.write_str("total_thread_in_threadgroup")
550    }
551
552    fn compile_cube_dim_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
553        Self::compile_cube_dim_base_name(f)?;
554        write!(f, ".x")
555    }
556
557    fn compile_cube_dim_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
558        Self::compile_cube_dim_base_name(f)?;
559        write!(f, ".y")
560    }
561
562    fn compile_cube_dim_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
563        Self::compile_cube_dim_base_name(f)?;
564        write!(f, ".z")
565    }
566
567    fn compile_cube_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
568        f.write_str("threadgroup_pos_in_grid")
569    }
570
571    fn compile_cube_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
572        f.write_str("threadgroup_index_in_grid")
573    }
574
575    fn compile_cube_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
576        Self::compile_cube_pos_base_name(f)?;
577        write!(f, ".x")
578    }
579
580    fn compile_cube_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
581        Self::compile_cube_pos_base_name(f)?;
582        write!(f, ".y")
583    }
584
585    fn compile_cube_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
586        Self::compile_cube_pos_base_name(f)?;
587        write!(f, ".z")
588    }
589
590    fn compile_unit_pos_computation(_f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
591        // no need to compute it on metal as there is y a built-in for it
592        Ok(())
593    }
594
595    fn compile_unit_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
596        f.write_str("thread_pos_in_threadgroup")
597    }
598
599    fn compile_unit_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
600        f.write_str("thread_index_in_threadgroup")
601    }
602
603    fn compile_unit_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
604        Self::compile_unit_pos_base_name(f)?;
605        write!(f, ".x")
606    }
607
608    fn compile_unit_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
609        Self::compile_unit_pos_base_name(f)?;
610        write!(f, ".y")
611    }
612
613    fn compile_unit_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
614        Self::compile_unit_pos_base_name(f)?;
615        write!(f, ".z")
616    }
617
618    fn compile_plane_dim(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
619        f.write_str("simd_size")
620    }
621
622    fn compile_plane_dim_checked(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
623        f.write_str("threads_per_simdgroup_checked")
624    }
625
626    fn compile_plane_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
627        f.write_str("simd_group_id")
628    }
629
630    fn compile_unit_pos_plane(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
631        f.write_str("simd_lane_id")
632    }
633}
634
635// Instructions
636
637impl DialectInstructions<Self> for MslDialect {
638    // atomics
639    fn compile_atomic_add(
640        f: &mut std::fmt::Formatter<'_>,
641        lhs: &Variable<Self>,
642        rhs: &Variable<Self>,
643        out: &Variable<Self>,
644    ) -> std::fmt::Result {
645        let out = out.fmt_left();
646        writeln!(
647            f,
648            "{out} = atomic_fetch_add_explicit({lhs}, {rhs}, memory_order_relaxed);"
649        )
650    }
651
652    fn compile_atomic_and(
653        f: &mut std::fmt::Formatter<'_>,
654        lhs: &Variable<Self>,
655        rhs: &Variable<Self>,
656        out: &Variable<Self>,
657    ) -> std::fmt::Result {
658        let out = out.fmt_left();
659        writeln!(
660            f,
661            "{out} = atomic_fetch_and_explicit({lhs}, {rhs}, memory_order_relaxed);"
662        )
663    }
664
665    fn compile_atomic_cas(
666        f: &mut std::fmt::Formatter<'_>,
667        input: &Variable<Self>,
668        cmp: &Variable<Self>,
669        val: &Variable<Self>,
670        out: &Variable<Self>,
671    ) -> std::fmt::Result {
672        let out = out.fmt_left();
673        writeln!(
674            f,
675            "{out} = atomic_compare_exchange_weak_explicit({input}, &{cmp}, {val}, memory_order_relaxed, memory_order_relaxed);"
676        )
677    }
678
679    fn compile_atomic_load(
680        f: &mut std::fmt::Formatter<'_>,
681        input: &Variable<Self>,
682        out: &Variable<Self>,
683    ) -> std::fmt::Result {
684        let out = out.fmt_left();
685        writeln!(
686            f,
687            "{out} = atomic_load_explicit({input}, memory_order_relaxed);"
688        )
689    }
690
691    fn compile_atomic_max(
692        f: &mut std::fmt::Formatter<'_>,
693        lhs: &Variable<Self>,
694        rhs: &Variable<Self>,
695        out: &Variable<Self>,
696    ) -> std::fmt::Result {
697        let out = out.fmt_left();
698        writeln!(
699            f,
700            "{out} = atomic_fetch_max_explicit({lhs}, {rhs}, memory_order_relaxed);"
701        )
702    }
703
704    fn compile_atomic_min(
705        f: &mut std::fmt::Formatter<'_>,
706        lhs: &Variable<Self>,
707        rhs: &Variable<Self>,
708        out: &Variable<Self>,
709    ) -> std::fmt::Result {
710        let out = out.fmt_left();
711        writeln!(
712            f,
713            "{out} = atomic_fetch_min_explicit({lhs}, {rhs}, memory_order_relaxed);"
714        )
715    }
716
717    fn compile_atomic_or(
718        f: &mut std::fmt::Formatter<'_>,
719        lhs: &Variable<Self>,
720        rhs: &Variable<Self>,
721        out: &Variable<Self>,
722    ) -> std::fmt::Result {
723        let out = out.fmt_left();
724        writeln!(
725            f,
726            "{out} = atomic_fetch_or_explicit({lhs}, {rhs}, memory_order_relaxed);"
727        )
728    }
729
730    fn compile_atomic_store(
731        f: &mut std::fmt::Formatter<'_>,
732        input: &Variable<Self>,
733        out: &Variable<Self>,
734    ) -> std::fmt::Result {
735        writeln!(
736            f,
737            "atomic_store_explicit({out}, {input}, memory_order_relaxed);"
738        )
739    }
740
741    fn compile_atomic_sub(
742        f: &mut std::fmt::Formatter<'_>,
743        lhs: &Variable<Self>,
744        rhs: &Variable<Self>,
745        out: &Variable<Self>,
746    ) -> std::fmt::Result {
747        let out = out.fmt_left();
748        writeln!(
749            f,
750            "{out} = atomic_fetch_sub_explicit({lhs}, {rhs}, memory_order_relaxed);"
751        )
752    }
753
754    fn compile_atomic_swap(
755        f: &mut std::fmt::Formatter<'_>,
756        lhs: &Variable<Self>,
757        rhs: &Variable<Self>,
758        out: &Variable<Self>,
759    ) -> std::fmt::Result {
760        let out = out.fmt_left();
761        writeln!(
762            f,
763            "{out} = atomic_exchange_explicit({lhs}, {rhs}, memory_order_relaxed);"
764        )
765    }
766
767    fn compile_atomic_xor(
768        f: &mut std::fmt::Formatter<'_>,
769        lhs: &Variable<Self>,
770        rhs: &Variable<Self>,
771        out: &Variable<Self>,
772    ) -> std::fmt::Result {
773        let out = out.fmt_left();
774        writeln!(
775            f,
776            "{out} = atomic_fetch_xor_explicit({lhs}, {rhs}, memory_order_relaxed);"
777        )
778    }
779
780    fn compile_saturating_add(
781        f: &mut std::fmt::Formatter<'_>,
782        lhs: impl Display,
783        rhs: impl Display,
784        _item: Item<Self>,
785    ) -> std::fmt::Result {
786        write!(f, "addsat({lhs}, {rhs})")
787    }
788
789    fn compile_saturating_sub(
790        f: &mut std::fmt::Formatter<'_>,
791        lhs: impl Display,
792        rhs: impl Display,
793        _item: Item<Self>,
794    ) -> std::fmt::Result {
795        write!(f, "subsat({lhs}, {rhs})")
796    }
797
798    // debug
799    fn compile_instruction_printf(
800        f: &mut std::fmt::Formatter<'_>,
801        format_string: &str,
802        args: &[Variable<Self>],
803    ) -> std::fmt::Result {
804        let args = args.iter().map(|arg| format!("{arg}")).collect::<Vec<_>>();
805        let args = match args.is_empty() {
806            true => "".to_string(),
807            false => format!(", {}", args.join(",")),
808        };
809        writeln!(f, "os_log_default.log({format_string:?}{args});")
810    }
811
812    // logs
813    fn compile_instruction_log1p_scalar<T: Component<Self>>(
814        f: &mut std::fmt::Formatter<'_>,
815        input: T,
816    ) -> std::fmt::Result {
817        match input.elem() {
818            Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
819                write!(f, "log(half(1.0f) + {input})")
820            }
821            _ => write!(f, "log(1.0f + {input})"),
822        }
823    }
824
825    // sync
826    fn compile_instruction_sync_threads(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
827        writeln!(f, "threadgroup_barrier(mem_flags::mem_threadgroup);")
828    }
829
830    fn compile_instruction_sync_warp(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
831        writeln!(f, "simdgroup_barrier(mem_flags::mem_none);")
832    }
833
834    fn compile_instruction_thread_fence(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
835        writeln!(f, "threadgroup_thread_fence(mem_flags::mem_device);")
836    }
837
838    // trigo
839    fn compile_instruction_tanh_scalar<T: Component<Self>>(
840        f: &mut std::fmt::Formatter<'_>,
841        input: T,
842    ) -> std::fmt::Result {
843        write!(f, "safe_tanh_scalar({input})")
844    }
845
846    // unary
847    fn compile_instruction_find_first_set<T: Component<Self>>(
848        f: &mut std::fmt::Formatter<'_>,
849        input: T,
850        out_elem: Elem<Self>,
851    ) -> std::fmt::Result {
852        write!(f, "{out_elem}(")?;
853        match input.elem() {
854            Elem::I32 | Elem::U32 => write!(f, "__ffs({input})"),
855            Elem::I64 | Elem::U64 => write!(f, "__ffsll({input})"),
856            _ => write!(f, "__ffs({}({input}))", Elem::<Self>::I32),
857        }?;
858        write!(f, ")")
859    }
860
861    fn compile_instruction_leading_zeros_scalar<T: Component<Self>>(
862        f: &mut std::fmt::Formatter<'_>,
863        input: T,
864        out_elem: Elem<Self>,
865    ) -> std::fmt::Result {
866        write!(f, "{out_elem}(clz({input}))")
867    }
868
869    fn compile_instruction_popcount_scalar<T: Component<Self>>(
870        f: &mut std::fmt::Formatter<'_>,
871        input: T,
872        out_elem: Elem<Self>,
873    ) -> std::fmt::Result {
874        write!(f, "{out_elem}(")?;
875        match input.elem() {
876            Elem::I32 | Elem::U32 | Elem::I64 | Elem::U64 => write!(f, "popcount({input})"),
877            _ => write!(f, "popcount({})", shared::unary::zero_extend(input)),
878        }?;
879        write!(f, ")")
880    }
881
882    fn compile_instruction_reverse_bits_scalar<T: Component<Self>>(
883        f: &mut std::fmt::Formatter<'_>,
884        input: T,
885        out_elem: Elem<Self>,
886    ) -> std::fmt::Result {
887        write!(f, "{out_elem}(")?;
888        match out_elem {
889            Elem::I32 | Elem::U32 | Elem::I64 | Elem::U64 => write!(f, "reverse_bits({input})"),
890            _ => write!(
891                f,
892                "reverse_bits({}) >> {}",
893                shared::unary::zero_extend(input),
894                (size_of::<u32>() - out_elem.size()) * 8
895            ),
896        }?;
897        write!(f, ")")
898    }
899
900    // others
901    fn compile_instruction_max_function_name(
902        f: &mut std::fmt::Formatter<'_>,
903        _item: Item<Self>,
904    ) -> std::fmt::Result {
905        write!(f, "max")
906    }
907
908    fn compile_instruction_min_function_name(
909        f: &mut std::fmt::Formatter<'_>,
910        _item: Item<Self>,
911    ) -> std::fmt::Result {
912        write!(f, "min")
913    }
914
915    fn compile_instruction_powf(
916        f: &mut std::fmt::Formatter<'_>,
917        lhs: &str,
918        rhs: &str,
919        elem: Elem<Self>,
920    ) -> std::fmt::Result {
921        write!(f, "pow({lhs}, {elem}({rhs}))")
922    }
923
924    fn compile_instruction_half_function_name_prefix() -> &'static str {
925        ""
926    }
927
928    fn compile_instruction_half2_function_name_prefix() -> &'static str {
929        ""
930    }
931
932    // Warp
933    fn compile_warp_shuffle(
934        f: &mut std::fmt::Formatter<'_>,
935        var: &str,
936        source: &str,
937    ) -> std::fmt::Result {
938        write!(f, "simd_shuffle({var}, {source})")
939    }
940
941    fn compile_warp_shuffle_xor(
942        f: &mut std::fmt::Formatter<'_>,
943        var: &str,
944        _elem: &Elem<Self>,
945        offset: &str,
946    ) -> std::fmt::Result {
947        write!(f, "simd_shuffle_xor({var}, {offset})")
948    }
949
950    fn compile_warp_shuffle_up(
951        f: &mut std::fmt::Formatter<'_>,
952        var: &str,
953        offset: &str,
954    ) -> std::fmt::Result {
955        write!(f, "simd_shuffle_up({var}, {offset})")
956    }
957
958    fn compile_warp_shuffle_down(
959        f: &mut std::fmt::Formatter<'_>,
960        var: &str,
961        offset: &str,
962    ) -> std::fmt::Result {
963        write!(f, "simd_shuffle_down({var}, {offset})")
964    }
965
966    fn compile_warp_all<T: Component<Self>>(
967        f: &mut std::fmt::Formatter<'_>,
968        input: &T,
969    ) -> std::fmt::Result {
970        write!(f, "simd_all({input})")
971    }
972
973    fn compile_warp_any<T: Component<Self>>(
974        f: &mut std::fmt::Formatter<'_>,
975        input: &T,
976    ) -> std::fmt::Result {
977        write!(f, "simd_any({input})")
978    }
979
980    fn compile_warp_ballot(
981        f: &mut std::fmt::Formatter<'_>,
982        input: &Variable<Self>,
983        out_elem: &Elem<Self>,
984    ) -> std::fmt::Result {
985        write!(f, "{out_elem}(uint64_t(simd_ballot({input})))")
986    }
987}
988
989// Coop Matrices dialect
990
991impl DialectWmmaCompiler<Self> for MslDialect {
992    fn compile_wmma_includes(f: &mut std::fmt::Formatter<'_>, _flags: &Flags) -> std::fmt::Result {
993        writeln!(f, "#include <metal_simdgroup_matrix>")
994    }
995
996    fn compile_wmma_local_variables(_f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
997        // not used
998        Ok(())
999    }
1000
1001    fn compile_wmma_fragment_declaration(
1002        f: &mut std::fmt::Formatter<'_>,
1003        var: &crate::shared::Variable<MslDialect>,
1004    ) -> std::fmt::Result {
1005        wmma_api_base::compile_fragment_declaration(f, var)
1006    }
1007
1008    fn compile_wwma_fragment_ident(
1009        _f: &mut std::fmt::Formatter<'_>,
1010        _ident: &FragmentIdent<Self>,
1011    ) -> std::fmt::Result {
1012        // not used
1013        Ok(())
1014    }
1015
1016    fn compile_wmma_fragment_layout(
1017        _f: &mut std::fmt::Formatter<'_>,
1018        _layout: &FragmentLayout<Self>,
1019    ) -> std::fmt::Result {
1020        // not used
1021        Ok(())
1022    }
1023
1024    fn compile_wmma_fragment(
1025        f: &mut std::fmt::Formatter<'_>,
1026        fragment: &Fragment<Self>,
1027    ) -> std::fmt::Result {
1028        let ty = fragment.elem;
1029        // currently as of Metal 3.2 only fragments of 8x8x8 are supported
1030        let m = fragment.m;
1031        let n = fragment.n;
1032        let k = fragment.k;
1033        if m != 8 || n != 8 || k != 8 {
1034            panic!("{m}x{n}x{k} fragments not supported. Only 8x8x8 fragments are supported.");
1035        }
1036        write!(f, "simdgroup_{ty}8x8")
1037    }
1038
1039    fn compile_wmma_instruction(
1040        f: &mut std::fmt::Formatter<'_>,
1041        instruction: &WmmaInstruction<Self>,
1042    ) -> std::fmt::Result {
1043        match instruction {
1044            WmmaInstruction::Fill { frag, value } => {
1045                match frag {
1046                    Variable::WmmaFragment { .. } => {
1047                        let ty = frag.elem();
1048                        // Only 8x8x8 fragemts are supported. Check is done at fragment compilation time.
1049                        writeln!(
1050                            f,
1051                            "{frag} = make_filled_simdgroup_matrix<{ty}, 8, 8>({value});"
1052                        )
1053                    }
1054                    _ => panic!("should be a fragment"),
1055                }
1056            }
1057            WmmaInstruction::Load {
1058                frag,
1059                value,
1060                stride,
1061                offset,
1062                layout: _layout,
1063            } => {
1064                let transpose = match frag {
1065                    Variable::WmmaFragment { frag: inner, .. } => match inner.layout {
1066                        Some(FragmentLayout::RowMajor) => false,
1067                        Some(FragmentLayout::ColMajor) => true,
1068                        _ => false,
1069                    },
1070                    _ => panic!("should be a fragment"),
1071                };
1072                let item = value.item();
1073                if item.vectorization > 1 {
1074                    let elem = item.elem;
1075                    writeln!(
1076                        f,
1077                        "simdgroup_load({frag}, reinterpret_cast<threadgroup {elem} *>({value} + {offset}), {stride}, 0, {transpose});"
1078                    )
1079                } else {
1080                    writeln!(
1081                        f,
1082                        "simdgroup_load({frag}, {value} + {offset}, {stride}, 0, {transpose});"
1083                    )
1084                }
1085            }
1086            WmmaInstruction::Execute {
1087                frag_a: a,
1088                frag_b: b,
1089                frag_c: c,
1090                frag_d: d,
1091                ..
1092            } => {
1093                writeln!(f, "simdgroup_multiply_accumulate({d}, {a}, {b}, {c});")
1094            }
1095            WmmaInstruction::Store {
1096                output,
1097                frag,
1098                stride,
1099                offset,
1100                layout: _layout,
1101            } => {
1102                let item = output.item();
1103                let mut reinterpret_cast = item.vectorization > 1;
1104                let elem = match item.elem {
1105                    Elem::BF16 => {
1106                        reinterpret_cast = true;
1107                        Elem::F16
1108                    }
1109                    _ => item.elem,
1110                };
1111                if reinterpret_cast {
1112                    writeln!(
1113                        f,
1114                        "simdgroup_store({frag}, reinterpret_cast<threadgroup {elem} *>({output} + {offset}), {stride});"
1115                    )
1116                } else {
1117                    writeln!(f, "simdgroup_store({frag}, {output} + {offset}, {stride});")
1118                }?;
1119                writeln!(f, "simdgroup_barrier(mem_flags::mem_none);")
1120            }
1121            WmmaInstruction::Cast { input, output } => {
1122                writeln!(f, "simdgroup_barrier(mem_flags::mem_none);")?;
1123                let ty = match output {
1124                    Variable::WmmaFragment { frag, .. } => frag.elem,
1125                    _ => panic!("should be a fragment"),
1126                };
1127                match ty {
1128                    Elem::BF16 => {
1129                        let addr_space = Self::address_space_for_variable(output);
1130                        let elem = Elem::<Self>::F16;
1131                        // TODO: to test with benchmarks
1132
1133                        writeln!(
1134                            f,
1135                            "for(int e=0; e<8; e++) {{
1136    {ty} elem = {ty}({input}.thread_elements()[e]);
1137    {output}.thread_elements()[e] = *reinterpret_cast<{addr_space}{elem} *>(&elem);
1138}}"
1139                        )
1140                    }
1141                    _ => {
1142                        writeln!(
1143                            f,
1144                            "for(int e=0; e<8; e++) {{
1145    {output}.thread_elements()[e] = {ty}({input}.thread_elements()[e]);
1146}}"
1147                        )
1148                    }
1149                }
1150            }
1151            WmmaInstruction::ExecuteManual {
1152                shape,
1153                frag_a,
1154                frag_b,
1155                frag_c,
1156                frag_d,
1157            } => {
1158                Self::compile_manual_mma(f, ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d))
1159            }
1160            WmmaInstruction::ExecuteScaled {
1161                shape,
1162                frag_a,
1163                frag_b,
1164                frag_c,
1165                frag_d,
1166                scales_a,
1167                scales_b,
1168                scales_factor,
1169            } => Self::compile_scaled_mma(
1170                f,
1171                ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d),
1172                *scales_a,
1173                *scales_b,
1174                *scales_factor,
1175            ),
1176        }
1177    }
1178
1179    fn compile_manual_mma(
1180        _f: &mut std::fmt::Formatter<'_>,
1181        _mma: shared::ManualMma<Self>,
1182    ) -> std::fmt::Result {
1183        unimplemented!("Not supported")
1184    }
1185
1186    fn compile_scaled_mma(
1187        _f: &mut std::fmt::Formatter<'_>,
1188        _mma: shared::ManualMma<Self>,
1189        _scales_a: Variable<Self>,
1190        _scales_b: Variable<Self>,
1191        _scales_factor: u32,
1192    ) -> std::fmt::Result {
1193        unimplemented!("Not supported")
1194    }
1195
1196    fn supported_wmma_combinations(_arch: &MetalArchitecture) -> SupportedMmaCombinations {
1197        let types = vec![
1198            (
1199                gpu::ElemType::Float(gpu::FloatKind::F16).into(),
1200                gpu::ElemType::Float(gpu::FloatKind::F16).into(),
1201                gpu::ElemType::Float(gpu::FloatKind::F16).into(),
1202            ),
1203            (
1204                gpu::ElemType::Float(gpu::FloatKind::F16).into(),
1205                gpu::ElemType::Float(gpu::FloatKind::F16).into(),
1206                gpu::ElemType::Float(gpu::FloatKind::F32).into(),
1207            ),
1208            (
1209                gpu::ElemType::Float(gpu::FloatKind::BF16).into(),
1210                gpu::ElemType::Float(gpu::FloatKind::BF16).into(),
1211                gpu::ElemType::Float(gpu::FloatKind::BF16).into(),
1212            ),
1213            (
1214                gpu::ElemType::Float(gpu::FloatKind::F32).into(),
1215                gpu::ElemType::Float(gpu::FloatKind::F32).into(),
1216                gpu::ElemType::Float(gpu::FloatKind::F32).into(),
1217            ),
1218        ];
1219        types
1220            .into_iter()
1221            .map(|(a_type, b_type, cd_type)| MmaConfig {
1222                a_type,
1223                b_type,
1224                cd_type,
1225                m: 8,
1226                n: 8,
1227                k: 8,
1228            })
1229            .collect()
1230    }
1231
1232    fn supported_mma_combinations(_arch: &MetalArchitecture) -> SupportedMmaCombinations {
1233        Vec::new()
1234    }
1235}
1236
// Processors dialect
1238
1239impl DialectProcessors<Self> for MslDialect {
1240    fn processors() -> Vec<Box<dyn gpu::Processor>> {
1241        Vec::new()
1242    }
1243}