1use std::{collections::HashSet, fmt::Debug};
2use std::{fmt::Display, hash::Hash};
3
4use cubecl_core::ir::Processor;
5
6use crate::shared::{
7 FmtLeft, IndexedVariable, MmaShape, SupportedMmaCombinations, SupportedScaledMmaCombinations,
8 reduce_comparison, reduce_exclusive, reduce_inclusive, reduce_operator, reduce_quantifier,
9};
10
11use super::{
12 Architecture, AtomicKind, Binding, Body, Component, CubeIndexFlags, Elem, Flags, Fragment,
13 FragmentIdent, FragmentLayout, Instruction, Item, SharedMemory, Variable, WarpInstruction,
14 WmmaInstruction,
15};
16
/// Umbrella trait tying together every capability a code-generation dialect
/// must provide.
///
/// An implementor must also implement each `Dialect*` sub-trait listed below
/// with itself as the type parameter, and satisfy the value-type bounds
/// (`Default`, `Clone`, `Copy`, `Debug`, `Eq`, `Hash`, `Send`, `Sync`,
/// `'static`). Generic code can then carry a single `D: Dialect` bound.
pub trait Dialect:
    DialectIncludes<Self>
    + DialectTypes<Self>
    + DialectBindings<Self>
    + DialectWarpReduceCompiler<Self>
    + DialectCubeBuiltins<Self>
    + DialectInstructions<Self>
    + DialectWmmaCompiler<Self>
    + DialectProcessors<Self>
    + Default
    + Clone
    + Copy
    + Debug
    + Send
    + Sync
    + Eq
    + Hash
    + 'static
{
    /// Architecture description paired with this dialect. It is the argument
    /// type of the `supported_*_combinations` queries on
    /// [`DialectWmmaCompiler`].
    type Architecture: Architecture;
}
40
/// Hooks for emitting a dialect's includes and for collecting and emitting
/// dialect-specific extensions.
pub trait DialectIncludes<D: Dialect> {
    /// Dialect-defined description of one extension that instructions may
    /// register and that `compile_extensions` later receives.
    type Extension: Debug + Clone + Sync + Send;

    /// Writes the includes required by the kernel to `f`, given the
    /// compilation `flags`.
    fn compile_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags) -> std::fmt::Result;

    /// Writes the code for every registered extension in `extensions` to `f`.
    fn compile_extensions(
        f: &mut std::fmt::Formatter<'_>,
        extensions: &[Self::Extension],
    ) -> std::fmt::Result;

    /// Lets the dialect add to `extensions` whatever `instruction` requires.
    fn register_instruction_extension(
        extensions: &mut Vec<Self::Extension>,
        instruction: &Instruction<D>,
    );

    /// Lets the dialect add to `extensions` whatever the warp `instruction`
    /// requires.
    fn register_warp_instruction_extension(
        extensions: &mut Vec<Self::Extension>,
        instruction: &WarpInstruction<D>,
    );

    /// Lets the dialect add to `extensions` whatever the WMMA `instruction`
    /// requires.
    ///
    /// The default implementation registers nothing.
    // The parameters are intentionally unused by the empty default body.
    #[allow(unused_variables)]
    fn register_wmma_instruction_extension(
        extensions: &mut Vec<Self::Extension>,
        instruction: &WmmaInstruction<D>,
    ) {
    }
}
66
67pub trait DialectTypes<D: Dialect> {
70 fn item_can_be_optimized() -> bool;
71 fn compile_elem(
72 f: &mut std::fmt::Formatter<'_>,
73 elem: &Elem<D>,
74 word: bool,
75 ) -> std::fmt::Result;
76
77 fn compile_atomic_kind(
78 f: &mut std::fmt::Formatter<'_>,
79 kind: &AtomicKind<D>,
80 ) -> std::fmt::Result {
81 match kind {
82 AtomicKind::I32 => write!(f, "{}", Elem::<D>::I32),
83 AtomicKind::I64 => write!(f, "{}", Elem::<D>::I64),
84 AtomicKind::U32 => write!(f, "{}", Elem::<D>::U32),
85 AtomicKind::U64 => write!(f, "{}", Elem::<D>::U64),
86 AtomicKind::F16 => write!(f, "{}", Elem::<D>::F16),
87 AtomicKind::BF16 => write!(f, "{}", Elem::<D>::BF16),
88 AtomicKind::F32 => write!(f, "{}", Elem::<D>::F32),
89 AtomicKind::F64 => write!(f, "{}", Elem::<D>::F64),
90 AtomicKind::_Dialect(_) => Ok(()),
91 }
92 }
93
94 fn compile_item(f: &mut std::fmt::Formatter<'_>, item: &Item<D>) -> std::fmt::Result;
95 fn compile_type_definitions(
96 f: &mut std::fmt::Formatter<'_>,
97 items: &HashSet<Item<D>>,
98 scalars: &[(Elem<D>, usize)],
99 flags: &Flags,
100 ) -> std::fmt::Result;
101 fn compile_local_memory_qualifier(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
102 fn compile_shared_memory_declaration(
103 f: &mut std::fmt::Formatter<'_>,
104 shared: &SharedMemory<D>,
105 ) -> std::fmt::Result {
106 let item = shared.item;
107 let index = shared.index;
108 let offset = shared.offset;
109 let size = shared.length;
110 let size_bytes = size * shared.item.size() as u32;
111 writeln!(f, "// Shared memory size: {size}, {size_bytes} bytes")?;
112 writeln!(
113 f,
114 "{item} *shared_memory_{index} = reinterpret_cast<{item}*>(&dynamic_shared_mem[{offset}]);"
115 )
116 }
117 fn compile_polyfills(_f: &mut std::fmt::Formatter<'_>, _flags: &Flags) -> std::fmt::Result {
118 Ok(())
119 }
120 fn address_space_for_variable(_variable: &Variable<D>) -> String {
122 "".to_string()
123 }
124}
125
/// Hooks for emitting the kernel signature and any binding-related prologue.
pub trait DialectBindings<D: Dialect> {
    /// Writes the signature of the kernel named `kernel_name` to `f`, given
    /// its tensor-map bindings, buffer bindings, scalars (element type paired
    /// with a `usize`), and the compilation `flags`.
    fn compile_kernel_signature(
        f: &mut std::fmt::Formatter<'_>,
        kernel_name: &str,
        tensor_maps: &[Binding<D>],
        buffers: &[Binding<D>],
        scalars: &[(Elem<D>, usize)],
        flags: &Flags,
    ) -> std::fmt::Result;

    /// Writes binding-related code for `_body`.
    ///
    /// The default writes nothing and succeeds.
    fn compile_bindings_body(
        _f: &mut std::fmt::Formatter<'_>,
        _body: &Body<D>,
    ) -> std::fmt::Result {
        Ok(())
    }
}
144
145pub trait DialectCubeBuiltins<D: Dialect> {
148 fn builtin_rules(flags: &CubeIndexFlags) -> CubeIndexFlags {
154 let unit_pos_plane = flags.unit_pos_plane;
155 let plane_dim_checked = flags.plane_dim_checked;
156 let plane_dim = flags.plane_dim || plane_dim_checked || unit_pos_plane;
157 let plane_index = flags.plane_index;
158 let absolute_pos = flags.absolute_pos || unit_pos_plane;
159 let absolute_pos_tuple = flags.absolute_pos_tuple || absolute_pos;
160 let cube_dim = flags.cube_dim;
161 let cube_dim_tuple = flags.cube_dim_tuple || cube_dim || absolute_pos || plane_dim_checked;
162 let unit_pos = flags.unit_pos;
163 let unit_pos_tuple = flags.unit_pos_tuple || unit_pos;
164 let cube_count = flags.cube_count;
165 let cube_count_tuple = flags.cube_count_tuple || absolute_pos;
166 let cube_pos = flags.cube_pos;
167 let cube_pos_tuple = flags.cube_pos_tuple || cube_pos;
168 let cluster_group = flags.cluster_pos;
169
170 CubeIndexFlags {
171 absolute_pos,
172 absolute_pos_tuple,
173 cube_count,
174 cube_count_tuple,
175 cube_dim,
176 cube_dim_tuple,
177 cube_pos,
178 cube_pos_tuple,
179 plane_dim,
180 plane_dim_checked,
181 plane_index,
182 unit_pos_tuple,
183 unit_pos,
184 unit_pos_plane,
185 cluster_pos: cluster_group,
186 }
187 }
188
189 fn compile_absolute_pos_tuple_computation(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
190 let variable = Variable::<D>::AbsolutePosBaseName;
191 let ty = variable.item();
192 let cube_pos_x = Variable::<D>::CubePosX;
193 let cube_pos_y = Variable::<D>::CubePosY;
194 let cube_pos_z = Variable::<D>::CubePosZ;
195 let cube_dim_x = Variable::<D>::CubeDimX;
196 let cube_dim_y = Variable::<D>::CubeDimY;
197 let cube_dim_z = Variable::<D>::CubeDimZ;
198 let unit_pos_x = Variable::<D>::UnitPosX;
199 let unit_pos_y = Variable::<D>::UnitPosY;
200 let unit_pos_z = Variable::<D>::UnitPosZ;
201 writeln!(
202 f,
203 "{ty} {variable} = make_{ty}(
204 {cube_pos_x} * {cube_dim_x} + {unit_pos_x},
205 {cube_pos_y} * {cube_dim_y} + {unit_pos_y},
206 {cube_pos_z} * {cube_dim_z} + {unit_pos_z}
207);"
208 )
209 }
210
211 fn compile_absolute_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
212 f.write_str("absoluteIdx")
213 }
214
215 fn compile_absolute_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
216 f.write_str("idxGlobal")
217 }
218
219 fn compile_absolute_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
220 Self::compile_absolute_pos_base_name(f)?;
221 write!(f, ".x")
222 }
223
224 fn compile_absolute_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
225 Self::compile_absolute_pos_base_name(f)?;
226 write!(f, ".y")
227 }
228
229 fn compile_absolute_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
230 Self::compile_absolute_pos_base_name(f)?;
231 write!(f, ".z")
232 }
233
234 fn compile_cube_count_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
235 f.write_str("gridDim")
236 }
237
238 fn compile_cube_count(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
239 f.write_str("gridDimGlobal")
240 }
241
242 fn compile_cube_count_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
243 Self::compile_cube_count_base_name(f)?;
244 write!(f, ".x")
245 }
246
247 fn compile_cube_count_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
248 Self::compile_cube_count_base_name(f)?;
249 write!(f, ".y")
250 }
251
252 fn compile_cube_count_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
253 Self::compile_cube_count_base_name(f)?;
254 write!(f, ".z")
255 }
256
257 fn compile_cube_dim_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
258 f.write_str("blockDim")
259 }
260
261 fn compile_cube_dim(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
262 f.write_str("blockDimGlobal")
263 }
264
265 fn compile_cube_dim_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
266 Self::compile_cube_dim_base_name(f)?;
267 write!(f, ".x")
268 }
269
270 fn compile_cube_dim_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
271 Self::compile_cube_dim_base_name(f)?;
272 write!(f, ".y")
273 }
274
275 fn compile_cube_dim_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
276 Self::compile_cube_dim_base_name(f)?;
277 write!(f, ".z")
278 }
279
280 fn compile_cube_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
281 f.write_str("blockIdx")
282 }
283
284 fn compile_cube_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
285 f.write_str("blockIdxGlobal")
286 }
287
288 fn compile_cube_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
289 Self::compile_cube_pos_base_name(f)?;
290 write!(f, ".x")
291 }
292
293 fn compile_cube_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
294 Self::compile_cube_pos_base_name(f)?;
295 write!(f, ".y")
296 }
297
298 fn compile_cube_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
299 Self::compile_cube_pos_base_name(f)?;
300 write!(f, ".z")
301 }
302
303 fn compile_unit_pos_computation(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
304 let variable = Variable::<D>::UnitPos;
305 let ty = variable.item();
306 let cube_dim_x = Variable::<D>::CubeDimX;
307 let cube_dim_y = Variable::<D>::CubeDimY;
308 let unit_pos_x = Variable::<D>::UnitPosX;
309 let unit_pos_y = Variable::<D>::UnitPosY;
310 let unit_pos_z = Variable::<D>::UnitPosZ;
311 writeln!(
312 f,
313 "{ty} {variable} = {unit_pos_x} + {unit_pos_y} * {cube_dim_x} + {unit_pos_z} * ({cube_dim_x} * {cube_dim_y});"
314 )
315 }
316
317 fn compile_unit_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
318 f.write_str("threadIdxGlobal")
319 }
320
321 fn compile_unit_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
322 f.write_str("threadIdx")
323 }
324
325 fn compile_unit_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
326 Self::compile_unit_pos_base_name(f)?;
327 write!(f, ".x")
328 }
329
330 fn compile_unit_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
331 Self::compile_unit_pos_base_name(f)?;
332 write!(f, ".y")
333 }
334
335 fn compile_unit_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
336 Self::compile_unit_pos_base_name(f)?;
337 write!(f, ".z")
338 }
339
340 fn compile_plane_dim(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
341 f.write_str("warpSize")
342 }
343
344 fn compile_plane_dim_checked(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
345 f.write_str("warpSizeChecked")
346 }
347
348 fn compile_plane_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
349 let unit_pos_x = Variable::<D>::UnitPosX;
350 let plane_dim = Variable::<D>::PlaneDim;
351 write!(f, "{unit_pos_x} / {plane_dim}")
352 }
353
354 fn compile_unit_pos_plane(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
355 let absolute_pos = Variable::<D>::AbsolutePos;
356 let plane_dim = Variable::<D>::PlaneDim;
357 write!(f, "{absolute_pos} % {plane_dim}")
358 }
359
360 fn compile_cluster_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
361 write!(f, "0")
362 }
363 fn compile_cluster_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
364 write!(f, "0")
365 }
366 fn compile_cluster_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
367 write!(f, "0")
368 }
369 fn compile_cluster_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
370 write!(f, "0")
371 }
372}
373
374pub trait DialectInstructions<D: Dialect> {
377 fn compile_atomic_add(
379 f: &mut std::fmt::Formatter<'_>,
380 lhs: &Variable<D>,
381 rhs: &Variable<D>,
382 out: &Variable<D>,
383 ) -> std::fmt::Result {
384 let out = out.fmt_left();
385 match rhs.elem() {
386 Elem::I64 => writeln!(
387 f,
388 "{out} = atomicAdd(reinterpret_cast<{uint}*>({lhs}), {uint}({rhs}));",
389 uint = Elem::<D>::U64
390 ),
391 _ => writeln!(f, "{out} = atomicAdd({lhs}, {rhs});"),
392 }
393 }
394
395 fn compile_atomic_and(
396 f: &mut std::fmt::Formatter<'_>,
397 lhs: &Variable<D>,
398 rhs: &Variable<D>,
399 out: &Variable<D>,
400 ) -> std::fmt::Result {
401 let out = out.fmt_left();
402 writeln!(f, "{out} = atomicAnd({lhs}, {rhs});")
403 }
404
405 fn compile_atomic_cas(
406 f: &mut std::fmt::Formatter<'_>,
407 input: &Variable<D>,
408 cmp: &Variable<D>,
409 val: &Variable<D>,
410 out: &Variable<D>,
411 ) -> std::fmt::Result {
412 let out = out.fmt_left();
413 writeln!(f, "{out} = atomicCAS({input}, {cmp}, {val});")
414 }
415
416 fn compile_atomic_load(
417 f: &mut std::fmt::Formatter<'_>,
418 input: &Variable<D>,
419 out: &Variable<D>,
420 ) -> std::fmt::Result {
421 let out = out.fmt_left();
422 writeln!(f, "{out} = atomicAdd({input}, 0);")
423 }
424
425 fn compile_atomic_max(
426 f: &mut std::fmt::Formatter<'_>,
427 lhs: &Variable<D>,
428 rhs: &Variable<D>,
429 out: &Variable<D>,
430 ) -> std::fmt::Result {
431 let out = out.fmt_left();
432 writeln!(f, "{out} = atomicMax({lhs}, {rhs});")
433 }
434
435 fn compile_atomic_min(
436 f: &mut std::fmt::Formatter<'_>,
437 lhs: &Variable<D>,
438 rhs: &Variable<D>,
439 out: &Variable<D>,
440 ) -> std::fmt::Result {
441 let out = out.fmt_left();
442 writeln!(f, "{out} = atomicMin({lhs}, {rhs});")
443 }
444
445 fn compile_atomic_or(
446 f: &mut std::fmt::Formatter<'_>,
447 lhs: &Variable<D>,
448 rhs: &Variable<D>,
449 out: &Variable<D>,
450 ) -> std::fmt::Result {
451 let out = out.fmt_left();
452 writeln!(f, "{out} = atomicOr({lhs}, {rhs});")
453 }
454
455 fn compile_atomic_store(
456 f: &mut std::fmt::Formatter<'_>,
457 input: &Variable<D>,
458 out: &Variable<D>,
459 ) -> std::fmt::Result {
460 writeln!(f, "atomicExch({out}, {input});")
461 }
462
463 fn compile_atomic_sub(
464 f: &mut std::fmt::Formatter<'_>,
465 lhs: &Variable<D>,
466 rhs: &Variable<D>,
467 out: &Variable<D>,
468 ) -> std::fmt::Result {
469 let out = out.fmt_left();
470 match rhs.elem() {
471 Elem::U32 | Elem::I32 => writeln!(f, "{out} = atomicSub({lhs}, {rhs});"),
472 Elem::U64 => writeln!(f, "{out} = atomicAdd({lhs}, -{rhs});"),
473 Elem::I64 => writeln!(
474 f,
475 "{out} = atomicAdd(reinterpret_cast<{uint}*>({lhs}), {uint}(-{rhs}));",
476 uint = Elem::<D>::U64
477 ),
478 _ => writeln!(f, "{out} = atomicAdd({lhs}, -{rhs});"),
479 }
480 }
481
482 fn compile_atomic_swap(
483 f: &mut std::fmt::Formatter<'_>,
484 lhs: &Variable<D>,
485 rhs: &Variable<D>,
486 out: &Variable<D>,
487 ) -> std::fmt::Result {
488 let out = out.fmt_left();
489 writeln!(f, "{out} = atomicExch({lhs}, {rhs});")
490 }
491
492 fn compile_atomic_xor(
493 f: &mut std::fmt::Formatter<'_>,
494 lhs: &Variable<D>,
495 rhs: &Variable<D>,
496 out: &Variable<D>,
497 ) -> std::fmt::Result {
498 let out = out.fmt_left();
499 writeln!(f, "{out} = atomicXor({lhs}, {rhs});")
500 }
501
502 fn compile_saturating_add(
503 f: &mut std::fmt::Formatter<'_>,
504 lhs: impl Display,
505 rhs: impl Display,
506 item: Item<D>,
507 ) -> std::fmt::Result;
508
509 fn compile_saturating_sub(
510 f: &mut std::fmt::Formatter<'_>,
511 lhs: impl Display,
512 rhs: impl Display,
513 item: Item<D>,
514 ) -> std::fmt::Result;
515
516 fn compile_instruction_printf(
518 f: &mut std::fmt::Formatter<'_>,
519 format_string: &str,
520 args: &[Variable<D>],
521 ) -> std::fmt::Result {
522 let args = args.iter().map(|arg| format!("{arg}")).collect::<Vec<_>>();
523 let args = match args.is_empty() {
524 true => "".to_string(),
525 false => format!(", {}", args.join(",")),
526 };
527 writeln!(f, "printf({format_string:?}{args});")
528 }
529
530 fn compile_instruction_log1p_scalar<T: Component<D>>(
532 f: &mut std::fmt::Formatter<'_>,
533 input: T,
534 ) -> std::fmt::Result {
535 let elem = input.elem();
536 match elem {
537 Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
538 write!(f, "{elem}(log1p(float({input})))")
539 }
540 _ => write!(f, "log1p({input})"),
541 }
542 }
543
544 fn compile_instruction_sync_threads(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
546 fn compile_instruction_sync_warp(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
547 fn compile_instruction_thread_fence(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
548
549 fn compile_instruction_tanh_scalar<T: Component<D>>(
551 f: &mut std::fmt::Formatter<'_>,
552 input: T,
553 ) -> std::fmt::Result {
554 let elem = input.elem();
555 match elem {
556 Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
557 write!(f, "{elem}(tanh(float({input})))")
558 }
559 _ => write!(f, "tanh({input})"),
560 }
561 }
562
563 fn compile_instruction_find_first_set<T: Component<D>>(
565 f: &mut std::fmt::Formatter<'_>,
566 input: T,
567 out_elem: Elem<D>,
568 ) -> std::fmt::Result;
569 fn compile_instruction_leading_zeros_scalar<T: Component<D>>(
570 f: &mut std::fmt::Formatter<'_>,
571 input: T,
572 out_elem: Elem<D>,
573 ) -> std::fmt::Result;
574
575 fn compile_instruction_popcount_scalar<T: Component<D>>(
576 f: &mut std::fmt::Formatter<'_>,
577 input: T,
578 out_elem: Elem<D>,
579 ) -> std::fmt::Result {
580 write!(f, "{out_elem}(")?;
581 match input.elem() {
582 Elem::I32 => write!(f, "__popc({}({input}))", Elem::<D>::U32),
583 Elem::U32 => write!(f, "__popc({input})"),
584 Elem::I64 => write!(f, "__popcll({}({input}))", Elem::<D>::U64),
585 Elem::U64 => write!(f, "__popcll({input})"),
586 _ => write!(f, "__popc({})", super::unary::zero_extend(input)),
587 }?;
588 write!(f, ")")
589 }
590
591 fn compile_instruction_reverse_bits_scalar<T: Component<D>>(
592 f: &mut std::fmt::Formatter<'_>,
593 input: T,
594 out_elem: Elem<D>,
595 ) -> std::fmt::Result {
596 write!(f, "{out_elem}(")?;
597 match out_elem {
598 Elem::I32 => write!(f, "__brev({}({input}))", Elem::<D>::U32),
599 Elem::U32 => write!(f, "__brev({input})"),
600 Elem::I64 => write!(f, "__brevll({}({input}))", Elem::<D>::U64),
601 Elem::U64 => write!(f, "__brevll({input})"),
602 _ => write!(
603 f,
604 "__brev({}) >> {}",
605 super::unary::zero_extend(input),
606 (size_of::<u32>() - out_elem.size()) * 8
607 ),
608 }?;
609 write!(f, ")")
610 }
611
612 fn compile_instruction_max_function_name(
614 f: &mut std::fmt::Formatter<'_>,
615 item: Item<D>,
616 ) -> std::fmt::Result;
617
618 fn compile_instruction_min_function_name(
619 f: &mut std::fmt::Formatter<'_>,
620 item: Item<D>,
621 ) -> std::fmt::Result;
622
623 fn compile_instruction_powf(
624 f: &mut std::fmt::Formatter<'_>,
625 lhs: &str,
626 rhs: &str,
627 elem: Elem<D>,
628 ) -> std::fmt::Result {
629 match elem {
630 Elem::F32 => write!(f, "powf({lhs}, {rhs})"),
631 Elem::F64 => write!(f, "pow({lhs}, {rhs})"),
632 _ => panic!("Unsupported type for powf"),
633 }
634 }
635
636 fn compile_instruction_half_function_name_prefix() -> &'static str {
637 "h"
638 }
639
640 fn compile_instruction_half2_function_name_prefix() -> &'static str {
641 "h2"
642 }
643
644 fn compile_warp_shuffle(
646 f: &mut std::fmt::Formatter<'_>,
647 var: &str,
648 source: &str,
649 ) -> std::fmt::Result;
650 fn compile_warp_shuffle_xor(
651 f: &mut std::fmt::Formatter<'_>,
652 var: &str,
653 elem: &Elem<D>,
654 offset: &str,
655 ) -> std::fmt::Result;
656 fn compile_warp_shuffle_up(
657 f: &mut std::fmt::Formatter<'_>,
658 var: &str,
659 offset: &str,
660 ) -> std::fmt::Result;
661 fn compile_warp_shuffle_down(
662 f: &mut std::fmt::Formatter<'_>,
663 var: &str,
664 offset: &str,
665 ) -> std::fmt::Result;
666 fn compile_warp_all<T: Component<D>>(
667 f: &mut std::fmt::Formatter<'_>,
668 input: &T,
669 ) -> std::fmt::Result;
670 fn compile_warp_any<T: Component<D>>(
671 f: &mut std::fmt::Formatter<'_>,
672 input: &T,
673 ) -> std::fmt::Result;
674 fn compile_warp_ballot(
675 f: &mut std::fmt::Formatter<'_>,
676 input: &Variable<D>,
677 out_elem: &Elem<D>,
678 ) -> std::fmt::Result;
679}
680
/// Borrowed operands of one manually emitted MMA instruction, passed to
/// `DialectWmmaCompiler::compile_manual_mma` and `compile_scaled_mma`.
///
/// NOTE(review): presumably `frag_a`/`frag_b`/`frag_c` are the inputs and
/// `frag_d` the output of `D = A * B + C` — confirm against the callers.
#[derive(Debug, Clone, Copy, new)]
pub struct ManualMma<'a, D: Dialect> {
    /// Shape of the MMA operation.
    pub shape: MmaShape<D>,
    /// Variables making up fragment A.
    pub frag_a: &'a [Variable<D>],
    /// Variables making up fragment B.
    pub frag_b: &'a [Variable<D>],
    /// Variables making up fragment C.
    pub frag_c: &'a [Variable<D>],
    /// Single variable for fragment D.
    pub frag_d: &'a Variable<D>,
}
689
/// Hooks that emit warp-level reductions of `input` into `out`.
///
/// Every method has a default that delegates to one of the shared `reduce_*`
/// helpers; dialects override only what they need to spell differently.
pub trait DialectWarpReduceCompiler<D: Dialect>:
    Default + Clone + Copy + Debug + Send + Sync + Eq + Hash + 'static
{
    /// Sum reduction: delegates to `reduce_operator` with the `+=` operator.
    fn warp_reduce_sum(
        f: &mut core::fmt::Formatter<'_>,
        input: &Variable<D>,
        out: &Variable<D>,
    ) -> core::fmt::Result {
        reduce_operator(f, input, out, "+=")
    }
    /// Product reduction: delegates to `reduce_operator` with the `*=`
    /// operator.
    fn warp_reduce_prod(
        f: &mut core::fmt::Formatter<'_>,
        input: &Variable<D>,
        out: &Variable<D>,
    ) -> core::fmt::Result {
        reduce_operator(f, input, out, "*=")
    }
    /// Max reduction: delegates to `reduce_comparison` using the dialect's
    /// `compile_instruction_max_function_name`.
    fn warp_reduce_max(
        f: &mut core::fmt::Formatter<'_>,
        input: &Variable<D>,
        out: &Variable<D>,
    ) -> core::fmt::Result {
        reduce_comparison(f, input, out, D::compile_instruction_max_function_name)
    }
    /// Min reduction: delegates to `reduce_comparison` using the dialect's
    /// `compile_instruction_min_function_name`.
    fn warp_reduce_min(
        f: &mut core::fmt::Formatter<'_>,
        input: &Variable<D>,
        out: &Variable<D>,
    ) -> core::fmt::Result {
        reduce_comparison(f, input, out, D::compile_instruction_min_function_name)
    }
    /// "All" reduction: delegates to `reduce_quantifier` using the dialect's
    /// `compile_warp_all` instantiated for `IndexedVariable<D>`.
    fn warp_reduce_all(
        f: &mut core::fmt::Formatter<'_>,
        input: &Variable<D>,
        out: &Variable<D>,
    ) -> core::fmt::Result {
        reduce_quantifier(f, input, out, D::compile_warp_all::<IndexedVariable<D>>)
    }
    /// "Any" reduction: delegates to `reduce_quantifier` using the dialect's
    /// `compile_warp_any` instantiated for `IndexedVariable<D>`.
    fn warp_reduce_any(
        f: &mut core::fmt::Formatter<'_>,
        input: &Variable<D>,
        out: &Variable<D>,
    ) -> core::fmt::Result {
        reduce_quantifier(f, input, out, D::compile_warp_any::<IndexedVariable<D>>)
    }
    /// Inclusive sum: delegates to `reduce_inclusive` with the `+=` operator.
    fn warp_reduce_sum_inclusive(
        f: &mut core::fmt::Formatter<'_>,
        input: &Variable<D>,
        out: &Variable<D>,
    ) -> core::fmt::Result {
        reduce_inclusive(f, input, out, "+=")
    }
    /// Inclusive product: delegates to `reduce_inclusive` with the `*=`
    /// operator.
    fn warp_reduce_prod_inclusive(
        f: &mut core::fmt::Formatter<'_>,
        input: &Variable<D>,
        out: &Variable<D>,
    ) -> core::fmt::Result {
        reduce_inclusive(f, input, out, "*=")
    }
    /// Exclusive sum: delegates to `reduce_exclusive` with the `+=` operator
    /// and `"0"` as the extra argument (presumably the identity value —
    /// confirm against the helper).
    fn warp_reduce_sum_exclusive(
        f: &mut core::fmt::Formatter<'_>,
        input: &Variable<D>,
        out: &Variable<D>,
    ) -> core::fmt::Result {
        reduce_exclusive(f, input, out, "+=", "0")
    }
    /// Exclusive product: delegates to `reduce_exclusive` with the `*=`
    /// operator and `"1"` as the extra argument (presumably the identity
    /// value — confirm against the helper).
    fn warp_reduce_prod_exclusive(
        f: &mut core::fmt::Formatter<'_>,
        input: &Variable<D>,
        out: &Variable<D>,
    ) -> core::fmt::Result {
        reduce_exclusive(f, input, out, "*=", "1")
    }
}
764
/// Hooks that emit matrix-multiply (WMMA / MMA) code and report which
/// combinations an architecture supports.
///
/// The include, type-definition, local-variable, fragment-ident,
/// fragment-layout and fragment hooks default to writing nothing.
pub trait DialectWmmaCompiler<D: Dialect>:
    Default + Clone + Copy + Debug + Send + Sync + Eq + Hash + 'static
{
    /// Writes the includes needed for WMMA support. Default: writes nothing.
    // Parameters are unused by the no-op default body.
    #[allow(unused_variables)]
    fn compile_wmma_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags) -> std::fmt::Result {
        Ok(())
    }
    /// Writes the type definitions needed for WMMA support. Default: writes
    /// nothing.
    // Parameters are unused by the no-op default body.
    #[allow(unused_variables)]
    fn compile_wmma_type_definitions(
        f: &mut std::fmt::Formatter<'_>,
        flags: &Flags,
    ) -> std::fmt::Result {
        Ok(())
    }
    /// Writes local variables needed for WMMA support. Default: writes
    /// nothing.
    // Parameters are unused by the no-op default body.
    #[allow(unused_variables)]
    fn compile_wmma_local_variables(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        Ok(())
    }
    /// Writes the dialect's spelling of a fragment `ident`. Default: writes
    /// nothing.
    ///
    /// NOTE(review): the method name is spelled `wwma`, unlike its `wmma`
    /// siblings; it is part of the public trait interface, so renaming it
    /// would break implementors.
    // Parameters are unused by the no-op default body.
    #[allow(unused_variables)]
    fn compile_wwma_fragment_ident(
        f: &mut std::fmt::Formatter<'_>,
        ident: &FragmentIdent<D>,
    ) -> std::fmt::Result {
        Ok(())
    }
    /// Writes the dialect's spelling of a fragment `layout`. Default: writes
    /// nothing.
    // Parameters are unused by the no-op default body.
    #[allow(unused_variables)]
    fn compile_wmma_fragment_layout(
        f: &mut std::fmt::Formatter<'_>,
        layout: &FragmentLayout<D>,
    ) -> std::fmt::Result {
        Ok(())
    }
    /// Writes the dialect's spelling of a `fragment`. Default: writes
    /// nothing.
    // Parameters are unused by the no-op default body.
    #[allow(unused_variables)]
    fn compile_wmma_fragment(
        f: &mut std::fmt::Formatter<'_>,
        fragment: &Fragment<D>,
    ) -> std::fmt::Result {
        Ok(())
    }

    /// Writes the declaration of the fragment variable `var`.
    fn compile_wmma_fragment_declaration(
        f: &mut std::fmt::Formatter<'_>,
        var: &Variable<D>,
    ) -> std::fmt::Result;

    /// Writes the code for one WMMA `instruction`.
    fn compile_wmma_instruction(
        f: &mut std::fmt::Formatter<'_>,
        instruction: &WmmaInstruction<D>,
    ) -> std::fmt::Result;
    /// Writes the code for a manual MMA described by `mma`.
    fn compile_manual_mma(f: &mut std::fmt::Formatter<'_>, mma: ManualMma<D>) -> std::fmt::Result;
    /// Writes the code for a scaled MMA described by `mma`, with scale
    /// variables for A and B and a `scales_factor`.
    ///
    /// NOTE(review): the exact meaning of `scales_factor` is not visible in
    /// this file — confirm against implementors.
    fn compile_scaled_mma(
        f: &mut std::fmt::Formatter<'_>,
        mma: ManualMma<D>,
        scales_a: Variable<D>,
        scales_b: Variable<D>,
        scales_factor: u32,
    ) -> std::fmt::Result;
    /// Returns the WMMA combinations supported on `arch`.
    fn supported_wmma_combinations(arch: &D::Architecture) -> SupportedMmaCombinations;
    /// Returns the manual MMA combinations supported on `arch`.
    fn supported_mma_combinations(arch: &D::Architecture) -> SupportedMmaCombinations;
    /// Returns the scaled MMA combinations supported on `_arch`.
    ///
    /// The default reports none (an empty list).
    fn supported_scaled_mma_combinations(
        _arch: &D::Architecture,
    ) -> SupportedScaledMmaCombinations {
        Vec::new()
    }
}
830
/// Supplies the IR processors a dialect wants applied.
pub trait DialectProcessors<D: Dialect> {
    /// Returns the dialect's list of boxed [`Processor`]s.
    ///
    /// NOTE(review): when and in what order callers run these is not visible
    /// in this file — confirm against the compiler that consumes them.
    fn processors() -> Vec<Box<dyn Processor>>;
}