cubecl_cpp/shared/
dialect.rs

1use std::{collections::HashSet, fmt::Debug};
2use std::{fmt::Display, hash::Hash};
3
4use cubecl_core::ir::Processor;
5
6use crate::shared::{
7    FmtLeft, IndexedVariable, MmaShape, SupportedMmaCombinations, SupportedScaledMmaCombinations,
8    reduce_comparison, reduce_exclusive, reduce_inclusive, reduce_operator, reduce_quantifier,
9};
10
11use super::{
12    Architecture, AtomicKind, Binding, Body, Component, CubeIndexFlags, Elem, Flags, Fragment,
13    FragmentIdent, FragmentLayout, Instruction, Item, SharedMemory, Variable, WarpInstruction,
14    WmmaInstruction,
15};
16
17// Base dialect
18
19pub trait Dialect:
20    DialectIncludes<Self>
21    + DialectTypes<Self>
22    + DialectBindings<Self>
23    + DialectWarpReduceCompiler<Self>
24    + DialectCubeBuiltins<Self>
25    + DialectInstructions<Self>
26    + DialectWmmaCompiler<Self>
27    + DialectProcessors<Self>
28    + Default
29    + Clone
30    + Copy
31    + Debug
32    + Send
33    + Sync
34    + Eq
35    + Hash
36    + 'static
37{
38    type Architecture: Architecture;
39}
40
41// Includes
42
43pub trait DialectIncludes<D: Dialect> {
44    type Extension: Debug + Clone + Sync + Send;
45
46    fn compile_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags) -> std::fmt::Result;
47    fn compile_extensions(
48        f: &mut std::fmt::Formatter<'_>,
49        extensions: &[Self::Extension],
50    ) -> std::fmt::Result;
51    fn register_instruction_extension(
52        extensions: &mut Vec<Self::Extension>,
53        instruction: &Instruction<D>,
54    );
55    fn register_warp_instruction_extension(
56        extensions: &mut Vec<Self::Extension>,
57        instruction: &WarpInstruction<D>,
58    );
59    #[allow(unused_variables)]
60    fn register_wmma_instruction_extension(
61        extensions: &mut Vec<Self::Extension>,
62        instruction: &WmmaInstruction<D>,
63    ) {
64    }
65}
66
67// Types
68
69pub trait DialectTypes<D: Dialect> {
70    fn item_can_be_optimized() -> bool;
71    fn compile_elem(
72        f: &mut std::fmt::Formatter<'_>,
73        elem: &Elem<D>,
74        word: bool,
75    ) -> std::fmt::Result;
76
77    fn compile_atomic_kind(
78        f: &mut std::fmt::Formatter<'_>,
79        kind: &AtomicKind<D>,
80    ) -> std::fmt::Result {
81        match kind {
82            AtomicKind::I32 => write!(f, "{}", Elem::<D>::I32),
83            AtomicKind::I64 => write!(f, "{}", Elem::<D>::I64),
84            AtomicKind::U32 => write!(f, "{}", Elem::<D>::U32),
85            AtomicKind::U64 => write!(f, "{}", Elem::<D>::U64),
86            AtomicKind::F16 => write!(f, "{}", Elem::<D>::F16),
87            AtomicKind::BF16 => write!(f, "{}", Elem::<D>::BF16),
88            AtomicKind::F32 => write!(f, "{}", Elem::<D>::F32),
89            AtomicKind::F64 => write!(f, "{}", Elem::<D>::F64),
90            AtomicKind::_Dialect(_) => Ok(()),
91        }
92    }
93
94    fn compile_item(f: &mut std::fmt::Formatter<'_>, item: &Item<D>) -> std::fmt::Result;
95    fn compile_type_definitions(
96        f: &mut std::fmt::Formatter<'_>,
97        items: &HashSet<Item<D>>,
98        scalars: &[(Elem<D>, usize)],
99        flags: &Flags,
100    ) -> std::fmt::Result;
101    fn compile_local_memory_qualifier(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
102    fn compile_shared_memory_declaration(
103        f: &mut std::fmt::Formatter<'_>,
104        shared: &SharedMemory<D>,
105    ) -> std::fmt::Result {
106        match shared {
107            SharedMemory::Array {
108                index,
109                item,
110                length,
111                offset,
112                ..
113            } => {
114                let size_bytes = *length * item.size() as u32;
115                writeln!(f, "// Shared array size: {length}, {size_bytes} bytes")?;
116                writeln!(
117                    f,
118                    "{item} *shared_memory_{index} = reinterpret_cast<{item}*>(&dynamic_shared_mem[{offset}]);"
119                )
120            }
121            SharedMemory::Value {
122                index,
123                item,
124                offset,
125                ..
126            } => {
127                let size_bytes = item.size() as u32;
128                writeln!(f, "// Shared value size: {size_bytes} bytes")?;
129                writeln!(
130                    f,
131                    "{item} &shared_memory_{index} = reinterpret_cast<{item}&>(dynamic_shared_mem[{offset}]);"
132                )
133            }
134        }
135    }
136    fn compile_polyfills(_f: &mut std::fmt::Formatter<'_>, _flags: &Flags) -> std::fmt::Result {
137        Ok(())
138    }
139    /// Address space (for Metal dialect only).
140    fn address_space_for_variable(_variable: &Variable<D>) -> String {
141        "".to_string()
142    }
143}
144
145// Kernel argument bindings
146
147pub trait DialectBindings<D: Dialect> {
148    fn compile_kernel_signature(
149        f: &mut std::fmt::Formatter<'_>,
150        kernel_name: &str,
151        tensor_maps: &[Binding<D>],
152        buffers: &[Binding<D>],
153        scalars: &[(Elem<D>, usize)],
154        flags: &Flags,
155    ) -> std::fmt::Result;
156    fn compile_bindings_body(
157        _f: &mut std::fmt::Formatter<'_>,
158        _body: &Body<D>,
159    ) -> std::fmt::Result {
160        Ok(())
161    }
162}
163
164// Cube builtins dialect
165
166pub trait DialectCubeBuiltins<D: Dialect> {
167    /// Depending on the dialect available built-in variables the
168    /// inclusion rules might change.
169    /// For instance in metal we have a built-in for the Unit plane position
170    /// but in other dialects there is none so we have to compute it using
171    /// other built-ins.
172    fn builtin_rules(flags: &CubeIndexFlags) -> CubeIndexFlags {
173        let unit_pos_plane = flags.unit_pos_plane;
174        let plane_dim_checked = flags.plane_dim_checked;
175        let plane_dim = flags.plane_dim || plane_dim_checked || unit_pos_plane;
176        let plane_index = flags.plane_index;
177        let absolute_pos = flags.absolute_pos || unit_pos_plane;
178        let absolute_pos_tuple = flags.absolute_pos_tuple || absolute_pos;
179        let cube_dim = flags.cube_dim;
180        let cube_dim_tuple = flags.cube_dim_tuple || cube_dim || absolute_pos || plane_dim_checked;
181        let unit_pos = flags.unit_pos;
182        let unit_pos_tuple = flags.unit_pos_tuple || unit_pos;
183        let cube_count = flags.cube_count;
184        let cube_count_tuple = flags.cube_count_tuple || absolute_pos;
185        let cube_pos = flags.cube_pos;
186        let cube_pos_tuple = flags.cube_pos_tuple || cube_pos;
187        let cluster_group = flags.cluster_pos;
188
189        CubeIndexFlags {
190            absolute_pos,
191            absolute_pos_tuple,
192            cube_count,
193            cube_count_tuple,
194            cube_dim,
195            cube_dim_tuple,
196            cube_pos,
197            cube_pos_tuple,
198            plane_dim,
199            plane_dim_checked,
200            plane_index,
201            unit_pos_tuple,
202            unit_pos,
203            unit_pos_plane,
204            cluster_pos: cluster_group,
205        }
206    }
207
208    fn compile_absolute_pos_tuple_computation(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
209        let variable = Variable::<D>::AbsolutePosBaseName;
210        let ty = variable.item();
211        let cube_pos_x = Variable::<D>::CubePosX;
212        let cube_pos_y = Variable::<D>::CubePosY;
213        let cube_pos_z = Variable::<D>::CubePosZ;
214        let cube_dim_x = Variable::<D>::CubeDimX;
215        let cube_dim_y = Variable::<D>::CubeDimY;
216        let cube_dim_z = Variable::<D>::CubeDimZ;
217        let unit_pos_x = Variable::<D>::UnitPosX;
218        let unit_pos_y = Variable::<D>::UnitPosY;
219        let unit_pos_z = Variable::<D>::UnitPosZ;
220        writeln!(
221            f,
222            "{ty} {variable} = make_{ty}(
223    {cube_pos_x} * {cube_dim_x} + {unit_pos_x},
224    {cube_pos_y} * {cube_dim_y} + {unit_pos_y},
225    {cube_pos_z} * {cube_dim_z} + {unit_pos_z}
226);"
227        )
228    }
229
230    fn compile_absolute_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
231        f.write_str("absoluteIdx")
232    }
233
234    fn compile_absolute_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
235        f.write_str("idxGlobal")
236    }
237
238    fn compile_absolute_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
239        Self::compile_absolute_pos_base_name(f)?;
240        write!(f, ".x")
241    }
242
243    fn compile_absolute_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
244        Self::compile_absolute_pos_base_name(f)?;
245        write!(f, ".y")
246    }
247
248    fn compile_absolute_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
249        Self::compile_absolute_pos_base_name(f)?;
250        write!(f, ".z")
251    }
252
253    fn compile_cube_count_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
254        f.write_str("gridDim")
255    }
256
257    fn compile_cube_count(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
258        f.write_str("gridDimGlobal")
259    }
260
261    fn compile_cube_count_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
262        Self::compile_cube_count_base_name(f)?;
263        write!(f, ".x")
264    }
265
266    fn compile_cube_count_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
267        Self::compile_cube_count_base_name(f)?;
268        write!(f, ".y")
269    }
270
271    fn compile_cube_count_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
272        Self::compile_cube_count_base_name(f)?;
273        write!(f, ".z")
274    }
275
276    fn compile_cube_dim_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
277        f.write_str("blockDim")
278    }
279
280    fn compile_cube_dim(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
281        f.write_str("blockDimGlobal")
282    }
283
284    fn compile_cube_dim_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
285        Self::compile_cube_dim_base_name(f)?;
286        write!(f, ".x")
287    }
288
289    fn compile_cube_dim_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
290        Self::compile_cube_dim_base_name(f)?;
291        write!(f, ".y")
292    }
293
294    fn compile_cube_dim_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
295        Self::compile_cube_dim_base_name(f)?;
296        write!(f, ".z")
297    }
298
299    fn compile_cube_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
300        f.write_str("blockIdx")
301    }
302
303    fn compile_cube_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
304        f.write_str("blockIdxGlobal")
305    }
306
307    fn compile_cube_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
308        Self::compile_cube_pos_base_name(f)?;
309        write!(f, ".x")
310    }
311
312    fn compile_cube_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
313        Self::compile_cube_pos_base_name(f)?;
314        write!(f, ".y")
315    }
316
317    fn compile_cube_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
318        Self::compile_cube_pos_base_name(f)?;
319        write!(f, ".z")
320    }
321
322    fn compile_unit_pos_computation(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
323        let variable = Variable::<D>::UnitPos;
324        let ty = variable.item();
325        let cube_dim_x = Variable::<D>::CubeDimX;
326        let cube_dim_y = Variable::<D>::CubeDimY;
327        let unit_pos_x = Variable::<D>::UnitPosX;
328        let unit_pos_y = Variable::<D>::UnitPosY;
329        let unit_pos_z = Variable::<D>::UnitPosZ;
330        writeln!(
331            f,
332            "{ty} {variable} = {unit_pos_x} + {unit_pos_y} * {cube_dim_x} + {unit_pos_z} * ({cube_dim_x} * {cube_dim_y});"
333        )
334    }
335
336    fn compile_unit_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
337        f.write_str("threadIdxGlobal")
338    }
339
340    fn compile_unit_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
341        f.write_str("threadIdx")
342    }
343
344    fn compile_unit_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
345        Self::compile_unit_pos_base_name(f)?;
346        write!(f, ".x")
347    }
348
349    fn compile_unit_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
350        Self::compile_unit_pos_base_name(f)?;
351        write!(f, ".y")
352    }
353
354    fn compile_unit_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
355        Self::compile_unit_pos_base_name(f)?;
356        write!(f, ".z")
357    }
358
359    fn compile_plane_dim(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
360        f.write_str("warpSize")
361    }
362
363    fn compile_plane_dim_checked(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
364        f.write_str("warpSizeChecked")
365    }
366
367    fn compile_plane_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
368        let unit_pos_x = Variable::<D>::UnitPosX;
369        let plane_dim = Variable::<D>::PlaneDim;
370        write!(f, "{unit_pos_x} / {plane_dim}")
371    }
372
373    fn compile_unit_pos_plane(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
374        let absolute_pos = Variable::<D>::AbsolutePos;
375        let plane_dim = Variable::<D>::PlaneDim;
376        write!(f, "{absolute_pos} % {plane_dim}")
377    }
378
379    fn compile_cluster_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
380        write!(f, "0")
381    }
382    fn compile_cluster_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
383        write!(f, "0")
384    }
385    fn compile_cluster_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
386        write!(f, "0")
387    }
388    fn compile_cluster_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
389        write!(f, "0")
390    }
391}
392
393// Instructions
394
395pub trait DialectInstructions<D: Dialect> {
396    // atomics
397    fn compile_atomic_add(
398        f: &mut std::fmt::Formatter<'_>,
399        lhs: &Variable<D>,
400        rhs: &Variable<D>,
401        out: &Variable<D>,
402    ) -> std::fmt::Result {
403        let out = out.fmt_left();
404        match rhs.elem() {
405            Elem::I64 => writeln!(
406                f,
407                "{out} = atomicAdd(reinterpret_cast<{uint}*>({lhs}), {uint}({rhs}));",
408                uint = Elem::<D>::U64
409            ),
410            _ => writeln!(f, "{out} = atomicAdd({lhs}, {rhs});"),
411        }
412    }
413
414    fn compile_atomic_and(
415        f: &mut std::fmt::Formatter<'_>,
416        lhs: &Variable<D>,
417        rhs: &Variable<D>,
418        out: &Variable<D>,
419    ) -> std::fmt::Result {
420        let out = out.fmt_left();
421        writeln!(f, "{out} = atomicAnd({lhs}, {rhs});")
422    }
423
424    fn compile_atomic_cas(
425        f: &mut std::fmt::Formatter<'_>,
426        input: &Variable<D>,
427        cmp: &Variable<D>,
428        val: &Variable<D>,
429        out: &Variable<D>,
430    ) -> std::fmt::Result {
431        let out = out.fmt_left();
432        writeln!(f, "{out} = atomicCAS({input}, {cmp}, {val});")
433    }
434
435    fn compile_atomic_load(
436        f: &mut std::fmt::Formatter<'_>,
437        input: &Variable<D>,
438        out: &Variable<D>,
439    ) -> std::fmt::Result {
440        let out = out.fmt_left();
441        writeln!(f, "{out} = atomicAdd({input}, 0);")
442    }
443
444    fn compile_atomic_max(
445        f: &mut std::fmt::Formatter<'_>,
446        lhs: &Variable<D>,
447        rhs: &Variable<D>,
448        out: &Variable<D>,
449    ) -> std::fmt::Result {
450        let out = out.fmt_left();
451        writeln!(f, "{out} = atomicMax({lhs}, {rhs});")
452    }
453
454    fn compile_atomic_min(
455        f: &mut std::fmt::Formatter<'_>,
456        lhs: &Variable<D>,
457        rhs: &Variable<D>,
458        out: &Variable<D>,
459    ) -> std::fmt::Result {
460        let out = out.fmt_left();
461        writeln!(f, "{out} = atomicMin({lhs}, {rhs});")
462    }
463
464    fn compile_atomic_or(
465        f: &mut std::fmt::Formatter<'_>,
466        lhs: &Variable<D>,
467        rhs: &Variable<D>,
468        out: &Variable<D>,
469    ) -> std::fmt::Result {
470        let out = out.fmt_left();
471        writeln!(f, "{out} = atomicOr({lhs}, {rhs});")
472    }
473
474    fn compile_atomic_store(
475        f: &mut std::fmt::Formatter<'_>,
476        input: &Variable<D>,
477        out: &Variable<D>,
478    ) -> std::fmt::Result {
479        writeln!(f, "atomicExch({out}, {input});")
480    }
481
482    fn compile_atomic_sub(
483        f: &mut std::fmt::Formatter<'_>,
484        lhs: &Variable<D>,
485        rhs: &Variable<D>,
486        out: &Variable<D>,
487    ) -> std::fmt::Result {
488        let out = out.fmt_left();
489        match rhs.elem() {
490            Elem::U32 | Elem::I32 => writeln!(f, "{out} = atomicSub({lhs}, {rhs});"),
491            Elem::U64 => writeln!(f, "{out} = atomicAdd({lhs}, -{rhs});"),
492            Elem::I64 => writeln!(
493                f,
494                "{out} = atomicAdd(reinterpret_cast<{uint}*>({lhs}), {uint}(-{rhs}));",
495                uint = Elem::<D>::U64
496            ),
497            _ => writeln!(f, "{out} = atomicAdd({lhs}, -{rhs});"),
498        }
499    }
500
501    fn compile_atomic_swap(
502        f: &mut std::fmt::Formatter<'_>,
503        lhs: &Variable<D>,
504        rhs: &Variable<D>,
505        out: &Variable<D>,
506    ) -> std::fmt::Result {
507        let out = out.fmt_left();
508        writeln!(f, "{out} = atomicExch({lhs}, {rhs});")
509    }
510
511    fn compile_atomic_xor(
512        f: &mut std::fmt::Formatter<'_>,
513        lhs: &Variable<D>,
514        rhs: &Variable<D>,
515        out: &Variable<D>,
516    ) -> std::fmt::Result {
517        let out = out.fmt_left();
518        writeln!(f, "{out} = atomicXor({lhs}, {rhs});")
519    }
520
521    fn compile_saturating_add(
522        f: &mut std::fmt::Formatter<'_>,
523        lhs: impl Display,
524        rhs: impl Display,
525        item: Item<D>,
526    ) -> std::fmt::Result;
527
528    fn compile_saturating_sub(
529        f: &mut std::fmt::Formatter<'_>,
530        lhs: impl Display,
531        rhs: impl Display,
532        item: Item<D>,
533    ) -> std::fmt::Result;
534
535    // debug
536    fn compile_instruction_printf(
537        f: &mut std::fmt::Formatter<'_>,
538        format_string: &str,
539        args: &[Variable<D>],
540    ) -> std::fmt::Result {
541        let args = args.iter().map(|arg| format!("{arg}")).collect::<Vec<_>>();
542        let args = match args.is_empty() {
543            true => "".to_string(),
544            false => format!(", {}", args.join(",")),
545        };
546        writeln!(f, "printf({format_string:?}{args});")
547    }
548
549    // logs
550    fn compile_instruction_log1p_scalar<T: Component<D>>(
551        f: &mut std::fmt::Formatter<'_>,
552        input: T,
553    ) -> std::fmt::Result {
554        let elem = input.elem();
555        match elem {
556            Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
557                write!(f, "{elem}(log1p(float({input})))")
558            }
559            _ => write!(f, "log1p({input})"),
560        }
561    }
562
563    // sync
564    fn compile_instruction_sync_threads(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
565    fn compile_instruction_sync_warp(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
566    fn compile_instruction_thread_fence(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
567
568    // trigo
569    fn compile_instruction_tanh_scalar<T: Component<D>>(
570        f: &mut std::fmt::Formatter<'_>,
571        input: T,
572    ) -> std::fmt::Result {
573        let elem = input.elem();
574        match elem {
575            Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
576                write!(f, "{elem}(tanh(float({input})))")
577            }
578            _ => write!(f, "tanh({input})"),
579        }
580    }
581
582    // unary
583    fn compile_instruction_find_first_set<T: Component<D>>(
584        f: &mut std::fmt::Formatter<'_>,
585        input: T,
586        out_elem: Elem<D>,
587    ) -> std::fmt::Result;
588    fn compile_instruction_leading_zeros_scalar<T: Component<D>>(
589        f: &mut std::fmt::Formatter<'_>,
590        input: T,
591        out_elem: Elem<D>,
592    ) -> std::fmt::Result;
593
594    fn compile_instruction_popcount_scalar<T: Component<D>>(
595        f: &mut std::fmt::Formatter<'_>,
596        input: T,
597        out_elem: Elem<D>,
598    ) -> std::fmt::Result {
599        write!(f, "{out_elem}(")?;
600        match input.elem() {
601            Elem::I32 => write!(f, "__popc({}({input}))", Elem::<D>::U32),
602            Elem::U32 => write!(f, "__popc({input})"),
603            Elem::I64 => write!(f, "__popcll({}({input}))", Elem::<D>::U64),
604            Elem::U64 => write!(f, "__popcll({input})"),
605            _ => write!(f, "__popc({})", super::unary::zero_extend(input)),
606        }?;
607        write!(f, ")")
608    }
609
610    fn compile_instruction_reverse_bits_scalar<T: Component<D>>(
611        f: &mut std::fmt::Formatter<'_>,
612        input: T,
613        out_elem: Elem<D>,
614    ) -> std::fmt::Result {
615        write!(f, "{out_elem}(")?;
616        match out_elem {
617            Elem::I32 => write!(f, "__brev({}({input}))", Elem::<D>::U32),
618            Elem::U32 => write!(f, "__brev({input})"),
619            Elem::I64 => write!(f, "__brevll({}({input}))", Elem::<D>::U64),
620            Elem::U64 => write!(f, "__brevll({input})"),
621            _ => write!(
622                f,
623                "__brev({}) >> {}",
624                super::unary::zero_extend(input),
625                (size_of::<u32>() - out_elem.size()) * 8
626            ),
627        }?;
628        write!(f, ")")
629    }
630
631    // others
632    fn compile_instruction_max_function_name(
633        f: &mut std::fmt::Formatter<'_>,
634        item: Item<D>,
635    ) -> std::fmt::Result;
636
637    fn compile_instruction_min_function_name(
638        f: &mut std::fmt::Formatter<'_>,
639        item: Item<D>,
640    ) -> std::fmt::Result;
641
642    fn compile_instruction_powf(
643        f: &mut std::fmt::Formatter<'_>,
644        lhs: &str,
645        rhs: &str,
646        elem: Elem<D>,
647    ) -> std::fmt::Result {
648        match elem {
649            Elem::F32 => write!(f, "powf({lhs}, {rhs})"),
650            Elem::F64 => write!(f, "pow({lhs}, {rhs})"),
651            _ => write!(f, "#error Unsupported type for powf: {elem}"),
652        }
653    }
654
655    fn compile_instruction_hypot(
656        f: &mut std::fmt::Formatter<'_>,
657        lhs: &str,
658        rhs: &str,
659        elem: Elem<D>,
660    ) -> std::fmt::Result {
661        match elem {
662            Elem::F32 => write!(f, "hypotf({lhs}, {rhs})"),
663            Elem::F64 => write!(f, "hypot({lhs}, {rhs})"),
664            _ => write!(f, "#error Unsupported type for hypot: {elem}"),
665        }
666    }
667
668    fn compile_instruction_rhypot(
669        f: &mut std::fmt::Formatter<'_>,
670        lhs: &str,
671        rhs: &str,
672        elem: Elem<D>,
673    ) -> std::fmt::Result {
674        match elem {
675            Elem::F32 => write!(f, "rhypotf({lhs}, {rhs})"),
676            Elem::F64 => write!(f, "rhypot({lhs}, {rhs})"),
677            _ => write!(f, "#error Unsupported type for rhypot: {elem}"),
678        }
679    }
680
681    fn compile_instruction_half_function_name_prefix() -> &'static str {
682        "h"
683    }
684
685    fn compile_instruction_half2_function_name_prefix() -> &'static str {
686        "h2"
687    }
688
689    // warp
690    fn compile_warp_shuffle(
691        f: &mut std::fmt::Formatter<'_>,
692        var: &str,
693        source: &str,
694    ) -> std::fmt::Result;
695    fn compile_warp_shuffle_xor(
696        f: &mut std::fmt::Formatter<'_>,
697        var: &str,
698        elem: &Elem<D>,
699        offset: &str,
700    ) -> std::fmt::Result;
701    fn compile_warp_shuffle_up(
702        f: &mut std::fmt::Formatter<'_>,
703        var: &str,
704        offset: &str,
705    ) -> std::fmt::Result;
706    fn compile_warp_shuffle_down(
707        f: &mut std::fmt::Formatter<'_>,
708        var: &str,
709        offset: &str,
710    ) -> std::fmt::Result;
711    fn compile_warp_all<T: Component<D>>(
712        f: &mut std::fmt::Formatter<'_>,
713        input: &T,
714    ) -> std::fmt::Result;
715    fn compile_warp_any<T: Component<D>>(
716        f: &mut std::fmt::Formatter<'_>,
717        input: &T,
718    ) -> std::fmt::Result;
719    fn compile_warp_ballot(
720        f: &mut std::fmt::Formatter<'_>,
721        input: &Variable<D>,
722        out_elem: &Elem<D>,
723    ) -> std::fmt::Result;
724    fn compile_warp_elect(f: &mut std::fmt::Formatter<'_>, out: &str) -> std::fmt::Result {
725        write!(
726            f,
727            "
728unsigned int mask = __activemask();
729unsigned int leader = __ffs(mask) - 1;
730{out} = threadIdx.x % warpSize == leader;
731            "
732        )
733    }
734}
735
736#[derive(Debug, Clone, Copy, new)]
737pub struct ManualMma<'a, D: Dialect> {
738    pub shape: MmaShape<D>,
739    pub frag_a: &'a Variable<D>,
740    pub frag_b: &'a Variable<D>,
741    pub frag_c: &'a Variable<D>,
742    pub frag_d: &'a Variable<D>,
743}
744
745pub trait DialectWarpReduceCompiler<D: Dialect>:
746    Default + Clone + Copy + Debug + Send + Sync + Eq + Hash + 'static
747{
748    fn warp_reduce_sum(
749        f: &mut core::fmt::Formatter<'_>,
750        input: &Variable<D>,
751        out: &Variable<D>,
752    ) -> core::fmt::Result {
753        reduce_operator(f, input, out, "+=")
754    }
755    fn warp_reduce_prod(
756        f: &mut core::fmt::Formatter<'_>,
757        input: &Variable<D>,
758        out: &Variable<D>,
759    ) -> core::fmt::Result {
760        reduce_operator(f, input, out, "*=")
761    }
762    fn warp_reduce_max(
763        f: &mut core::fmt::Formatter<'_>,
764        input: &Variable<D>,
765        out: &Variable<D>,
766    ) -> core::fmt::Result {
767        reduce_comparison(f, input, out, D::compile_instruction_max_function_name)
768    }
769    fn warp_reduce_min(
770        f: &mut core::fmt::Formatter<'_>,
771        input: &Variable<D>,
772        out: &Variable<D>,
773    ) -> core::fmt::Result {
774        reduce_comparison(f, input, out, D::compile_instruction_min_function_name)
775    }
776    fn warp_reduce_all(
777        f: &mut core::fmt::Formatter<'_>,
778        input: &Variable<D>,
779        out: &Variable<D>,
780    ) -> core::fmt::Result {
781        reduce_quantifier(f, input, out, D::compile_warp_all::<IndexedVariable<D>>)
782    }
783    fn warp_reduce_any(
784        f: &mut core::fmt::Formatter<'_>,
785        input: &Variable<D>,
786        out: &Variable<D>,
787    ) -> core::fmt::Result {
788        reduce_quantifier(f, input, out, D::compile_warp_any::<IndexedVariable<D>>)
789    }
790    fn warp_reduce_sum_inclusive(
791        f: &mut core::fmt::Formatter<'_>,
792        input: &Variable<D>,
793        out: &Variable<D>,
794    ) -> core::fmt::Result {
795        reduce_inclusive(f, input, out, "+=")
796    }
797    fn warp_reduce_prod_inclusive(
798        f: &mut core::fmt::Formatter<'_>,
799        input: &Variable<D>,
800        out: &Variable<D>,
801    ) -> core::fmt::Result {
802        reduce_inclusive(f, input, out, "*=")
803    }
804    fn warp_reduce_sum_exclusive(
805        f: &mut core::fmt::Formatter<'_>,
806        input: &Variable<D>,
807        out: &Variable<D>,
808    ) -> core::fmt::Result {
809        reduce_exclusive(f, input, out, "+=", "0")
810    }
811    fn warp_reduce_prod_exclusive(
812        f: &mut core::fmt::Formatter<'_>,
813        input: &Variable<D>,
814        out: &Variable<D>,
815    ) -> core::fmt::Result {
816        reduce_exclusive(f, input, out, "*=", "1")
817    }
818}
819
820pub trait DialectWmmaCompiler<D: Dialect>:
821    Default + Clone + Copy + Debug + Send + Sync + Eq + Hash + 'static
822{
823    #[allow(unused_variables)]
824    fn compile_wmma_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags) -> std::fmt::Result {
825        Ok(())
826    }
827    #[allow(unused_variables)]
828    fn compile_wmma_type_definitions(
829        f: &mut std::fmt::Formatter<'_>,
830        flags: &Flags,
831    ) -> std::fmt::Result {
832        Ok(())
833    }
834    #[allow(unused_variables)]
835    fn compile_wmma_local_variables(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
836        Ok(())
837    }
838    #[allow(unused_variables)]
839    fn compile_wwma_fragment_ident(
840        f: &mut std::fmt::Formatter<'_>,
841        ident: &FragmentIdent<D>,
842    ) -> std::fmt::Result {
843        Ok(())
844    }
845    #[allow(unused_variables)]
846    fn compile_wmma_fragment_layout(
847        f: &mut std::fmt::Formatter<'_>,
848        layout: &FragmentLayout<D>,
849    ) -> std::fmt::Result {
850        Ok(())
851    }
852    #[allow(unused_variables)]
853    fn compile_wmma_fragment(
854        f: &mut std::fmt::Formatter<'_>,
855        fragment: &Fragment<D>,
856    ) -> std::fmt::Result {
857        Ok(())
858    }
859
860    fn compile_wmma_fragment_declaration(
861        f: &mut std::fmt::Formatter<'_>,
862        var: &Variable<D>,
863    ) -> std::fmt::Result;
864
865    fn compile_wmma_instruction(
866        f: &mut std::fmt::Formatter<'_>,
867        instruction: &WmmaInstruction<D>,
868    ) -> std::fmt::Result;
869    fn compile_manual_mma(f: &mut std::fmt::Formatter<'_>, mma: ManualMma<D>) -> std::fmt::Result;
870    fn compile_scaled_mma(
871        f: &mut std::fmt::Formatter<'_>,
872        mma: ManualMma<D>,
873        scales_a: Variable<D>,
874        scales_b: Variable<D>,
875        scales_factor: u32,
876    ) -> std::fmt::Result;
877    fn supported_wmma_combinations(arch: &D::Architecture) -> SupportedMmaCombinations;
878    fn supported_mma_combinations(arch: &D::Architecture) -> SupportedMmaCombinations;
879    fn supported_scaled_mma_combinations(
880        _arch: &D::Architecture,
881    ) -> SupportedScaledMmaCombinations {
882        Vec::new()
883    }
884}
885
886/// IR Processors to be applied to the scopes during processing. [`CheckedIO`] is always applied
887/// by default, so these are only for target specific processors like MMA index processors.
888pub trait DialectProcessors<D: Dialect> {
889    fn processors() -> Vec<Box<dyn Processor>>;
890}