1use std::{collections::HashSet, fmt::Debug};
2use std::{fmt::Display, hash::Hash};
3
4use cubecl_core::ir::Processor;
5
6use crate::shared::{
7 FmtLeft, IndexedVariable, MmaShape, SupportedMmaCombinations, SupportedScaledMmaCombinations,
8 reduce_comparison, reduce_exclusive, reduce_inclusive, reduce_operator, reduce_quantifier,
9};
10
11use super::{
12 Architecture, AtomicKind, Binding, Body, Component, CubeIndexFlags, Elem, Flags, Fragment,
13 FragmentIdent, FragmentLayout, Instruction, Item, SharedMemory, Variable, WarpInstruction,
14 WmmaInstruction,
15};
16
17pub trait Dialect:
20 DialectIncludes<Self>
21 + DialectTypes<Self>
22 + DialectBindings<Self>
23 + DialectWarpReduceCompiler<Self>
24 + DialectCubeBuiltins<Self>
25 + DialectInstructions<Self>
26 + DialectWmmaCompiler<Self>
27 + DialectProcessors<Self>
28 + Default
29 + Clone
30 + Copy
31 + Debug
32 + Send
33 + Sync
34 + Eq
35 + Hash
36 + 'static
37{
38 type Architecture: Architecture;
39}
40
41pub trait DialectIncludes<D: Dialect> {
44 type Extension: Debug + Clone + Sync + Send;
45
46 fn compile_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags) -> std::fmt::Result;
47 fn compile_extensions(
48 f: &mut std::fmt::Formatter<'_>,
49 extensions: &[Self::Extension],
50 ) -> std::fmt::Result;
51 fn register_instruction_extension(
52 extensions: &mut Vec<Self::Extension>,
53 instruction: &Instruction<D>,
54 );
55 fn register_warp_instruction_extension(
56 extensions: &mut Vec<Self::Extension>,
57 instruction: &WarpInstruction<D>,
58 );
59 #[allow(unused_variables)]
60 fn register_wmma_instruction_extension(
61 extensions: &mut Vec<Self::Extension>,
62 instruction: &WmmaInstruction<D>,
63 ) {
64 }
65}
66
67pub trait DialectTypes<D: Dialect> {
70 fn item_can_be_optimized() -> bool;
71 fn compile_elem(
72 f: &mut std::fmt::Formatter<'_>,
73 elem: &Elem<D>,
74 word: bool,
75 ) -> std::fmt::Result;
76
77 fn compile_atomic_kind(
78 f: &mut std::fmt::Formatter<'_>,
79 kind: &AtomicKind<D>,
80 ) -> std::fmt::Result {
81 match kind {
82 AtomicKind::I32 => write!(f, "{}", Elem::<D>::I32),
83 AtomicKind::I64 => write!(f, "{}", Elem::<D>::I64),
84 AtomicKind::U32 => write!(f, "{}", Elem::<D>::U32),
85 AtomicKind::U64 => write!(f, "{}", Elem::<D>::U64),
86 AtomicKind::F16 => write!(f, "{}", Elem::<D>::F16),
87 AtomicKind::BF16 => write!(f, "{}", Elem::<D>::BF16),
88 AtomicKind::F32 => write!(f, "{}", Elem::<D>::F32),
89 AtomicKind::F64 => write!(f, "{}", Elem::<D>::F64),
90 AtomicKind::_Dialect(_) => Ok(()),
91 }
92 }
93
94 fn compile_item(f: &mut std::fmt::Formatter<'_>, item: &Item<D>) -> std::fmt::Result;
95 fn compile_type_definitions(
96 f: &mut std::fmt::Formatter<'_>,
97 items: &HashSet<Item<D>>,
98 scalars: &[(Elem<D>, usize)],
99 flags: &Flags,
100 ) -> std::fmt::Result;
101 fn compile_local_memory_qualifier(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
102 fn compile_shared_memory_declaration(
103 f: &mut std::fmt::Formatter<'_>,
104 shared: &SharedMemory<D>,
105 ) -> std::fmt::Result {
106 match shared {
107 SharedMemory::Array {
108 index,
109 item,
110 length,
111 offset,
112 ..
113 } => {
114 let size_bytes = *length * item.size() as u32;
115 writeln!(f, "// Shared array size: {length}, {size_bytes} bytes")?;
116 writeln!(
117 f,
118 "{item} *shared_memory_{index} = reinterpret_cast<{item}*>(&dynamic_shared_mem[{offset}]);"
119 )
120 }
121 SharedMemory::Value {
122 index,
123 item,
124 offset,
125 ..
126 } => {
127 let size_bytes = item.size() as u32;
128 writeln!(f, "// Shared value size: {size_bytes} bytes")?;
129 writeln!(
130 f,
131 "{item} &shared_memory_{index} = reinterpret_cast<{item}&>(dynamic_shared_mem[{offset}]);"
132 )
133 }
134 }
135 }
136 fn compile_polyfills(_f: &mut std::fmt::Formatter<'_>, _flags: &Flags) -> std::fmt::Result {
137 Ok(())
138 }
139 fn address_space_for_variable(_variable: &Variable<D>) -> String {
141 "".to_string()
142 }
143}
144
145pub trait DialectBindings<D: Dialect> {
148 fn compile_kernel_signature(
149 f: &mut std::fmt::Formatter<'_>,
150 kernel_name: &str,
151 tensor_maps: &[Binding<D>],
152 buffers: &[Binding<D>],
153 scalars: &[(Elem<D>, usize)],
154 flags: &Flags,
155 ) -> std::fmt::Result;
156 fn compile_bindings_body(
157 _f: &mut std::fmt::Formatter<'_>,
158 _body: &Body<D>,
159 ) -> std::fmt::Result {
160 Ok(())
161 }
162}
163
164pub trait DialectCubeBuiltins<D: Dialect> {
167 fn builtin_rules(flags: &CubeIndexFlags) -> CubeIndexFlags {
173 let unit_pos_plane = flags.unit_pos_plane;
174 let plane_dim_checked = flags.plane_dim_checked;
175 let plane_dim = flags.plane_dim || plane_dim_checked || unit_pos_plane;
176 let plane_index = flags.plane_index;
177 let absolute_pos = flags.absolute_pos || unit_pos_plane;
178 let absolute_pos_tuple = flags.absolute_pos_tuple || absolute_pos;
179 let cube_dim = flags.cube_dim;
180 let cube_dim_tuple = flags.cube_dim_tuple || cube_dim || absolute_pos || plane_dim_checked;
181 let unit_pos = flags.unit_pos;
182 let unit_pos_tuple = flags.unit_pos_tuple || unit_pos;
183 let cube_count = flags.cube_count;
184 let cube_count_tuple = flags.cube_count_tuple || absolute_pos;
185 let cube_pos = flags.cube_pos;
186 let cube_pos_tuple = flags.cube_pos_tuple || cube_pos;
187 let cluster_group = flags.cluster_pos;
188
189 CubeIndexFlags {
190 absolute_pos,
191 absolute_pos_tuple,
192 cube_count,
193 cube_count_tuple,
194 cube_dim,
195 cube_dim_tuple,
196 cube_pos,
197 cube_pos_tuple,
198 plane_dim,
199 plane_dim_checked,
200 plane_index,
201 unit_pos_tuple,
202 unit_pos,
203 unit_pos_plane,
204 cluster_pos: cluster_group,
205 }
206 }
207
208 fn compile_absolute_pos_tuple_computation(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
209 let variable = Variable::<D>::AbsolutePosBaseName;
210 let ty = variable.item();
211 let cube_pos_x = Variable::<D>::CubePosX;
212 let cube_pos_y = Variable::<D>::CubePosY;
213 let cube_pos_z = Variable::<D>::CubePosZ;
214 let cube_dim_x = Variable::<D>::CubeDimX;
215 let cube_dim_y = Variable::<D>::CubeDimY;
216 let cube_dim_z = Variable::<D>::CubeDimZ;
217 let unit_pos_x = Variable::<D>::UnitPosX;
218 let unit_pos_y = Variable::<D>::UnitPosY;
219 let unit_pos_z = Variable::<D>::UnitPosZ;
220 writeln!(
221 f,
222 "{ty} {variable} = make_{ty}(
223 {cube_pos_x} * {cube_dim_x} + {unit_pos_x},
224 {cube_pos_y} * {cube_dim_y} + {unit_pos_y},
225 {cube_pos_z} * {cube_dim_z} + {unit_pos_z}
226);"
227 )
228 }
229
230 fn compile_absolute_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
231 f.write_str("absoluteIdx")
232 }
233
234 fn compile_absolute_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
235 f.write_str("idxGlobal")
236 }
237
238 fn compile_absolute_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
239 Self::compile_absolute_pos_base_name(f)?;
240 write!(f, ".x")
241 }
242
243 fn compile_absolute_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
244 Self::compile_absolute_pos_base_name(f)?;
245 write!(f, ".y")
246 }
247
248 fn compile_absolute_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
249 Self::compile_absolute_pos_base_name(f)?;
250 write!(f, ".z")
251 }
252
253 fn compile_cube_count_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
254 f.write_str("gridDim")
255 }
256
257 fn compile_cube_count(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
258 f.write_str("gridDimGlobal")
259 }
260
261 fn compile_cube_count_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
262 Self::compile_cube_count_base_name(f)?;
263 write!(f, ".x")
264 }
265
266 fn compile_cube_count_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
267 Self::compile_cube_count_base_name(f)?;
268 write!(f, ".y")
269 }
270
271 fn compile_cube_count_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
272 Self::compile_cube_count_base_name(f)?;
273 write!(f, ".z")
274 }
275
276 fn compile_cube_dim_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
277 f.write_str("blockDim")
278 }
279
280 fn compile_cube_dim(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
281 f.write_str("blockDimGlobal")
282 }
283
284 fn compile_cube_dim_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
285 Self::compile_cube_dim_base_name(f)?;
286 write!(f, ".x")
287 }
288
289 fn compile_cube_dim_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
290 Self::compile_cube_dim_base_name(f)?;
291 write!(f, ".y")
292 }
293
294 fn compile_cube_dim_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
295 Self::compile_cube_dim_base_name(f)?;
296 write!(f, ".z")
297 }
298
299 fn compile_cube_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
300 f.write_str("blockIdx")
301 }
302
303 fn compile_cube_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
304 f.write_str("blockIdxGlobal")
305 }
306
307 fn compile_cube_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
308 Self::compile_cube_pos_base_name(f)?;
309 write!(f, ".x")
310 }
311
312 fn compile_cube_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
313 Self::compile_cube_pos_base_name(f)?;
314 write!(f, ".y")
315 }
316
317 fn compile_cube_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
318 Self::compile_cube_pos_base_name(f)?;
319 write!(f, ".z")
320 }
321
322 fn compile_unit_pos_computation(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
323 let variable = Variable::<D>::UnitPos;
324 let ty = variable.item();
325 let cube_dim_x = Variable::<D>::CubeDimX;
326 let cube_dim_y = Variable::<D>::CubeDimY;
327 let unit_pos_x = Variable::<D>::UnitPosX;
328 let unit_pos_y = Variable::<D>::UnitPosY;
329 let unit_pos_z = Variable::<D>::UnitPosZ;
330 writeln!(
331 f,
332 "{ty} {variable} = {unit_pos_x} + {unit_pos_y} * {cube_dim_x} + {unit_pos_z} * ({cube_dim_x} * {cube_dim_y});"
333 )
334 }
335
336 fn compile_unit_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
337 f.write_str("threadIdxGlobal")
338 }
339
340 fn compile_unit_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
341 f.write_str("threadIdx")
342 }
343
344 fn compile_unit_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
345 Self::compile_unit_pos_base_name(f)?;
346 write!(f, ".x")
347 }
348
349 fn compile_unit_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
350 Self::compile_unit_pos_base_name(f)?;
351 write!(f, ".y")
352 }
353
354 fn compile_unit_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
355 Self::compile_unit_pos_base_name(f)?;
356 write!(f, ".z")
357 }
358
359 fn compile_plane_dim(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
360 f.write_str("warpSize")
361 }
362
363 fn compile_plane_dim_checked(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
364 f.write_str("warpSizeChecked")
365 }
366
367 fn compile_plane_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
368 let unit_pos_x = Variable::<D>::UnitPosX;
369 let plane_dim = Variable::<D>::PlaneDim;
370 write!(f, "{unit_pos_x} / {plane_dim}")
371 }
372
373 fn compile_unit_pos_plane(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
374 let absolute_pos = Variable::<D>::AbsolutePos;
375 let plane_dim = Variable::<D>::PlaneDim;
376 write!(f, "{absolute_pos} % {plane_dim}")
377 }
378
379 fn compile_cluster_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
380 write!(f, "0")
381 }
382 fn compile_cluster_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
383 write!(f, "0")
384 }
385 fn compile_cluster_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
386 write!(f, "0")
387 }
388 fn compile_cluster_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
389 write!(f, "0")
390 }
391}
392
393pub trait DialectInstructions<D: Dialect> {
396 fn compile_atomic_add(
398 f: &mut std::fmt::Formatter<'_>,
399 lhs: &Variable<D>,
400 rhs: &Variable<D>,
401 out: &Variable<D>,
402 ) -> std::fmt::Result {
403 let out = out.fmt_left();
404 match rhs.elem() {
405 Elem::I64 => writeln!(
406 f,
407 "{out} = atomicAdd(reinterpret_cast<{uint}*>({lhs}), {uint}({rhs}));",
408 uint = Elem::<D>::U64
409 ),
410 _ => writeln!(f, "{out} = atomicAdd({lhs}, {rhs});"),
411 }
412 }
413
414 fn compile_atomic_and(
415 f: &mut std::fmt::Formatter<'_>,
416 lhs: &Variable<D>,
417 rhs: &Variable<D>,
418 out: &Variable<D>,
419 ) -> std::fmt::Result {
420 let out = out.fmt_left();
421 writeln!(f, "{out} = atomicAnd({lhs}, {rhs});")
422 }
423
424 fn compile_atomic_cas(
425 f: &mut std::fmt::Formatter<'_>,
426 input: &Variable<D>,
427 cmp: &Variable<D>,
428 val: &Variable<D>,
429 out: &Variable<D>,
430 ) -> std::fmt::Result {
431 let out = out.fmt_left();
432 writeln!(f, "{out} = atomicCAS({input}, {cmp}, {val});")
433 }
434
435 fn compile_atomic_load(
436 f: &mut std::fmt::Formatter<'_>,
437 input: &Variable<D>,
438 out: &Variable<D>,
439 ) -> std::fmt::Result {
440 let out = out.fmt_left();
441 writeln!(f, "{out} = atomicAdd({input}, 0);")
442 }
443
444 fn compile_atomic_max(
445 f: &mut std::fmt::Formatter<'_>,
446 lhs: &Variable<D>,
447 rhs: &Variable<D>,
448 out: &Variable<D>,
449 ) -> std::fmt::Result {
450 let out = out.fmt_left();
451 writeln!(f, "{out} = atomicMax({lhs}, {rhs});")
452 }
453
454 fn compile_atomic_min(
455 f: &mut std::fmt::Formatter<'_>,
456 lhs: &Variable<D>,
457 rhs: &Variable<D>,
458 out: &Variable<D>,
459 ) -> std::fmt::Result {
460 let out = out.fmt_left();
461 writeln!(f, "{out} = atomicMin({lhs}, {rhs});")
462 }
463
464 fn compile_atomic_or(
465 f: &mut std::fmt::Formatter<'_>,
466 lhs: &Variable<D>,
467 rhs: &Variable<D>,
468 out: &Variable<D>,
469 ) -> std::fmt::Result {
470 let out = out.fmt_left();
471 writeln!(f, "{out} = atomicOr({lhs}, {rhs});")
472 }
473
474 fn compile_atomic_store(
475 f: &mut std::fmt::Formatter<'_>,
476 input: &Variable<D>,
477 out: &Variable<D>,
478 ) -> std::fmt::Result {
479 writeln!(f, "atomicExch({out}, {input});")
480 }
481
482 fn compile_atomic_sub(
483 f: &mut std::fmt::Formatter<'_>,
484 lhs: &Variable<D>,
485 rhs: &Variable<D>,
486 out: &Variable<D>,
487 ) -> std::fmt::Result {
488 let out = out.fmt_left();
489 match rhs.elem() {
490 Elem::U32 | Elem::I32 => writeln!(f, "{out} = atomicSub({lhs}, {rhs});"),
491 Elem::U64 => writeln!(f, "{out} = atomicAdd({lhs}, -{rhs});"),
492 Elem::I64 => writeln!(
493 f,
494 "{out} = atomicAdd(reinterpret_cast<{uint}*>({lhs}), {uint}(-{rhs}));",
495 uint = Elem::<D>::U64
496 ),
497 _ => writeln!(f, "{out} = atomicAdd({lhs}, -{rhs});"),
498 }
499 }
500
501 fn compile_atomic_swap(
502 f: &mut std::fmt::Formatter<'_>,
503 lhs: &Variable<D>,
504 rhs: &Variable<D>,
505 out: &Variable<D>,
506 ) -> std::fmt::Result {
507 let out = out.fmt_left();
508 writeln!(f, "{out} = atomicExch({lhs}, {rhs});")
509 }
510
511 fn compile_atomic_xor(
512 f: &mut std::fmt::Formatter<'_>,
513 lhs: &Variable<D>,
514 rhs: &Variable<D>,
515 out: &Variable<D>,
516 ) -> std::fmt::Result {
517 let out = out.fmt_left();
518 writeln!(f, "{out} = atomicXor({lhs}, {rhs});")
519 }
520
521 fn compile_saturating_add(
522 f: &mut std::fmt::Formatter<'_>,
523 lhs: impl Display,
524 rhs: impl Display,
525 item: Item<D>,
526 ) -> std::fmt::Result;
527
528 fn compile_saturating_sub(
529 f: &mut std::fmt::Formatter<'_>,
530 lhs: impl Display,
531 rhs: impl Display,
532 item: Item<D>,
533 ) -> std::fmt::Result;
534
535 fn compile_instruction_printf(
537 f: &mut std::fmt::Formatter<'_>,
538 format_string: &str,
539 args: &[Variable<D>],
540 ) -> std::fmt::Result {
541 let args = args.iter().map(|arg| format!("{arg}")).collect::<Vec<_>>();
542 let args = match args.is_empty() {
543 true => "".to_string(),
544 false => format!(", {}", args.join(",")),
545 };
546 writeln!(f, "printf({format_string:?}{args});")
547 }
548
549 fn compile_instruction_log1p_scalar<T: Component<D>>(
551 f: &mut std::fmt::Formatter<'_>,
552 input: T,
553 ) -> std::fmt::Result {
554 let elem = input.elem();
555 match elem {
556 Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
557 write!(f, "{elem}(log1p(float({input})))")
558 }
559 _ => write!(f, "log1p({input})"),
560 }
561 }
562
563 fn compile_instruction_sync_threads(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
565 fn compile_instruction_sync_warp(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
566 fn compile_instruction_thread_fence(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
567
568 fn compile_instruction_tanh_scalar<T: Component<D>>(
570 f: &mut std::fmt::Formatter<'_>,
571 input: T,
572 ) -> std::fmt::Result {
573 let elem = input.elem();
574 match elem {
575 Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
576 write!(f, "{elem}(tanh(float({input})))")
577 }
578 _ => write!(f, "tanh({input})"),
579 }
580 }
581
582 fn compile_instruction_find_first_set<T: Component<D>>(
584 f: &mut std::fmt::Formatter<'_>,
585 input: T,
586 out_elem: Elem<D>,
587 ) -> std::fmt::Result;
588 fn compile_instruction_leading_zeros_scalar<T: Component<D>>(
589 f: &mut std::fmt::Formatter<'_>,
590 input: T,
591 out_elem: Elem<D>,
592 ) -> std::fmt::Result;
593
594 fn compile_instruction_popcount_scalar<T: Component<D>>(
595 f: &mut std::fmt::Formatter<'_>,
596 input: T,
597 out_elem: Elem<D>,
598 ) -> std::fmt::Result {
599 write!(f, "{out_elem}(")?;
600 match input.elem() {
601 Elem::I32 => write!(f, "__popc({}({input}))", Elem::<D>::U32),
602 Elem::U32 => write!(f, "__popc({input})"),
603 Elem::I64 => write!(f, "__popcll({}({input}))", Elem::<D>::U64),
604 Elem::U64 => write!(f, "__popcll({input})"),
605 _ => write!(f, "__popc({})", super::unary::zero_extend(input)),
606 }?;
607 write!(f, ")")
608 }
609
610 fn compile_instruction_reverse_bits_scalar<T: Component<D>>(
611 f: &mut std::fmt::Formatter<'_>,
612 input: T,
613 out_elem: Elem<D>,
614 ) -> std::fmt::Result {
615 write!(f, "{out_elem}(")?;
616 match out_elem {
617 Elem::I32 => write!(f, "__brev({}({input}))", Elem::<D>::U32),
618 Elem::U32 => write!(f, "__brev({input})"),
619 Elem::I64 => write!(f, "__brevll({}({input}))", Elem::<D>::U64),
620 Elem::U64 => write!(f, "__brevll({input})"),
621 _ => write!(
622 f,
623 "__brev({}) >> {}",
624 super::unary::zero_extend(input),
625 (size_of::<u32>() - out_elem.size()) * 8
626 ),
627 }?;
628 write!(f, ")")
629 }
630
631 fn compile_instruction_max_function_name(
633 f: &mut std::fmt::Formatter<'_>,
634 item: Item<D>,
635 ) -> std::fmt::Result;
636
637 fn compile_instruction_min_function_name(
638 f: &mut std::fmt::Formatter<'_>,
639 item: Item<D>,
640 ) -> std::fmt::Result;
641
642 fn compile_instruction_powf(
643 f: &mut std::fmt::Formatter<'_>,
644 lhs: &str,
645 rhs: &str,
646 elem: Elem<D>,
647 ) -> std::fmt::Result {
648 match elem {
649 Elem::F32 => write!(f, "powf({lhs}, {rhs})"),
650 Elem::F64 => write!(f, "pow({lhs}, {rhs})"),
651 _ => panic!("Unsupported type for powf"),
652 }
653 }
654
655 fn compile_instruction_half_function_name_prefix() -> &'static str {
656 "h"
657 }
658
659 fn compile_instruction_half2_function_name_prefix() -> &'static str {
660 "h2"
661 }
662
663 fn compile_warp_shuffle(
665 f: &mut std::fmt::Formatter<'_>,
666 var: &str,
667 source: &str,
668 ) -> std::fmt::Result;
669 fn compile_warp_shuffle_xor(
670 f: &mut std::fmt::Formatter<'_>,
671 var: &str,
672 elem: &Elem<D>,
673 offset: &str,
674 ) -> std::fmt::Result;
675 fn compile_warp_shuffle_up(
676 f: &mut std::fmt::Formatter<'_>,
677 var: &str,
678 offset: &str,
679 ) -> std::fmt::Result;
680 fn compile_warp_shuffle_down(
681 f: &mut std::fmt::Formatter<'_>,
682 var: &str,
683 offset: &str,
684 ) -> std::fmt::Result;
685 fn compile_warp_all<T: Component<D>>(
686 f: &mut std::fmt::Formatter<'_>,
687 input: &T,
688 ) -> std::fmt::Result;
689 fn compile_warp_any<T: Component<D>>(
690 f: &mut std::fmt::Formatter<'_>,
691 input: &T,
692 ) -> std::fmt::Result;
693 fn compile_warp_ballot(
694 f: &mut std::fmt::Formatter<'_>,
695 input: &Variable<D>,
696 out_elem: &Elem<D>,
697 ) -> std::fmt::Result;
698 fn compile_warp_elect(f: &mut std::fmt::Formatter<'_>, out: &str) -> std::fmt::Result {
699 write!(
700 f,
701 "
702unsigned int mask = __activemask();
703unsigned int leader = __ffs(mask) - 1;
704{out} = threadIdx.x % warpSize == leader;
705 "
706 )
707 }
708}
709
/// Borrowed operands of a manually emitted MMA operation.
///
/// Values of this type are passed by value to `compile_manual_mma` and
/// `compile_scaled_mma` of [`DialectWmmaCompiler`].
// NOTE(review): `new` is not a standard-library derive; it is presumably a constructor
// derive macro brought into scope at the crate root. If so, the field order below is
// its argument order — confirm before reordering fields.
#[derive(Debug, Clone, Copy, new)]
pub struct ManualMma<'a, D: Dialect> {
    /// Shape of the operation.
    pub shape: MmaShape<D>,
    // NOTE(review): the four fragments presumably follow the usual `D = A * B + C`
    // naming (A/B inputs, C accumulator, D output) — confirm against the dialects.
    /// Fragment variable `a`.
    pub frag_a: &'a Variable<D>,
    /// Fragment variable `b`.
    pub frag_b: &'a Variable<D>,
    /// Fragment variable `c`.
    pub frag_c: &'a Variable<D>,
    /// Fragment variable `d`.
    pub frag_d: &'a Variable<D>,
}
718
719pub trait DialectWarpReduceCompiler<D: Dialect>:
720 Default + Clone + Copy + Debug + Send + Sync + Eq + Hash + 'static
721{
722 fn warp_reduce_sum(
723 f: &mut core::fmt::Formatter<'_>,
724 input: &Variable<D>,
725 out: &Variable<D>,
726 ) -> core::fmt::Result {
727 reduce_operator(f, input, out, "+=")
728 }
729 fn warp_reduce_prod(
730 f: &mut core::fmt::Formatter<'_>,
731 input: &Variable<D>,
732 out: &Variable<D>,
733 ) -> core::fmt::Result {
734 reduce_operator(f, input, out, "*=")
735 }
736 fn warp_reduce_max(
737 f: &mut core::fmt::Formatter<'_>,
738 input: &Variable<D>,
739 out: &Variable<D>,
740 ) -> core::fmt::Result {
741 reduce_comparison(f, input, out, D::compile_instruction_max_function_name)
742 }
743 fn warp_reduce_min(
744 f: &mut core::fmt::Formatter<'_>,
745 input: &Variable<D>,
746 out: &Variable<D>,
747 ) -> core::fmt::Result {
748 reduce_comparison(f, input, out, D::compile_instruction_min_function_name)
749 }
750 fn warp_reduce_all(
751 f: &mut core::fmt::Formatter<'_>,
752 input: &Variable<D>,
753 out: &Variable<D>,
754 ) -> core::fmt::Result {
755 reduce_quantifier(f, input, out, D::compile_warp_all::<IndexedVariable<D>>)
756 }
757 fn warp_reduce_any(
758 f: &mut core::fmt::Formatter<'_>,
759 input: &Variable<D>,
760 out: &Variable<D>,
761 ) -> core::fmt::Result {
762 reduce_quantifier(f, input, out, D::compile_warp_any::<IndexedVariable<D>>)
763 }
764 fn warp_reduce_sum_inclusive(
765 f: &mut core::fmt::Formatter<'_>,
766 input: &Variable<D>,
767 out: &Variable<D>,
768 ) -> core::fmt::Result {
769 reduce_inclusive(f, input, out, "+=")
770 }
771 fn warp_reduce_prod_inclusive(
772 f: &mut core::fmt::Formatter<'_>,
773 input: &Variable<D>,
774 out: &Variable<D>,
775 ) -> core::fmt::Result {
776 reduce_inclusive(f, input, out, "*=")
777 }
778 fn warp_reduce_sum_exclusive(
779 f: &mut core::fmt::Formatter<'_>,
780 input: &Variable<D>,
781 out: &Variable<D>,
782 ) -> core::fmt::Result {
783 reduce_exclusive(f, input, out, "+=", "0")
784 }
785 fn warp_reduce_prod_exclusive(
786 f: &mut core::fmt::Formatter<'_>,
787 input: &Variable<D>,
788 out: &Variable<D>,
789 ) -> core::fmt::Result {
790 reduce_exclusive(f, input, out, "*=", "1")
791 }
792}
793
794pub trait DialectWmmaCompiler<D: Dialect>:
795 Default + Clone + Copy + Debug + Send + Sync + Eq + Hash + 'static
796{
797 #[allow(unused_variables)]
798 fn compile_wmma_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags) -> std::fmt::Result {
799 Ok(())
800 }
801 #[allow(unused_variables)]
802 fn compile_wmma_type_definitions(
803 f: &mut std::fmt::Formatter<'_>,
804 flags: &Flags,
805 ) -> std::fmt::Result {
806 Ok(())
807 }
808 #[allow(unused_variables)]
809 fn compile_wmma_local_variables(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
810 Ok(())
811 }
812 #[allow(unused_variables)]
813 fn compile_wwma_fragment_ident(
814 f: &mut std::fmt::Formatter<'_>,
815 ident: &FragmentIdent<D>,
816 ) -> std::fmt::Result {
817 Ok(())
818 }
819 #[allow(unused_variables)]
820 fn compile_wmma_fragment_layout(
821 f: &mut std::fmt::Formatter<'_>,
822 layout: &FragmentLayout<D>,
823 ) -> std::fmt::Result {
824 Ok(())
825 }
826 #[allow(unused_variables)]
827 fn compile_wmma_fragment(
828 f: &mut std::fmt::Formatter<'_>,
829 fragment: &Fragment<D>,
830 ) -> std::fmt::Result {
831 Ok(())
832 }
833
834 fn compile_wmma_fragment_declaration(
835 f: &mut std::fmt::Formatter<'_>,
836 var: &Variable<D>,
837 ) -> std::fmt::Result;
838
839 fn compile_wmma_instruction(
840 f: &mut std::fmt::Formatter<'_>,
841 instruction: &WmmaInstruction<D>,
842 ) -> std::fmt::Result;
843 fn compile_manual_mma(f: &mut std::fmt::Formatter<'_>, mma: ManualMma<D>) -> std::fmt::Result;
844 fn compile_scaled_mma(
845 f: &mut std::fmt::Formatter<'_>,
846 mma: ManualMma<D>,
847 scales_a: Variable<D>,
848 scales_b: Variable<D>,
849 scales_factor: u32,
850 ) -> std::fmt::Result;
851 fn supported_wmma_combinations(arch: &D::Architecture) -> SupportedMmaCombinations;
852 fn supported_mma_combinations(arch: &D::Architecture) -> SupportedMmaCombinations;
853 fn supported_scaled_mma_combinations(
854 _arch: &D::Architecture,
855 ) -> SupportedScaledMmaCombinations {
856 Vec::new()
857 }
858}
859
/// Dialect facet exposing the IR [`Processor`]s associated with a dialect.
pub trait DialectProcessors<D: Dialect> {
    /// Returns the dialect's processors as boxed trait objects.
    // NOTE(review): where these are applied, and whether their order matters, is not
    // visible from this file — confirm against the caller.
    fn processors() -> Vec<Box<dyn Processor>>;
}