cubecl_cpp/shared/
dialect.rs

1use std::{collections::HashSet, fmt::Debug};
2use std::{fmt::Display, hash::Hash};
3
4use cubecl_core::ir::Processor;
5
6use crate::shared::{
7    FmtLeft, IndexedVariable, MmaShape, SupportedMmaCombinations, SupportedScaledMmaCombinations,
8    reduce_comparison, reduce_exclusive, reduce_inclusive, reduce_operator, reduce_quantifier,
9};
10
11use super::{
12    Architecture, AtomicKind, Binding, Body, Component, CubeIndexFlags, Elem, Flags, Fragment,
13    FragmentIdent, FragmentLayout, Instruction, Item, SharedMemory, Variable, WarpInstruction,
14    WmmaInstruction,
15};
16
17// Base dialect
18
19pub trait Dialect:
20    DialectIncludes<Self>
21    + DialectTypes<Self>
22    + DialectBindings<Self>
23    + DialectWarpReduceCompiler<Self>
24    + DialectCubeBuiltins<Self>
25    + DialectInstructions<Self>
26    + DialectWmmaCompiler<Self>
27    + DialectProcessors<Self>
28    + Default
29    + Clone
30    + Copy
31    + Debug
32    + Send
33    + Sync
34    + Eq
35    + Hash
36    + 'static
37{
38    type Architecture: Architecture;
39}
40
41// Includes
42
43pub trait DialectIncludes<D: Dialect> {
44    type Extension: Debug + Clone + Sync + Send;
45
46    fn compile_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags<D>) -> std::fmt::Result;
47    fn compile_extensions(
48        f: &mut std::fmt::Formatter<'_>,
49        extensions: &[Self::Extension],
50    ) -> std::fmt::Result;
51    fn register_instruction_extension(
52        extensions: &mut Vec<Self::Extension>,
53        instruction: &Instruction<D>,
54    );
55    fn register_warp_instruction_extension(
56        extensions: &mut Vec<Self::Extension>,
57        instruction: &WarpInstruction<D>,
58    );
59    #[allow(unused_variables)]
60    fn register_wmma_instruction_extension(
61        extensions: &mut Vec<Self::Extension>,
62        instruction: &WmmaInstruction<D>,
63    ) {
64    }
65}
66
67// Types
68
69pub trait DialectTypes<D: Dialect> {
70    fn item_can_be_optimized() -> bool;
71    fn compile_elem(
72        f: &mut std::fmt::Formatter<'_>,
73        elem: &Elem<D>,
74        word: bool,
75    ) -> std::fmt::Result;
76
77    fn compile_atomic_kind(
78        f: &mut std::fmt::Formatter<'_>,
79        kind: &AtomicKind<D>,
80    ) -> std::fmt::Result {
81        match kind {
82            AtomicKind::I32 => write!(f, "{}", Elem::<D>::I32),
83            AtomicKind::I64 => write!(f, "{}", Elem::<D>::I64),
84            AtomicKind::U32 => write!(f, "{}", Elem::<D>::U32),
85            AtomicKind::U64 => write!(f, "{}", Elem::<D>::U64),
86            AtomicKind::F16 => write!(f, "{}", Elem::<D>::F16),
87            AtomicKind::BF16 => write!(f, "{}", Elem::<D>::BF16),
88            AtomicKind::F32 => write!(f, "{}", Elem::<D>::F32),
89            AtomicKind::F64 => write!(f, "{}", Elem::<D>::F64),
90            AtomicKind::_Dialect(_) => Ok(()),
91        }
92    }
93
94    fn compile_item(f: &mut std::fmt::Formatter<'_>, item: &Item<D>) -> std::fmt::Result;
95    fn compile_type_definitions(
96        f: &mut std::fmt::Formatter<'_>,
97        items: &HashSet<Item<D>>,
98        scalars: &[(Elem<D>, usize)],
99        flags: &Flags<D>,
100    ) -> std::fmt::Result;
101    fn compile_local_memory_qualifier(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
102    fn compile_shared_memory_declaration(
103        f: &mut std::fmt::Formatter<'_>,
104        shared: &SharedMemory<D>,
105    ) -> std::fmt::Result {
106        match shared {
107            SharedMemory::Array {
108                index,
109                item,
110                length,
111                offset,
112                ..
113            } => {
114                let size_bytes = *length * item.size();
115                writeln!(f, "// Shared array size: {length}, {size_bytes} bytes")?;
116                writeln!(
117                    f,
118                    "{item} *shared_memory_{index} = reinterpret_cast<{item}*>(&dynamic_shared_mem[{offset}]);"
119                )
120            }
121            SharedMemory::Value {
122                index,
123                item,
124                offset,
125                ..
126            } => {
127                let size_bytes = item.size() as u32;
128                writeln!(f, "// Shared value size: {size_bytes} bytes")?;
129                writeln!(
130                    f,
131                    "{item} &shared_memory_{index} = reinterpret_cast<{item}&>(dynamic_shared_mem[{offset}]);"
132                )
133            }
134        }
135    }
136    fn compile_polyfills(_f: &mut std::fmt::Formatter<'_>, _flags: &Flags<D>) -> std::fmt::Result {
137        Ok(())
138    }
139    /// Address space (for Metal dialect only).
140    fn address_space_for_variable(_variable: &Variable<D>) -> String {
141        "".to_string()
142    }
143}
144
145// Kernel argument bindings
146
147pub trait DialectBindings<D: Dialect> {
148    fn compile_kernel_signature(
149        f: &mut std::fmt::Formatter<'_>,
150        kernel_name: &str,
151        tensor_maps: &[Binding<D>],
152        buffers: &[Binding<D>],
153        scalars: &[(Elem<D>, usize)],
154        flags: &Flags<D>,
155    ) -> std::fmt::Result;
156    fn compile_bindings_body(
157        _f: &mut std::fmt::Formatter<'_>,
158        _body: &Body<D>,
159    ) -> std::fmt::Result {
160        Ok(())
161    }
162}
163
164// Cube builtins dialect
165
166pub trait DialectCubeBuiltins<D: Dialect> {
167    /// Depending on the dialect available built-in variables the
168    /// inclusion rules might change.
169    /// For instance in metal we have a built-in for the Unit plane position
170    /// but in other dialects there is none so we have to compute it using
171    /// other built-ins.
172    fn builtin_rules(flags: &CubeIndexFlags) -> CubeIndexFlags {
173        let unit_pos_plane = flags.unit_pos_plane;
174        let plane_dim_checked = flags.plane_dim_checked;
175        let plane_dim = flags.plane_dim || plane_dim_checked || unit_pos_plane;
176        let plane_index = flags.plane_index;
177        let absolute_pos = flags.absolute_pos || unit_pos_plane;
178        let absolute_pos_tuple = flags.absolute_pos_tuple || absolute_pos;
179        let cube_dim = flags.cube_dim;
180        let cube_dim_tuple = flags.cube_dim_tuple || cube_dim || absolute_pos || plane_dim_checked;
181        let unit_pos = flags.unit_pos;
182        let unit_pos_tuple = flags.unit_pos_tuple || unit_pos;
183        let cube_count = flags.cube_count;
184        let cube_count_tuple = flags.cube_count_tuple || absolute_pos;
185        let cube_pos = flags.cube_pos;
186        let cube_pos_tuple = flags.cube_pos_tuple || cube_pos;
187        let cluster_group = flags.cluster_pos;
188
189        CubeIndexFlags {
190            absolute_pos,
191            absolute_pos_tuple,
192            cube_count,
193            cube_count_tuple,
194            cube_dim,
195            cube_dim_tuple,
196            cube_pos,
197            cube_pos_tuple,
198            plane_dim,
199            plane_dim_checked,
200            plane_index,
201            unit_pos_tuple,
202            unit_pos,
203            unit_pos_plane,
204            cluster_pos: cluster_group,
205        }
206    }
207
208    fn compile_absolute_pos_tuple_computation(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
209        let variable = Variable::<D>::AbsolutePosBaseName;
210        let ty = variable.item();
211        let cube_pos_x = Variable::<D>::CubePosX;
212        let cube_pos_y = Variable::<D>::CubePosY;
213        let cube_pos_z = Variable::<D>::CubePosZ;
214        let cube_dim_x = Variable::<D>::CubeDimX;
215        let cube_dim_y = Variable::<D>::CubeDimY;
216        let cube_dim_z = Variable::<D>::CubeDimZ;
217        let unit_pos_x = Variable::<D>::UnitPosX;
218        let unit_pos_y = Variable::<D>::UnitPosY;
219        let unit_pos_z = Variable::<D>::UnitPosZ;
220        writeln!(
221            f,
222            "{ty} {variable} = make_{ty}(
223    {cube_pos_x} * {cube_dim_x} + {unit_pos_x},
224    {cube_pos_y} * {cube_dim_y} + {unit_pos_y},
225    {cube_pos_z} * {cube_dim_z} + {unit_pos_z}
226);"
227        )
228    }
229
230    fn compile_absolute_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
231        f.write_str("absoluteIdx")
232    }
233
234    fn compile_absolute_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
235        f.write_str("idxGlobal")
236    }
237
238    fn compile_absolute_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
239        Self::compile_absolute_pos_base_name(f)?;
240        write!(f, ".x")
241    }
242
243    fn compile_absolute_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
244        Self::compile_absolute_pos_base_name(f)?;
245        write!(f, ".y")
246    }
247
248    fn compile_absolute_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
249        Self::compile_absolute_pos_base_name(f)?;
250        write!(f, ".z")
251    }
252
253    fn compile_cube_count_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
254        f.write_str("gridDim")
255    }
256
257    fn compile_cube_count(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
258        f.write_str("gridDimGlobal")
259    }
260
261    fn compile_cube_count_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
262        Self::compile_cube_count_base_name(f)?;
263        write!(f, ".x")
264    }
265
266    fn compile_cube_count_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
267        Self::compile_cube_count_base_name(f)?;
268        write!(f, ".y")
269    }
270
271    fn compile_cube_count_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
272        Self::compile_cube_count_base_name(f)?;
273        write!(f, ".z")
274    }
275
276    fn compile_cube_dim_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
277        f.write_str("blockDim")
278    }
279
280    fn compile_cube_dim(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
281        f.write_str("blockDimGlobal")
282    }
283
284    fn compile_cube_dim_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
285        Self::compile_cube_dim_base_name(f)?;
286        write!(f, ".x")
287    }
288
289    fn compile_cube_dim_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
290        Self::compile_cube_dim_base_name(f)?;
291        write!(f, ".y")
292    }
293
294    fn compile_cube_dim_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
295        Self::compile_cube_dim_base_name(f)?;
296        write!(f, ".z")
297    }
298
299    fn compile_cube_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
300        f.write_str("blockIdx")
301    }
302
303    fn compile_cube_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
304        f.write_str("blockIdxGlobal")
305    }
306
307    fn compile_cube_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
308        Self::compile_cube_pos_base_name(f)?;
309        write!(f, ".x")
310    }
311
312    fn compile_cube_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
313        Self::compile_cube_pos_base_name(f)?;
314        write!(f, ".y")
315    }
316
317    fn compile_cube_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
318        Self::compile_cube_pos_base_name(f)?;
319        write!(f, ".z")
320    }
321
322    fn compile_unit_pos_computation(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
323        let variable = Variable::<D>::UnitPos;
324        let ty = variable.item();
325        let cube_dim_x = Variable::<D>::CubeDimX;
326        let cube_dim_y = Variable::<D>::CubeDimY;
327        let unit_pos_x = Variable::<D>::UnitPosX;
328        let unit_pos_y = Variable::<D>::UnitPosY;
329        let unit_pos_z = Variable::<D>::UnitPosZ;
330        writeln!(
331            f,
332            "{ty} {variable} = {unit_pos_x} + {unit_pos_y} * {cube_dim_x} + {unit_pos_z} * ({cube_dim_x} * {cube_dim_y});"
333        )
334    }
335
336    fn compile_unit_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
337        f.write_str("threadIdxGlobal")
338    }
339
340    fn compile_unit_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
341        f.write_str("threadIdx")
342    }
343
344    fn compile_unit_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
345        Self::compile_unit_pos_base_name(f)?;
346        write!(f, ".x")
347    }
348
349    fn compile_unit_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
350        Self::compile_unit_pos_base_name(f)?;
351        write!(f, ".y")
352    }
353
354    fn compile_unit_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
355        Self::compile_unit_pos_base_name(f)?;
356        write!(f, ".z")
357    }
358
359    fn compile_plane_dim(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
360        f.write_str("warpSize")
361    }
362
363    fn compile_plane_dim_checked(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
364        f.write_str("warpSizeChecked")
365    }
366
367    fn compile_plane_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
368        let unit_pos_x = Variable::<D>::UnitPosX;
369        let plane_dim = Variable::<D>::PlaneDim;
370        write!(f, "{unit_pos_x} / {plane_dim}")
371    }
372
373    fn compile_unit_pos_plane(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
374        let absolute_pos = Variable::<D>::AbsolutePos(Elem::U32);
375        let plane_dim = Variable::<D>::PlaneDim;
376        let ty = plane_dim.item();
377        write!(f, "{ty}({absolute_pos}) % {plane_dim}")
378    }
379
380    fn compile_cluster_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
381        write!(f, "0")
382    }
383    fn compile_cluster_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
384        write!(f, "0")
385    }
386    fn compile_cluster_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
387        write!(f, "0")
388    }
389    fn compile_cluster_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
390        write!(f, "0")
391    }
392}
393
394// Instructions
395
396pub trait DialectInstructions<D: Dialect> {
397    // atomics
398    fn compile_atomic_add(
399        f: &mut std::fmt::Formatter<'_>,
400        lhs: &Variable<D>,
401        rhs: &Variable<D>,
402        out: &Variable<D>,
403    ) -> std::fmt::Result {
404        let out = out.fmt_left();
405        match rhs.elem() {
406            Elem::I64 => writeln!(
407                f,
408                "{out} = atomicAdd(reinterpret_cast<{uint}*>({lhs}), {uint}({rhs}));",
409                uint = Elem::<D>::U64
410            ),
411            _ => writeln!(f, "{out} = atomicAdd({lhs}, {rhs});"),
412        }
413    }
414
415    fn compile_atomic_and(
416        f: &mut std::fmt::Formatter<'_>,
417        lhs: &Variable<D>,
418        rhs: &Variable<D>,
419        out: &Variable<D>,
420    ) -> std::fmt::Result {
421        let out = out.fmt_left();
422        writeln!(f, "{out} = atomicAnd({lhs}, {rhs});")
423    }
424
425    fn compile_atomic_cas(
426        f: &mut std::fmt::Formatter<'_>,
427        input: &Variable<D>,
428        cmp: &Variable<D>,
429        val: &Variable<D>,
430        out: &Variable<D>,
431    ) -> std::fmt::Result {
432        let out = out.fmt_left();
433        writeln!(f, "{out} = atomicCAS({input}, {cmp}, {val});")
434    }
435
436    fn compile_atomic_load(
437        f: &mut std::fmt::Formatter<'_>,
438        input: &Variable<D>,
439        out: &Variable<D>,
440    ) -> std::fmt::Result {
441        let out = out.fmt_left();
442        writeln!(f, "{out} = atomicAdd({input}, 0);")
443    }
444
445    fn compile_atomic_max(
446        f: &mut std::fmt::Formatter<'_>,
447        lhs: &Variable<D>,
448        rhs: &Variable<D>,
449        out: &Variable<D>,
450    ) -> std::fmt::Result {
451        let out = out.fmt_left();
452        writeln!(f, "{out} = atomicMax({lhs}, {rhs});")
453    }
454
455    fn compile_atomic_min(
456        f: &mut std::fmt::Formatter<'_>,
457        lhs: &Variable<D>,
458        rhs: &Variable<D>,
459        out: &Variable<D>,
460    ) -> std::fmt::Result {
461        let out = out.fmt_left();
462        writeln!(f, "{out} = atomicMin({lhs}, {rhs});")
463    }
464
465    fn compile_atomic_or(
466        f: &mut std::fmt::Formatter<'_>,
467        lhs: &Variable<D>,
468        rhs: &Variable<D>,
469        out: &Variable<D>,
470    ) -> std::fmt::Result {
471        let out = out.fmt_left();
472        writeln!(f, "{out} = atomicOr({lhs}, {rhs});")
473    }
474
475    fn compile_atomic_store(
476        f: &mut std::fmt::Formatter<'_>,
477        input: &Variable<D>,
478        out: &Variable<D>,
479    ) -> std::fmt::Result {
480        writeln!(f, "atomicExch({out}, {input});")
481    }
482
483    fn compile_atomic_sub(
484        f: &mut std::fmt::Formatter<'_>,
485        lhs: &Variable<D>,
486        rhs: &Variable<D>,
487        out: &Variable<D>,
488    ) -> std::fmt::Result {
489        let out = out.fmt_left();
490        match rhs.elem() {
491            Elem::U32 | Elem::I32 => writeln!(f, "{out} = atomicSub({lhs}, {rhs});"),
492            Elem::U64 => writeln!(f, "{out} = atomicAdd({lhs}, -{rhs});"),
493            Elem::I64 => writeln!(
494                f,
495                "{out} = atomicAdd(reinterpret_cast<{uint}*>({lhs}), {uint}(-{rhs}));",
496                uint = Elem::<D>::U64
497            ),
498            _ => writeln!(f, "{out} = atomicAdd({lhs}, -{rhs});"),
499        }
500    }
501
502    fn compile_atomic_swap(
503        f: &mut std::fmt::Formatter<'_>,
504        lhs: &Variable<D>,
505        rhs: &Variable<D>,
506        out: &Variable<D>,
507    ) -> std::fmt::Result {
508        let out = out.fmt_left();
509        writeln!(f, "{out} = atomicExch({lhs}, {rhs});")
510    }
511
512    fn compile_atomic_xor(
513        f: &mut std::fmt::Formatter<'_>,
514        lhs: &Variable<D>,
515        rhs: &Variable<D>,
516        out: &Variable<D>,
517    ) -> std::fmt::Result {
518        let out = out.fmt_left();
519        writeln!(f, "{out} = atomicXor({lhs}, {rhs});")
520    }
521
522    fn compile_saturating_add(
523        f: &mut std::fmt::Formatter<'_>,
524        lhs: impl Display,
525        rhs: impl Display,
526        item: Item<D>,
527    ) -> std::fmt::Result;
528
529    fn compile_saturating_sub(
530        f: &mut std::fmt::Formatter<'_>,
531        lhs: impl Display,
532        rhs: impl Display,
533        item: Item<D>,
534    ) -> std::fmt::Result;
535
536    // debug
537    fn compile_instruction_printf(
538        f: &mut std::fmt::Formatter<'_>,
539        format_string: &str,
540        args: &[Variable<D>],
541    ) -> std::fmt::Result {
542        let args = args.iter().map(|arg| format!("{arg}")).collect::<Vec<_>>();
543        let args = match args.is_empty() {
544            true => "".to_string(),
545            false => format!(", {}", args.join(",")),
546        };
547        writeln!(f, "printf({format_string:?}{args});")
548    }
549
550    // logs
551    fn compile_instruction_log1p_scalar<T: Component<D>>(
552        f: &mut std::fmt::Formatter<'_>,
553        input: T,
554    ) -> std::fmt::Result {
555        let elem = input.elem();
556        match elem {
557            Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
558                write!(f, "{elem}(log1p(float({input})))")
559            }
560            _ => write!(f, "log1p({input})"),
561        }
562    }
563
564    // sync
565    fn compile_instruction_sync_threads(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
566    fn compile_instruction_sync_warp(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
567    fn compile_instruction_thread_fence(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
568
569    // trigo
570    fn compile_instruction_tanh_scalar<T: Component<D>>(
571        f: &mut std::fmt::Formatter<'_>,
572        input: T,
573    ) -> std::fmt::Result {
574        let elem = input.elem();
575        match elem {
576            Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
577                write!(f, "{elem}(tanh(float({input})))")
578            }
579            _ => write!(f, "tanh({input})"),
580        }
581    }
582
583    // unary
584    fn compile_instruction_find_first_set<T: Component<D>>(
585        f: &mut std::fmt::Formatter<'_>,
586        input: T,
587        out_elem: Elem<D>,
588    ) -> std::fmt::Result;
589    fn compile_instruction_leading_zeros_scalar<T: Component<D>>(
590        f: &mut std::fmt::Formatter<'_>,
591        input: T,
592        out_elem: Elem<D>,
593    ) -> std::fmt::Result;
594
595    fn compile_instruction_popcount_scalar<T: Component<D>>(
596        f: &mut std::fmt::Formatter<'_>,
597        input: T,
598        out_elem: Elem<D>,
599    ) -> std::fmt::Result {
600        write!(f, "{out_elem}(")?;
601        match input.elem() {
602            Elem::I32 => write!(f, "__popc({}({input}))", Elem::<D>::U32),
603            Elem::U32 => write!(f, "__popc({input})"),
604            Elem::I64 => write!(f, "__popcll({}({input}))", Elem::<D>::U64),
605            Elem::U64 => write!(f, "__popcll({input})"),
606            _ => write!(f, "__popc({})", super::unary::zero_extend(input)),
607        }?;
608        write!(f, ")")
609    }
610
611    fn compile_instruction_reverse_bits_scalar<T: Component<D>>(
612        f: &mut std::fmt::Formatter<'_>,
613        input: T,
614        out_elem: Elem<D>,
615    ) -> std::fmt::Result {
616        write!(f, "{out_elem}(")?;
617        match out_elem {
618            Elem::I32 => write!(f, "__brev({}({input}))", Elem::<D>::U32),
619            Elem::U32 => write!(f, "__brev({input})"),
620            Elem::I64 => write!(f, "__brevll({}({input}))", Elem::<D>::U64),
621            Elem::U64 => write!(f, "__brevll({input})"),
622            _ => write!(
623                f,
624                "__brev({}) >> {}",
625                super::unary::zero_extend(input),
626                (size_of::<u32>() - out_elem.size()) * 8
627            ),
628        }?;
629        write!(f, ")")
630    }
631
632    // others
633    fn compile_instruction_max_function_name(
634        f: &mut std::fmt::Formatter<'_>,
635        item: Item<D>,
636    ) -> std::fmt::Result;
637
638    fn compile_instruction_min_function_name(
639        f: &mut std::fmt::Formatter<'_>,
640        item: Item<D>,
641    ) -> std::fmt::Result;
642
643    fn compile_instruction_powf(
644        f: &mut std::fmt::Formatter<'_>,
645        lhs: &str,
646        rhs: &str,
647        elem: Elem<D>,
648    ) -> std::fmt::Result {
649        match elem {
650            Elem::F32 => write!(f, "powf({lhs}, {rhs})"),
651            Elem::F64 => write!(f, "pow({lhs}, {rhs})"),
652            _ => write!(f, "#error Unsupported type for powf: {elem}"),
653        }
654    }
655
656    fn compile_instruction_hypot(
657        f: &mut std::fmt::Formatter<'_>,
658        lhs: &str,
659        rhs: &str,
660        elem: Elem<D>,
661    ) -> std::fmt::Result {
662        match elem {
663            Elem::F32 => write!(f, "hypotf({lhs}, {rhs})"),
664            Elem::F64 => write!(f, "hypot({lhs}, {rhs})"),
665            _ => write!(f, "#error Unsupported type for hypot: {elem}"),
666        }
667    }
668
669    fn compile_instruction_rhypot(
670        f: &mut std::fmt::Formatter<'_>,
671        lhs: &str,
672        rhs: &str,
673        elem: Elem<D>,
674    ) -> std::fmt::Result {
675        match elem {
676            Elem::F32 => write!(f, "rhypotf({lhs}, {rhs})"),
677            Elem::F64 => write!(f, "rhypot({lhs}, {rhs})"),
678            _ => write!(f, "#error Unsupported type for rhypot: {elem}"),
679        }
680    }
681
682    fn compile_instruction_half_function_name_prefix() -> &'static str {
683        "h"
684    }
685
686    fn compile_instruction_half2_function_name_prefix() -> &'static str {
687        "h2"
688    }
689
690    // warp
691    fn compile_warp_shuffle(
692        f: &mut std::fmt::Formatter<'_>,
693        var: &str,
694        source: &str,
695    ) -> std::fmt::Result;
696    fn compile_warp_shuffle_xor(
697        f: &mut std::fmt::Formatter<'_>,
698        var: &str,
699        elem: &Elem<D>,
700        offset: &str,
701    ) -> std::fmt::Result;
702    fn compile_warp_shuffle_up(
703        f: &mut std::fmt::Formatter<'_>,
704        var: &str,
705        offset: &str,
706    ) -> std::fmt::Result;
707    fn compile_warp_shuffle_down(
708        f: &mut std::fmt::Formatter<'_>,
709        var: &str,
710        offset: &str,
711    ) -> std::fmt::Result;
712    fn compile_warp_all<T: Component<D>>(
713        f: &mut std::fmt::Formatter<'_>,
714        input: &T,
715    ) -> std::fmt::Result;
716    fn compile_warp_any<T: Component<D>>(
717        f: &mut std::fmt::Formatter<'_>,
718        input: &T,
719    ) -> std::fmt::Result;
720    fn compile_warp_ballot(
721        f: &mut std::fmt::Formatter<'_>,
722        input: &Variable<D>,
723        out_elem: &Elem<D>,
724    ) -> std::fmt::Result;
725    fn compile_warp_elect(f: &mut std::fmt::Formatter<'_>, out: &str) -> std::fmt::Result {
726        write!(
727            f,
728            "
729unsigned int mask = __activemask();
730unsigned int leader = __ffs(mask) - 1;
731{out} = threadIdx.x % warpSize == leader;
732            "
733        )
734    }
735}
736
737#[derive(Debug, Clone, Copy, new)]
738pub struct ManualMma<'a, D: Dialect> {
739    pub shape: MmaShape<D>,
740    pub frag_a: &'a Variable<D>,
741    pub frag_b: &'a Variable<D>,
742    pub frag_c: &'a Variable<D>,
743    pub frag_d: &'a Variable<D>,
744}
745
746pub trait DialectWarpReduceCompiler<D: Dialect>:
747    Default + Clone + Copy + Debug + Send + Sync + Eq + Hash + 'static
748{
749    fn warp_reduce_sum(
750        f: &mut core::fmt::Formatter<'_>,
751        input: &Variable<D>,
752        out: &Variable<D>,
753    ) -> core::fmt::Result {
754        reduce_operator(f, input, out, "+=")
755    }
756    fn warp_reduce_prod(
757        f: &mut core::fmt::Formatter<'_>,
758        input: &Variable<D>,
759        out: &Variable<D>,
760    ) -> core::fmt::Result {
761        reduce_operator(f, input, out, "*=")
762    }
763    fn warp_reduce_max(
764        f: &mut core::fmt::Formatter<'_>,
765        input: &Variable<D>,
766        out: &Variable<D>,
767    ) -> core::fmt::Result {
768        reduce_comparison(f, input, out, D::compile_instruction_max_function_name)
769    }
770    fn warp_reduce_min(
771        f: &mut core::fmt::Formatter<'_>,
772        input: &Variable<D>,
773        out: &Variable<D>,
774    ) -> core::fmt::Result {
775        reduce_comparison(f, input, out, D::compile_instruction_min_function_name)
776    }
777    fn warp_reduce_all(
778        f: &mut core::fmt::Formatter<'_>,
779        input: &Variable<D>,
780        out: &Variable<D>,
781    ) -> core::fmt::Result {
782        reduce_quantifier(f, input, out, D::compile_warp_all::<IndexedVariable<D>>)
783    }
784    fn warp_reduce_any(
785        f: &mut core::fmt::Formatter<'_>,
786        input: &Variable<D>,
787        out: &Variable<D>,
788    ) -> core::fmt::Result {
789        reduce_quantifier(f, input, out, D::compile_warp_any::<IndexedVariable<D>>)
790    }
791    fn warp_reduce_sum_inclusive(
792        f: &mut core::fmt::Formatter<'_>,
793        input: &Variable<D>,
794        out: &Variable<D>,
795    ) -> core::fmt::Result {
796        reduce_inclusive(f, input, out, "+=")
797    }
798    fn warp_reduce_prod_inclusive(
799        f: &mut core::fmt::Formatter<'_>,
800        input: &Variable<D>,
801        out: &Variable<D>,
802    ) -> core::fmt::Result {
803        reduce_inclusive(f, input, out, "*=")
804    }
805    fn warp_reduce_sum_exclusive(
806        f: &mut core::fmt::Formatter<'_>,
807        input: &Variable<D>,
808        out: &Variable<D>,
809    ) -> core::fmt::Result {
810        reduce_exclusive(f, input, out, "+=", "0")
811    }
812    fn warp_reduce_prod_exclusive(
813        f: &mut core::fmt::Formatter<'_>,
814        input: &Variable<D>,
815        out: &Variable<D>,
816    ) -> core::fmt::Result {
817        reduce_exclusive(f, input, out, "*=", "1")
818    }
819}
820
821pub trait DialectWmmaCompiler<D: Dialect>:
822    Default + Clone + Copy + Debug + Send + Sync + Eq + Hash + 'static
823{
824    #[allow(unused_variables)]
825    fn compile_wmma_includes(
826        f: &mut std::fmt::Formatter<'_>,
827        flags: &Flags<D>,
828    ) -> std::fmt::Result {
829        Ok(())
830    }
831    #[allow(unused_variables)]
832    fn compile_wmma_type_definitions(
833        f: &mut std::fmt::Formatter<'_>,
834        flags: &Flags<D>,
835    ) -> std::fmt::Result {
836        Ok(())
837    }
838    #[allow(unused_variables)]
839    fn compile_wmma_local_variables(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
840        Ok(())
841    }
842    #[allow(unused_variables)]
843    fn compile_wwma_fragment_ident(
844        f: &mut std::fmt::Formatter<'_>,
845        ident: &FragmentIdent<D>,
846    ) -> std::fmt::Result {
847        Ok(())
848    }
849    #[allow(unused_variables)]
850    fn compile_wmma_fragment_layout(
851        f: &mut std::fmt::Formatter<'_>,
852        layout: &FragmentLayout<D>,
853    ) -> std::fmt::Result {
854        Ok(())
855    }
856    #[allow(unused_variables)]
857    fn compile_wmma_fragment(
858        f: &mut std::fmt::Formatter<'_>,
859        fragment: &Fragment<D>,
860    ) -> std::fmt::Result {
861        Ok(())
862    }
863
864    fn compile_wmma_fragment_declaration(
865        f: &mut std::fmt::Formatter<'_>,
866        var: &Variable<D>,
867    ) -> std::fmt::Result;
868
869    fn compile_wmma_instruction(
870        f: &mut std::fmt::Formatter<'_>,
871        instruction: &WmmaInstruction<D>,
872    ) -> std::fmt::Result;
873    fn compile_manual_mma(f: &mut std::fmt::Formatter<'_>, mma: ManualMma<D>) -> std::fmt::Result;
874    fn compile_scaled_mma(
875        f: &mut std::fmt::Formatter<'_>,
876        mma: ManualMma<D>,
877        scales_a: Variable<D>,
878        scales_b: Variable<D>,
879        scales_factor: u32,
880    ) -> std::fmt::Result;
881    fn supported_wmma_combinations(arch: &D::Architecture) -> SupportedMmaCombinations;
882    fn supported_mma_combinations(arch: &D::Architecture) -> SupportedMmaCombinations;
883    fn supported_scaled_mma_combinations(
884        _arch: &D::Architecture,
885    ) -> SupportedScaledMmaCombinations {
886        Vec::new()
887    }
888}
889
890/// IR Processors to be applied to the scopes during processing. [`CheckedIO`] is always applied
891/// by default, so these are only for target specific processors like MMA index processors.
892pub trait DialectProcessors<D: Dialect> {
893    fn processors() -> Vec<Box<dyn Processor>>;
894}