1use std::hash::Hash;
2use std::{collections::HashSet, fmt::Debug};
3
4use cubecl_core::ir::Id;
5
6use crate::shared::FmtLeft;
7
8use super::{
9 Architecture, AtomicKind, Binding, Component, CubeIndexFlags, Elem, Flags, Fragment,
10 FragmentIdent, FragmentLayout, Instruction, Item, SharedMemory, SupportedWmmaCombinations,
11 Variable, WarpInstruction, WmmaInstruction,
12};
13
14pub trait Dialect:
17 DialectIncludes<Self>
18 + DialectTypes<Self>
19 + DialectBindings<Self>
20 + DialectCubeBuiltins<Self>
21 + DialectInstructions<Self>
22 + DialectWmmaCompiler<Self>
23 + Default
24 + Clone
25 + Copy
26 + Debug
27 + Send
28 + Sync
29 + Eq
30 + Hash
31 + 'static
32{
33}
34
35pub trait DialectIncludes<D: Dialect> {
38 type Extension: Debug + Clone + Sync + Send;
39
40 fn compile_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags) -> std::fmt::Result;
41 fn compile_extensions(
42 f: &mut std::fmt::Formatter<'_>,
43 extensions: &[Self::Extension],
44 ) -> std::fmt::Result;
45 fn register_instruction_extension(
46 extensions: &mut Vec<Self::Extension>,
47 instruction: &Instruction<D>,
48 );
49 fn register_warp_instruction_extension(
50 extensions: &mut Vec<Self::Extension>,
51 instruction: &WarpInstruction<D>,
52 );
53}
54
55pub trait DialectTypes<D: Dialect> {
58 fn item_can_be_optimized() -> bool;
59 fn compile_elem(
60 f: &mut std::fmt::Formatter<'_>,
61 elem: &Elem<D>,
62 word: bool,
63 ) -> std::fmt::Result;
64
65 fn compile_atomic_kind(
66 f: &mut std::fmt::Formatter<'_>,
67 kind: &AtomicKind<D>,
68 ) -> std::fmt::Result {
69 match kind {
70 AtomicKind::I32 => write!(f, "{}", Elem::<D>::I32),
71 AtomicKind::I64 => write!(f, "{}", Elem::<D>::I64),
72 AtomicKind::U32 => write!(f, "{}", Elem::<D>::U32),
73 AtomicKind::U64 => write!(f, "{}", Elem::<D>::U64),
74 AtomicKind::F16 => write!(f, "{}", Elem::<D>::F16),
75 AtomicKind::BF16 => write!(f, "{}", Elem::<D>::BF16),
76 AtomicKind::F32 => write!(f, "{}", Elem::<D>::F32),
77 AtomicKind::F64 => write!(f, "{}", Elem::<D>::F64),
78 AtomicKind::_Dialect(_) => Ok(()),
79 }
80 }
81
82 fn compile_item(f: &mut std::fmt::Formatter<'_>, item: &Item<D>) -> std::fmt::Result;
83 fn compile_type_definitions(
84 f: &mut std::fmt::Formatter<'_>,
85 items: &HashSet<Item<D>>,
86 scalars: &[(Elem<D>, usize)],
87 flags: &Flags,
88 ) -> std::fmt::Result;
89 fn compile_local_memory_qualifier(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
90 fn compile_shared_memory_qualifier(
91 f: &mut std::fmt::Formatter<'_>,
92 shared: &SharedMemory<D>,
93 ) -> std::fmt::Result;
94 fn compile_polyfills(_f: &mut std::fmt::Formatter<'_>, _flags: &Flags) -> std::fmt::Result {
95 Ok(())
96 }
97 fn address_space_for_variable(_variable: &Variable<D>) -> String {
99 "".to_string()
100 }
101}
102
103pub trait DialectBindings<D: Dialect> {
106 fn compile_kernel_signature(
107 f: &mut std::fmt::Formatter<'_>,
108 kernel_name: &str,
109 tensor_maps: &[Id],
110 buffers: &[Binding<D>],
111 scalars: &[(Elem<D>, usize)],
112 flags: &Flags,
113 ) -> std::fmt::Result;
114}
115
116pub trait DialectCubeBuiltins<D: Dialect> {
119 fn builtin_rules(flags: &CubeIndexFlags) -> CubeIndexFlags {
125 let unit_pos_plane = flags.unit_pos_plane;
126 let plane_dim_checked = flags.plane_dim_checked;
127 let plane_dim = flags.plane_dim || plane_dim_checked || unit_pos_plane;
128 let plane_index = flags.plane_index;
129 let absolute_pos = flags.absolute_pos || unit_pos_plane;
130 let absolute_pos_tuple = flags.absolute_pos_tuple || absolute_pos;
131 let cube_dim = flags.cube_dim;
132 let cube_dim_tuple = flags.cube_dim_tuple || cube_dim || absolute_pos || plane_dim_checked;
133 let unit_pos = flags.unit_pos;
134 let unit_pos_tuple = flags.unit_pos_tuple || unit_pos;
135 let cube_count = flags.cube_count;
136 let cube_count_tuple = flags.cube_count_tuple || absolute_pos;
137 let cube_pos = flags.cube_pos;
138 let cube_pos_tuple = flags.cube_pos_tuple || cube_pos;
139 let cluster_group = flags.cluster_pos;
140
141 CubeIndexFlags {
142 absolute_pos,
143 absolute_pos_tuple,
144 cube_count,
145 cube_count_tuple,
146 cube_dim,
147 cube_dim_tuple,
148 cube_pos,
149 cube_pos_tuple,
150 plane_dim,
151 plane_dim_checked,
152 plane_index,
153 unit_pos_tuple,
154 unit_pos,
155 unit_pos_plane,
156 cluster_pos: cluster_group,
157 }
158 }
159
160 fn compile_absolute_pos_tuple_computation(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
161 let variable = Variable::<D>::AbsolutePosBaseName;
162 let ty = variable.item();
163 let cube_pos_x = Variable::<D>::CubePosX;
164 let cube_pos_y = Variable::<D>::CubePosY;
165 let cube_pos_z = Variable::<D>::CubePosZ;
166 let cube_dim_x = Variable::<D>::CubeDimX;
167 let cube_dim_y = Variable::<D>::CubeDimY;
168 let cube_dim_z = Variable::<D>::CubeDimZ;
169 let unit_pos_x = Variable::<D>::UnitPosX;
170 let unit_pos_y = Variable::<D>::UnitPosY;
171 let unit_pos_z = Variable::<D>::UnitPosZ;
172 writeln!(
173 f,
174 "{ty} {variable} = make_{ty}(
175 {cube_pos_x} * {cube_dim_x} + {unit_pos_x},
176 {cube_pos_y} * {cube_dim_y} + {unit_pos_y},
177 {cube_pos_z} * {cube_dim_z} + {unit_pos_z}
178);"
179 )
180 }
181
182 fn compile_absolute_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
183 f.write_str("absoluteIdx")
184 }
185
186 fn compile_absolute_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
187 f.write_str("idxGlobal")
188 }
189
190 fn compile_absolute_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
191 Self::compile_absolute_pos_base_name(f)?;
192 write!(f, ".x")
193 }
194
195 fn compile_absolute_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
196 Self::compile_absolute_pos_base_name(f)?;
197 write!(f, ".y")
198 }
199
200 fn compile_absolute_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
201 Self::compile_absolute_pos_base_name(f)?;
202 write!(f, ".z")
203 }
204
205 fn compile_cube_count_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
206 f.write_str("gridDim")
207 }
208
209 fn compile_cube_count(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
210 f.write_str("gridDimGlobal")
211 }
212
213 fn compile_cube_count_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
214 Self::compile_cube_count_base_name(f)?;
215 write!(f, ".x")
216 }
217
218 fn compile_cube_count_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
219 Self::compile_cube_count_base_name(f)?;
220 write!(f, ".y")
221 }
222
223 fn compile_cube_count_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
224 Self::compile_cube_count_base_name(f)?;
225 write!(f, ".z")
226 }
227
228 fn compile_cube_dim_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
229 f.write_str("blockDim")
230 }
231
232 fn compile_cube_dim(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
233 f.write_str("blockDimGlobal")
234 }
235
236 fn compile_cube_dim_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
237 Self::compile_cube_dim_base_name(f)?;
238 write!(f, ".x")
239 }
240
241 fn compile_cube_dim_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
242 Self::compile_cube_dim_base_name(f)?;
243 write!(f, ".y")
244 }
245
246 fn compile_cube_dim_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
247 Self::compile_cube_dim_base_name(f)?;
248 write!(f, ".z")
249 }
250
251 fn compile_cube_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
252 f.write_str("blockIdx")
253 }
254
255 fn compile_cube_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
256 f.write_str("blockIdxGlobal")
257 }
258
259 fn compile_cube_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
260 Self::compile_cube_pos_base_name(f)?;
261 write!(f, ".x")
262 }
263
264 fn compile_cube_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
265 Self::compile_cube_pos_base_name(f)?;
266 write!(f, ".y")
267 }
268
269 fn compile_cube_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
270 Self::compile_cube_pos_base_name(f)?;
271 write!(f, ".z")
272 }
273
274 fn compile_unit_pos_computation(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
275 let variable = Variable::<D>::UnitPos;
276 let ty = variable.item();
277 let cube_dim_x = Variable::<D>::CubeDimX;
278 let cube_dim_y = Variable::<D>::CubeDimY;
279 let unit_pos_x = Variable::<D>::UnitPosX;
280 let unit_pos_y = Variable::<D>::UnitPosY;
281 let unit_pos_z = Variable::<D>::UnitPosZ;
282 writeln!(
283 f,
284 "{ty} {variable} = {unit_pos_x} + {unit_pos_y} * {cube_dim_x} + {unit_pos_z} * ({cube_dim_x} * {cube_dim_y});"
285 )
286 }
287
288 fn compile_unit_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
289 f.write_str("threadIdxGlobal")
290 }
291
292 fn compile_unit_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
293 f.write_str("threadIdx")
294 }
295
296 fn compile_unit_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
297 Self::compile_unit_pos_base_name(f)?;
298 write!(f, ".x")
299 }
300
301 fn compile_unit_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
302 Self::compile_unit_pos_base_name(f)?;
303 write!(f, ".y")
304 }
305
306 fn compile_unit_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
307 Self::compile_unit_pos_base_name(f)?;
308 write!(f, ".z")
309 }
310
311 fn compile_plane_dim(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
312 f.write_str("warpSize")
313 }
314
315 fn compile_plane_dim_checked(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
316 f.write_str("warpSizeChecked")
317 }
318
319 fn compile_plane_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
320 let unit_pos_x = Variable::<D>::UnitPosX;
321 let plane_dim = Variable::<D>::PlaneDim;
322 write!(f, "{unit_pos_x} / {plane_dim}")
323 }
324
325 fn compile_unit_pos_plane(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
326 let absolute_pos = Variable::<D>::AbsolutePos;
327 let plane_dim = Variable::<D>::PlaneDim;
328 write!(f, "{absolute_pos} % {plane_dim}")
329 }
330
331 fn compile_cluster_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
332 write!(f, "0")
333 }
334 fn compile_cluster_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
335 write!(f, "0")
336 }
337 fn compile_cluster_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
338 write!(f, "0")
339 }
340 fn compile_cluster_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
341 write!(f, "0")
342 }
343}
344
345pub trait DialectInstructions<D: Dialect> {
348 fn compile_atomic_add(
350 f: &mut std::fmt::Formatter<'_>,
351 lhs: &Variable<D>,
352 rhs: &Variable<D>,
353 out: &Variable<D>,
354 ) -> std::fmt::Result {
355 let out = out.fmt_left();
356 match rhs.elem() {
357 Elem::I64 => writeln!(
358 f,
359 "{out} = atomicAdd(reinterpret_cast<{uint}*>({lhs}), {uint}({rhs}));",
360 uint = Elem::<D>::U64
361 ),
362 _ => writeln!(f, "{out} = atomicAdd({lhs}, {rhs});"),
363 }
364 }
365
366 fn compile_atomic_and(
367 f: &mut std::fmt::Formatter<'_>,
368 lhs: &Variable<D>,
369 rhs: &Variable<D>,
370 out: &Variable<D>,
371 ) -> std::fmt::Result {
372 let out = out.fmt_left();
373 writeln!(f, "{out} = atomicAnd({lhs}, {rhs});")
374 }
375
376 fn compile_atomic_cas(
377 f: &mut std::fmt::Formatter<'_>,
378 input: &Variable<D>,
379 cmp: &Variable<D>,
380 val: &Variable<D>,
381 out: &Variable<D>,
382 ) -> std::fmt::Result {
383 let out = out.fmt_left();
384 writeln!(f, "{out} = atomicCAS({input}, {cmp}, {val});")
385 }
386
387 fn compile_atomic_load(
388 f: &mut std::fmt::Formatter<'_>,
389 input: &Variable<D>,
390 out: &Variable<D>,
391 ) -> std::fmt::Result {
392 let out = out.fmt_left();
393 writeln!(f, "{out} = atomicAdd({input}, 0);")
394 }
395
396 fn compile_atomic_max(
397 f: &mut std::fmt::Formatter<'_>,
398 lhs: &Variable<D>,
399 rhs: &Variable<D>,
400 out: &Variable<D>,
401 ) -> std::fmt::Result {
402 let out = out.fmt_left();
403 writeln!(f, "{out} = atomicMax({lhs}, {rhs});")
404 }
405
406 fn compile_atomic_min(
407 f: &mut std::fmt::Formatter<'_>,
408 lhs: &Variable<D>,
409 rhs: &Variable<D>,
410 out: &Variable<D>,
411 ) -> std::fmt::Result {
412 let out = out.fmt_left();
413 writeln!(f, "{out} = atomicMin({lhs}, {rhs});")
414 }
415
416 fn compile_atomic_or(
417 f: &mut std::fmt::Formatter<'_>,
418 lhs: &Variable<D>,
419 rhs: &Variable<D>,
420 out: &Variable<D>,
421 ) -> std::fmt::Result {
422 let out = out.fmt_left();
423 writeln!(f, "{out} = atomicOr({lhs}, {rhs});")
424 }
425
426 fn compile_atomic_store(
427 f: &mut std::fmt::Formatter<'_>,
428 input: &Variable<D>,
429 out: &Variable<D>,
430 ) -> std::fmt::Result {
431 writeln!(f, "atomicExch({out}, {input});")
432 }
433
434 fn compile_atomic_sub(
435 f: &mut std::fmt::Formatter<'_>,
436 lhs: &Variable<D>,
437 rhs: &Variable<D>,
438 out: &Variable<D>,
439 ) -> std::fmt::Result {
440 let out = out.fmt_left();
441 match rhs.elem() {
442 Elem::U32 | Elem::I32 => writeln!(f, "{out} = atomicSub({lhs}, {rhs});"),
443 Elem::U64 => writeln!(f, "{out} = atomicAdd({lhs}, -{rhs});"),
444 Elem::I64 => writeln!(
445 f,
446 "{out} = atomicAdd(reinterpret_cast<{uint}*>({lhs}), {uint}(-{rhs}));",
447 uint = Elem::<D>::U64
448 ),
449 _ => writeln!(f, "{out} = atomicAdd({lhs}, -{rhs});"),
450 }
451 }
452
453 fn compile_atomic_swap(
454 f: &mut std::fmt::Formatter<'_>,
455 lhs: &Variable<D>,
456 rhs: &Variable<D>,
457 out: &Variable<D>,
458 ) -> std::fmt::Result {
459 let out = out.fmt_left();
460 writeln!(f, "{out} = atomicExch({lhs}, {rhs});")
461 }
462
463 fn compile_atomic_xor(
464 f: &mut std::fmt::Formatter<'_>,
465 lhs: &Variable<D>,
466 rhs: &Variable<D>,
467 out: &Variable<D>,
468 ) -> std::fmt::Result {
469 let out = out.fmt_left();
470 writeln!(f, "{out} = atomicXor({lhs}, {rhs});")
471 }
472
473 fn compile_instruction_printf(
475 f: &mut std::fmt::Formatter<'_>,
476 format_string: &str,
477 args: &[Variable<D>],
478 ) -> std::fmt::Result {
479 let format_string = format_string
480 .replace("\t", "\\t")
481 .replace("\n", "\\n")
482 .replace("\r", "\\r");
483 let args = args.iter().map(|arg| format!("{arg}")).collect::<Vec<_>>();
484 let args = match args.is_empty() {
485 true => "".to_string(),
486 false => format!(", {}", args.join(",")),
487 };
488 writeln!(f, "printf(\"{format_string}\"{args});")
489 }
490
491 fn compile_instruction_log1p_scalar<T: Component<D>>(
493 f: &mut std::fmt::Formatter<'_>,
494 input: T,
495 ) -> std::fmt::Result {
496 let elem = input.elem();
497 match elem {
498 Elem::F16 | Elem::F162 | Elem::BF16 | Elem::BF162 => {
499 write!(f, "{}(log1p(float({input})))", elem)
500 }
501 _ => write!(f, "log1p({input})"),
502 }
503 }
504
505 fn compile_instruction_sync_threads(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
507 fn compile_instruction_thread_fence(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
508
509 fn compile_instruction_tanh_scalar<T: Component<D>>(
511 f: &mut std::fmt::Formatter<'_>,
512 input: T,
513 ) -> std::fmt::Result {
514 let elem = input.elem();
515 match elem {
516 Elem::F16 | Elem::F162 | Elem::BF16 | Elem::BF162 => {
517 write!(f, "{}(tanh(float({input})))", elem)
518 }
519 _ => write!(f, "tanh({input})"),
520 }
521 }
522
523 fn compile_instruction_find_first_set<T: Component<D>>(
525 f: &mut std::fmt::Formatter<'_>,
526 input: T,
527 out_elem: Elem<D>,
528 ) -> std::fmt::Result;
529 fn compile_instruction_leading_zeros_scalar<T: Component<D>>(
530 f: &mut std::fmt::Formatter<'_>,
531 input: T,
532 out_elem: Elem<D>,
533 ) -> std::fmt::Result;
534
535 fn compile_instruction_popcount_scalar<T: Component<D>>(
536 f: &mut std::fmt::Formatter<'_>,
537 input: T,
538 out_elem: Elem<D>,
539 ) -> std::fmt::Result {
540 write!(f, "{out_elem}(")?;
541 match input.elem() {
542 Elem::I32 => write!(f, "__popc({}({input}))", Elem::<D>::U32),
543 Elem::U32 => write!(f, "__popc({input})"),
544 Elem::I64 => write!(f, "__popcll({}({input}))", Elem::<D>::U64),
545 Elem::U64 => write!(f, "__popcll({input})"),
546 _ => write!(f, "__popc({})", super::unary::zero_extend(input)),
547 }?;
548 write!(f, ")")
549 }
550
551 fn compile_instruction_reverse_bits_scalar<T: Component<D>>(
552 f: &mut std::fmt::Formatter<'_>,
553 input: T,
554 out_elem: Elem<D>,
555 ) -> std::fmt::Result {
556 write!(f, "{out_elem}(")?;
557 match out_elem {
558 Elem::I32 => write!(f, "__brev({}({input}))", Elem::<D>::U32),
559 Elem::U32 => write!(f, "__brev({input})"),
560 Elem::I64 => write!(f, "__brevll({}({input}))", Elem::<D>::U64),
561 Elem::U64 => write!(f, "__brevll({input})"),
562 _ => write!(
563 f,
564 "__brev({}) >> {}",
565 super::unary::zero_extend(input),
566 (size_of::<u32>() - out_elem.size()) * 8
567 ),
568 }?;
569 write!(f, ")")
570 }
571
572 fn compile_instruction_max_function_name(
574 f: &mut std::fmt::Formatter<'_>,
575 item: Item<D>,
576 ) -> std::fmt::Result;
577
578 fn compile_instruction_min_function_name(
579 f: &mut std::fmt::Formatter<'_>,
580 item: Item<D>,
581 ) -> std::fmt::Result;
582
583 fn compile_instruction_powf(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
584 write!(f, "powf")
585 }
586
587 fn compile_instruction_half_function_name_prefix() -> &'static str {
588 "h"
589 }
590
591 fn compile_instruction_half2_function_name_prefix() -> &'static str {
592 "h2"
593 }
594
595 fn compile_warp_shuffle(
597 f: &mut std::fmt::Formatter<'_>,
598 var: &str,
599 source: &str,
600 ) -> std::fmt::Result;
601 fn compile_warp_shuffle_xor(
602 f: &mut std::fmt::Formatter<'_>,
603 var: &str,
604 elem: &Elem<D>,
605 offset: &str,
606 ) -> std::fmt::Result;
607 fn compile_warp_shuffle_up(
608 f: &mut std::fmt::Formatter<'_>,
609 var: &str,
610 offset: &str,
611 ) -> std::fmt::Result;
612 fn compile_warp_shuffle_down(
613 f: &mut std::fmt::Formatter<'_>,
614 var: &str,
615 offset: &str,
616 ) -> std::fmt::Result;
617 fn compile_warp_all<T: Component<D>>(
618 f: &mut std::fmt::Formatter<'_>,
619 input: &T,
620 ) -> std::fmt::Result;
621 fn compile_warp_any<T: Component<D>>(
622 f: &mut std::fmt::Formatter<'_>,
623 input: &T,
624 ) -> std::fmt::Result;
625 fn compile_warp_ballot(
626 f: &mut std::fmt::Formatter<'_>,
627 input: &Variable<D>,
628 out_elem: &Elem<D>,
629 ) -> std::fmt::Result;
630}
631
632pub trait DialectWmmaCompiler<D: Dialect>:
635 Default + Clone + Copy + Debug + Send + Sync + Eq + Hash + 'static
636{
637 type Architecture: Architecture;
638
639 fn compile_wmma_includes(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
640 fn compile_wmma_type_definitions(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
641 fn compile_local_variables(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
642 fn compile_fragment_ident(
643 ident: &FragmentIdent<D>,
644 f: &mut std::fmt::Formatter<'_>,
645 ) -> std::fmt::Result;
646 fn compile_fragment_layout(
647 layout: &FragmentLayout<D>,
648 f: &mut std::fmt::Formatter<'_>,
649 ) -> std::fmt::Result;
650 fn compile_fragment(
651 fragment: &Fragment<D>,
652 f: &mut std::fmt::Formatter<'_>,
653 ) -> std::fmt::Result;
654 fn compile_instruction(
655 instruction: &WmmaInstruction<D>,
656 f: &mut std::fmt::Formatter<'_>,
657 ) -> std::fmt::Result;
658 fn supported_wmma_combinations(arch: &Self::Architecture) -> SupportedWmmaCombinations;
659}