cubecl_cpp/shared/
dialect.rs

1use std::hash::Hash;
2use std::{collections::HashSet, fmt::Debug};
3
4use cubecl_core::ir::Id;
5
6use crate::shared::FmtLeft;
7
8use super::{
9    Architecture, AtomicKind, Binding, Component, CubeIndexFlags, Elem, Flags, Fragment,
10    FragmentIdent, FragmentLayout, Instruction, Item, SharedMemory, SupportedWmmaCombinations,
11    Variable, WarpInstruction, WmmaInstruction,
12};
13
14// Base dialect
15
16pub trait Dialect:
17    DialectIncludes<Self>
18    + DialectTypes<Self>
19    + DialectBindings<Self>
20    + DialectCubeBuiltins<Self>
21    + DialectInstructions<Self>
22    + DialectWmmaCompiler<Self>
23    + Default
24    + Clone
25    + Copy
26    + Debug
27    + Send
28    + Sync
29    + Eq
30    + Hash
31    + 'static
32{
33}
34
35// Includes
36
37pub trait DialectIncludes<D: Dialect> {
38    type Extension: Debug + Clone + Sync + Send;
39
40    fn compile_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags) -> std::fmt::Result;
41    fn compile_extensions(
42        f: &mut std::fmt::Formatter<'_>,
43        extensions: &[Self::Extension],
44    ) -> std::fmt::Result;
45    fn register_instruction_extension(
46        extensions: &mut Vec<Self::Extension>,
47        instruction: &Instruction<D>,
48    );
49    fn register_warp_instruction_extension(
50        extensions: &mut Vec<Self::Extension>,
51        instruction: &WarpInstruction<D>,
52    );
53}
54
55// Types
56
57pub trait DialectTypes<D: Dialect> {
58    fn item_can_be_optimized() -> bool;
59    fn compile_elem(
60        f: &mut std::fmt::Formatter<'_>,
61        elem: &Elem<D>,
62        word: bool,
63    ) -> std::fmt::Result;
64
65    fn compile_atomic_kind(
66        f: &mut std::fmt::Formatter<'_>,
67        kind: &AtomicKind<D>,
68    ) -> std::fmt::Result {
69        match kind {
70            AtomicKind::I32 => write!(f, "{}", Elem::<D>::I32),
71            AtomicKind::I64 => write!(f, "{}", Elem::<D>::I64),
72            AtomicKind::U32 => write!(f, "{}", Elem::<D>::U32),
73            AtomicKind::U64 => write!(f, "{}", Elem::<D>::U64),
74            AtomicKind::F16 => write!(f, "{}", Elem::<D>::F16),
75            AtomicKind::BF16 => write!(f, "{}", Elem::<D>::BF16),
76            AtomicKind::F32 => write!(f, "{}", Elem::<D>::F32),
77            AtomicKind::F64 => write!(f, "{}", Elem::<D>::F64),
78            AtomicKind::_Dialect(_) => Ok(()),
79        }
80    }
81
82    fn compile_item(f: &mut std::fmt::Formatter<'_>, item: &Item<D>) -> std::fmt::Result;
83    fn compile_type_definitions(
84        f: &mut std::fmt::Formatter<'_>,
85        items: &HashSet<Item<D>>,
86        scalars: &[(Elem<D>, usize)],
87        flags: &Flags,
88    ) -> std::fmt::Result;
89    fn compile_local_memory_qualifier(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
90    fn compile_shared_memory_qualifier(
91        f: &mut std::fmt::Formatter<'_>,
92        shared: &SharedMemory<D>,
93    ) -> std::fmt::Result;
94    fn compile_polyfills(_f: &mut std::fmt::Formatter<'_>, _flags: &Flags) -> std::fmt::Result {
95        Ok(())
96    }
97    /// Address space (for Metal dialect only).
98    fn address_space_for_variable(_variable: &Variable<D>) -> String {
99        "".to_string()
100    }
101}
102
103// Kernel argument bindings
104
105pub trait DialectBindings<D: Dialect> {
106    fn compile_kernel_signature(
107        f: &mut std::fmt::Formatter<'_>,
108        kernel_name: &str,
109        tensor_maps: &[Id],
110        buffers: &[Binding<D>],
111        scalars: &[(Elem<D>, usize)],
112        flags: &Flags,
113    ) -> std::fmt::Result;
114}
115
116// Cube builtins dialect
117
118pub trait DialectCubeBuiltins<D: Dialect> {
119    /// Depending on the dialect available built-in variables the
120    /// inclusion rules might change.
121    /// For instance in metal we have a built-in for the Unit plane position
122    /// but in other dialects there is none so we have to compute it using
123    /// other built-ins.
124    fn builtin_rules(flags: &CubeIndexFlags) -> CubeIndexFlags {
125        let unit_pos_plane = flags.unit_pos_plane;
126        let plane_dim_checked = flags.plane_dim_checked;
127        let plane_dim = flags.plane_dim || plane_dim_checked || unit_pos_plane;
128        let plane_index = flags.plane_index;
129        let absolute_pos = flags.absolute_pos || unit_pos_plane;
130        let absolute_pos_tuple = flags.absolute_pos_tuple || absolute_pos;
131        let cube_dim = flags.cube_dim;
132        let cube_dim_tuple = flags.cube_dim_tuple || cube_dim || absolute_pos || plane_dim_checked;
133        let unit_pos = flags.unit_pos;
134        let unit_pos_tuple = flags.unit_pos_tuple || unit_pos;
135        let cube_count = flags.cube_count;
136        let cube_count_tuple = flags.cube_count_tuple || absolute_pos;
137        let cube_pos = flags.cube_pos;
138        let cube_pos_tuple = flags.cube_pos_tuple || cube_pos;
139        let cluster_group = flags.cluster_pos;
140
141        CubeIndexFlags {
142            absolute_pos,
143            absolute_pos_tuple,
144            cube_count,
145            cube_count_tuple,
146            cube_dim,
147            cube_dim_tuple,
148            cube_pos,
149            cube_pos_tuple,
150            plane_dim,
151            plane_dim_checked,
152            plane_index,
153            unit_pos_tuple,
154            unit_pos,
155            unit_pos_plane,
156            cluster_pos: cluster_group,
157        }
158    }
159
160    fn compile_absolute_pos_tuple_computation(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
161        let variable = Variable::<D>::AbsolutePosBaseName;
162        let ty = variable.item();
163        let cube_pos_x = Variable::<D>::CubePosX;
164        let cube_pos_y = Variable::<D>::CubePosY;
165        let cube_pos_z = Variable::<D>::CubePosZ;
166        let cube_dim_x = Variable::<D>::CubeDimX;
167        let cube_dim_y = Variable::<D>::CubeDimY;
168        let cube_dim_z = Variable::<D>::CubeDimZ;
169        let unit_pos_x = Variable::<D>::UnitPosX;
170        let unit_pos_y = Variable::<D>::UnitPosY;
171        let unit_pos_z = Variable::<D>::UnitPosZ;
172        writeln!(
173            f,
174            "{ty} {variable} = make_{ty}(
175    {cube_pos_x} * {cube_dim_x} + {unit_pos_x},
176    {cube_pos_y} * {cube_dim_y} + {unit_pos_y},
177    {cube_pos_z} * {cube_dim_z} + {unit_pos_z}
178);"
179        )
180    }
181
182    fn compile_absolute_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
183        f.write_str("absoluteIdx")
184    }
185
186    fn compile_absolute_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
187        f.write_str("idxGlobal")
188    }
189
190    fn compile_absolute_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
191        Self::compile_absolute_pos_base_name(f)?;
192        write!(f, ".x")
193    }
194
195    fn compile_absolute_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
196        Self::compile_absolute_pos_base_name(f)?;
197        write!(f, ".y")
198    }
199
200    fn compile_absolute_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
201        Self::compile_absolute_pos_base_name(f)?;
202        write!(f, ".z")
203    }
204
205    fn compile_cube_count_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
206        f.write_str("gridDim")
207    }
208
209    fn compile_cube_count(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
210        f.write_str("gridDimGlobal")
211    }
212
213    fn compile_cube_count_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
214        Self::compile_cube_count_base_name(f)?;
215        write!(f, ".x")
216    }
217
218    fn compile_cube_count_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
219        Self::compile_cube_count_base_name(f)?;
220        write!(f, ".y")
221    }
222
223    fn compile_cube_count_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
224        Self::compile_cube_count_base_name(f)?;
225        write!(f, ".z")
226    }
227
228    fn compile_cube_dim_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
229        f.write_str("blockDim")
230    }
231
232    fn compile_cube_dim(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
233        f.write_str("blockDimGlobal")
234    }
235
236    fn compile_cube_dim_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
237        Self::compile_cube_dim_base_name(f)?;
238        write!(f, ".x")
239    }
240
241    fn compile_cube_dim_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
242        Self::compile_cube_dim_base_name(f)?;
243        write!(f, ".y")
244    }
245
246    fn compile_cube_dim_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
247        Self::compile_cube_dim_base_name(f)?;
248        write!(f, ".z")
249    }
250
251    fn compile_cube_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
252        f.write_str("blockIdx")
253    }
254
255    fn compile_cube_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
256        f.write_str("blockIdxGlobal")
257    }
258
259    fn compile_cube_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
260        Self::compile_cube_pos_base_name(f)?;
261        write!(f, ".x")
262    }
263
264    fn compile_cube_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
265        Self::compile_cube_pos_base_name(f)?;
266        write!(f, ".y")
267    }
268
269    fn compile_cube_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
270        Self::compile_cube_pos_base_name(f)?;
271        write!(f, ".z")
272    }
273
274    fn compile_unit_pos_computation(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
275        let variable = Variable::<D>::UnitPos;
276        let ty = variable.item();
277        let cube_dim_x = Variable::<D>::CubeDimX;
278        let cube_dim_y = Variable::<D>::CubeDimY;
279        let unit_pos_x = Variable::<D>::UnitPosX;
280        let unit_pos_y = Variable::<D>::UnitPosY;
281        let unit_pos_z = Variable::<D>::UnitPosZ;
282        writeln!(
283            f,
284            "{ty} {variable} = {unit_pos_x} + {unit_pos_y} * {cube_dim_x} + {unit_pos_z} * ({cube_dim_x} * {cube_dim_y});"
285        )
286    }
287
288    fn compile_unit_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
289        f.write_str("threadIdxGlobal")
290    }
291
292    fn compile_unit_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
293        f.write_str("threadIdx")
294    }
295
296    fn compile_unit_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
297        Self::compile_unit_pos_base_name(f)?;
298        write!(f, ".x")
299    }
300
301    fn compile_unit_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
302        Self::compile_unit_pos_base_name(f)?;
303        write!(f, ".y")
304    }
305
306    fn compile_unit_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
307        Self::compile_unit_pos_base_name(f)?;
308        write!(f, ".z")
309    }
310
311    fn compile_plane_dim(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
312        f.write_str("warpSize")
313    }
314
315    fn compile_plane_dim_checked(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
316        f.write_str("warpSizeChecked")
317    }
318
319    fn compile_plane_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
320        let unit_pos_x = Variable::<D>::UnitPosX;
321        let plane_dim = Variable::<D>::PlaneDim;
322        write!(f, "{unit_pos_x} / {plane_dim}")
323    }
324
325    fn compile_unit_pos_plane(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
326        let absolute_pos = Variable::<D>::AbsolutePos;
327        let plane_dim = Variable::<D>::PlaneDim;
328        write!(f, "{absolute_pos} % {plane_dim}")
329    }
330
331    fn compile_cluster_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
332        write!(f, "0")
333    }
334    fn compile_cluster_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
335        write!(f, "0")
336    }
337    fn compile_cluster_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
338        write!(f, "0")
339    }
340    fn compile_cluster_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
341        write!(f, "0")
342    }
343}
344
345// Instructions
346
347pub trait DialectInstructions<D: Dialect> {
348    // atomics
349    fn compile_atomic_add(
350        f: &mut std::fmt::Formatter<'_>,
351        lhs: &Variable<D>,
352        rhs: &Variable<D>,
353        out: &Variable<D>,
354    ) -> std::fmt::Result {
355        let out = out.fmt_left();
356        match rhs.elem() {
357            Elem::I64 => writeln!(
358                f,
359                "{out} = atomicAdd(reinterpret_cast<{uint}*>({lhs}), {uint}({rhs}));",
360                uint = Elem::<D>::U64
361            ),
362            _ => writeln!(f, "{out} = atomicAdd({lhs}, {rhs});"),
363        }
364    }
365
366    fn compile_atomic_and(
367        f: &mut std::fmt::Formatter<'_>,
368        lhs: &Variable<D>,
369        rhs: &Variable<D>,
370        out: &Variable<D>,
371    ) -> std::fmt::Result {
372        let out = out.fmt_left();
373        writeln!(f, "{out} = atomicAnd({lhs}, {rhs});")
374    }
375
376    fn compile_atomic_cas(
377        f: &mut std::fmt::Formatter<'_>,
378        input: &Variable<D>,
379        cmp: &Variable<D>,
380        val: &Variable<D>,
381        out: &Variable<D>,
382    ) -> std::fmt::Result {
383        let out = out.fmt_left();
384        writeln!(f, "{out} = atomicCAS({input}, {cmp}, {val});")
385    }
386
387    fn compile_atomic_load(
388        f: &mut std::fmt::Formatter<'_>,
389        input: &Variable<D>,
390        out: &Variable<D>,
391    ) -> std::fmt::Result {
392        let out = out.fmt_left();
393        writeln!(f, "{out} = atomicAdd({input}, 0);")
394    }
395
396    fn compile_atomic_max(
397        f: &mut std::fmt::Formatter<'_>,
398        lhs: &Variable<D>,
399        rhs: &Variable<D>,
400        out: &Variable<D>,
401    ) -> std::fmt::Result {
402        let out = out.fmt_left();
403        writeln!(f, "{out} = atomicMax({lhs}, {rhs});")
404    }
405
406    fn compile_atomic_min(
407        f: &mut std::fmt::Formatter<'_>,
408        lhs: &Variable<D>,
409        rhs: &Variable<D>,
410        out: &Variable<D>,
411    ) -> std::fmt::Result {
412        let out = out.fmt_left();
413        writeln!(f, "{out} = atomicMin({lhs}, {rhs});")
414    }
415
416    fn compile_atomic_or(
417        f: &mut std::fmt::Formatter<'_>,
418        lhs: &Variable<D>,
419        rhs: &Variable<D>,
420        out: &Variable<D>,
421    ) -> std::fmt::Result {
422        let out = out.fmt_left();
423        writeln!(f, "{out} = atomicOr({lhs}, {rhs});")
424    }
425
426    fn compile_atomic_store(
427        f: &mut std::fmt::Formatter<'_>,
428        input: &Variable<D>,
429        out: &Variable<D>,
430    ) -> std::fmt::Result {
431        writeln!(f, "atomicExch({out}, {input});")
432    }
433
434    fn compile_atomic_sub(
435        f: &mut std::fmt::Formatter<'_>,
436        lhs: &Variable<D>,
437        rhs: &Variable<D>,
438        out: &Variable<D>,
439    ) -> std::fmt::Result {
440        let out = out.fmt_left();
441        match rhs.elem() {
442            Elem::U32 | Elem::I32 => writeln!(f, "{out} = atomicSub({lhs}, {rhs});"),
443            Elem::U64 => writeln!(f, "{out} = atomicAdd({lhs}, -{rhs});"),
444            Elem::I64 => writeln!(
445                f,
446                "{out} = atomicAdd(reinterpret_cast<{uint}*>({lhs}), {uint}(-{rhs}));",
447                uint = Elem::<D>::U64
448            ),
449            _ => writeln!(f, "{out} = atomicAdd({lhs}, -{rhs});"),
450        }
451    }
452
453    fn compile_atomic_swap(
454        f: &mut std::fmt::Formatter<'_>,
455        lhs: &Variable<D>,
456        rhs: &Variable<D>,
457        out: &Variable<D>,
458    ) -> std::fmt::Result {
459        let out = out.fmt_left();
460        writeln!(f, "{out} = atomicExch({lhs}, {rhs});")
461    }
462
463    fn compile_atomic_xor(
464        f: &mut std::fmt::Formatter<'_>,
465        lhs: &Variable<D>,
466        rhs: &Variable<D>,
467        out: &Variable<D>,
468    ) -> std::fmt::Result {
469        let out = out.fmt_left();
470        writeln!(f, "{out} = atomicXor({lhs}, {rhs});")
471    }
472
473    // debug
474    fn compile_instruction_printf(
475        f: &mut std::fmt::Formatter<'_>,
476        format_string: &str,
477        args: &[Variable<D>],
478    ) -> std::fmt::Result {
479        let format_string = format_string
480            .replace("\t", "\\t")
481            .replace("\n", "\\n")
482            .replace("\r", "\\r");
483        let args = args.iter().map(|arg| format!("{arg}")).collect::<Vec<_>>();
484        let args = match args.is_empty() {
485            true => "".to_string(),
486            false => format!(", {}", args.join(",")),
487        };
488        writeln!(f, "printf(\"{format_string}\"{args});")
489    }
490
491    // logs
492    fn compile_instruction_log1p_scalar<T: Component<D>>(
493        f: &mut std::fmt::Formatter<'_>,
494        input: T,
495    ) -> std::fmt::Result {
496        let elem = input.elem();
497        match elem {
498            Elem::F16 | Elem::F162 | Elem::BF16 | Elem::BF162 => {
499                write!(f, "{}(log1p(float({input})))", elem)
500            }
501            _ => write!(f, "log1p({input})"),
502        }
503    }
504
505    // sync
506    fn compile_instruction_sync_threads(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
507    fn compile_instruction_thread_fence(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
508
509    // trigo
510    fn compile_instruction_tanh_scalar<T: Component<D>>(
511        f: &mut std::fmt::Formatter<'_>,
512        input: T,
513    ) -> std::fmt::Result {
514        let elem = input.elem();
515        match elem {
516            Elem::F16 | Elem::F162 | Elem::BF16 | Elem::BF162 => {
517                write!(f, "{}(tanh(float({input})))", elem)
518            }
519            _ => write!(f, "tanh({input})"),
520        }
521    }
522
523    // unary
524    fn compile_instruction_find_first_set<T: Component<D>>(
525        f: &mut std::fmt::Formatter<'_>,
526        input: T,
527        out_elem: Elem<D>,
528    ) -> std::fmt::Result;
529    fn compile_instruction_leading_zeros_scalar<T: Component<D>>(
530        f: &mut std::fmt::Formatter<'_>,
531        input: T,
532        out_elem: Elem<D>,
533    ) -> std::fmt::Result;
534
535    fn compile_instruction_popcount_scalar<T: Component<D>>(
536        f: &mut std::fmt::Formatter<'_>,
537        input: T,
538        out_elem: Elem<D>,
539    ) -> std::fmt::Result {
540        write!(f, "{out_elem}(")?;
541        match input.elem() {
542            Elem::I32 => write!(f, "__popc({}({input}))", Elem::<D>::U32),
543            Elem::U32 => write!(f, "__popc({input})"),
544            Elem::I64 => write!(f, "__popcll({}({input}))", Elem::<D>::U64),
545            Elem::U64 => write!(f, "__popcll({input})"),
546            _ => write!(f, "__popc({})", super::unary::zero_extend(input)),
547        }?;
548        write!(f, ")")
549    }
550
551    fn compile_instruction_reverse_bits_scalar<T: Component<D>>(
552        f: &mut std::fmt::Formatter<'_>,
553        input: T,
554        out_elem: Elem<D>,
555    ) -> std::fmt::Result {
556        write!(f, "{out_elem}(")?;
557        match out_elem {
558            Elem::I32 => write!(f, "__brev({}({input}))", Elem::<D>::U32),
559            Elem::U32 => write!(f, "__brev({input})"),
560            Elem::I64 => write!(f, "__brevll({}({input}))", Elem::<D>::U64),
561            Elem::U64 => write!(f, "__brevll({input})"),
562            _ => write!(
563                f,
564                "__brev({}) >> {}",
565                super::unary::zero_extend(input),
566                (size_of::<u32>() - out_elem.size()) * 8
567            ),
568        }?;
569        write!(f, ")")
570    }
571
572    // others
573    fn compile_instruction_max_function_name(
574        f: &mut std::fmt::Formatter<'_>,
575        item: Item<D>,
576    ) -> std::fmt::Result;
577
578    fn compile_instruction_min_function_name(
579        f: &mut std::fmt::Formatter<'_>,
580        item: Item<D>,
581    ) -> std::fmt::Result;
582
583    fn compile_instruction_powf(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
584        write!(f, "powf")
585    }
586
587    fn compile_instruction_half_function_name_prefix() -> &'static str {
588        "h"
589    }
590
591    fn compile_instruction_half2_function_name_prefix() -> &'static str {
592        "h2"
593    }
594
595    // warp
596    fn compile_warp_shuffle(
597        f: &mut std::fmt::Formatter<'_>,
598        var: &str,
599        source: &str,
600    ) -> std::fmt::Result;
601    fn compile_warp_shuffle_xor(
602        f: &mut std::fmt::Formatter<'_>,
603        var: &str,
604        elem: &Elem<D>,
605        offset: &str,
606    ) -> std::fmt::Result;
607    fn compile_warp_shuffle_up(
608        f: &mut std::fmt::Formatter<'_>,
609        var: &str,
610        offset: &str,
611    ) -> std::fmt::Result;
612    fn compile_warp_shuffle_down(
613        f: &mut std::fmt::Formatter<'_>,
614        var: &str,
615        offset: &str,
616    ) -> std::fmt::Result;
617    fn compile_warp_all<T: Component<D>>(
618        f: &mut std::fmt::Formatter<'_>,
619        input: &T,
620    ) -> std::fmt::Result;
621    fn compile_warp_any<T: Component<D>>(
622        f: &mut std::fmt::Formatter<'_>,
623        input: &T,
624    ) -> std::fmt::Result;
625    fn compile_warp_ballot(
626        f: &mut std::fmt::Formatter<'_>,
627        input: &Variable<D>,
628        out_elem: &Elem<D>,
629    ) -> std::fmt::Result;
630}
631
632// Coop Matrices dialect
633
634pub trait DialectWmmaCompiler<D: Dialect>:
635    Default + Clone + Copy + Debug + Send + Sync + Eq + Hash + 'static
636{
637    type Architecture: Architecture;
638
639    fn compile_wmma_includes(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
640    fn compile_wmma_type_definitions(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
641    fn compile_local_variables(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
642    fn compile_fragment_ident(
643        ident: &FragmentIdent<D>,
644        f: &mut std::fmt::Formatter<'_>,
645    ) -> std::fmt::Result;
646    fn compile_fragment_layout(
647        layout: &FragmentLayout<D>,
648        f: &mut std::fmt::Formatter<'_>,
649    ) -> std::fmt::Result;
650    fn compile_fragment(
651        fragment: &Fragment<D>,
652        f: &mut std::fmt::Formatter<'_>,
653    ) -> std::fmt::Result;
654    fn compile_instruction(
655        instruction: &WmmaInstruction<D>,
656        f: &mut std::fmt::Formatter<'_>,
657    ) -> std::fmt::Result;
658    fn supported_wmma_combinations(arch: &Self::Architecture) -> SupportedWmmaCombinations;
659}