// cubecl_cpp/shared/dialect.rs

1use std::{collections::HashSet, fmt::Debug};
2use std::{fmt::Display, hash::Hash};
3
4use cubecl_core::ir::{ConstantValue, Processor};
5
6use crate::shared::{
7    FmtLeft, IndexedVariable, MmaShape, SupportedMmaCombinations, SupportedScaledMmaCombinations,
8    reduce_comparison, reduce_exclusive, reduce_inclusive, reduce_operator, reduce_quantifier,
9};
10
11use super::{
12    Architecture, AtomicKind, Body, Component, CubeIndexFlags, Elem, Flags, Fragment,
13    FragmentIdent, FragmentLayout, Instruction, Item, KernelArg, SharedMemory, Variable,
14    WarpInstruction, WmmaInstruction,
15};
16
// Base dialect
18
19pub trait Dialect:
20    DialectIncludes<Self>
21    + DialectTypes<Self>
22    + DialectBindings<Self>
23    + DialectWarpReduceCompiler<Self>
24    + DialectCubeBuiltins<Self>
25    + DialectInstructions<Self>
26    + DialectWmmaCompiler<Self>
27    + DialectProcessors<Self>
28    + Default
29    + Clone
30    + Copy
31    + Debug
32    + Send
33    + Sync
34    + Eq
35    + Hash
36    + 'static
37{
38    type Architecture: Architecture;
39}
40
// Includes
42
43pub trait DialectIncludes<D: Dialect> {
44    type Extension: Debug + Clone + Sync + Send;
45
46    fn compile_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags<D>) -> std::fmt::Result;
47    fn compile_extensions(
48        f: &mut std::fmt::Formatter<'_>,
49        extensions: &[Self::Extension],
50    ) -> std::fmt::Result;
51    fn register_instruction_extension(
52        extensions: &mut Vec<Self::Extension>,
53        instruction: &Instruction<D>,
54    );
55    fn register_warp_instruction_extension(
56        extensions: &mut Vec<Self::Extension>,
57        instruction: &WarpInstruction<D>,
58    );
59    #[allow(unused_variables)]
60    fn register_wmma_instruction_extension(
61        extensions: &mut Vec<Self::Extension>,
62        instruction: &WmmaInstruction<D>,
63    ) {
64    }
65}
66
// Types
68
69pub trait DialectTypes<D: Dialect> {
70    fn item_can_be_optimized() -> bool;
71    fn compile_elem(
72        f: &mut std::fmt::Formatter<'_>,
73        elem: &Elem<D>,
74        word: bool,
75    ) -> std::fmt::Result;
76
77    fn compile_atomic_kind(
78        f: &mut std::fmt::Formatter<'_>,
79        kind: &AtomicKind<D>,
80    ) -> std::fmt::Result {
81        match kind {
82            AtomicKind::I32 => write!(f, "{}", Elem::<D>::I32),
83            AtomicKind::I64 => write!(f, "{}", Elem::<D>::I64),
84            AtomicKind::U32 => write!(f, "{}", Elem::<D>::U32),
85            AtomicKind::U64 => write!(f, "{}", Elem::<D>::U64),
86            AtomicKind::F16 => write!(f, "{}", Elem::<D>::F16),
87            AtomicKind::F16x2 => write!(f, "{}", Elem::<D>::F16x2),
88            AtomicKind::BF16 => write!(f, "{}", Elem::<D>::BF16),
89            AtomicKind::BF16x2 => write!(f, "{}", Elem::<D>::BF16x2),
90            AtomicKind::F32 => write!(f, "{}", Elem::<D>::F32),
91            AtomicKind::F64 => write!(f, "{}", Elem::<D>::F64),
92            AtomicKind::_Dialect(_) => Ok(()),
93        }
94    }
95
96    fn compile_item(f: &mut std::fmt::Formatter<'_>, item: &Item<D>) -> std::fmt::Result;
97    fn compile_type_definitions(
98        f: &mut std::fmt::Formatter<'_>,
99        items: &HashSet<Item<D>>,
100        scalars: &[(Elem<D>, usize)],
101        info: &cubecl_core::Info,
102        flags: &Flags<D>,
103    ) -> std::fmt::Result;
104    fn compile_local_memory_qualifier(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
105    fn compile_shared_memory_declaration(
106        f: &mut std::fmt::Formatter<'_>,
107        shared: &SharedMemory<D>,
108    ) -> std::fmt::Result {
109        match shared {
110            SharedMemory::Array {
111                index,
112                item,
113                length,
114                offset,
115                ..
116            } => {
117                let size_bytes = *length * item.size();
118                writeln!(f, "// Shared array size: {length}, {size_bytes} bytes")?;
119                writeln!(
120                    f,
121                    "{item} *shared_memory_{index} = reinterpret_cast<{item}*>(&dynamic_shared_mem[{offset}]);"
122                )
123            }
124            SharedMemory::Value {
125                index,
126                item,
127                offset,
128                ..
129            } => {
130                let size_bytes = item.size() as u32;
131                writeln!(f, "// Shared value size: {size_bytes} bytes")?;
132                writeln!(
133                    f,
134                    "{item} &shared_memory_{index} = reinterpret_cast<{item}&>(dynamic_shared_mem[{offset}]);"
135                )
136            }
137        }
138    }
139    fn compile_polyfills(_f: &mut std::fmt::Formatter<'_>, _flags: &Flags<D>) -> std::fmt::Result {
140        Ok(())
141    }
142    /// Address space (for Metal dialect only).
143    fn address_space_for_variable(_variable: &Variable<D>) -> String {
144        "".to_string()
145    }
146}
147
// Kernel argument bindings
149
150pub trait DialectBindings<D: Dialect> {
151    fn compile_kernel_signature(
152        f: &mut std::fmt::Formatter<'_>,
153        kernel_name: &str,
154        tensor_maps: &[KernelArg<D>],
155        buffers: &[KernelArg<D>],
156        flags: &Flags<D>,
157    ) -> std::fmt::Result;
158    fn compile_bindings_body(
159        _f: &mut std::fmt::Formatter<'_>,
160        _body: &Body<D>,
161    ) -> std::fmt::Result {
162        Ok(())
163    }
164}
165
// Cube builtins dialect
167
168pub trait DialectCubeBuiltins<D: Dialect> {
169    /// Depending on the dialect available built-in variables the
170    /// inclusion rules might change.
171    /// For instance in metal we have a built-in for the Unit plane position
172    /// but in other dialects there is none so we have to compute it using
173    /// other built-ins.
174    fn builtin_rules(flags: &CubeIndexFlags) -> CubeIndexFlags {
175        let unit_pos_plane = flags.unit_pos_plane;
176        let plane_dim_checked = flags.plane_dim_checked;
177        let plane_dim = flags.plane_dim || plane_dim_checked || unit_pos_plane;
178        let plane_pos = flags.plane_pos;
179        let absolute_pos = flags.absolute_pos || unit_pos_plane;
180        let absolute_pos_tuple = flags.absolute_pos_tuple || absolute_pos;
181        let cube_dim = flags.cube_dim;
182        let cube_dim_tuple = flags.cube_dim_tuple || cube_dim || absolute_pos || plane_dim_checked;
183        let unit_pos = flags.unit_pos;
184        let unit_pos_tuple = flags.unit_pos_tuple || unit_pos;
185        let cube_count = flags.cube_count;
186        let cube_count_tuple = flags.cube_count_tuple || absolute_pos;
187        let cube_pos = flags.cube_pos;
188        let cube_pos_tuple = flags.cube_pos_tuple || cube_pos;
189        let cluster_group = flags.cluster_pos;
190
191        CubeIndexFlags {
192            absolute_pos,
193            absolute_pos_tuple,
194            cube_count,
195            cube_count_tuple,
196            cube_dim,
197            cube_dim_tuple,
198            cube_pos,
199            cube_pos_tuple,
200            plane_dim,
201            plane_dim_checked,
202            plane_pos,
203            unit_pos_tuple,
204            unit_pos,
205            unit_pos_plane,
206            cluster_pos: cluster_group,
207        }
208    }
209
210    fn compile_absolute_pos_tuple_computation(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
211        let variable = Variable::<D>::AbsolutePosBaseName;
212        let ty = variable.item();
213        let cube_pos_x = Variable::<D>::CubePosX;
214        let cube_pos_y = Variable::<D>::CubePosY;
215        let cube_pos_z = Variable::<D>::CubePosZ;
216        let cube_dim_x = Variable::<D>::CubeDimX;
217        let cube_dim_y = Variable::<D>::CubeDimY;
218        let cube_dim_z = Variable::<D>::CubeDimZ;
219        let unit_pos_x = Variable::<D>::UnitPosX;
220        let unit_pos_y = Variable::<D>::UnitPosY;
221        let unit_pos_z = Variable::<D>::UnitPosZ;
222        writeln!(
223            f,
224            "{ty} {variable} = make_{ty}(
225    {cube_pos_x} * {cube_dim_x} + {unit_pos_x},
226    {cube_pos_y} * {cube_dim_y} + {unit_pos_y},
227    {cube_pos_z} * {cube_dim_z} + {unit_pos_z}
228);"
229        )
230    }
231
232    fn compile_absolute_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
233        f.write_str("absoluteIdx")
234    }
235
236    fn compile_absolute_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
237        f.write_str("idxGlobal")
238    }
239
240    fn compile_absolute_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
241        Self::compile_absolute_pos_base_name(f)?;
242        write!(f, ".x")
243    }
244
245    fn compile_absolute_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
246        Self::compile_absolute_pos_base_name(f)?;
247        write!(f, ".y")
248    }
249
250    fn compile_absolute_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
251        Self::compile_absolute_pos_base_name(f)?;
252        write!(f, ".z")
253    }
254
255    fn compile_cube_count_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
256        f.write_str("gridDim")
257    }
258
259    fn compile_cube_count(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
260        f.write_str("gridDimGlobal")
261    }
262
263    fn compile_cube_count_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
264        Self::compile_cube_count_base_name(f)?;
265        write!(f, ".x")
266    }
267
268    fn compile_cube_count_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
269        Self::compile_cube_count_base_name(f)?;
270        write!(f, ".y")
271    }
272
273    fn compile_cube_count_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
274        Self::compile_cube_count_base_name(f)?;
275        write!(f, ".z")
276    }
277
278    fn compile_cube_dim_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
279        f.write_str("blockDim")
280    }
281
282    fn compile_cube_dim(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
283        f.write_str("blockDimGlobal")
284    }
285
286    fn compile_cube_dim_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
287        Self::compile_cube_dim_base_name(f)?;
288        write!(f, ".x")
289    }
290
291    fn compile_cube_dim_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
292        Self::compile_cube_dim_base_name(f)?;
293        write!(f, ".y")
294    }
295
296    fn compile_cube_dim_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
297        Self::compile_cube_dim_base_name(f)?;
298        write!(f, ".z")
299    }
300
301    fn compile_cube_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
302        f.write_str("blockIdx")
303    }
304
305    fn compile_cube_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
306        f.write_str("blockIdxGlobal")
307    }
308
309    fn compile_cube_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
310        Self::compile_cube_pos_base_name(f)?;
311        write!(f, ".x")
312    }
313
314    fn compile_cube_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
315        Self::compile_cube_pos_base_name(f)?;
316        write!(f, ".y")
317    }
318
319    fn compile_cube_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
320        Self::compile_cube_pos_base_name(f)?;
321        write!(f, ".z")
322    }
323
324    fn compile_unit_pos_computation(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
325        let variable = Variable::<D>::UnitPos;
326        let ty = variable.item();
327        let cube_dim_x = Variable::<D>::CubeDimX;
328        let cube_dim_y = Variable::<D>::CubeDimY;
329        let unit_pos_x = Variable::<D>::UnitPosX;
330        let unit_pos_y = Variable::<D>::UnitPosY;
331        let unit_pos_z = Variable::<D>::UnitPosZ;
332        writeln!(
333            f,
334            "{ty} {variable} = {unit_pos_x} + {unit_pos_y} * {cube_dim_x} + {unit_pos_z} * ({cube_dim_x} * {cube_dim_y});"
335        )
336    }
337
338    fn compile_unit_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
339        f.write_str("threadIdxGlobal")
340    }
341
342    fn compile_unit_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
343        f.write_str("threadIdx")
344    }
345
346    fn compile_unit_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
347        Self::compile_unit_pos_base_name(f)?;
348        write!(f, ".x")
349    }
350
351    fn compile_unit_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
352        Self::compile_unit_pos_base_name(f)?;
353        write!(f, ".y")
354    }
355
356    fn compile_unit_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
357        Self::compile_unit_pos_base_name(f)?;
358        write!(f, ".z")
359    }
360
361    fn compile_plane_dim(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
362        f.write_str("warpSize")
363    }
364
365    fn compile_plane_dim_checked(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
366        f.write_str("warpSizeChecked")
367    }
368
369    fn compile_plane_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
370        let unit_pos_x = Variable::<D>::UnitPosX;
371        let plane_dim = Variable::<D>::PlaneDim;
372        write!(f, "{unit_pos_x} / {plane_dim}")
373    }
374
375    fn compile_unit_pos_plane(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
376        let absolute_pos = Variable::<D>::AbsolutePos(Elem::U32);
377        let plane_dim = Variable::<D>::PlaneDim;
378        let ty = plane_dim.item();
379        write!(f, "{ty}({absolute_pos}) % {plane_dim}")
380    }
381
382    fn compile_cluster_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
383        write!(f, "0")
384    }
385    fn compile_cluster_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
386        write!(f, "0")
387    }
388    fn compile_cluster_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
389        write!(f, "0")
390    }
391    fn compile_cluster_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
392        write!(f, "0")
393    }
394}
395
// Instructions
397
398pub trait DialectInstructions<D: Dialect> {
399    // atomics
400    fn compile_atomic_add(
401        f: &mut std::fmt::Formatter<'_>,
402        lhs: &Variable<D>,
403        rhs: &Variable<D>,
404        out: &Variable<D>,
405    ) -> std::fmt::Result {
406        let optimized = Variable::optimized_args([*lhs, *rhs, *out]);
407        let [lhs, rhs, out_optimized] = optimized.args;
408
409        let addr_space = D::address_space_for_variable(out);
410        let out_item = out.item();
411        let out = out.fmt_left();
412
413        match out_optimized.elem() {
414            Elem::I64 => writeln!(
415                f,
416                "{out} = atomicAdd(reinterpret_cast<{uint}*>({lhs}), {uint}({rhs}));",
417                uint = Elem::<D>::U64
418            ),
419            Elem::F32 if out_item.vectorization > 1 => {
420                // Hacky but CUDA needs this to be the builtin vector type. Revisit if metal ever
421                // supports this
422                let vec_ty = format!("float{}", out_item.vectorization);
423                let out_tmp = Variable::tmp(out_optimized.item());
424                writeln!(
425                    f,
426                    "{vec_ty} {out_tmp} = atomicAdd(
427                    reinterpret_cast<{addr_space}{vec_ty}*>({lhs}),
428                    reinterpret_cast<const {addr_space}{vec_ty}&>({rhs}));",
429                )?;
430                writeln!(
431                    f,
432                    "{out} = reinterpret_cast<{addr_space}{out_item}&>({out_tmp});"
433                )
434            }
435            Elem::F16x2 | Elem::BF16x2 => {
436                let out_tmp = Variable::tmp(out_optimized.item());
437                writeln!(
438                    f,
439                    "{} = atomicAdd(
440                    reinterpret_cast<{addr_space}{}*>({lhs}),
441                    reinterpret_cast<const {addr_space}{}&>({rhs}));",
442                    out_tmp.fmt_left(),
443                    lhs.item(),
444                    rhs.item()
445                )?;
446                writeln!(
447                    f,
448                    "{out} = reinterpret_cast<{addr_space}{out_item}&>({out_tmp});"
449                )
450            }
451            _ => writeln!(f, "{out} = atomicAdd({lhs}, {rhs});"),
452        }
453    }
454
455    fn compile_atomic_and(
456        f: &mut std::fmt::Formatter<'_>,
457        lhs: &Variable<D>,
458        rhs: &Variable<D>,
459        out: &Variable<D>,
460    ) -> std::fmt::Result {
461        let out = out.fmt_left();
462        writeln!(f, "{out} = atomicAnd({lhs}, {rhs});")
463    }
464
465    fn compile_atomic_cas(
466        f: &mut std::fmt::Formatter<'_>,
467        input: &Variable<D>,
468        cmp: &Variable<D>,
469        val: &Variable<D>,
470        out: &Variable<D>,
471    ) -> std::fmt::Result {
472        let out_item = out.item();
473        let out = out.fmt_left();
474        match val.elem() {
475            // vec4 is automatically supported by the new 128-bit template version
476            Elem::F32 if val.item().vectorization == 2 => {
477                let u64 = Item::new(Elem::<D>::U64, 1, true);
478                let out_tmp = Variable::tmp(u64);
479                writeln!(
480                    f,
481                    "{} = atomicCAS(
482                reinterpret_cast<{u64}*>({input}),
483                reinterpret_cast<{u64}&>({cmp}),
484                reinterpret_cast<{u64}&>({val}));",
485                    out_tmp.fmt_left()
486                )?;
487                writeln!(f, "{out} = reinterpret_cast<{out_item}&>({out_tmp});")
488            }
489            Elem::F16 | Elem::BF16 if val.item().vectorization == 2 => {
490                let u32 = Item::new(Elem::<D>::U32, 1, true);
491                let out_tmp = Variable::tmp(u32);
492                writeln!(
493                    f,
494                    "{} = atomicCAS(
495                reinterpret_cast<{u32}*>({input}),
496                reinterpret_cast<{u32}&>({cmp}),
497                reinterpret_cast<{u32}&>({val}));",
498                    out_tmp.fmt_left()
499                )?;
500                writeln!(f, "{out} = reinterpret_cast<{out_item}&>({out_tmp});")
501            }
502            _ => writeln!(f, "{out} = atomicCAS({input}, {cmp}, {val});"),
503        }
504    }
505
506    fn compile_atomic_load(
507        f: &mut std::fmt::Formatter<'_>,
508        input: &Variable<D>,
509        out: &Variable<D>,
510    ) -> std::fmt::Result {
511        let zero = Variable::Constant(ConstantValue::UInt(0), input.item());
512        Self::compile_atomic_add(f, input, &zero, out)
513    }
514
515    fn compile_atomic_max(
516        f: &mut std::fmt::Formatter<'_>,
517        lhs: &Variable<D>,
518        rhs: &Variable<D>,
519        out: &Variable<D>,
520    ) -> std::fmt::Result {
521        let out = out.fmt_left();
522        writeln!(f, "{out} = atomicMax({lhs}, {rhs});")
523    }
524
525    fn compile_atomic_min(
526        f: &mut std::fmt::Formatter<'_>,
527        lhs: &Variable<D>,
528        rhs: &Variable<D>,
529        out: &Variable<D>,
530    ) -> std::fmt::Result {
531        let out = out.fmt_left();
532        writeln!(f, "{out} = atomicMin({lhs}, {rhs});")
533    }
534
535    fn compile_atomic_or(
536        f: &mut std::fmt::Formatter<'_>,
537        lhs: &Variable<D>,
538        rhs: &Variable<D>,
539        out: &Variable<D>,
540    ) -> std::fmt::Result {
541        let out = out.fmt_left();
542        writeln!(f, "{out} = atomicOr({lhs}, {rhs});")
543    }
544
545    fn compile_atomic_store(
546        f: &mut std::fmt::Formatter<'_>,
547        input: &Variable<D>,
548        out: &Variable<D>,
549    ) -> std::fmt::Result {
550        let tmp = Variable::tmp(input.item());
551        Self::compile_atomic_swap(f, out, input, &tmp)
552    }
553
554    fn compile_atomic_sub(
555        f: &mut std::fmt::Formatter<'_>,
556        lhs: &Variable<D>,
557        rhs: &Variable<D>,
558        out: &Variable<D>,
559    ) -> std::fmt::Result {
560        let out = out.fmt_left();
561        match rhs.elem() {
562            Elem::U32 | Elem::I32 => writeln!(f, "{out} = atomicSub({lhs}, {rhs});"),
563            Elem::U64 => writeln!(f, "{out} = atomicAdd({lhs}, -{rhs});"),
564            Elem::I64 => writeln!(
565                f,
566                "{out} = atomicAdd(reinterpret_cast<{uint}*>({lhs}), {uint}(-{rhs}));",
567                uint = Elem::<D>::U64
568            ),
569            _ => writeln!(f, "{out} = atomicAdd({lhs}, -{rhs});"),
570        }
571    }
572
573    fn compile_atomic_swap(
574        f: &mut std::fmt::Formatter<'_>,
575        lhs: &Variable<D>,
576        rhs: &Variable<D>,
577        out: &Variable<D>,
578    ) -> std::fmt::Result {
579        let out_item = out.item();
580        let out = out.fmt_left();
581        match rhs.elem() {
582            // vec4 is automatically supported by the new 128-bit template version
583            Elem::F32 if rhs.item().vectorization == 2 => {
584                let u64 = Item::new(Elem::<D>::U64, 1, true);
585                let out_tmp = Variable::tmp(u64);
586                writeln!(
587                    f,
588                    "{} = atomicExch(
589                reinterpret_cast<{u64}*>({lhs}),
590                reinterpret_cast<{u64}&>({rhs}));",
591                    out_tmp.fmt_left()
592                )?;
593                writeln!(f, "{out} = reinterpret_cast<{out_item}&>({out_tmp});")
594            }
595            Elem::F16 | Elem::BF16 if rhs.item().vectorization == 2 => {
596                let u32 = Item::new(Elem::<D>::U32, 1, true);
597                let out_tmp = Variable::tmp(u32);
598                writeln!(
599                    f,
600                    "{} = atomicExch(
601                reinterpret_cast<{u32}*>({lhs}),
602                reinterpret_cast<{u32}&>({rhs}));",
603                    out_tmp.fmt_left()
604                )?;
605                writeln!(f, "{out} = reinterpret_cast<{out_item}&>({out_tmp});")
606            }
607            _ => writeln!(f, "{out} = atomicExch({lhs}, {rhs});"),
608        }
609    }
610
611    fn compile_atomic_xor(
612        f: &mut std::fmt::Formatter<'_>,
613        lhs: &Variable<D>,
614        rhs: &Variable<D>,
615        out: &Variable<D>,
616    ) -> std::fmt::Result {
617        let out = out.fmt_left();
618        writeln!(f, "{out} = atomicXor({lhs}, {rhs});")
619    }
620
621    fn compile_saturating_add(
622        f: &mut std::fmt::Formatter<'_>,
623        lhs: impl Display,
624        rhs: impl Display,
625        item: Item<D>,
626    ) -> std::fmt::Result;
627
628    fn compile_saturating_sub(
629        f: &mut std::fmt::Formatter<'_>,
630        lhs: impl Display,
631        rhs: impl Display,
632        item: Item<D>,
633    ) -> std::fmt::Result;
634
635    // debug
636    fn compile_instruction_printf(
637        f: &mut std::fmt::Formatter<'_>,
638        format_string: &str,
639        args: &[Variable<D>],
640    ) -> std::fmt::Result {
641        let args = args.iter().map(|arg| format!("{arg}")).collect::<Vec<_>>();
642        let args = match args.is_empty() {
643            true => "".to_string(),
644            false => format!(", {}", args.join(",")),
645        };
646        writeln!(f, "printf({format_string:?}{args});")
647    }
648
649    // logs
650    fn compile_instruction_log1p_scalar<T: Component<D>>(
651        f: &mut std::fmt::Formatter<'_>,
652        input: T,
653    ) -> std::fmt::Result {
654        let elem = input.elem();
655        match elem {
656            Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
657                write!(f, "{elem}(log1p(float({input})))")
658            }
659            _ => write!(f, "log1p({input})"),
660        }
661    }
662
663    // sync
664    fn compile_instruction_sync_threads(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
665    fn compile_instruction_sync_warp(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
666    fn compile_instruction_thread_fence(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
667
668    // trigo
669    fn compile_instruction_tanh_scalar<T: Component<D>>(
670        f: &mut std::fmt::Formatter<'_>,
671        input: T,
672    ) -> std::fmt::Result {
673        let elem = input.elem();
674        match elem {
675            Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
676                write!(f, "{elem}(tanh(float({input})))")
677            }
678            _ => write!(f, "tanh({input})"),
679        }
680    }
681
682    // unary
683    fn compile_instruction_find_first_set<T: Component<D>>(
684        f: &mut std::fmt::Formatter<'_>,
685        input: T,
686        out_elem: Elem<D>,
687    ) -> std::fmt::Result;
688    fn compile_instruction_leading_zeros_scalar<T: Component<D>>(
689        f: &mut std::fmt::Formatter<'_>,
690        input: T,
691        out_elem: Elem<D>,
692    ) -> std::fmt::Result;
693
694    fn compile_instruction_trailing_zeros_scalar<T: Component<D>>(
695        f: &mut std::fmt::Formatter<'_>,
696        input: T,
697        out_elem: Elem<D>,
698    ) -> std::fmt::Result;
699
700    fn compile_instruction_popcount_scalar<T: Component<D>>(
701        f: &mut std::fmt::Formatter<'_>,
702        input: T,
703        out_elem: Elem<D>,
704    ) -> std::fmt::Result {
705        write!(f, "{out_elem}(")?;
706        match input.elem() {
707            Elem::I32 => write!(f, "__popc({}({input}))", Elem::<D>::U32),
708            Elem::U32 => write!(f, "__popc({input})"),
709            Elem::I64 => write!(f, "__popcll({}({input}))", Elem::<D>::U64),
710            Elem::U64 => write!(f, "__popcll({input})"),
711            _ => write!(f, "__popc({})", super::unary::zero_extend(input)),
712        }?;
713        write!(f, ")")
714    }
715
716    fn compile_instruction_reverse_bits_scalar<T: Component<D>>(
717        f: &mut std::fmt::Formatter<'_>,
718        input: T,
719        out_elem: Elem<D>,
720    ) -> std::fmt::Result {
721        write!(f, "{out_elem}(")?;
722        match out_elem {
723            Elem::I32 => write!(f, "__brev({}({input}))", Elem::<D>::U32),
724            Elem::U32 => write!(f, "__brev({input})"),
725            Elem::I64 => write!(f, "__brevll({}({input}))", Elem::<D>::U64),
726            Elem::U64 => write!(f, "__brevll({input})"),
727            _ => write!(
728                f,
729                "__brev({}) >> {}",
730                super::unary::zero_extend(input),
731                (size_of::<u32>() - out_elem.size()) * 8
732            ),
733        }?;
734        write!(f, ")")
735    }
736
737    // others
738    fn compile_instruction_max_function_name(
739        f: &mut std::fmt::Formatter<'_>,
740        item: Item<D>,
741    ) -> std::fmt::Result;
742
743    fn compile_instruction_min_function_name(
744        f: &mut std::fmt::Formatter<'_>,
745        item: Item<D>,
746    ) -> std::fmt::Result;
747
748    fn compile_instruction_powf(
749        f: &mut std::fmt::Formatter<'_>,
750        lhs: &str,
751        rhs: &str,
752        elem: Elem<D>,
753    ) -> std::fmt::Result {
754        match elem {
755            Elem::F32 => write!(f, "powf({lhs}, {rhs})"),
756            Elem::F64 => write!(f, "pow({lhs}, {rhs})"),
757            _ => write!(f, "#error Unsupported type for powf: {elem}"),
758        }
759    }
760
761    fn compile_instruction_hypot(
762        f: &mut std::fmt::Formatter<'_>,
763        lhs: &str,
764        rhs: &str,
765        elem: Elem<D>,
766    ) -> std::fmt::Result {
767        match elem {
768            Elem::F32 => write!(f, "hypotf({lhs}, {rhs})"),
769            Elem::F64 => write!(f, "hypot({lhs}, {rhs})"),
770            _ => write!(f, "#error Unsupported type for hypot: {elem}"),
771        }
772    }
773
774    fn compile_instruction_rhypot(
775        f: &mut std::fmt::Formatter<'_>,
776        lhs: &str,
777        rhs: &str,
778        elem: Elem<D>,
779    ) -> std::fmt::Result {
780        match elem {
781            Elem::F32 => write!(f, "rhypotf({lhs}, {rhs})"),
782            Elem::F64 => write!(f, "rhypot({lhs}, {rhs})"),
783            _ => write!(f, "#error Unsupported type for rhypot: {elem}"),
784        }
785    }
786
787    fn compile_instruction_half_function_name_prefix() -> &'static str {
788        "h"
789    }
790
791    fn compile_instruction_half2_function_name_prefix() -> &'static str {
792        "h2"
793    }
794
795    // warp
796    fn compile_warp_shuffle(
797        f: &mut std::fmt::Formatter<'_>,
798        var: &str,
799        source: &str,
800    ) -> std::fmt::Result;
801    fn compile_warp_shuffle_xor(
802        f: &mut std::fmt::Formatter<'_>,
803        var: &str,
804        elem: &Elem<D>,
805        offset: &str,
806    ) -> std::fmt::Result;
807    fn compile_warp_shuffle_up(
808        f: &mut std::fmt::Formatter<'_>,
809        var: &str,
810        offset: &str,
811    ) -> std::fmt::Result;
812    fn compile_warp_shuffle_down(
813        f: &mut std::fmt::Formatter<'_>,
814        var: &str,
815        offset: &str,
816    ) -> std::fmt::Result;
817    fn compile_warp_all<T: Component<D>>(
818        f: &mut std::fmt::Formatter<'_>,
819        input: &T,
820    ) -> std::fmt::Result;
821    fn compile_warp_any<T: Component<D>>(
822        f: &mut std::fmt::Formatter<'_>,
823        input: &T,
824    ) -> std::fmt::Result;
825    fn compile_warp_ballot(
826        f: &mut std::fmt::Formatter<'_>,
827        input: &Variable<D>,
828        out_elem: &Elem<D>,
829    ) -> std::fmt::Result;
830    fn compile_warp_elect(f: &mut std::fmt::Formatter<'_>, out: &str) -> std::fmt::Result {
831        write!(
832            f,
833            "
834unsigned int mask = __activemask();
835unsigned int leader = __ffs(mask) - 1;
836{out} = threadIdx.x % warpSize == leader;
837            "
838        )
839    }
840    fn compile_unreachable(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
841}
842
843#[derive(Debug, Clone, Copy, new)]
844pub struct ManualMma<'a, D: Dialect> {
845    pub shape: MmaShape<D>,
846    pub frag_a: &'a Variable<D>,
847    pub frag_b: &'a Variable<D>,
848    pub frag_c: &'a Variable<D>,
849    pub frag_d: &'a Variable<D>,
850}
851
852pub trait DialectWarpReduceCompiler<D: Dialect>:
853    Default + Clone + Copy + Debug + Send + Sync + Eq + Hash + 'static
854{
855    fn warp_reduce_sum(
856        f: &mut core::fmt::Formatter<'_>,
857        input: &Variable<D>,
858        out: &Variable<D>,
859    ) -> core::fmt::Result {
860        reduce_operator(f, input, out, "+=")
861    }
862    fn warp_reduce_prod(
863        f: &mut core::fmt::Formatter<'_>,
864        input: &Variable<D>,
865        out: &Variable<D>,
866    ) -> core::fmt::Result {
867        reduce_operator(f, input, out, "*=")
868    }
869    fn warp_reduce_max(
870        f: &mut core::fmt::Formatter<'_>,
871        input: &Variable<D>,
872        out: &Variable<D>,
873    ) -> core::fmt::Result {
874        reduce_comparison(f, input, out, D::compile_instruction_max_function_name)
875    }
876    fn warp_reduce_min(
877        f: &mut core::fmt::Formatter<'_>,
878        input: &Variable<D>,
879        out: &Variable<D>,
880    ) -> core::fmt::Result {
881        reduce_comparison(f, input, out, D::compile_instruction_min_function_name)
882    }
883    fn warp_reduce_all(
884        f: &mut core::fmt::Formatter<'_>,
885        input: &Variable<D>,
886        out: &Variable<D>,
887    ) -> core::fmt::Result {
888        reduce_quantifier(f, input, out, D::compile_warp_all::<IndexedVariable<D>>)
889    }
890    fn warp_reduce_any(
891        f: &mut core::fmt::Formatter<'_>,
892        input: &Variable<D>,
893        out: &Variable<D>,
894    ) -> core::fmt::Result {
895        reduce_quantifier(f, input, out, D::compile_warp_any::<IndexedVariable<D>>)
896    }
897    fn warp_reduce_sum_inclusive(
898        f: &mut core::fmt::Formatter<'_>,
899        input: &Variable<D>,
900        out: &Variable<D>,
901    ) -> core::fmt::Result {
902        reduce_inclusive(f, input, out, "+=")
903    }
904    fn warp_reduce_prod_inclusive(
905        f: &mut core::fmt::Formatter<'_>,
906        input: &Variable<D>,
907        out: &Variable<D>,
908    ) -> core::fmt::Result {
909        reduce_inclusive(f, input, out, "*=")
910    }
911    fn warp_reduce_sum_exclusive(
912        f: &mut core::fmt::Formatter<'_>,
913        input: &Variable<D>,
914        out: &Variable<D>,
915    ) -> core::fmt::Result {
916        reduce_exclusive(f, input, out, "+=", "0")
917    }
918    fn warp_reduce_prod_exclusive(
919        f: &mut core::fmt::Formatter<'_>,
920        input: &Variable<D>,
921        out: &Variable<D>,
922    ) -> core::fmt::Result {
923        reduce_exclusive(f, input, out, "*=", "1")
924    }
925}
926
927pub trait DialectWmmaCompiler<D: Dialect>:
928    Default + Clone + Copy + Debug + Send + Sync + Eq + Hash + 'static
929{
930    #[allow(unused_variables)]
931    fn compile_wmma_includes(
932        f: &mut std::fmt::Formatter<'_>,
933        flags: &Flags<D>,
934    ) -> std::fmt::Result {
935        Ok(())
936    }
937    #[allow(unused_variables)]
938    fn compile_wmma_type_definitions(
939        f: &mut std::fmt::Formatter<'_>,
940        flags: &Flags<D>,
941    ) -> std::fmt::Result {
942        Ok(())
943    }
944    #[allow(unused_variables)]
945    fn compile_wmma_local_variables(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
946        Ok(())
947    }
948    #[allow(unused_variables)]
949    fn compile_wwma_fragment_ident(
950        f: &mut std::fmt::Formatter<'_>,
951        ident: &FragmentIdent<D>,
952    ) -> std::fmt::Result {
953        Ok(())
954    }
955    #[allow(unused_variables)]
956    fn compile_wmma_fragment_layout(
957        f: &mut std::fmt::Formatter<'_>,
958        layout: &FragmentLayout<D>,
959    ) -> std::fmt::Result {
960        Ok(())
961    }
962    #[allow(unused_variables)]
963    fn compile_wmma_fragment(
964        f: &mut std::fmt::Formatter<'_>,
965        fragment: &Fragment<D>,
966    ) -> std::fmt::Result {
967        Ok(())
968    }
969
970    fn compile_wmma_fragment_declaration(
971        f: &mut std::fmt::Formatter<'_>,
972        var: &Variable<D>,
973    ) -> std::fmt::Result;
974
975    fn compile_wmma_instruction(
976        f: &mut std::fmt::Formatter<'_>,
977        instruction: &WmmaInstruction<D>,
978    ) -> std::fmt::Result;
979    fn compile_manual_mma(f: &mut std::fmt::Formatter<'_>, mma: ManualMma<D>) -> std::fmt::Result;
980    fn compile_scaled_mma(
981        f: &mut std::fmt::Formatter<'_>,
982        mma: ManualMma<D>,
983        scales_a: Variable<D>,
984        scales_b: Variable<D>,
985        scales_factor: u32,
986    ) -> std::fmt::Result;
987    fn supported_wmma_combinations(arch: &D::Architecture) -> SupportedMmaCombinations;
988    fn supported_mma_combinations(arch: &D::Architecture) -> SupportedMmaCombinations;
989    fn supported_scaled_mma_combinations(
990        _arch: &D::Architecture,
991    ) -> SupportedScaledMmaCombinations {
992        Vec::new()
993    }
994}
995
996/// IR Processors to be applied to the scopes during processing. ``CheckedIO`` is always applied
997/// by default, so these are only for target specific processors like MMA index processors.
998pub trait DialectProcessors<D: Dialect> {
999    fn processors() -> Vec<Box<dyn Processor>>;
1000}