// cubecl_cpp/shared/dialect.rs

1use std::{collections::HashSet, fmt::Debug};
2use std::{fmt::Display, hash::Hash};
3
4use cubecl_core::ir::Processor;
5
6use crate::shared::{
7    FmtLeft, IndexedVariable, MmaShape, SupportedMmaCombinations, SupportedScaledMmaCombinations,
8    reduce_comparison, reduce_exclusive, reduce_inclusive, reduce_operator, reduce_quantifier,
9};
10
11use super::{
12    Architecture, AtomicKind, Binding, Body, Component, CubeIndexFlags, Elem, Flags, Fragment,
13    FragmentIdent, FragmentLayout, Instruction, Item, SharedMemory, Variable, WarpInstruction,
14    WmmaInstruction,
15};
16
17// Base dialect
18
19pub trait Dialect:
20    DialectIncludes<Self>
21    + DialectTypes<Self>
22    + DialectBindings<Self>
23    + DialectWarpReduceCompiler<Self>
24    + DialectCubeBuiltins<Self>
25    + DialectInstructions<Self>
26    + DialectWmmaCompiler<Self>
27    + DialectProcessors<Self>
28    + Default
29    + Clone
30    + Copy
31    + Debug
32    + Send
33    + Sync
34    + Eq
35    + Hash
36    + 'static
37{
38    type Architecture: Architecture;
39}
40
41// Includes
42
43pub trait DialectIncludes<D: Dialect> {
44    type Extension: Debug + Clone + Sync + Send;
45
46    fn compile_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags) -> std::fmt::Result;
47    fn compile_extensions(
48        f: &mut std::fmt::Formatter<'_>,
49        extensions: &[Self::Extension],
50    ) -> std::fmt::Result;
51    fn register_instruction_extension(
52        extensions: &mut Vec<Self::Extension>,
53        instruction: &Instruction<D>,
54    );
55    fn register_warp_instruction_extension(
56        extensions: &mut Vec<Self::Extension>,
57        instruction: &WarpInstruction<D>,
58    );
59    #[allow(unused_variables)]
60    fn register_wmma_instruction_extension(
61        extensions: &mut Vec<Self::Extension>,
62        instruction: &WmmaInstruction<D>,
63    ) {
64    }
65}
66
67// Types
68
69pub trait DialectTypes<D: Dialect> {
70    fn item_can_be_optimized() -> bool;
71    fn compile_elem(
72        f: &mut std::fmt::Formatter<'_>,
73        elem: &Elem<D>,
74        word: bool,
75    ) -> std::fmt::Result;
76
77    fn compile_atomic_kind(
78        f: &mut std::fmt::Formatter<'_>,
79        kind: &AtomicKind<D>,
80    ) -> std::fmt::Result {
81        match kind {
82            AtomicKind::I32 => write!(f, "{}", Elem::<D>::I32),
83            AtomicKind::I64 => write!(f, "{}", Elem::<D>::I64),
84            AtomicKind::U32 => write!(f, "{}", Elem::<D>::U32),
85            AtomicKind::U64 => write!(f, "{}", Elem::<D>::U64),
86            AtomicKind::F16 => write!(f, "{}", Elem::<D>::F16),
87            AtomicKind::BF16 => write!(f, "{}", Elem::<D>::BF16),
88            AtomicKind::F32 => write!(f, "{}", Elem::<D>::F32),
89            AtomicKind::F64 => write!(f, "{}", Elem::<D>::F64),
90            AtomicKind::_Dialect(_) => Ok(()),
91        }
92    }
93
94    fn compile_item(f: &mut std::fmt::Formatter<'_>, item: &Item<D>) -> std::fmt::Result;
95    fn compile_type_definitions(
96        f: &mut std::fmt::Formatter<'_>,
97        items: &HashSet<Item<D>>,
98        scalars: &[(Elem<D>, usize)],
99        flags: &Flags,
100    ) -> std::fmt::Result;
101    fn compile_local_memory_qualifier(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
102    fn compile_shared_memory_declaration(
103        f: &mut std::fmt::Formatter<'_>,
104        shared: &SharedMemory<D>,
105    ) -> std::fmt::Result {
106        let item = shared.item;
107        let index = shared.index;
108        let offset = shared.offset;
109        let size = shared.length;
110        let size_bytes = size * shared.item.size() as u32;
111        writeln!(f, "// Shared memory size: {size}, {size_bytes} bytes")?;
112        writeln!(
113            f,
114            "{item} *shared_memory_{index} = reinterpret_cast<{item}*>(&dynamic_shared_mem[{offset}]);"
115        )
116    }
117    fn compile_polyfills(_f: &mut std::fmt::Formatter<'_>, _flags: &Flags) -> std::fmt::Result {
118        Ok(())
119    }
120    /// Address space (for Metal dialect only).
121    fn address_space_for_variable(_variable: &Variable<D>) -> String {
122        "".to_string()
123    }
124}
125
126// Kernel argument bindings
127
128pub trait DialectBindings<D: Dialect> {
129    fn compile_kernel_signature(
130        f: &mut std::fmt::Formatter<'_>,
131        kernel_name: &str,
132        tensor_maps: &[Binding<D>],
133        buffers: &[Binding<D>],
134        scalars: &[(Elem<D>, usize)],
135        flags: &Flags,
136    ) -> std::fmt::Result;
137    fn compile_bindings_body(
138        _f: &mut std::fmt::Formatter<'_>,
139        _body: &Body<D>,
140    ) -> std::fmt::Result {
141        Ok(())
142    }
143}
144
145// Cube builtins dialect
146
147pub trait DialectCubeBuiltins<D: Dialect> {
148    /// Depending on the dialect available built-in variables the
149    /// inclusion rules might change.
150    /// For instance in metal we have a built-in for the Unit plane position
151    /// but in other dialects there is none so we have to compute it using
152    /// other built-ins.
153    fn builtin_rules(flags: &CubeIndexFlags) -> CubeIndexFlags {
154        let unit_pos_plane = flags.unit_pos_plane;
155        let plane_dim_checked = flags.plane_dim_checked;
156        let plane_dim = flags.plane_dim || plane_dim_checked || unit_pos_plane;
157        let plane_index = flags.plane_index;
158        let absolute_pos = flags.absolute_pos || unit_pos_plane;
159        let absolute_pos_tuple = flags.absolute_pos_tuple || absolute_pos;
160        let cube_dim = flags.cube_dim;
161        let cube_dim_tuple = flags.cube_dim_tuple || cube_dim || absolute_pos || plane_dim_checked;
162        let unit_pos = flags.unit_pos;
163        let unit_pos_tuple = flags.unit_pos_tuple || unit_pos;
164        let cube_count = flags.cube_count;
165        let cube_count_tuple = flags.cube_count_tuple || absolute_pos;
166        let cube_pos = flags.cube_pos;
167        let cube_pos_tuple = flags.cube_pos_tuple || cube_pos;
168        let cluster_group = flags.cluster_pos;
169
170        CubeIndexFlags {
171            absolute_pos,
172            absolute_pos_tuple,
173            cube_count,
174            cube_count_tuple,
175            cube_dim,
176            cube_dim_tuple,
177            cube_pos,
178            cube_pos_tuple,
179            plane_dim,
180            plane_dim_checked,
181            plane_index,
182            unit_pos_tuple,
183            unit_pos,
184            unit_pos_plane,
185            cluster_pos: cluster_group,
186        }
187    }
188
189    fn compile_absolute_pos_tuple_computation(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
190        let variable = Variable::<D>::AbsolutePosBaseName;
191        let ty = variable.item();
192        let cube_pos_x = Variable::<D>::CubePosX;
193        let cube_pos_y = Variable::<D>::CubePosY;
194        let cube_pos_z = Variable::<D>::CubePosZ;
195        let cube_dim_x = Variable::<D>::CubeDimX;
196        let cube_dim_y = Variable::<D>::CubeDimY;
197        let cube_dim_z = Variable::<D>::CubeDimZ;
198        let unit_pos_x = Variable::<D>::UnitPosX;
199        let unit_pos_y = Variable::<D>::UnitPosY;
200        let unit_pos_z = Variable::<D>::UnitPosZ;
201        writeln!(
202            f,
203            "{ty} {variable} = make_{ty}(
204    {cube_pos_x} * {cube_dim_x} + {unit_pos_x},
205    {cube_pos_y} * {cube_dim_y} + {unit_pos_y},
206    {cube_pos_z} * {cube_dim_z} + {unit_pos_z}
207);"
208        )
209    }
210
211    fn compile_absolute_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
212        f.write_str("absoluteIdx")
213    }
214
215    fn compile_absolute_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
216        f.write_str("idxGlobal")
217    }
218
219    fn compile_absolute_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
220        Self::compile_absolute_pos_base_name(f)?;
221        write!(f, ".x")
222    }
223
224    fn compile_absolute_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
225        Self::compile_absolute_pos_base_name(f)?;
226        write!(f, ".y")
227    }
228
229    fn compile_absolute_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
230        Self::compile_absolute_pos_base_name(f)?;
231        write!(f, ".z")
232    }
233
234    fn compile_cube_count_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
235        f.write_str("gridDim")
236    }
237
238    fn compile_cube_count(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
239        f.write_str("gridDimGlobal")
240    }
241
242    fn compile_cube_count_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
243        Self::compile_cube_count_base_name(f)?;
244        write!(f, ".x")
245    }
246
247    fn compile_cube_count_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
248        Self::compile_cube_count_base_name(f)?;
249        write!(f, ".y")
250    }
251
252    fn compile_cube_count_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
253        Self::compile_cube_count_base_name(f)?;
254        write!(f, ".z")
255    }
256
257    fn compile_cube_dim_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
258        f.write_str("blockDim")
259    }
260
261    fn compile_cube_dim(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
262        f.write_str("blockDimGlobal")
263    }
264
265    fn compile_cube_dim_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
266        Self::compile_cube_dim_base_name(f)?;
267        write!(f, ".x")
268    }
269
270    fn compile_cube_dim_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
271        Self::compile_cube_dim_base_name(f)?;
272        write!(f, ".y")
273    }
274
275    fn compile_cube_dim_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
276        Self::compile_cube_dim_base_name(f)?;
277        write!(f, ".z")
278    }
279
280    fn compile_cube_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
281        f.write_str("blockIdx")
282    }
283
284    fn compile_cube_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
285        f.write_str("blockIdxGlobal")
286    }
287
288    fn compile_cube_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
289        Self::compile_cube_pos_base_name(f)?;
290        write!(f, ".x")
291    }
292
293    fn compile_cube_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
294        Self::compile_cube_pos_base_name(f)?;
295        write!(f, ".y")
296    }
297
298    fn compile_cube_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
299        Self::compile_cube_pos_base_name(f)?;
300        write!(f, ".z")
301    }
302
303    fn compile_unit_pos_computation(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
304        let variable = Variable::<D>::UnitPos;
305        let ty = variable.item();
306        let cube_dim_x = Variable::<D>::CubeDimX;
307        let cube_dim_y = Variable::<D>::CubeDimY;
308        let unit_pos_x = Variable::<D>::UnitPosX;
309        let unit_pos_y = Variable::<D>::UnitPosY;
310        let unit_pos_z = Variable::<D>::UnitPosZ;
311        writeln!(
312            f,
313            "{ty} {variable} = {unit_pos_x} + {unit_pos_y} * {cube_dim_x} + {unit_pos_z} * ({cube_dim_x} * {cube_dim_y});"
314        )
315    }
316
317    fn compile_unit_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
318        f.write_str("threadIdxGlobal")
319    }
320
321    fn compile_unit_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
322        f.write_str("threadIdx")
323    }
324
325    fn compile_unit_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
326        Self::compile_unit_pos_base_name(f)?;
327        write!(f, ".x")
328    }
329
330    fn compile_unit_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
331        Self::compile_unit_pos_base_name(f)?;
332        write!(f, ".y")
333    }
334
335    fn compile_unit_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
336        Self::compile_unit_pos_base_name(f)?;
337        write!(f, ".z")
338    }
339
340    fn compile_plane_dim(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
341        f.write_str("warpSize")
342    }
343
344    fn compile_plane_dim_checked(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
345        f.write_str("warpSizeChecked")
346    }
347
348    fn compile_plane_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
349        let unit_pos_x = Variable::<D>::UnitPosX;
350        let plane_dim = Variable::<D>::PlaneDim;
351        write!(f, "{unit_pos_x} / {plane_dim}")
352    }
353
354    fn compile_unit_pos_plane(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
355        let absolute_pos = Variable::<D>::AbsolutePos;
356        let plane_dim = Variable::<D>::PlaneDim;
357        write!(f, "{absolute_pos} % {plane_dim}")
358    }
359
360    fn compile_cluster_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
361        write!(f, "0")
362    }
363    fn compile_cluster_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
364        write!(f, "0")
365    }
366    fn compile_cluster_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
367        write!(f, "0")
368    }
369    fn compile_cluster_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
370        write!(f, "0")
371    }
372}
373
374// Instructions
375
376pub trait DialectInstructions<D: Dialect> {
377    // atomics
378    fn compile_atomic_add(
379        f: &mut std::fmt::Formatter<'_>,
380        lhs: &Variable<D>,
381        rhs: &Variable<D>,
382        out: &Variable<D>,
383    ) -> std::fmt::Result {
384        let out = out.fmt_left();
385        match rhs.elem() {
386            Elem::I64 => writeln!(
387                f,
388                "{out} = atomicAdd(reinterpret_cast<{uint}*>({lhs}), {uint}({rhs}));",
389                uint = Elem::<D>::U64
390            ),
391            _ => writeln!(f, "{out} = atomicAdd({lhs}, {rhs});"),
392        }
393    }
394
395    fn compile_atomic_and(
396        f: &mut std::fmt::Formatter<'_>,
397        lhs: &Variable<D>,
398        rhs: &Variable<D>,
399        out: &Variable<D>,
400    ) -> std::fmt::Result {
401        let out = out.fmt_left();
402        writeln!(f, "{out} = atomicAnd({lhs}, {rhs});")
403    }
404
405    fn compile_atomic_cas(
406        f: &mut std::fmt::Formatter<'_>,
407        input: &Variable<D>,
408        cmp: &Variable<D>,
409        val: &Variable<D>,
410        out: &Variable<D>,
411    ) -> std::fmt::Result {
412        let out = out.fmt_left();
413        writeln!(f, "{out} = atomicCAS({input}, {cmp}, {val});")
414    }
415
416    fn compile_atomic_load(
417        f: &mut std::fmt::Formatter<'_>,
418        input: &Variable<D>,
419        out: &Variable<D>,
420    ) -> std::fmt::Result {
421        let out = out.fmt_left();
422        writeln!(f, "{out} = atomicAdd({input}, 0);")
423    }
424
425    fn compile_atomic_max(
426        f: &mut std::fmt::Formatter<'_>,
427        lhs: &Variable<D>,
428        rhs: &Variable<D>,
429        out: &Variable<D>,
430    ) -> std::fmt::Result {
431        let out = out.fmt_left();
432        writeln!(f, "{out} = atomicMax({lhs}, {rhs});")
433    }
434
435    fn compile_atomic_min(
436        f: &mut std::fmt::Formatter<'_>,
437        lhs: &Variable<D>,
438        rhs: &Variable<D>,
439        out: &Variable<D>,
440    ) -> std::fmt::Result {
441        let out = out.fmt_left();
442        writeln!(f, "{out} = atomicMin({lhs}, {rhs});")
443    }
444
445    fn compile_atomic_or(
446        f: &mut std::fmt::Formatter<'_>,
447        lhs: &Variable<D>,
448        rhs: &Variable<D>,
449        out: &Variable<D>,
450    ) -> std::fmt::Result {
451        let out = out.fmt_left();
452        writeln!(f, "{out} = atomicOr({lhs}, {rhs});")
453    }
454
455    fn compile_atomic_store(
456        f: &mut std::fmt::Formatter<'_>,
457        input: &Variable<D>,
458        out: &Variable<D>,
459    ) -> std::fmt::Result {
460        writeln!(f, "atomicExch({out}, {input});")
461    }
462
463    fn compile_atomic_sub(
464        f: &mut std::fmt::Formatter<'_>,
465        lhs: &Variable<D>,
466        rhs: &Variable<D>,
467        out: &Variable<D>,
468    ) -> std::fmt::Result {
469        let out = out.fmt_left();
470        match rhs.elem() {
471            Elem::U32 | Elem::I32 => writeln!(f, "{out} = atomicSub({lhs}, {rhs});"),
472            Elem::U64 => writeln!(f, "{out} = atomicAdd({lhs}, -{rhs});"),
473            Elem::I64 => writeln!(
474                f,
475                "{out} = atomicAdd(reinterpret_cast<{uint}*>({lhs}), {uint}(-{rhs}));",
476                uint = Elem::<D>::U64
477            ),
478            _ => writeln!(f, "{out} = atomicAdd({lhs}, -{rhs});"),
479        }
480    }
481
482    fn compile_atomic_swap(
483        f: &mut std::fmt::Formatter<'_>,
484        lhs: &Variable<D>,
485        rhs: &Variable<D>,
486        out: &Variable<D>,
487    ) -> std::fmt::Result {
488        let out = out.fmt_left();
489        writeln!(f, "{out} = atomicExch({lhs}, {rhs});")
490    }
491
492    fn compile_atomic_xor(
493        f: &mut std::fmt::Formatter<'_>,
494        lhs: &Variable<D>,
495        rhs: &Variable<D>,
496        out: &Variable<D>,
497    ) -> std::fmt::Result {
498        let out = out.fmt_left();
499        writeln!(f, "{out} = atomicXor({lhs}, {rhs});")
500    }
501
502    fn compile_saturating_add(
503        f: &mut std::fmt::Formatter<'_>,
504        lhs: impl Display,
505        rhs: impl Display,
506        item: Item<D>,
507    ) -> std::fmt::Result;
508
509    fn compile_saturating_sub(
510        f: &mut std::fmt::Formatter<'_>,
511        lhs: impl Display,
512        rhs: impl Display,
513        item: Item<D>,
514    ) -> std::fmt::Result;
515
516    // debug
517    fn compile_instruction_printf(
518        f: &mut std::fmt::Formatter<'_>,
519        format_string: &str,
520        args: &[Variable<D>],
521    ) -> std::fmt::Result {
522        let args = args.iter().map(|arg| format!("{arg}")).collect::<Vec<_>>();
523        let args = match args.is_empty() {
524            true => "".to_string(),
525            false => format!(", {}", args.join(",")),
526        };
527        writeln!(f, "printf({format_string:?}{args});")
528    }
529
530    // logs
531    fn compile_instruction_log1p_scalar<T: Component<D>>(
532        f: &mut std::fmt::Formatter<'_>,
533        input: T,
534    ) -> std::fmt::Result {
535        let elem = input.elem();
536        match elem {
537            Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
538                write!(f, "{elem}(log1p(float({input})))")
539            }
540            _ => write!(f, "log1p({input})"),
541        }
542    }
543
544    // sync
545    fn compile_instruction_sync_threads(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
546    fn compile_instruction_sync_warp(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
547    fn compile_instruction_thread_fence(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
548
549    // trigo
550    fn compile_instruction_tanh_scalar<T: Component<D>>(
551        f: &mut std::fmt::Formatter<'_>,
552        input: T,
553    ) -> std::fmt::Result {
554        let elem = input.elem();
555        match elem {
556            Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
557                write!(f, "{elem}(tanh(float({input})))")
558            }
559            _ => write!(f, "tanh({input})"),
560        }
561    }
562
563    // unary
564    fn compile_instruction_find_first_set<T: Component<D>>(
565        f: &mut std::fmt::Formatter<'_>,
566        input: T,
567        out_elem: Elem<D>,
568    ) -> std::fmt::Result;
569    fn compile_instruction_leading_zeros_scalar<T: Component<D>>(
570        f: &mut std::fmt::Formatter<'_>,
571        input: T,
572        out_elem: Elem<D>,
573    ) -> std::fmt::Result;
574
575    fn compile_instruction_popcount_scalar<T: Component<D>>(
576        f: &mut std::fmt::Formatter<'_>,
577        input: T,
578        out_elem: Elem<D>,
579    ) -> std::fmt::Result {
580        write!(f, "{out_elem}(")?;
581        match input.elem() {
582            Elem::I32 => write!(f, "__popc({}({input}))", Elem::<D>::U32),
583            Elem::U32 => write!(f, "__popc({input})"),
584            Elem::I64 => write!(f, "__popcll({}({input}))", Elem::<D>::U64),
585            Elem::U64 => write!(f, "__popcll({input})"),
586            _ => write!(f, "__popc({})", super::unary::zero_extend(input)),
587        }?;
588        write!(f, ")")
589    }
590
591    fn compile_instruction_reverse_bits_scalar<T: Component<D>>(
592        f: &mut std::fmt::Formatter<'_>,
593        input: T,
594        out_elem: Elem<D>,
595    ) -> std::fmt::Result {
596        write!(f, "{out_elem}(")?;
597        match out_elem {
598            Elem::I32 => write!(f, "__brev({}({input}))", Elem::<D>::U32),
599            Elem::U32 => write!(f, "__brev({input})"),
600            Elem::I64 => write!(f, "__brevll({}({input}))", Elem::<D>::U64),
601            Elem::U64 => write!(f, "__brevll({input})"),
602            _ => write!(
603                f,
604                "__brev({}) >> {}",
605                super::unary::zero_extend(input),
606                (size_of::<u32>() - out_elem.size()) * 8
607            ),
608        }?;
609        write!(f, ")")
610    }
611
612    // others
613    fn compile_instruction_max_function_name(
614        f: &mut std::fmt::Formatter<'_>,
615        item: Item<D>,
616    ) -> std::fmt::Result;
617
618    fn compile_instruction_min_function_name(
619        f: &mut std::fmt::Formatter<'_>,
620        item: Item<D>,
621    ) -> std::fmt::Result;
622
623    fn compile_instruction_powf(
624        f: &mut std::fmt::Formatter<'_>,
625        lhs: &str,
626        rhs: &str,
627        elem: Elem<D>,
628    ) -> std::fmt::Result {
629        match elem {
630            Elem::F32 => write!(f, "powf({lhs}, {rhs})"),
631            Elem::F64 => write!(f, "pow({lhs}, {rhs})"),
632            _ => panic!("Unsupported type for powf"),
633        }
634    }
635
636    fn compile_instruction_half_function_name_prefix() -> &'static str {
637        "h"
638    }
639
640    fn compile_instruction_half2_function_name_prefix() -> &'static str {
641        "h2"
642    }
643
644    // warp
645    fn compile_warp_shuffle(
646        f: &mut std::fmt::Formatter<'_>,
647        var: &str,
648        source: &str,
649    ) -> std::fmt::Result;
650    fn compile_warp_shuffle_xor(
651        f: &mut std::fmt::Formatter<'_>,
652        var: &str,
653        elem: &Elem<D>,
654        offset: &str,
655    ) -> std::fmt::Result;
656    fn compile_warp_shuffle_up(
657        f: &mut std::fmt::Formatter<'_>,
658        var: &str,
659        offset: &str,
660    ) -> std::fmt::Result;
661    fn compile_warp_shuffle_down(
662        f: &mut std::fmt::Formatter<'_>,
663        var: &str,
664        offset: &str,
665    ) -> std::fmt::Result;
666    fn compile_warp_all<T: Component<D>>(
667        f: &mut std::fmt::Formatter<'_>,
668        input: &T,
669    ) -> std::fmt::Result;
670    fn compile_warp_any<T: Component<D>>(
671        f: &mut std::fmt::Formatter<'_>,
672        input: &T,
673    ) -> std::fmt::Result;
674    fn compile_warp_ballot(
675        f: &mut std::fmt::Formatter<'_>,
676        input: &Variable<D>,
677        out_elem: &Elem<D>,
678    ) -> std::fmt::Result;
679}
680
/// Operands of a manually emitted matrix multiply-accumulate, handed to
/// `DialectWmmaCompiler::compile_manual_mma` and `compile_scaled_mma`.
///
/// Fragments are borrowed for `'a`. Presumably this computes `d = a * b + c`,
/// per the fragment names — TODO confirm against the implementing dialects.
/// The `new` constructor comes from the derive; its arguments follow field order.
#[derive(Debug, Clone, Copy, new)]
pub struct ManualMma<'a, D: Dialect> {
    /// Shape of the MMA operation.
    pub shape: MmaShape<D>,
    /// Registers holding the `a` fragment.
    pub frag_a: &'a [Variable<D>],
    /// Registers holding the `b` fragment.
    pub frag_b: &'a [Variable<D>],
    /// Registers holding the `c` (accumulator input) fragment.
    pub frag_c: &'a [Variable<D>],
    /// Variable receiving the `d` result.
    pub frag_d: &'a Variable<D>,
}
689
690pub trait DialectWarpReduceCompiler<D: Dialect>:
691    Default + Clone + Copy + Debug + Send + Sync + Eq + Hash + 'static
692{
693    fn warp_reduce_sum(
694        f: &mut core::fmt::Formatter<'_>,
695        input: &Variable<D>,
696        out: &Variable<D>,
697    ) -> core::fmt::Result {
698        reduce_operator(f, input, out, "+=")
699    }
700    fn warp_reduce_prod(
701        f: &mut core::fmt::Formatter<'_>,
702        input: &Variable<D>,
703        out: &Variable<D>,
704    ) -> core::fmt::Result {
705        reduce_operator(f, input, out, "*=")
706    }
707    fn warp_reduce_max(
708        f: &mut core::fmt::Formatter<'_>,
709        input: &Variable<D>,
710        out: &Variable<D>,
711    ) -> core::fmt::Result {
712        reduce_comparison(f, input, out, D::compile_instruction_max_function_name)
713    }
714    fn warp_reduce_min(
715        f: &mut core::fmt::Formatter<'_>,
716        input: &Variable<D>,
717        out: &Variable<D>,
718    ) -> core::fmt::Result {
719        reduce_comparison(f, input, out, D::compile_instruction_min_function_name)
720    }
721    fn warp_reduce_all(
722        f: &mut core::fmt::Formatter<'_>,
723        input: &Variable<D>,
724        out: &Variable<D>,
725    ) -> core::fmt::Result {
726        reduce_quantifier(f, input, out, D::compile_warp_all::<IndexedVariable<D>>)
727    }
728    fn warp_reduce_any(
729        f: &mut core::fmt::Formatter<'_>,
730        input: &Variable<D>,
731        out: &Variable<D>,
732    ) -> core::fmt::Result {
733        reduce_quantifier(f, input, out, D::compile_warp_any::<IndexedVariable<D>>)
734    }
735    fn warp_reduce_sum_inclusive(
736        f: &mut core::fmt::Formatter<'_>,
737        input: &Variable<D>,
738        out: &Variable<D>,
739    ) -> core::fmt::Result {
740        reduce_inclusive(f, input, out, "+=")
741    }
742    fn warp_reduce_prod_inclusive(
743        f: &mut core::fmt::Formatter<'_>,
744        input: &Variable<D>,
745        out: &Variable<D>,
746    ) -> core::fmt::Result {
747        reduce_inclusive(f, input, out, "*=")
748    }
749    fn warp_reduce_sum_exclusive(
750        f: &mut core::fmt::Formatter<'_>,
751        input: &Variable<D>,
752        out: &Variable<D>,
753    ) -> core::fmt::Result {
754        reduce_exclusive(f, input, out, "+=", "0")
755    }
756    fn warp_reduce_prod_exclusive(
757        f: &mut core::fmt::Formatter<'_>,
758        input: &Variable<D>,
759        out: &Variable<D>,
760    ) -> core::fmt::Result {
761        reduce_exclusive(f, input, out, "*=", "1")
762    }
763}
764
765pub trait DialectWmmaCompiler<D: Dialect>:
766    Default + Clone + Copy + Debug + Send + Sync + Eq + Hash + 'static
767{
768    #[allow(unused_variables)]
769    fn compile_wmma_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags) -> std::fmt::Result {
770        Ok(())
771    }
772    #[allow(unused_variables)]
773    fn compile_wmma_type_definitions(
774        f: &mut std::fmt::Formatter<'_>,
775        flags: &Flags,
776    ) -> std::fmt::Result {
777        Ok(())
778    }
779    #[allow(unused_variables)]
780    fn compile_wmma_local_variables(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
781        Ok(())
782    }
783    #[allow(unused_variables)]
784    fn compile_wwma_fragment_ident(
785        f: &mut std::fmt::Formatter<'_>,
786        ident: &FragmentIdent<D>,
787    ) -> std::fmt::Result {
788        Ok(())
789    }
790    #[allow(unused_variables)]
791    fn compile_wmma_fragment_layout(
792        f: &mut std::fmt::Formatter<'_>,
793        layout: &FragmentLayout<D>,
794    ) -> std::fmt::Result {
795        Ok(())
796    }
797    #[allow(unused_variables)]
798    fn compile_wmma_fragment(
799        f: &mut std::fmt::Formatter<'_>,
800        fragment: &Fragment<D>,
801    ) -> std::fmt::Result {
802        Ok(())
803    }
804
805    fn compile_wmma_fragment_declaration(
806        f: &mut std::fmt::Formatter<'_>,
807        var: &Variable<D>,
808    ) -> std::fmt::Result;
809
810    fn compile_wmma_instruction(
811        f: &mut std::fmt::Formatter<'_>,
812        instruction: &WmmaInstruction<D>,
813    ) -> std::fmt::Result;
814    fn compile_manual_mma(f: &mut std::fmt::Formatter<'_>, mma: ManualMma<D>) -> std::fmt::Result;
815    fn compile_scaled_mma(
816        f: &mut std::fmt::Formatter<'_>,
817        mma: ManualMma<D>,
818        scales_a: Variable<D>,
819        scales_b: Variable<D>,
820        scales_factor: u32,
821    ) -> std::fmt::Result;
822    fn supported_wmma_combinations(arch: &D::Architecture) -> SupportedMmaCombinations;
823    fn supported_mma_combinations(arch: &D::Architecture) -> SupportedMmaCombinations;
824    fn supported_scaled_mma_combinations(
825        _arch: &D::Architecture,
826    ) -> SupportedScaledMmaCombinations {
827        Vec::new()
828    }
829}
830
/// IR Processors to be applied to the scopes during processing. [`CheckedIO`] is always applied
/// by default, so these are only for target specific processors like MMA index processors.
pub trait DialectProcessors<D: Dialect> {
    /// Returns the dialect-specific processors to apply, as boxed `Processor`
    /// trait objects. An empty vector presumably means "no extra processing"
    /// — confirm with the caller.
    fn processors() -> Vec<Box<dyn Processor>>;
}