1use std::hash::Hash;
2use std::{collections::HashSet, fmt::Debug};
3
4use cubecl_core::ir::Id;
5
6use crate::shared::FmtLeft;
7
8use super::{
9 Architecture, AtomicKind, Binding, Body, Component, CubeIndexFlags, Elem, Flags, Fragment,
10 FragmentIdent, FragmentLayout, Instruction, Item, SharedMemory, SupportedWmmaCombinations,
11 Variable, WarpInstruction, WmmaInstruction,
12};
13
/// Umbrella trait for a C++-like code-generation dialect.
///
/// A dialect bundles every compilation hook (includes, types, bindings, cube builtins,
/// instructions and WMMA) by requiring each sub-trait parameterized by itself. It must
/// also be a cheap, hashable, thread-safe marker type with a default value.
pub trait Dialect:
    DialectIncludes<Self>
    + DialectTypes<Self>
    + DialectBindings<Self>
    + DialectCubeBuiltins<Self>
    + DialectInstructions<Self>
    + DialectWmmaCompiler<Self>
    + Default
    + Clone
    + Copy
    + Debug
    + Send
    + Sync
    + Eq
    + Hash
    + 'static
{
    /// Target architecture description used by this dialect (e.g. passed to
    /// `DialectWmmaCompiler::supported_wmma_combinations`).
    type Architecture: Architecture;
}
35
/// Hooks that produce the preamble of generated source: includes and extensions.
pub trait DialectIncludes<D: Dialect> {
    /// Dialect-specific extension record, collected while walking instructions and
    /// later rendered by `compile_extensions`.
    type Extension: Debug + Clone + Sync + Send;

    /// Writes the include section to `f`; `flags` carries the compilation features
    /// that may decide which includes are needed.
    fn compile_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags) -> std::fmt::Result;
    /// Writes the code for the registered `extensions` to `f`.
    fn compile_extensions(
        f: &mut std::fmt::Formatter<'_>,
        extensions: &[Self::Extension],
    ) -> std::fmt::Result;
    /// Inspects `instruction` and records any extension it needs in `extensions`
    /// (NOTE(review): de-duplication, if any, is up to the implementor).
    fn register_instruction_extension(
        extensions: &mut Vec<Self::Extension>,
        instruction: &Instruction<D>,
    );
    /// Same as `register_instruction_extension`, for warp instructions.
    fn register_warp_instruction_extension(
        extensions: &mut Vec<Self::Extension>,
        instruction: &WarpInstruction<D>,
    );
}
55
56pub trait DialectTypes<D: Dialect> {
59 fn item_can_be_optimized() -> bool;
60 fn compile_elem(
61 f: &mut std::fmt::Formatter<'_>,
62 elem: &Elem<D>,
63 word: bool,
64 ) -> std::fmt::Result;
65
66 fn compile_atomic_kind(
67 f: &mut std::fmt::Formatter<'_>,
68 kind: &AtomicKind<D>,
69 ) -> std::fmt::Result {
70 match kind {
71 AtomicKind::I32 => write!(f, "{}", Elem::<D>::I32),
72 AtomicKind::I64 => write!(f, "{}", Elem::<D>::I64),
73 AtomicKind::U32 => write!(f, "{}", Elem::<D>::U32),
74 AtomicKind::U64 => write!(f, "{}", Elem::<D>::U64),
75 AtomicKind::F16 => write!(f, "{}", Elem::<D>::F16),
76 AtomicKind::BF16 => write!(f, "{}", Elem::<D>::BF16),
77 AtomicKind::F32 => write!(f, "{}", Elem::<D>::F32),
78 AtomicKind::F64 => write!(f, "{}", Elem::<D>::F64),
79 AtomicKind::_Dialect(_) => Ok(()),
80 }
81 }
82
83 fn compile_item(f: &mut std::fmt::Formatter<'_>, item: &Item<D>) -> std::fmt::Result;
84 fn compile_type_definitions(
85 f: &mut std::fmt::Formatter<'_>,
86 items: &HashSet<Item<D>>,
87 scalars: &[(Elem<D>, usize)],
88 flags: &Flags,
89 ) -> std::fmt::Result;
90 fn compile_local_memory_qualifier(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
91 fn compile_shared_memory_declaration(
92 f: &mut std::fmt::Formatter<'_>,
93 shared: &SharedMemory<D>,
94 ) -> std::fmt::Result;
95 fn compile_polyfills(_f: &mut std::fmt::Formatter<'_>, _flags: &Flags) -> std::fmt::Result {
96 Ok(())
97 }
98 fn address_space_for_variable(_variable: &Variable<D>) -> String {
100 "".to_string()
101 }
102}
103
/// Hooks for the kernel entry point: its signature and any binding setup in the body.
pub trait DialectBindings<D: Dialect> {
    /// Writes the kernel signature named `kernel_name`, declaring the tensor maps,
    /// buffer bindings and scalar parameters. `scalars` pairs an element type with a
    /// count (NOTE(review): presumably the number of scalars of that type — confirm).
    fn compile_kernel_signature(
        f: &mut std::fmt::Formatter<'_>,
        kernel_name: &str,
        tensor_maps: &[Id],
        buffers: &[Binding<D>],
        scalars: &[(Elem<D>, usize)],
        flags: &Flags,
    ) -> std::fmt::Result;
    /// Writes binding-related code inside the kernel body; the default writes nothing
    /// and returns `Ok(())`.
    fn compile_bindings_body(
        _f: &mut std::fmt::Formatter<'_>,
        _body: &Body<D>,
    ) -> std::fmt::Result {
        Ok(())
    }
}
122
123pub trait DialectCubeBuiltins<D: Dialect> {
126 fn builtin_rules(flags: &CubeIndexFlags) -> CubeIndexFlags {
132 let unit_pos_plane = flags.unit_pos_plane;
133 let plane_dim_checked = flags.plane_dim_checked;
134 let plane_dim = flags.plane_dim || plane_dim_checked || unit_pos_plane;
135 let plane_index = flags.plane_index;
136 let absolute_pos = flags.absolute_pos || unit_pos_plane;
137 let absolute_pos_tuple = flags.absolute_pos_tuple || absolute_pos;
138 let cube_dim = flags.cube_dim;
139 let cube_dim_tuple = flags.cube_dim_tuple || cube_dim || absolute_pos || plane_dim_checked;
140 let unit_pos = flags.unit_pos;
141 let unit_pos_tuple = flags.unit_pos_tuple || unit_pos;
142 let cube_count = flags.cube_count;
143 let cube_count_tuple = flags.cube_count_tuple || absolute_pos;
144 let cube_pos = flags.cube_pos;
145 let cube_pos_tuple = flags.cube_pos_tuple || cube_pos;
146 let cluster_group = flags.cluster_pos;
147
148 CubeIndexFlags {
149 absolute_pos,
150 absolute_pos_tuple,
151 cube_count,
152 cube_count_tuple,
153 cube_dim,
154 cube_dim_tuple,
155 cube_pos,
156 cube_pos_tuple,
157 plane_dim,
158 plane_dim_checked,
159 plane_index,
160 unit_pos_tuple,
161 unit_pos,
162 unit_pos_plane,
163 cluster_pos: cluster_group,
164 }
165 }
166
167 fn compile_absolute_pos_tuple_computation(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
168 let variable = Variable::<D>::AbsolutePosBaseName;
169 let ty = variable.item();
170 let cube_pos_x = Variable::<D>::CubePosX;
171 let cube_pos_y = Variable::<D>::CubePosY;
172 let cube_pos_z = Variable::<D>::CubePosZ;
173 let cube_dim_x = Variable::<D>::CubeDimX;
174 let cube_dim_y = Variable::<D>::CubeDimY;
175 let cube_dim_z = Variable::<D>::CubeDimZ;
176 let unit_pos_x = Variable::<D>::UnitPosX;
177 let unit_pos_y = Variable::<D>::UnitPosY;
178 let unit_pos_z = Variable::<D>::UnitPosZ;
179 writeln!(
180 f,
181 "{ty} {variable} = make_{ty}(
182 {cube_pos_x} * {cube_dim_x} + {unit_pos_x},
183 {cube_pos_y} * {cube_dim_y} + {unit_pos_y},
184 {cube_pos_z} * {cube_dim_z} + {unit_pos_z}
185);"
186 )
187 }
188
189 fn compile_absolute_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
190 f.write_str("absoluteIdx")
191 }
192
193 fn compile_absolute_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
194 f.write_str("idxGlobal")
195 }
196
197 fn compile_absolute_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
198 Self::compile_absolute_pos_base_name(f)?;
199 write!(f, ".x")
200 }
201
202 fn compile_absolute_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
203 Self::compile_absolute_pos_base_name(f)?;
204 write!(f, ".y")
205 }
206
207 fn compile_absolute_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
208 Self::compile_absolute_pos_base_name(f)?;
209 write!(f, ".z")
210 }
211
212 fn compile_cube_count_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
213 f.write_str("gridDim")
214 }
215
216 fn compile_cube_count(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
217 f.write_str("gridDimGlobal")
218 }
219
220 fn compile_cube_count_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
221 Self::compile_cube_count_base_name(f)?;
222 write!(f, ".x")
223 }
224
225 fn compile_cube_count_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
226 Self::compile_cube_count_base_name(f)?;
227 write!(f, ".y")
228 }
229
230 fn compile_cube_count_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
231 Self::compile_cube_count_base_name(f)?;
232 write!(f, ".z")
233 }
234
235 fn compile_cube_dim_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
236 f.write_str("blockDim")
237 }
238
239 fn compile_cube_dim(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
240 f.write_str("blockDimGlobal")
241 }
242
243 fn compile_cube_dim_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
244 Self::compile_cube_dim_base_name(f)?;
245 write!(f, ".x")
246 }
247
248 fn compile_cube_dim_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
249 Self::compile_cube_dim_base_name(f)?;
250 write!(f, ".y")
251 }
252
253 fn compile_cube_dim_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
254 Self::compile_cube_dim_base_name(f)?;
255 write!(f, ".z")
256 }
257
258 fn compile_cube_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
259 f.write_str("blockIdx")
260 }
261
262 fn compile_cube_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
263 f.write_str("blockIdxGlobal")
264 }
265
266 fn compile_cube_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
267 Self::compile_cube_pos_base_name(f)?;
268 write!(f, ".x")
269 }
270
271 fn compile_cube_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
272 Self::compile_cube_pos_base_name(f)?;
273 write!(f, ".y")
274 }
275
276 fn compile_cube_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
277 Self::compile_cube_pos_base_name(f)?;
278 write!(f, ".z")
279 }
280
281 fn compile_unit_pos_computation(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
282 let variable = Variable::<D>::UnitPos;
283 let ty = variable.item();
284 let cube_dim_x = Variable::<D>::CubeDimX;
285 let cube_dim_y = Variable::<D>::CubeDimY;
286 let unit_pos_x = Variable::<D>::UnitPosX;
287 let unit_pos_y = Variable::<D>::UnitPosY;
288 let unit_pos_z = Variable::<D>::UnitPosZ;
289 writeln!(
290 f,
291 "{ty} {variable} = {unit_pos_x} + {unit_pos_y} * {cube_dim_x} + {unit_pos_z} * ({cube_dim_x} * {cube_dim_y});"
292 )
293 }
294
295 fn compile_unit_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
296 f.write_str("threadIdxGlobal")
297 }
298
299 fn compile_unit_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
300 f.write_str("threadIdx")
301 }
302
303 fn compile_unit_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
304 Self::compile_unit_pos_base_name(f)?;
305 write!(f, ".x")
306 }
307
308 fn compile_unit_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
309 Self::compile_unit_pos_base_name(f)?;
310 write!(f, ".y")
311 }
312
313 fn compile_unit_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
314 Self::compile_unit_pos_base_name(f)?;
315 write!(f, ".z")
316 }
317
318 fn compile_plane_dim(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
319 f.write_str("warpSize")
320 }
321
322 fn compile_plane_dim_checked(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
323 f.write_str("warpSizeChecked")
324 }
325
326 fn compile_plane_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
327 let unit_pos_x = Variable::<D>::UnitPosX;
328 let plane_dim = Variable::<D>::PlaneDim;
329 write!(f, "{unit_pos_x} / {plane_dim}")
330 }
331
332 fn compile_unit_pos_plane(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
333 let absolute_pos = Variable::<D>::AbsolutePos;
334 let plane_dim = Variable::<D>::PlaneDim;
335 write!(f, "{absolute_pos} % {plane_dim}")
336 }
337
338 fn compile_cluster_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
339 write!(f, "0")
340 }
341 fn compile_cluster_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
342 write!(f, "0")
343 }
344 fn compile_cluster_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
345 write!(f, "0")
346 }
347 fn compile_cluster_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
348 write!(f, "0")
349 }
350}
351
352pub trait DialectInstructions<D: Dialect> {
355 fn compile_atomic_add(
357 f: &mut std::fmt::Formatter<'_>,
358 lhs: &Variable<D>,
359 rhs: &Variable<D>,
360 out: &Variable<D>,
361 ) -> std::fmt::Result {
362 let out = out.fmt_left();
363 match rhs.elem() {
364 Elem::I64 => writeln!(
365 f,
366 "{out} = atomicAdd(reinterpret_cast<{uint}*>({lhs}), {uint}({rhs}));",
367 uint = Elem::<D>::U64
368 ),
369 _ => writeln!(f, "{out} = atomicAdd({lhs}, {rhs});"),
370 }
371 }
372
373 fn compile_atomic_and(
374 f: &mut std::fmt::Formatter<'_>,
375 lhs: &Variable<D>,
376 rhs: &Variable<D>,
377 out: &Variable<D>,
378 ) -> std::fmt::Result {
379 let out = out.fmt_left();
380 writeln!(f, "{out} = atomicAnd({lhs}, {rhs});")
381 }
382
383 fn compile_atomic_cas(
384 f: &mut std::fmt::Formatter<'_>,
385 input: &Variable<D>,
386 cmp: &Variable<D>,
387 val: &Variable<D>,
388 out: &Variable<D>,
389 ) -> std::fmt::Result {
390 let out = out.fmt_left();
391 writeln!(f, "{out} = atomicCAS({input}, {cmp}, {val});")
392 }
393
394 fn compile_atomic_load(
395 f: &mut std::fmt::Formatter<'_>,
396 input: &Variable<D>,
397 out: &Variable<D>,
398 ) -> std::fmt::Result {
399 let out = out.fmt_left();
400 writeln!(f, "{out} = atomicAdd({input}, 0);")
401 }
402
403 fn compile_atomic_max(
404 f: &mut std::fmt::Formatter<'_>,
405 lhs: &Variable<D>,
406 rhs: &Variable<D>,
407 out: &Variable<D>,
408 ) -> std::fmt::Result {
409 let out = out.fmt_left();
410 writeln!(f, "{out} = atomicMax({lhs}, {rhs});")
411 }
412
413 fn compile_atomic_min(
414 f: &mut std::fmt::Formatter<'_>,
415 lhs: &Variable<D>,
416 rhs: &Variable<D>,
417 out: &Variable<D>,
418 ) -> std::fmt::Result {
419 let out = out.fmt_left();
420 writeln!(f, "{out} = atomicMin({lhs}, {rhs});")
421 }
422
423 fn compile_atomic_or(
424 f: &mut std::fmt::Formatter<'_>,
425 lhs: &Variable<D>,
426 rhs: &Variable<D>,
427 out: &Variable<D>,
428 ) -> std::fmt::Result {
429 let out = out.fmt_left();
430 writeln!(f, "{out} = atomicOr({lhs}, {rhs});")
431 }
432
433 fn compile_atomic_store(
434 f: &mut std::fmt::Formatter<'_>,
435 input: &Variable<D>,
436 out: &Variable<D>,
437 ) -> std::fmt::Result {
438 writeln!(f, "atomicExch({out}, {input});")
439 }
440
441 fn compile_atomic_sub(
442 f: &mut std::fmt::Formatter<'_>,
443 lhs: &Variable<D>,
444 rhs: &Variable<D>,
445 out: &Variable<D>,
446 ) -> std::fmt::Result {
447 let out = out.fmt_left();
448 match rhs.elem() {
449 Elem::U32 | Elem::I32 => writeln!(f, "{out} = atomicSub({lhs}, {rhs});"),
450 Elem::U64 => writeln!(f, "{out} = atomicAdd({lhs}, -{rhs});"),
451 Elem::I64 => writeln!(
452 f,
453 "{out} = atomicAdd(reinterpret_cast<{uint}*>({lhs}), {uint}(-{rhs}));",
454 uint = Elem::<D>::U64
455 ),
456 _ => writeln!(f, "{out} = atomicAdd({lhs}, -{rhs});"),
457 }
458 }
459
460 fn compile_atomic_swap(
461 f: &mut std::fmt::Formatter<'_>,
462 lhs: &Variable<D>,
463 rhs: &Variable<D>,
464 out: &Variable<D>,
465 ) -> std::fmt::Result {
466 let out = out.fmt_left();
467 writeln!(f, "{out} = atomicExch({lhs}, {rhs});")
468 }
469
470 fn compile_atomic_xor(
471 f: &mut std::fmt::Formatter<'_>,
472 lhs: &Variable<D>,
473 rhs: &Variable<D>,
474 out: &Variable<D>,
475 ) -> std::fmt::Result {
476 let out = out.fmt_left();
477 writeln!(f, "{out} = atomicXor({lhs}, {rhs});")
478 }
479
480 fn compile_instruction_printf(
482 f: &mut std::fmt::Formatter<'_>,
483 format_string: &str,
484 args: &[Variable<D>],
485 ) -> std::fmt::Result {
486 let format_string = format_string
487 .replace("\t", "\\t")
488 .replace("\n", "\\n")
489 .replace("\r", "\\r");
490 let args = args.iter().map(|arg| format!("{arg}")).collect::<Vec<_>>();
491 let args = match args.is_empty() {
492 true => "".to_string(),
493 false => format!(", {}", args.join(",")),
494 };
495 writeln!(f, "printf(\"{format_string}\"{args});")
496 }
497
498 fn compile_instruction_log1p_scalar<T: Component<D>>(
500 f: &mut std::fmt::Formatter<'_>,
501 input: T,
502 ) -> std::fmt::Result {
503 let elem = input.elem();
504 match elem {
505 Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
506 write!(f, "{elem}(log1p(float({input})))")
507 }
508 _ => write!(f, "log1p({input})"),
509 }
510 }
511
512 fn compile_instruction_sync_threads(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
514 fn compile_instruction_sync_warp(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
515 fn compile_instruction_thread_fence(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
516
517 fn compile_instruction_tanh_scalar<T: Component<D>>(
519 f: &mut std::fmt::Formatter<'_>,
520 input: T,
521 ) -> std::fmt::Result {
522 let elem = input.elem();
523 match elem {
524 Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
525 write!(f, "{elem}(tanh(float({input})))")
526 }
527 _ => write!(f, "tanh({input})"),
528 }
529 }
530
531 fn compile_instruction_find_first_set<T: Component<D>>(
533 f: &mut std::fmt::Formatter<'_>,
534 input: T,
535 out_elem: Elem<D>,
536 ) -> std::fmt::Result;
537 fn compile_instruction_leading_zeros_scalar<T: Component<D>>(
538 f: &mut std::fmt::Formatter<'_>,
539 input: T,
540 out_elem: Elem<D>,
541 ) -> std::fmt::Result;
542
543 fn compile_instruction_popcount_scalar<T: Component<D>>(
544 f: &mut std::fmt::Formatter<'_>,
545 input: T,
546 out_elem: Elem<D>,
547 ) -> std::fmt::Result {
548 write!(f, "{out_elem}(")?;
549 match input.elem() {
550 Elem::I32 => write!(f, "__popc({}({input}))", Elem::<D>::U32),
551 Elem::U32 => write!(f, "__popc({input})"),
552 Elem::I64 => write!(f, "__popcll({}({input}))", Elem::<D>::U64),
553 Elem::U64 => write!(f, "__popcll({input})"),
554 _ => write!(f, "__popc({})", super::unary::zero_extend(input)),
555 }?;
556 write!(f, ")")
557 }
558
559 fn compile_instruction_reverse_bits_scalar<T: Component<D>>(
560 f: &mut std::fmt::Formatter<'_>,
561 input: T,
562 out_elem: Elem<D>,
563 ) -> std::fmt::Result {
564 write!(f, "{out_elem}(")?;
565 match out_elem {
566 Elem::I32 => write!(f, "__brev({}({input}))", Elem::<D>::U32),
567 Elem::U32 => write!(f, "__brev({input})"),
568 Elem::I64 => write!(f, "__brevll({}({input}))", Elem::<D>::U64),
569 Elem::U64 => write!(f, "__brevll({input})"),
570 _ => write!(
571 f,
572 "__brev({}) >> {}",
573 super::unary::zero_extend(input),
574 (size_of::<u32>() - out_elem.size()) * 8
575 ),
576 }?;
577 write!(f, ")")
578 }
579
580 fn compile_instruction_max_function_name(
582 f: &mut std::fmt::Formatter<'_>,
583 item: Item<D>,
584 ) -> std::fmt::Result;
585
586 fn compile_instruction_min_function_name(
587 f: &mut std::fmt::Formatter<'_>,
588 item: Item<D>,
589 ) -> std::fmt::Result;
590
591 fn compile_instruction_powf(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
592 write!(f, "powf")
593 }
594
595 fn compile_instruction_half_function_name_prefix() -> &'static str {
596 "h"
597 }
598
599 fn compile_instruction_half2_function_name_prefix() -> &'static str {
600 "h2"
601 }
602
603 fn compile_warp_shuffle(
605 f: &mut std::fmt::Formatter<'_>,
606 var: &str,
607 source: &str,
608 ) -> std::fmt::Result;
609 fn compile_warp_shuffle_xor(
610 f: &mut std::fmt::Formatter<'_>,
611 var: &str,
612 elem: &Elem<D>,
613 offset: &str,
614 ) -> std::fmt::Result;
615 fn compile_warp_shuffle_up(
616 f: &mut std::fmt::Formatter<'_>,
617 var: &str,
618 offset: &str,
619 ) -> std::fmt::Result;
620 fn compile_warp_shuffle_down(
621 f: &mut std::fmt::Formatter<'_>,
622 var: &str,
623 offset: &str,
624 ) -> std::fmt::Result;
625 fn compile_warp_all<T: Component<D>>(
626 f: &mut std::fmt::Formatter<'_>,
627 input: &T,
628 ) -> std::fmt::Result;
629 fn compile_warp_any<T: Component<D>>(
630 f: &mut std::fmt::Formatter<'_>,
631 input: &T,
632 ) -> std::fmt::Result;
633 fn compile_warp_ballot(
634 f: &mut std::fmt::Formatter<'_>,
635 input: &Variable<D>,
636 out_elem: &Elem<D>,
637 ) -> std::fmt::Result;
638}
639
/// Hooks for emitting warp-level matrix-multiply-accumulate (WMMA) code.
///
/// `Dialect` requires `DialectWmmaCompiler<Self>`, and this trait carries the same
/// marker bounds as `Dialect`.
pub trait DialectWmmaCompiler<D: Dialect>:
    Default + Clone + Copy + Debug + Send + Sync + Eq + Hash + 'static
{
    /// Writes any WMMA-specific includes.
    fn compile_wmma_includes(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
    /// Writes any WMMA-specific type definitions.
    fn compile_wmma_type_definitions(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
    /// Writes any WMMA-specific local variables.
    fn compile_wmma_local_variables(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
    /// Writes the declaration of the fragment variable `var`.
    fn compile_wmma_fragment_declaration(
        f: &mut std::fmt::Formatter<'_>,
        var: &Variable<D>,
    ) -> std::fmt::Result;
    /// Writes the spelling of a fragment identifier.
    ///
    /// NOTE: the `wwma` spelling is part of the trait interface. Renaming it would break
    /// every implementor and caller.
    fn compile_wwma_fragment_ident(
        f: &mut std::fmt::Formatter<'_>,
        ident: &FragmentIdent<D>,
    ) -> std::fmt::Result;
    /// Writes the spelling of a fragment layout.
    fn compile_wmma_fragment_layout(
        f: &mut std::fmt::Formatter<'_>,
        layout: &FragmentLayout<D>,
    ) -> std::fmt::Result;
    /// Writes the type of a fragment.
    fn compile_wmma_fragment(
        f: &mut std::fmt::Formatter<'_>,
        fragment: &Fragment<D>,
    ) -> std::fmt::Result;
    /// Writes the code for a single WMMA instruction.
    fn compile_wmma_instruction(
        f: &mut std::fmt::Formatter<'_>,
        instruction: &WmmaInstruction<D>,
    ) -> std::fmt::Result;
    /// Returns the WMMA combinations supported on `arch`.
    fn supported_wmma_combinations(arch: &D::Architecture) -> SupportedWmmaCombinations;
}