// cubecl_cpp/shared/dialect.rs

1use std::{collections::HashSet, fmt::Debug};
2use std::{fmt::Display, hash::Hash};
3
4use cubecl_core::ir::Processor;
5
6use crate::shared::{
7    FmtLeft, IndexedVariable, MmaShape, SupportedMmaCombinations, SupportedScaledMmaCombinations,
8    reduce_comparison, reduce_exclusive, reduce_inclusive, reduce_operator, reduce_quantifier,
9};
10
11use super::{
12    Architecture, AtomicKind, Binding, Body, Component, CubeIndexFlags, Elem, Flags, Fragment,
13    FragmentIdent, FragmentLayout, Instruction, Item, SharedMemory, Variable, WarpInstruction,
14    WmmaInstruction,
15};
16
// Base dialect

/// A complete code-generation dialect.
///
/// This is an umbrella trait that bundles every dialect hook trait declared in
/// this module: includes, types, bindings, warp reductions, cube builtins,
/// instructions, WMMA/MMA and IR processors. Each hook trait is parameterized
/// by the dialect itself (`Self`), so default methods can call back into the
/// full dialect (e.g. `D::compile_instruction_max_function_name`).
///
/// The remaining bounds require a dialect type to be defaultable, copyable,
/// debuggable, thread-safe, comparable and hashable, and to be `'static`.
pub trait Dialect:
    DialectIncludes<Self>
    + DialectTypes<Self>
    + DialectBindings<Self>
    + DialectWarpReduceCompiler<Self>
    + DialectCubeBuiltins<Self>
    + DialectInstructions<Self>
    + DialectWmmaCompiler<Self>
    + DialectProcessors<Self>
    + Default
    + Clone
    + Copy
    + Debug
    + Send
    + Sync
    + Eq
    + Hash
    + 'static
{
    /// Target architecture descriptor. It is passed to the MMA capability
    /// queries (`supported_wmma_combinations`, `supported_mma_combinations`,
    /// `supported_scaled_mma_combinations`).
    type Architecture: Architecture;
}
40
// Includes

/// Hooks that emit the includes and dialect-specific extensions of a generated
/// source file, and that record which extensions the kernel's instructions need.
pub trait DialectIncludes<D: Dialect> {
    /// A dialect-specific extension that instructions can require. These are
    /// collected via the `register_*` hooks and written by `compile_extensions`.
    /// Presumably a helper definition; confirm per dialect.
    type Extension: Debug + Clone + Sync + Send;

    /// Writes the include section, selected according to `flags`.
    fn compile_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags) -> std::fmt::Result;
    /// Writes the code for every collected extension.
    fn compile_extensions(
        f: &mut std::fmt::Formatter<'_>,
        extensions: &[Self::Extension],
    ) -> std::fmt::Result;
    /// Pushes into `extensions` whatever `instruction` requires.
    fn register_instruction_extension(
        extensions: &mut Vec<Self::Extension>,
        instruction: &Instruction<D>,
    );
    /// Pushes into `extensions` whatever the warp `instruction` requires.
    fn register_warp_instruction_extension(
        extensions: &mut Vec<Self::Extension>,
        instruction: &WarpInstruction<D>,
    );
    /// Pushes into `extensions` whatever the WMMA `instruction` requires.
    /// The default registers nothing.
    // The no-op default leaves both parameters unused; the allow keeps the
    // trait-facing parameter names without a warning.
    #[allow(unused_variables)]
    fn register_wmma_instruction_extension(
        extensions: &mut Vec<Self::Extension>,
        instruction: &WmmaInstruction<D>,
    ) {
    }
}
66
67// Types
68
69pub trait DialectTypes<D: Dialect> {
70    fn item_can_be_optimized() -> bool;
71    fn compile_elem(
72        f: &mut std::fmt::Formatter<'_>,
73        elem: &Elem<D>,
74        word: bool,
75    ) -> std::fmt::Result;
76
77    fn compile_atomic_kind(
78        f: &mut std::fmt::Formatter<'_>,
79        kind: &AtomicKind<D>,
80    ) -> std::fmt::Result {
81        match kind {
82            AtomicKind::I32 => write!(f, "{}", Elem::<D>::I32),
83            AtomicKind::I64 => write!(f, "{}", Elem::<D>::I64),
84            AtomicKind::U32 => write!(f, "{}", Elem::<D>::U32),
85            AtomicKind::U64 => write!(f, "{}", Elem::<D>::U64),
86            AtomicKind::F16 => write!(f, "{}", Elem::<D>::F16),
87            AtomicKind::BF16 => write!(f, "{}", Elem::<D>::BF16),
88            AtomicKind::F32 => write!(f, "{}", Elem::<D>::F32),
89            AtomicKind::F64 => write!(f, "{}", Elem::<D>::F64),
90            AtomicKind::_Dialect(_) => Ok(()),
91        }
92    }
93
94    fn compile_item(f: &mut std::fmt::Formatter<'_>, item: &Item<D>) -> std::fmt::Result;
95    fn compile_type_definitions(
96        f: &mut std::fmt::Formatter<'_>,
97        items: &HashSet<Item<D>>,
98        scalars: &[(Elem<D>, usize)],
99        flags: &Flags,
100    ) -> std::fmt::Result;
101    fn compile_local_memory_qualifier(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
102    fn compile_shared_memory_declaration(
103        f: &mut std::fmt::Formatter<'_>,
104        shared: &SharedMemory<D>,
105    ) -> std::fmt::Result {
106        match shared {
107            SharedMemory::Array {
108                index,
109                item,
110                length,
111                offset,
112                ..
113            } => {
114                let size_bytes = *length * item.size() as u32;
115                writeln!(f, "// Shared array size: {length}, {size_bytes} bytes")?;
116                writeln!(
117                    f,
118                    "{item} *shared_memory_{index} = reinterpret_cast<{item}*>(&dynamic_shared_mem[{offset}]);"
119                )
120            }
121            SharedMemory::Value {
122                index,
123                item,
124                offset,
125                ..
126            } => {
127                let size_bytes = item.size() as u32;
128                writeln!(f, "// Shared value size: {size_bytes} bytes")?;
129                writeln!(
130                    f,
131                    "{item} &shared_memory_{index} = reinterpret_cast<{item}&>(dynamic_shared_mem[{offset}]);"
132                )
133            }
134        }
135    }
136    fn compile_polyfills(_f: &mut std::fmt::Formatter<'_>, _flags: &Flags) -> std::fmt::Result {
137        Ok(())
138    }
139    /// Address space (for Metal dialect only).
140    fn address_space_for_variable(_variable: &Variable<D>) -> String {
141        "".to_string()
142    }
143}
144
// Kernel argument bindings

/// Hooks that emit a kernel's signature and any binding-related body code.
pub trait DialectBindings<D: Dialect> {
    /// Writes the signature of the kernel named `kernel_name`. Its parameters
    /// come from `tensor_maps`, `buffers` and `scalars`. `scalars` is
    /// presumably `(element, count)` pairs; confirm against callers.
    fn compile_kernel_signature(
        f: &mut std::fmt::Formatter<'_>,
        kernel_name: &str,
        tensor_maps: &[Binding<D>],
        buffers: &[Binding<D>],
        scalars: &[(Elem<D>, usize)],
        flags: &Flags,
    ) -> std::fmt::Result;
    /// Emits binding-related code for `body`. Where it lands in the kernel is
    /// decided by the caller. The default emits nothing.
    fn compile_bindings_body(
        _f: &mut std::fmt::Formatter<'_>,
        _body: &Body<D>,
    ) -> std::fmt::Result {
        Ok(())
    }
}
163
164// Cube builtins dialect
165
166pub trait DialectCubeBuiltins<D: Dialect> {
167    /// Depending on the dialect available built-in variables the
168    /// inclusion rules might change.
169    /// For instance in metal we have a built-in for the Unit plane position
170    /// but in other dialects there is none so we have to compute it using
171    /// other built-ins.
172    fn builtin_rules(flags: &CubeIndexFlags) -> CubeIndexFlags {
173        let unit_pos_plane = flags.unit_pos_plane;
174        let plane_dim_checked = flags.plane_dim_checked;
175        let plane_dim = flags.plane_dim || plane_dim_checked || unit_pos_plane;
176        let plane_index = flags.plane_index;
177        let absolute_pos = flags.absolute_pos || unit_pos_plane;
178        let absolute_pos_tuple = flags.absolute_pos_tuple || absolute_pos;
179        let cube_dim = flags.cube_dim;
180        let cube_dim_tuple = flags.cube_dim_tuple || cube_dim || absolute_pos || plane_dim_checked;
181        let unit_pos = flags.unit_pos;
182        let unit_pos_tuple = flags.unit_pos_tuple || unit_pos;
183        let cube_count = flags.cube_count;
184        let cube_count_tuple = flags.cube_count_tuple || absolute_pos;
185        let cube_pos = flags.cube_pos;
186        let cube_pos_tuple = flags.cube_pos_tuple || cube_pos;
187        let cluster_group = flags.cluster_pos;
188
189        CubeIndexFlags {
190            absolute_pos,
191            absolute_pos_tuple,
192            cube_count,
193            cube_count_tuple,
194            cube_dim,
195            cube_dim_tuple,
196            cube_pos,
197            cube_pos_tuple,
198            plane_dim,
199            plane_dim_checked,
200            plane_index,
201            unit_pos_tuple,
202            unit_pos,
203            unit_pos_plane,
204            cluster_pos: cluster_group,
205        }
206    }
207
208    fn compile_absolute_pos_tuple_computation(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
209        let variable = Variable::<D>::AbsolutePosBaseName;
210        let ty = variable.item();
211        let cube_pos_x = Variable::<D>::CubePosX;
212        let cube_pos_y = Variable::<D>::CubePosY;
213        let cube_pos_z = Variable::<D>::CubePosZ;
214        let cube_dim_x = Variable::<D>::CubeDimX;
215        let cube_dim_y = Variable::<D>::CubeDimY;
216        let cube_dim_z = Variable::<D>::CubeDimZ;
217        let unit_pos_x = Variable::<D>::UnitPosX;
218        let unit_pos_y = Variable::<D>::UnitPosY;
219        let unit_pos_z = Variable::<D>::UnitPosZ;
220        writeln!(
221            f,
222            "{ty} {variable} = make_{ty}(
223    {cube_pos_x} * {cube_dim_x} + {unit_pos_x},
224    {cube_pos_y} * {cube_dim_y} + {unit_pos_y},
225    {cube_pos_z} * {cube_dim_z} + {unit_pos_z}
226);"
227        )
228    }
229
230    fn compile_absolute_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
231        f.write_str("absoluteIdx")
232    }
233
234    fn compile_absolute_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
235        f.write_str("idxGlobal")
236    }
237
238    fn compile_absolute_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
239        Self::compile_absolute_pos_base_name(f)?;
240        write!(f, ".x")
241    }
242
243    fn compile_absolute_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
244        Self::compile_absolute_pos_base_name(f)?;
245        write!(f, ".y")
246    }
247
248    fn compile_absolute_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
249        Self::compile_absolute_pos_base_name(f)?;
250        write!(f, ".z")
251    }
252
253    fn compile_cube_count_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
254        f.write_str("gridDim")
255    }
256
257    fn compile_cube_count(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
258        f.write_str("gridDimGlobal")
259    }
260
261    fn compile_cube_count_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
262        Self::compile_cube_count_base_name(f)?;
263        write!(f, ".x")
264    }
265
266    fn compile_cube_count_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
267        Self::compile_cube_count_base_name(f)?;
268        write!(f, ".y")
269    }
270
271    fn compile_cube_count_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
272        Self::compile_cube_count_base_name(f)?;
273        write!(f, ".z")
274    }
275
276    fn compile_cube_dim_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
277        f.write_str("blockDim")
278    }
279
280    fn compile_cube_dim(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
281        f.write_str("blockDimGlobal")
282    }
283
284    fn compile_cube_dim_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
285        Self::compile_cube_dim_base_name(f)?;
286        write!(f, ".x")
287    }
288
289    fn compile_cube_dim_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
290        Self::compile_cube_dim_base_name(f)?;
291        write!(f, ".y")
292    }
293
294    fn compile_cube_dim_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
295        Self::compile_cube_dim_base_name(f)?;
296        write!(f, ".z")
297    }
298
299    fn compile_cube_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
300        f.write_str("blockIdx")
301    }
302
303    fn compile_cube_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
304        f.write_str("blockIdxGlobal")
305    }
306
307    fn compile_cube_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
308        Self::compile_cube_pos_base_name(f)?;
309        write!(f, ".x")
310    }
311
312    fn compile_cube_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
313        Self::compile_cube_pos_base_name(f)?;
314        write!(f, ".y")
315    }
316
317    fn compile_cube_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
318        Self::compile_cube_pos_base_name(f)?;
319        write!(f, ".z")
320    }
321
322    fn compile_unit_pos_computation(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
323        let variable = Variable::<D>::UnitPos;
324        let ty = variable.item();
325        let cube_dim_x = Variable::<D>::CubeDimX;
326        let cube_dim_y = Variable::<D>::CubeDimY;
327        let unit_pos_x = Variable::<D>::UnitPosX;
328        let unit_pos_y = Variable::<D>::UnitPosY;
329        let unit_pos_z = Variable::<D>::UnitPosZ;
330        writeln!(
331            f,
332            "{ty} {variable} = {unit_pos_x} + {unit_pos_y} * {cube_dim_x} + {unit_pos_z} * ({cube_dim_x} * {cube_dim_y});"
333        )
334    }
335
336    fn compile_unit_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
337        f.write_str("threadIdxGlobal")
338    }
339
340    fn compile_unit_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
341        f.write_str("threadIdx")
342    }
343
344    fn compile_unit_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
345        Self::compile_unit_pos_base_name(f)?;
346        write!(f, ".x")
347    }
348
349    fn compile_unit_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
350        Self::compile_unit_pos_base_name(f)?;
351        write!(f, ".y")
352    }
353
354    fn compile_unit_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
355        Self::compile_unit_pos_base_name(f)?;
356        write!(f, ".z")
357    }
358
359    fn compile_plane_dim(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
360        f.write_str("warpSize")
361    }
362
363    fn compile_plane_dim_checked(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
364        f.write_str("warpSizeChecked")
365    }
366
367    fn compile_plane_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
368        let unit_pos_x = Variable::<D>::UnitPosX;
369        let plane_dim = Variable::<D>::PlaneDim;
370        write!(f, "{unit_pos_x} / {plane_dim}")
371    }
372
373    fn compile_unit_pos_plane(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
374        let absolute_pos = Variable::<D>::AbsolutePos;
375        let plane_dim = Variable::<D>::PlaneDim;
376        write!(f, "{absolute_pos} % {plane_dim}")
377    }
378
379    fn compile_cluster_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
380        write!(f, "0")
381    }
382    fn compile_cluster_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
383        write!(f, "0")
384    }
385    fn compile_cluster_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
386        write!(f, "0")
387    }
388    fn compile_cluster_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
389        write!(f, "0")
390    }
391}
392
393// Instructions
394
395pub trait DialectInstructions<D: Dialect> {
396    // atomics
397    fn compile_atomic_add(
398        f: &mut std::fmt::Formatter<'_>,
399        lhs: &Variable<D>,
400        rhs: &Variable<D>,
401        out: &Variable<D>,
402    ) -> std::fmt::Result {
403        let out = out.fmt_left();
404        match rhs.elem() {
405            Elem::I64 => writeln!(
406                f,
407                "{out} = atomicAdd(reinterpret_cast<{uint}*>({lhs}), {uint}({rhs}));",
408                uint = Elem::<D>::U64
409            ),
410            _ => writeln!(f, "{out} = atomicAdd({lhs}, {rhs});"),
411        }
412    }
413
414    fn compile_atomic_and(
415        f: &mut std::fmt::Formatter<'_>,
416        lhs: &Variable<D>,
417        rhs: &Variable<D>,
418        out: &Variable<D>,
419    ) -> std::fmt::Result {
420        let out = out.fmt_left();
421        writeln!(f, "{out} = atomicAnd({lhs}, {rhs});")
422    }
423
424    fn compile_atomic_cas(
425        f: &mut std::fmt::Formatter<'_>,
426        input: &Variable<D>,
427        cmp: &Variable<D>,
428        val: &Variable<D>,
429        out: &Variable<D>,
430    ) -> std::fmt::Result {
431        let out = out.fmt_left();
432        writeln!(f, "{out} = atomicCAS({input}, {cmp}, {val});")
433    }
434
435    fn compile_atomic_load(
436        f: &mut std::fmt::Formatter<'_>,
437        input: &Variable<D>,
438        out: &Variable<D>,
439    ) -> std::fmt::Result {
440        let out = out.fmt_left();
441        writeln!(f, "{out} = atomicAdd({input}, 0);")
442    }
443
444    fn compile_atomic_max(
445        f: &mut std::fmt::Formatter<'_>,
446        lhs: &Variable<D>,
447        rhs: &Variable<D>,
448        out: &Variable<D>,
449    ) -> std::fmt::Result {
450        let out = out.fmt_left();
451        writeln!(f, "{out} = atomicMax({lhs}, {rhs});")
452    }
453
454    fn compile_atomic_min(
455        f: &mut std::fmt::Formatter<'_>,
456        lhs: &Variable<D>,
457        rhs: &Variable<D>,
458        out: &Variable<D>,
459    ) -> std::fmt::Result {
460        let out = out.fmt_left();
461        writeln!(f, "{out} = atomicMin({lhs}, {rhs});")
462    }
463
464    fn compile_atomic_or(
465        f: &mut std::fmt::Formatter<'_>,
466        lhs: &Variable<D>,
467        rhs: &Variable<D>,
468        out: &Variable<D>,
469    ) -> std::fmt::Result {
470        let out = out.fmt_left();
471        writeln!(f, "{out} = atomicOr({lhs}, {rhs});")
472    }
473
474    fn compile_atomic_store(
475        f: &mut std::fmt::Formatter<'_>,
476        input: &Variable<D>,
477        out: &Variable<D>,
478    ) -> std::fmt::Result {
479        writeln!(f, "atomicExch({out}, {input});")
480    }
481
482    fn compile_atomic_sub(
483        f: &mut std::fmt::Formatter<'_>,
484        lhs: &Variable<D>,
485        rhs: &Variable<D>,
486        out: &Variable<D>,
487    ) -> std::fmt::Result {
488        let out = out.fmt_left();
489        match rhs.elem() {
490            Elem::U32 | Elem::I32 => writeln!(f, "{out} = atomicSub({lhs}, {rhs});"),
491            Elem::U64 => writeln!(f, "{out} = atomicAdd({lhs}, -{rhs});"),
492            Elem::I64 => writeln!(
493                f,
494                "{out} = atomicAdd(reinterpret_cast<{uint}*>({lhs}), {uint}(-{rhs}));",
495                uint = Elem::<D>::U64
496            ),
497            _ => writeln!(f, "{out} = atomicAdd({lhs}, -{rhs});"),
498        }
499    }
500
501    fn compile_atomic_swap(
502        f: &mut std::fmt::Formatter<'_>,
503        lhs: &Variable<D>,
504        rhs: &Variable<D>,
505        out: &Variable<D>,
506    ) -> std::fmt::Result {
507        let out = out.fmt_left();
508        writeln!(f, "{out} = atomicExch({lhs}, {rhs});")
509    }
510
511    fn compile_atomic_xor(
512        f: &mut std::fmt::Formatter<'_>,
513        lhs: &Variable<D>,
514        rhs: &Variable<D>,
515        out: &Variable<D>,
516    ) -> std::fmt::Result {
517        let out = out.fmt_left();
518        writeln!(f, "{out} = atomicXor({lhs}, {rhs});")
519    }
520
521    fn compile_saturating_add(
522        f: &mut std::fmt::Formatter<'_>,
523        lhs: impl Display,
524        rhs: impl Display,
525        item: Item<D>,
526    ) -> std::fmt::Result;
527
528    fn compile_saturating_sub(
529        f: &mut std::fmt::Formatter<'_>,
530        lhs: impl Display,
531        rhs: impl Display,
532        item: Item<D>,
533    ) -> std::fmt::Result;
534
535    // debug
536    fn compile_instruction_printf(
537        f: &mut std::fmt::Formatter<'_>,
538        format_string: &str,
539        args: &[Variable<D>],
540    ) -> std::fmt::Result {
541        let args = args.iter().map(|arg| format!("{arg}")).collect::<Vec<_>>();
542        let args = match args.is_empty() {
543            true => "".to_string(),
544            false => format!(", {}", args.join(",")),
545        };
546        writeln!(f, "printf({format_string:?}{args});")
547    }
548
549    // logs
550    fn compile_instruction_log1p_scalar<T: Component<D>>(
551        f: &mut std::fmt::Formatter<'_>,
552        input: T,
553    ) -> std::fmt::Result {
554        let elem = input.elem();
555        match elem {
556            Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
557                write!(f, "{elem}(log1p(float({input})))")
558            }
559            _ => write!(f, "log1p({input})"),
560        }
561    }
562
563    // sync
564    fn compile_instruction_sync_threads(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
565    fn compile_instruction_sync_warp(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
566    fn compile_instruction_thread_fence(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
567
568    // trigo
569    fn compile_instruction_tanh_scalar<T: Component<D>>(
570        f: &mut std::fmt::Formatter<'_>,
571        input: T,
572    ) -> std::fmt::Result {
573        let elem = input.elem();
574        match elem {
575            Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
576                write!(f, "{elem}(tanh(float({input})))")
577            }
578            _ => write!(f, "tanh({input})"),
579        }
580    }
581
582    // unary
583    fn compile_instruction_find_first_set<T: Component<D>>(
584        f: &mut std::fmt::Formatter<'_>,
585        input: T,
586        out_elem: Elem<D>,
587    ) -> std::fmt::Result;
588    fn compile_instruction_leading_zeros_scalar<T: Component<D>>(
589        f: &mut std::fmt::Formatter<'_>,
590        input: T,
591        out_elem: Elem<D>,
592    ) -> std::fmt::Result;
593
594    fn compile_instruction_popcount_scalar<T: Component<D>>(
595        f: &mut std::fmt::Formatter<'_>,
596        input: T,
597        out_elem: Elem<D>,
598    ) -> std::fmt::Result {
599        write!(f, "{out_elem}(")?;
600        match input.elem() {
601            Elem::I32 => write!(f, "__popc({}({input}))", Elem::<D>::U32),
602            Elem::U32 => write!(f, "__popc({input})"),
603            Elem::I64 => write!(f, "__popcll({}({input}))", Elem::<D>::U64),
604            Elem::U64 => write!(f, "__popcll({input})"),
605            _ => write!(f, "__popc({})", super::unary::zero_extend(input)),
606        }?;
607        write!(f, ")")
608    }
609
610    fn compile_instruction_reverse_bits_scalar<T: Component<D>>(
611        f: &mut std::fmt::Formatter<'_>,
612        input: T,
613        out_elem: Elem<D>,
614    ) -> std::fmt::Result {
615        write!(f, "{out_elem}(")?;
616        match out_elem {
617            Elem::I32 => write!(f, "__brev({}({input}))", Elem::<D>::U32),
618            Elem::U32 => write!(f, "__brev({input})"),
619            Elem::I64 => write!(f, "__brevll({}({input}))", Elem::<D>::U64),
620            Elem::U64 => write!(f, "__brevll({input})"),
621            _ => write!(
622                f,
623                "__brev({}) >> {}",
624                super::unary::zero_extend(input),
625                (size_of::<u32>() - out_elem.size()) * 8
626            ),
627        }?;
628        write!(f, ")")
629    }
630
631    // others
632    fn compile_instruction_max_function_name(
633        f: &mut std::fmt::Formatter<'_>,
634        item: Item<D>,
635    ) -> std::fmt::Result;
636
637    fn compile_instruction_min_function_name(
638        f: &mut std::fmt::Formatter<'_>,
639        item: Item<D>,
640    ) -> std::fmt::Result;
641
642    fn compile_instruction_powf(
643        f: &mut std::fmt::Formatter<'_>,
644        lhs: &str,
645        rhs: &str,
646        elem: Elem<D>,
647    ) -> std::fmt::Result {
648        match elem {
649            Elem::F32 => write!(f, "powf({lhs}, {rhs})"),
650            Elem::F64 => write!(f, "pow({lhs}, {rhs})"),
651            _ => panic!("Unsupported type for powf"),
652        }
653    }
654
655    fn compile_instruction_half_function_name_prefix() -> &'static str {
656        "h"
657    }
658
659    fn compile_instruction_half2_function_name_prefix() -> &'static str {
660        "h2"
661    }
662
663    // warp
664    fn compile_warp_shuffle(
665        f: &mut std::fmt::Formatter<'_>,
666        var: &str,
667        source: &str,
668    ) -> std::fmt::Result;
669    fn compile_warp_shuffle_xor(
670        f: &mut std::fmt::Formatter<'_>,
671        var: &str,
672        elem: &Elem<D>,
673        offset: &str,
674    ) -> std::fmt::Result;
675    fn compile_warp_shuffle_up(
676        f: &mut std::fmt::Formatter<'_>,
677        var: &str,
678        offset: &str,
679    ) -> std::fmt::Result;
680    fn compile_warp_shuffle_down(
681        f: &mut std::fmt::Formatter<'_>,
682        var: &str,
683        offset: &str,
684    ) -> std::fmt::Result;
685    fn compile_warp_all<T: Component<D>>(
686        f: &mut std::fmt::Formatter<'_>,
687        input: &T,
688    ) -> std::fmt::Result;
689    fn compile_warp_any<T: Component<D>>(
690        f: &mut std::fmt::Formatter<'_>,
691        input: &T,
692    ) -> std::fmt::Result;
693    fn compile_warp_ballot(
694        f: &mut std::fmt::Formatter<'_>,
695        input: &Variable<D>,
696        out_elem: &Elem<D>,
697    ) -> std::fmt::Result;
698    fn compile_warp_elect(f: &mut std::fmt::Formatter<'_>, out: &str) -> std::fmt::Result {
699        write!(
700            f,
701            "
702unsigned int mask = __activemask();
703unsigned int leader = __ffs(mask) - 1;
704{out} = threadIdx.x % warpSize == leader;
705            "
706        )
707    }
708}
709
/// Operands of a manually emitted matrix multiply-accumulate, consumed by
/// `compile_manual_mma` and `compile_scaled_mma`.
///
/// Presumably `frag_a`/`frag_b` are the input fragments, `frag_c` the
/// accumulator and `frag_d` the output (`D = A * B + C`). Confirm against the
/// dialect implementations.
// Field order matters: `derive(new)` builds `ManualMma::new` with these fields
// as positional arguments, in declaration order.
#[derive(Debug, Clone, Copy, new)]
pub struct ManualMma<'a, D: Dialect> {
    /// Matrix shape of the operation.
    pub shape: MmaShape<D>,
    pub frag_a: &'a Variable<D>,
    pub frag_b: &'a Variable<D>,
    pub frag_c: &'a Variable<D>,
    pub frag_d: &'a Variable<D>,
}
718
/// Hooks that emit plane (warp) reductions and scans.
///
/// Every default delegates to the shared `reduce_*` helpers, passing either the
/// compound-assignment operator to apply or, for min/max/all/any, the
/// dialect's function/vote emitter.
pub trait DialectWarpReduceCompiler<D: Dialect>:
    Default + Clone + Copy + Debug + Send + Sync + Eq + Hash + 'static
{
    /// Plane-wide sum of `input` into `out` (`+=`).
    fn warp_reduce_sum(
        f: &mut core::fmt::Formatter<'_>,
        input: &Variable<D>,
        out: &Variable<D>,
    ) -> core::fmt::Result {
        reduce_operator(f, input, out, "+=")
    }
    /// Plane-wide product of `input` into `out` (`*=`).
    fn warp_reduce_prod(
        f: &mut core::fmt::Formatter<'_>,
        input: &Variable<D>,
        out: &Variable<D>,
    ) -> core::fmt::Result {
        reduce_operator(f, input, out, "*=")
    }
    /// Plane-wide maximum, using the dialect's `max` function name.
    fn warp_reduce_max(
        f: &mut core::fmt::Formatter<'_>,
        input: &Variable<D>,
        out: &Variable<D>,
    ) -> core::fmt::Result {
        reduce_comparison(f, input, out, D::compile_instruction_max_function_name)
    }
    /// Plane-wide minimum, using the dialect's `min` function name.
    fn warp_reduce_min(
        f: &mut core::fmt::Formatter<'_>,
        input: &Variable<D>,
        out: &Variable<D>,
    ) -> core::fmt::Result {
        reduce_comparison(f, input, out, D::compile_instruction_min_function_name)
    }
    /// Plane-wide "all" quantifier, via the dialect's `compile_warp_all`.
    fn warp_reduce_all(
        f: &mut core::fmt::Formatter<'_>,
        input: &Variable<D>,
        out: &Variable<D>,
    ) -> core::fmt::Result {
        reduce_quantifier(f, input, out, D::compile_warp_all::<IndexedVariable<D>>)
    }
    /// Plane-wide "any" quantifier, via the dialect's `compile_warp_any`.
    fn warp_reduce_any(
        f: &mut core::fmt::Formatter<'_>,
        input: &Variable<D>,
        out: &Variable<D>,
    ) -> core::fmt::Result {
        reduce_quantifier(f, input, out, D::compile_warp_any::<IndexedVariable<D>>)
    }
    /// Inclusive prefix sum across the plane.
    fn warp_reduce_sum_inclusive(
        f: &mut core::fmt::Formatter<'_>,
        input: &Variable<D>,
        out: &Variable<D>,
    ) -> core::fmt::Result {
        reduce_inclusive(f, input, out, "+=")
    }
    /// Inclusive prefix product across the plane.
    fn warp_reduce_prod_inclusive(
        f: &mut core::fmt::Formatter<'_>,
        input: &Variable<D>,
        out: &Variable<D>,
    ) -> core::fmt::Result {
        reduce_inclusive(f, input, out, "*=")
    }
    /// Exclusive prefix sum across the plane; the identity is `0`.
    fn warp_reduce_sum_exclusive(
        f: &mut core::fmt::Formatter<'_>,
        input: &Variable<D>,
        out: &Variable<D>,
    ) -> core::fmt::Result {
        reduce_exclusive(f, input, out, "+=", "0")
    }
    /// Exclusive prefix product across the plane; the identity is `1`.
    fn warp_reduce_prod_exclusive(
        f: &mut core::fmt::Formatter<'_>,
        input: &Variable<D>,
        out: &Variable<D>,
    ) -> core::fmt::Result {
        reduce_exclusive(f, input, out, "*=", "1")
    }
}
793
/// Hooks that emit matrix-multiply (WMMA / manual MMA / scaled MMA) code and
/// report which MMA configurations an architecture supports.
///
/// The formatting hooks with defaults emit nothing, so dialects without
/// fragment-level code only need to implement the required methods.
pub trait DialectWmmaCompiler<D: Dialect>:
    Default + Clone + Copy + Debug + Send + Sync + Eq + Hash + 'static
{
    // The no-op defaults below ignore their arguments; the allows keep the
    // trait-facing parameter names without warnings.
    /// Writes WMMA-related includes selected by `flags`. Emits nothing by default.
    #[allow(unused_variables)]
    fn compile_wmma_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags) -> std::fmt::Result {
        Ok(())
    }
    /// Writes WMMA-related type definitions. Emits nothing by default.
    #[allow(unused_variables)]
    fn compile_wmma_type_definitions(
        f: &mut std::fmt::Formatter<'_>,
        flags: &Flags,
    ) -> std::fmt::Result {
        Ok(())
    }
    /// Writes WMMA-related local variables. Emits nothing by default.
    #[allow(unused_variables)]
    fn compile_wmma_local_variables(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        Ok(())
    }
    /// Writes a fragment identifier. Emits nothing by default.
    /// (The `wwma` spelling is part of the existing API and is kept for
    /// compatibility with implementors.)
    #[allow(unused_variables)]
    fn compile_wwma_fragment_ident(
        f: &mut std::fmt::Formatter<'_>,
        ident: &FragmentIdent<D>,
    ) -> std::fmt::Result {
        Ok(())
    }
    /// Writes a fragment layout. Emits nothing by default.
    #[allow(unused_variables)]
    fn compile_wmma_fragment_layout(
        f: &mut std::fmt::Formatter<'_>,
        layout: &FragmentLayout<D>,
    ) -> std::fmt::Result {
        Ok(())
    }
    /// Writes a fragment type. Emits nothing by default.
    #[allow(unused_variables)]
    fn compile_wmma_fragment(
        f: &mut std::fmt::Formatter<'_>,
        fragment: &Fragment<D>,
    ) -> std::fmt::Result {
        Ok(())
    }

    /// Writes the declaration of the fragment variable `var`.
    fn compile_wmma_fragment_declaration(
        f: &mut std::fmt::Formatter<'_>,
        var: &Variable<D>,
    ) -> std::fmt::Result;

    /// Writes a single WMMA instruction.
    fn compile_wmma_instruction(
        f: &mut std::fmt::Formatter<'_>,
        instruction: &WmmaInstruction<D>,
    ) -> std::fmt::Result;
    /// Writes a manual MMA over the operands in `mma`.
    fn compile_manual_mma(f: &mut std::fmt::Formatter<'_>, mma: ManualMma<D>) -> std::fmt::Result;
    /// Writes a block-scaled MMA, using `scales_a`/`scales_b` as the scale
    /// operands and `scales_factor` as the scale factor. The exact meaning of
    /// `scales_factor` is dialect-defined (TODO confirm).
    fn compile_scaled_mma(
        f: &mut std::fmt::Formatter<'_>,
        mma: ManualMma<D>,
        scales_a: Variable<D>,
        scales_b: Variable<D>,
        scales_factor: u32,
    ) -> std::fmt::Result;
    /// WMMA configurations supported on `arch`.
    fn supported_wmma_combinations(arch: &D::Architecture) -> SupportedMmaCombinations;
    /// Manual MMA configurations supported on `arch`.
    fn supported_mma_combinations(arch: &D::Architecture) -> SupportedMmaCombinations;
    /// Scaled MMA configurations supported on `arch`. None by default.
    fn supported_scaled_mma_combinations(
        _arch: &D::Architecture,
    ) -> SupportedScaledMmaCombinations {
        Vec::new()
    }
}
859
/// IR processors to apply to the scopes during processing.
///
/// `CheckedIO` is always applied by default, so this only needs to return
/// target-specific processors, such as MMA index processors.
pub trait DialectProcessors<D: Dialect> {
    /// Returns the dialect-specific processors to run.
    fn processors() -> Vec<Box<dyn Processor>>;
}