cubecl_cpp/shared/
dialect.rs

1use std::hash::Hash;
2use std::{collections::HashSet, fmt::Debug};
3
4use cubecl_core::ir::Id;
5
6use crate::shared::FmtLeft;
7
8use super::{
9    Architecture, AtomicKind, Binding, Body, Component, CubeIndexFlags, Elem, Flags, Fragment,
10    FragmentIdent, FragmentLayout, Instruction, Item, SharedMemory, SupportedWmmaCombinations,
11    Variable, WarpInstruction, WmmaInstruction,
12};
13
14// Base dialect
15
16pub trait Dialect:
17    DialectIncludes<Self>
18    + DialectTypes<Self>
19    + DialectBindings<Self>
20    + DialectCubeBuiltins<Self>
21    + DialectInstructions<Self>
22    + DialectWmmaCompiler<Self>
23    + Default
24    + Clone
25    + Copy
26    + Debug
27    + Send
28    + Sync
29    + Eq
30    + Hash
31    + 'static
32{
33    type Architecture: Architecture;
34}
35
36// Includes
37
38pub trait DialectIncludes<D: Dialect> {
39    type Extension: Debug + Clone + Sync + Send;
40
41    fn compile_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags) -> std::fmt::Result;
42    fn compile_extensions(
43        f: &mut std::fmt::Formatter<'_>,
44        extensions: &[Self::Extension],
45    ) -> std::fmt::Result;
46    fn register_instruction_extension(
47        extensions: &mut Vec<Self::Extension>,
48        instruction: &Instruction<D>,
49    );
50    fn register_warp_instruction_extension(
51        extensions: &mut Vec<Self::Extension>,
52        instruction: &WarpInstruction<D>,
53    );
54}
55
56// Types
57
58pub trait DialectTypes<D: Dialect> {
59    fn item_can_be_optimized() -> bool;
60    fn compile_elem(
61        f: &mut std::fmt::Formatter<'_>,
62        elem: &Elem<D>,
63        word: bool,
64    ) -> std::fmt::Result;
65
66    fn compile_atomic_kind(
67        f: &mut std::fmt::Formatter<'_>,
68        kind: &AtomicKind<D>,
69    ) -> std::fmt::Result {
70        match kind {
71            AtomicKind::I32 => write!(f, "{}", Elem::<D>::I32),
72            AtomicKind::I64 => write!(f, "{}", Elem::<D>::I64),
73            AtomicKind::U32 => write!(f, "{}", Elem::<D>::U32),
74            AtomicKind::U64 => write!(f, "{}", Elem::<D>::U64),
75            AtomicKind::F16 => write!(f, "{}", Elem::<D>::F16),
76            AtomicKind::BF16 => write!(f, "{}", Elem::<D>::BF16),
77            AtomicKind::F32 => write!(f, "{}", Elem::<D>::F32),
78            AtomicKind::F64 => write!(f, "{}", Elem::<D>::F64),
79            AtomicKind::_Dialect(_) => Ok(()),
80        }
81    }
82
83    fn compile_item(f: &mut std::fmt::Formatter<'_>, item: &Item<D>) -> std::fmt::Result;
84    fn compile_type_definitions(
85        f: &mut std::fmt::Formatter<'_>,
86        items: &HashSet<Item<D>>,
87        scalars: &[(Elem<D>, usize)],
88        flags: &Flags,
89    ) -> std::fmt::Result;
90    fn compile_local_memory_qualifier(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
91    fn compile_shared_memory_declaration(
92        f: &mut std::fmt::Formatter<'_>,
93        shared: &SharedMemory<D>,
94    ) -> std::fmt::Result;
95    fn compile_polyfills(_f: &mut std::fmt::Formatter<'_>, _flags: &Flags) -> std::fmt::Result {
96        Ok(())
97    }
98    /// Address space (for Metal dialect only).
99    fn address_space_for_variable(_variable: &Variable<D>) -> String {
100        "".to_string()
101    }
102}
103
104// Kernel argument bindings
105
106pub trait DialectBindings<D: Dialect> {
107    fn compile_kernel_signature(
108        f: &mut std::fmt::Formatter<'_>,
109        kernel_name: &str,
110        tensor_maps: &[Id],
111        buffers: &[Binding<D>],
112        scalars: &[(Elem<D>, usize)],
113        flags: &Flags,
114    ) -> std::fmt::Result;
115    fn compile_bindings_body(
116        _f: &mut std::fmt::Formatter<'_>,
117        _body: &Body<D>,
118    ) -> std::fmt::Result {
119        Ok(())
120    }
121}
122
123// Cube builtins dialect
124
125pub trait DialectCubeBuiltins<D: Dialect> {
126    /// Depending on the dialect available built-in variables the
127    /// inclusion rules might change.
128    /// For instance in metal we have a built-in for the Unit plane position
129    /// but in other dialects there is none so we have to compute it using
130    /// other built-ins.
131    fn builtin_rules(flags: &CubeIndexFlags) -> CubeIndexFlags {
132        let unit_pos_plane = flags.unit_pos_plane;
133        let plane_dim_checked = flags.plane_dim_checked;
134        let plane_dim = flags.plane_dim || plane_dim_checked || unit_pos_plane;
135        let plane_index = flags.plane_index;
136        let absolute_pos = flags.absolute_pos || unit_pos_plane;
137        let absolute_pos_tuple = flags.absolute_pos_tuple || absolute_pos;
138        let cube_dim = flags.cube_dim;
139        let cube_dim_tuple = flags.cube_dim_tuple || cube_dim || absolute_pos || plane_dim_checked;
140        let unit_pos = flags.unit_pos;
141        let unit_pos_tuple = flags.unit_pos_tuple || unit_pos;
142        let cube_count = flags.cube_count;
143        let cube_count_tuple = flags.cube_count_tuple || absolute_pos;
144        let cube_pos = flags.cube_pos;
145        let cube_pos_tuple = flags.cube_pos_tuple || cube_pos;
146        let cluster_group = flags.cluster_pos;
147
148        CubeIndexFlags {
149            absolute_pos,
150            absolute_pos_tuple,
151            cube_count,
152            cube_count_tuple,
153            cube_dim,
154            cube_dim_tuple,
155            cube_pos,
156            cube_pos_tuple,
157            plane_dim,
158            plane_dim_checked,
159            plane_index,
160            unit_pos_tuple,
161            unit_pos,
162            unit_pos_plane,
163            cluster_pos: cluster_group,
164        }
165    }
166
167    fn compile_absolute_pos_tuple_computation(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
168        let variable = Variable::<D>::AbsolutePosBaseName;
169        let ty = variable.item();
170        let cube_pos_x = Variable::<D>::CubePosX;
171        let cube_pos_y = Variable::<D>::CubePosY;
172        let cube_pos_z = Variable::<D>::CubePosZ;
173        let cube_dim_x = Variable::<D>::CubeDimX;
174        let cube_dim_y = Variable::<D>::CubeDimY;
175        let cube_dim_z = Variable::<D>::CubeDimZ;
176        let unit_pos_x = Variable::<D>::UnitPosX;
177        let unit_pos_y = Variable::<D>::UnitPosY;
178        let unit_pos_z = Variable::<D>::UnitPosZ;
179        writeln!(
180            f,
181            "{ty} {variable} = make_{ty}(
182    {cube_pos_x} * {cube_dim_x} + {unit_pos_x},
183    {cube_pos_y} * {cube_dim_y} + {unit_pos_y},
184    {cube_pos_z} * {cube_dim_z} + {unit_pos_z}
185);"
186        )
187    }
188
189    fn compile_absolute_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
190        f.write_str("absoluteIdx")
191    }
192
193    fn compile_absolute_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
194        f.write_str("idxGlobal")
195    }
196
197    fn compile_absolute_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
198        Self::compile_absolute_pos_base_name(f)?;
199        write!(f, ".x")
200    }
201
202    fn compile_absolute_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
203        Self::compile_absolute_pos_base_name(f)?;
204        write!(f, ".y")
205    }
206
207    fn compile_absolute_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
208        Self::compile_absolute_pos_base_name(f)?;
209        write!(f, ".z")
210    }
211
212    fn compile_cube_count_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
213        f.write_str("gridDim")
214    }
215
216    fn compile_cube_count(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
217        f.write_str("gridDimGlobal")
218    }
219
220    fn compile_cube_count_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
221        Self::compile_cube_count_base_name(f)?;
222        write!(f, ".x")
223    }
224
225    fn compile_cube_count_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
226        Self::compile_cube_count_base_name(f)?;
227        write!(f, ".y")
228    }
229
230    fn compile_cube_count_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
231        Self::compile_cube_count_base_name(f)?;
232        write!(f, ".z")
233    }
234
235    fn compile_cube_dim_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
236        f.write_str("blockDim")
237    }
238
239    fn compile_cube_dim(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
240        f.write_str("blockDimGlobal")
241    }
242
243    fn compile_cube_dim_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
244        Self::compile_cube_dim_base_name(f)?;
245        write!(f, ".x")
246    }
247
248    fn compile_cube_dim_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
249        Self::compile_cube_dim_base_name(f)?;
250        write!(f, ".y")
251    }
252
253    fn compile_cube_dim_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
254        Self::compile_cube_dim_base_name(f)?;
255        write!(f, ".z")
256    }
257
258    fn compile_cube_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
259        f.write_str("blockIdx")
260    }
261
262    fn compile_cube_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
263        f.write_str("blockIdxGlobal")
264    }
265
266    fn compile_cube_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
267        Self::compile_cube_pos_base_name(f)?;
268        write!(f, ".x")
269    }
270
271    fn compile_cube_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
272        Self::compile_cube_pos_base_name(f)?;
273        write!(f, ".y")
274    }
275
276    fn compile_cube_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
277        Self::compile_cube_pos_base_name(f)?;
278        write!(f, ".z")
279    }
280
281    fn compile_unit_pos_computation(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
282        let variable = Variable::<D>::UnitPos;
283        let ty = variable.item();
284        let cube_dim_x = Variable::<D>::CubeDimX;
285        let cube_dim_y = Variable::<D>::CubeDimY;
286        let unit_pos_x = Variable::<D>::UnitPosX;
287        let unit_pos_y = Variable::<D>::UnitPosY;
288        let unit_pos_z = Variable::<D>::UnitPosZ;
289        writeln!(
290            f,
291            "{ty} {variable} = {unit_pos_x} + {unit_pos_y} * {cube_dim_x} + {unit_pos_z} * ({cube_dim_x} * {cube_dim_y});"
292        )
293    }
294
295    fn compile_unit_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
296        f.write_str("threadIdxGlobal")
297    }
298
299    fn compile_unit_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
300        f.write_str("threadIdx")
301    }
302
303    fn compile_unit_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
304        Self::compile_unit_pos_base_name(f)?;
305        write!(f, ".x")
306    }
307
308    fn compile_unit_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
309        Self::compile_unit_pos_base_name(f)?;
310        write!(f, ".y")
311    }
312
313    fn compile_unit_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
314        Self::compile_unit_pos_base_name(f)?;
315        write!(f, ".z")
316    }
317
318    fn compile_plane_dim(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
319        f.write_str("warpSize")
320    }
321
322    fn compile_plane_dim_checked(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
323        f.write_str("warpSizeChecked")
324    }
325
326    fn compile_plane_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
327        let unit_pos_x = Variable::<D>::UnitPosX;
328        let plane_dim = Variable::<D>::PlaneDim;
329        write!(f, "{unit_pos_x} / {plane_dim}")
330    }
331
332    fn compile_unit_pos_plane(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
333        let absolute_pos = Variable::<D>::AbsolutePos;
334        let plane_dim = Variable::<D>::PlaneDim;
335        write!(f, "{absolute_pos} % {plane_dim}")
336    }
337
338    fn compile_cluster_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
339        write!(f, "0")
340    }
341    fn compile_cluster_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
342        write!(f, "0")
343    }
344    fn compile_cluster_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
345        write!(f, "0")
346    }
347    fn compile_cluster_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
348        write!(f, "0")
349    }
350}
351
352// Instructions
353
354pub trait DialectInstructions<D: Dialect> {
355    // atomics
356    fn compile_atomic_add(
357        f: &mut std::fmt::Formatter<'_>,
358        lhs: &Variable<D>,
359        rhs: &Variable<D>,
360        out: &Variable<D>,
361    ) -> std::fmt::Result {
362        let out = out.fmt_left();
363        match rhs.elem() {
364            Elem::I64 => writeln!(
365                f,
366                "{out} = atomicAdd(reinterpret_cast<{uint}*>({lhs}), {uint}({rhs}));",
367                uint = Elem::<D>::U64
368            ),
369            _ => writeln!(f, "{out} = atomicAdd({lhs}, {rhs});"),
370        }
371    }
372
373    fn compile_atomic_and(
374        f: &mut std::fmt::Formatter<'_>,
375        lhs: &Variable<D>,
376        rhs: &Variable<D>,
377        out: &Variable<D>,
378    ) -> std::fmt::Result {
379        let out = out.fmt_left();
380        writeln!(f, "{out} = atomicAnd({lhs}, {rhs});")
381    }
382
383    fn compile_atomic_cas(
384        f: &mut std::fmt::Formatter<'_>,
385        input: &Variable<D>,
386        cmp: &Variable<D>,
387        val: &Variable<D>,
388        out: &Variable<D>,
389    ) -> std::fmt::Result {
390        let out = out.fmt_left();
391        writeln!(f, "{out} = atomicCAS({input}, {cmp}, {val});")
392    }
393
394    fn compile_atomic_load(
395        f: &mut std::fmt::Formatter<'_>,
396        input: &Variable<D>,
397        out: &Variable<D>,
398    ) -> std::fmt::Result {
399        let out = out.fmt_left();
400        writeln!(f, "{out} = atomicAdd({input}, 0);")
401    }
402
403    fn compile_atomic_max(
404        f: &mut std::fmt::Formatter<'_>,
405        lhs: &Variable<D>,
406        rhs: &Variable<D>,
407        out: &Variable<D>,
408    ) -> std::fmt::Result {
409        let out = out.fmt_left();
410        writeln!(f, "{out} = atomicMax({lhs}, {rhs});")
411    }
412
413    fn compile_atomic_min(
414        f: &mut std::fmt::Formatter<'_>,
415        lhs: &Variable<D>,
416        rhs: &Variable<D>,
417        out: &Variable<D>,
418    ) -> std::fmt::Result {
419        let out = out.fmt_left();
420        writeln!(f, "{out} = atomicMin({lhs}, {rhs});")
421    }
422
423    fn compile_atomic_or(
424        f: &mut std::fmt::Formatter<'_>,
425        lhs: &Variable<D>,
426        rhs: &Variable<D>,
427        out: &Variable<D>,
428    ) -> std::fmt::Result {
429        let out = out.fmt_left();
430        writeln!(f, "{out} = atomicOr({lhs}, {rhs});")
431    }
432
433    fn compile_atomic_store(
434        f: &mut std::fmt::Formatter<'_>,
435        input: &Variable<D>,
436        out: &Variable<D>,
437    ) -> std::fmt::Result {
438        writeln!(f, "atomicExch({out}, {input});")
439    }
440
441    fn compile_atomic_sub(
442        f: &mut std::fmt::Formatter<'_>,
443        lhs: &Variable<D>,
444        rhs: &Variable<D>,
445        out: &Variable<D>,
446    ) -> std::fmt::Result {
447        let out = out.fmt_left();
448        match rhs.elem() {
449            Elem::U32 | Elem::I32 => writeln!(f, "{out} = atomicSub({lhs}, {rhs});"),
450            Elem::U64 => writeln!(f, "{out} = atomicAdd({lhs}, -{rhs});"),
451            Elem::I64 => writeln!(
452                f,
453                "{out} = atomicAdd(reinterpret_cast<{uint}*>({lhs}), {uint}(-{rhs}));",
454                uint = Elem::<D>::U64
455            ),
456            _ => writeln!(f, "{out} = atomicAdd({lhs}, -{rhs});"),
457        }
458    }
459
460    fn compile_atomic_swap(
461        f: &mut std::fmt::Formatter<'_>,
462        lhs: &Variable<D>,
463        rhs: &Variable<D>,
464        out: &Variable<D>,
465    ) -> std::fmt::Result {
466        let out = out.fmt_left();
467        writeln!(f, "{out} = atomicExch({lhs}, {rhs});")
468    }
469
470    fn compile_atomic_xor(
471        f: &mut std::fmt::Formatter<'_>,
472        lhs: &Variable<D>,
473        rhs: &Variable<D>,
474        out: &Variable<D>,
475    ) -> std::fmt::Result {
476        let out = out.fmt_left();
477        writeln!(f, "{out} = atomicXor({lhs}, {rhs});")
478    }
479
480    // debug
481    fn compile_instruction_printf(
482        f: &mut std::fmt::Formatter<'_>,
483        format_string: &str,
484        args: &[Variable<D>],
485    ) -> std::fmt::Result {
486        let format_string = format_string
487            .replace("\t", "\\t")
488            .replace("\n", "\\n")
489            .replace("\r", "\\r");
490        let args = args.iter().map(|arg| format!("{arg}")).collect::<Vec<_>>();
491        let args = match args.is_empty() {
492            true => "".to_string(),
493            false => format!(", {}", args.join(",")),
494        };
495        writeln!(f, "printf(\"{format_string}\"{args});")
496    }
497
498    // logs
499    fn compile_instruction_log1p_scalar<T: Component<D>>(
500        f: &mut std::fmt::Formatter<'_>,
501        input: T,
502    ) -> std::fmt::Result {
503        let elem = input.elem();
504        match elem {
505            Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
506                write!(f, "{elem}(log1p(float({input})))")
507            }
508            _ => write!(f, "log1p({input})"),
509        }
510    }
511
512    // sync
513    fn compile_instruction_sync_threads(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
514    fn compile_instruction_sync_warp(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
515    fn compile_instruction_thread_fence(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
516
517    // trigo
518    fn compile_instruction_tanh_scalar<T: Component<D>>(
519        f: &mut std::fmt::Formatter<'_>,
520        input: T,
521    ) -> std::fmt::Result {
522        let elem = input.elem();
523        match elem {
524            Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
525                write!(f, "{elem}(tanh(float({input})))")
526            }
527            _ => write!(f, "tanh({input})"),
528        }
529    }
530
531    // unary
532    fn compile_instruction_find_first_set<T: Component<D>>(
533        f: &mut std::fmt::Formatter<'_>,
534        input: T,
535        out_elem: Elem<D>,
536    ) -> std::fmt::Result;
537    fn compile_instruction_leading_zeros_scalar<T: Component<D>>(
538        f: &mut std::fmt::Formatter<'_>,
539        input: T,
540        out_elem: Elem<D>,
541    ) -> std::fmt::Result;
542
543    fn compile_instruction_popcount_scalar<T: Component<D>>(
544        f: &mut std::fmt::Formatter<'_>,
545        input: T,
546        out_elem: Elem<D>,
547    ) -> std::fmt::Result {
548        write!(f, "{out_elem}(")?;
549        match input.elem() {
550            Elem::I32 => write!(f, "__popc({}({input}))", Elem::<D>::U32),
551            Elem::U32 => write!(f, "__popc({input})"),
552            Elem::I64 => write!(f, "__popcll({}({input}))", Elem::<D>::U64),
553            Elem::U64 => write!(f, "__popcll({input})"),
554            _ => write!(f, "__popc({})", super::unary::zero_extend(input)),
555        }?;
556        write!(f, ")")
557    }
558
559    fn compile_instruction_reverse_bits_scalar<T: Component<D>>(
560        f: &mut std::fmt::Formatter<'_>,
561        input: T,
562        out_elem: Elem<D>,
563    ) -> std::fmt::Result {
564        write!(f, "{out_elem}(")?;
565        match out_elem {
566            Elem::I32 => write!(f, "__brev({}({input}))", Elem::<D>::U32),
567            Elem::U32 => write!(f, "__brev({input})"),
568            Elem::I64 => write!(f, "__brevll({}({input}))", Elem::<D>::U64),
569            Elem::U64 => write!(f, "__brevll({input})"),
570            _ => write!(
571                f,
572                "__brev({}) >> {}",
573                super::unary::zero_extend(input),
574                (size_of::<u32>() - out_elem.size()) * 8
575            ),
576        }?;
577        write!(f, ")")
578    }
579
580    // others
581    fn compile_instruction_max_function_name(
582        f: &mut std::fmt::Formatter<'_>,
583        item: Item<D>,
584    ) -> std::fmt::Result;
585
586    fn compile_instruction_min_function_name(
587        f: &mut std::fmt::Formatter<'_>,
588        item: Item<D>,
589    ) -> std::fmt::Result;
590
591    fn compile_instruction_powf(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
592        write!(f, "powf")
593    }
594
595    fn compile_instruction_half_function_name_prefix() -> &'static str {
596        "h"
597    }
598
599    fn compile_instruction_half2_function_name_prefix() -> &'static str {
600        "h2"
601    }
602
603    // warp
604    fn compile_warp_shuffle(
605        f: &mut std::fmt::Formatter<'_>,
606        var: &str,
607        source: &str,
608    ) -> std::fmt::Result;
609    fn compile_warp_shuffle_xor(
610        f: &mut std::fmt::Formatter<'_>,
611        var: &str,
612        elem: &Elem<D>,
613        offset: &str,
614    ) -> std::fmt::Result;
615    fn compile_warp_shuffle_up(
616        f: &mut std::fmt::Formatter<'_>,
617        var: &str,
618        offset: &str,
619    ) -> std::fmt::Result;
620    fn compile_warp_shuffle_down(
621        f: &mut std::fmt::Formatter<'_>,
622        var: &str,
623        offset: &str,
624    ) -> std::fmt::Result;
625    fn compile_warp_all<T: Component<D>>(
626        f: &mut std::fmt::Formatter<'_>,
627        input: &T,
628    ) -> std::fmt::Result;
629    fn compile_warp_any<T: Component<D>>(
630        f: &mut std::fmt::Formatter<'_>,
631        input: &T,
632    ) -> std::fmt::Result;
633    fn compile_warp_ballot(
634        f: &mut std::fmt::Formatter<'_>,
635        input: &Variable<D>,
636        out_elem: &Elem<D>,
637    ) -> std::fmt::Result;
638}
639
640// Coop Matrices dialect
641
642pub trait DialectWmmaCompiler<D: Dialect>:
643    Default + Clone + Copy + Debug + Send + Sync + Eq + Hash + 'static
644{
645    fn compile_wmma_includes(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
646    fn compile_wmma_type_definitions(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
647    fn compile_wmma_local_variables(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
648    fn compile_wmma_fragment_declaration(
649        f: &mut std::fmt::Formatter<'_>,
650        var: &Variable<D>,
651    ) -> std::fmt::Result;
652    fn compile_wwma_fragment_ident(
653        f: &mut std::fmt::Formatter<'_>,
654        ident: &FragmentIdent<D>,
655    ) -> std::fmt::Result;
656    fn compile_wmma_fragment_layout(
657        f: &mut std::fmt::Formatter<'_>,
658        layout: &FragmentLayout<D>,
659    ) -> std::fmt::Result;
660    fn compile_wmma_fragment(
661        f: &mut std::fmt::Formatter<'_>,
662        fragment: &Fragment<D>,
663    ) -> std::fmt::Result;
664    fn compile_wmma_instruction(
665        f: &mut std::fmt::Formatter<'_>,
666        instruction: &WmmaInstruction<D>,
667    ) -> std::fmt::Result;
668    fn supported_wmma_combinations(arch: &D::Architecture) -> SupportedWmmaCombinations;
669}