1use std::{collections::HashSet, fmt::Debug};
2use std::{fmt::Display, hash::Hash};
3
4use cubecl_core::ir::Processor;
5
6use crate::shared::{
7 FmtLeft, IndexedVariable, MmaShape, SupportedMmaCombinations, SupportedScaledMmaCombinations,
8 reduce_comparison, reduce_exclusive, reduce_inclusive, reduce_operator, reduce_quantifier,
9};
10
11use super::{
12 Architecture, AtomicKind, Binding, Body, Component, CubeIndexFlags, Elem, Flags, Fragment,
13 FragmentIdent, FragmentLayout, Instruction, Item, SharedMemory, Variable, WarpInstruction,
14 WmmaInstruction,
15};
16
17pub trait Dialect:
20 DialectIncludes<Self>
21 + DialectTypes<Self>
22 + DialectBindings<Self>
23 + DialectWarpReduceCompiler<Self>
24 + DialectCubeBuiltins<Self>
25 + DialectInstructions<Self>
26 + DialectWmmaCompiler<Self>
27 + DialectProcessors<Self>
28 + Default
29 + Clone
30 + Copy
31 + Debug
32 + Send
33 + Sync
34 + Eq
35 + Hash
36 + 'static
37{
38 type Architecture: Architecture;
39}
40
41pub trait DialectIncludes<D: Dialect> {
44 type Extension: Debug + Clone + Sync + Send;
45
46 fn compile_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags) -> std::fmt::Result;
47 fn compile_extensions(
48 f: &mut std::fmt::Formatter<'_>,
49 extensions: &[Self::Extension],
50 ) -> std::fmt::Result;
51 fn register_instruction_extension(
52 extensions: &mut Vec<Self::Extension>,
53 instruction: &Instruction<D>,
54 );
55 fn register_warp_instruction_extension(
56 extensions: &mut Vec<Self::Extension>,
57 instruction: &WarpInstruction<D>,
58 );
59 #[allow(unused_variables)]
60 fn register_wmma_instruction_extension(
61 extensions: &mut Vec<Self::Extension>,
62 instruction: &WmmaInstruction<D>,
63 ) {
64 }
65}
66
67pub trait DialectTypes<D: Dialect> {
70 fn item_can_be_optimized() -> bool;
71 fn compile_elem(
72 f: &mut std::fmt::Formatter<'_>,
73 elem: &Elem<D>,
74 word: bool,
75 ) -> std::fmt::Result;
76
77 fn compile_atomic_kind(
78 f: &mut std::fmt::Formatter<'_>,
79 kind: &AtomicKind<D>,
80 ) -> std::fmt::Result {
81 match kind {
82 AtomicKind::I32 => write!(f, "{}", Elem::<D>::I32),
83 AtomicKind::I64 => write!(f, "{}", Elem::<D>::I64),
84 AtomicKind::U32 => write!(f, "{}", Elem::<D>::U32),
85 AtomicKind::U64 => write!(f, "{}", Elem::<D>::U64),
86 AtomicKind::F16 => write!(f, "{}", Elem::<D>::F16),
87 AtomicKind::BF16 => write!(f, "{}", Elem::<D>::BF16),
88 AtomicKind::F32 => write!(f, "{}", Elem::<D>::F32),
89 AtomicKind::F64 => write!(f, "{}", Elem::<D>::F64),
90 AtomicKind::_Dialect(_) => Ok(()),
91 }
92 }
93
94 fn compile_item(f: &mut std::fmt::Formatter<'_>, item: &Item<D>) -> std::fmt::Result;
95 fn compile_type_definitions(
96 f: &mut std::fmt::Formatter<'_>,
97 items: &HashSet<Item<D>>,
98 scalars: &[(Elem<D>, usize)],
99 flags: &Flags,
100 ) -> std::fmt::Result;
101 fn compile_local_memory_qualifier(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
102 fn compile_shared_memory_declaration(
103 f: &mut std::fmt::Formatter<'_>,
104 shared: &SharedMemory<D>,
105 ) -> std::fmt::Result {
106 match shared {
107 SharedMemory::Array {
108 index,
109 item,
110 length,
111 offset,
112 ..
113 } => {
114 let size_bytes = *length * item.size() as u32;
115 writeln!(f, "// Shared array size: {length}, {size_bytes} bytes")?;
116 writeln!(
117 f,
118 "{item} *shared_memory_{index} = reinterpret_cast<{item}*>(&dynamic_shared_mem[{offset}]);"
119 )
120 }
121 SharedMemory::Value {
122 index,
123 item,
124 offset,
125 ..
126 } => {
127 let size_bytes = item.size() as u32;
128 writeln!(f, "// Shared value size: {size_bytes} bytes")?;
129 writeln!(
130 f,
131 "{item} &shared_memory_{index} = reinterpret_cast<{item}&>(dynamic_shared_mem[{offset}]);"
132 )
133 }
134 }
135 }
136 fn compile_polyfills(_f: &mut std::fmt::Formatter<'_>, _flags: &Flags) -> std::fmt::Result {
137 Ok(())
138 }
139 fn address_space_for_variable(_variable: &Variable<D>) -> String {
141 "".to_string()
142 }
143}
144
145pub trait DialectBindings<D: Dialect> {
148 fn compile_kernel_signature(
149 f: &mut std::fmt::Formatter<'_>,
150 kernel_name: &str,
151 tensor_maps: &[Binding<D>],
152 buffers: &[Binding<D>],
153 scalars: &[(Elem<D>, usize)],
154 flags: &Flags,
155 ) -> std::fmt::Result;
156 fn compile_bindings_body(
157 _f: &mut std::fmt::Formatter<'_>,
158 _body: &Body<D>,
159 ) -> std::fmt::Result {
160 Ok(())
161 }
162}
163
164pub trait DialectCubeBuiltins<D: Dialect> {
167 fn builtin_rules(flags: &CubeIndexFlags) -> CubeIndexFlags {
173 let unit_pos_plane = flags.unit_pos_plane;
174 let plane_dim_checked = flags.plane_dim_checked;
175 let plane_dim = flags.plane_dim || plane_dim_checked || unit_pos_plane;
176 let plane_index = flags.plane_index;
177 let absolute_pos = flags.absolute_pos || unit_pos_plane;
178 let absolute_pos_tuple = flags.absolute_pos_tuple || absolute_pos;
179 let cube_dim = flags.cube_dim;
180 let cube_dim_tuple = flags.cube_dim_tuple || cube_dim || absolute_pos || plane_dim_checked;
181 let unit_pos = flags.unit_pos;
182 let unit_pos_tuple = flags.unit_pos_tuple || unit_pos;
183 let cube_count = flags.cube_count;
184 let cube_count_tuple = flags.cube_count_tuple || absolute_pos;
185 let cube_pos = flags.cube_pos;
186 let cube_pos_tuple = flags.cube_pos_tuple || cube_pos;
187 let cluster_group = flags.cluster_pos;
188
189 CubeIndexFlags {
190 absolute_pos,
191 absolute_pos_tuple,
192 cube_count,
193 cube_count_tuple,
194 cube_dim,
195 cube_dim_tuple,
196 cube_pos,
197 cube_pos_tuple,
198 plane_dim,
199 plane_dim_checked,
200 plane_index,
201 unit_pos_tuple,
202 unit_pos,
203 unit_pos_plane,
204 cluster_pos: cluster_group,
205 }
206 }
207
208 fn compile_absolute_pos_tuple_computation(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
209 let variable = Variable::<D>::AbsolutePosBaseName;
210 let ty = variable.item();
211 let cube_pos_x = Variable::<D>::CubePosX;
212 let cube_pos_y = Variable::<D>::CubePosY;
213 let cube_pos_z = Variable::<D>::CubePosZ;
214 let cube_dim_x = Variable::<D>::CubeDimX;
215 let cube_dim_y = Variable::<D>::CubeDimY;
216 let cube_dim_z = Variable::<D>::CubeDimZ;
217 let unit_pos_x = Variable::<D>::UnitPosX;
218 let unit_pos_y = Variable::<D>::UnitPosY;
219 let unit_pos_z = Variable::<D>::UnitPosZ;
220 writeln!(
221 f,
222 "{ty} {variable} = make_{ty}(
223 {cube_pos_x} * {cube_dim_x} + {unit_pos_x},
224 {cube_pos_y} * {cube_dim_y} + {unit_pos_y},
225 {cube_pos_z} * {cube_dim_z} + {unit_pos_z}
226);"
227 )
228 }
229
230 fn compile_absolute_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
231 f.write_str("absoluteIdx")
232 }
233
234 fn compile_absolute_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
235 f.write_str("idxGlobal")
236 }
237
238 fn compile_absolute_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
239 Self::compile_absolute_pos_base_name(f)?;
240 write!(f, ".x")
241 }
242
243 fn compile_absolute_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
244 Self::compile_absolute_pos_base_name(f)?;
245 write!(f, ".y")
246 }
247
248 fn compile_absolute_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
249 Self::compile_absolute_pos_base_name(f)?;
250 write!(f, ".z")
251 }
252
253 fn compile_cube_count_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
254 f.write_str("gridDim")
255 }
256
257 fn compile_cube_count(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
258 f.write_str("gridDimGlobal")
259 }
260
261 fn compile_cube_count_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
262 Self::compile_cube_count_base_name(f)?;
263 write!(f, ".x")
264 }
265
266 fn compile_cube_count_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
267 Self::compile_cube_count_base_name(f)?;
268 write!(f, ".y")
269 }
270
271 fn compile_cube_count_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
272 Self::compile_cube_count_base_name(f)?;
273 write!(f, ".z")
274 }
275
276 fn compile_cube_dim_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
277 f.write_str("blockDim")
278 }
279
280 fn compile_cube_dim(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
281 f.write_str("blockDimGlobal")
282 }
283
284 fn compile_cube_dim_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
285 Self::compile_cube_dim_base_name(f)?;
286 write!(f, ".x")
287 }
288
289 fn compile_cube_dim_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
290 Self::compile_cube_dim_base_name(f)?;
291 write!(f, ".y")
292 }
293
294 fn compile_cube_dim_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
295 Self::compile_cube_dim_base_name(f)?;
296 write!(f, ".z")
297 }
298
299 fn compile_cube_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
300 f.write_str("blockIdx")
301 }
302
303 fn compile_cube_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
304 f.write_str("blockIdxGlobal")
305 }
306
307 fn compile_cube_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
308 Self::compile_cube_pos_base_name(f)?;
309 write!(f, ".x")
310 }
311
312 fn compile_cube_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
313 Self::compile_cube_pos_base_name(f)?;
314 write!(f, ".y")
315 }
316
317 fn compile_cube_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
318 Self::compile_cube_pos_base_name(f)?;
319 write!(f, ".z")
320 }
321
322 fn compile_unit_pos_computation(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
323 let variable = Variable::<D>::UnitPos;
324 let ty = variable.item();
325 let cube_dim_x = Variable::<D>::CubeDimX;
326 let cube_dim_y = Variable::<D>::CubeDimY;
327 let unit_pos_x = Variable::<D>::UnitPosX;
328 let unit_pos_y = Variable::<D>::UnitPosY;
329 let unit_pos_z = Variable::<D>::UnitPosZ;
330 writeln!(
331 f,
332 "{ty} {variable} = {unit_pos_x} + {unit_pos_y} * {cube_dim_x} + {unit_pos_z} * ({cube_dim_x} * {cube_dim_y});"
333 )
334 }
335
336 fn compile_unit_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
337 f.write_str("threadIdxGlobal")
338 }
339
340 fn compile_unit_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
341 f.write_str("threadIdx")
342 }
343
344 fn compile_unit_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
345 Self::compile_unit_pos_base_name(f)?;
346 write!(f, ".x")
347 }
348
349 fn compile_unit_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
350 Self::compile_unit_pos_base_name(f)?;
351 write!(f, ".y")
352 }
353
354 fn compile_unit_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
355 Self::compile_unit_pos_base_name(f)?;
356 write!(f, ".z")
357 }
358
359 fn compile_plane_dim(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
360 f.write_str("warpSize")
361 }
362
363 fn compile_plane_dim_checked(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
364 f.write_str("warpSizeChecked")
365 }
366
367 fn compile_plane_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
368 let unit_pos_x = Variable::<D>::UnitPosX;
369 let plane_dim = Variable::<D>::PlaneDim;
370 write!(f, "{unit_pos_x} / {plane_dim}")
371 }
372
373 fn compile_unit_pos_plane(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
374 let absolute_pos = Variable::<D>::AbsolutePos;
375 let plane_dim = Variable::<D>::PlaneDim;
376 write!(f, "{absolute_pos} % {plane_dim}")
377 }
378
379 fn compile_cluster_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
380 write!(f, "0")
381 }
382 fn compile_cluster_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
383 write!(f, "0")
384 }
385 fn compile_cluster_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
386 write!(f, "0")
387 }
388 fn compile_cluster_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
389 write!(f, "0")
390 }
391}
392
393pub trait DialectInstructions<D: Dialect> {
396 fn compile_atomic_add(
398 f: &mut std::fmt::Formatter<'_>,
399 lhs: &Variable<D>,
400 rhs: &Variable<D>,
401 out: &Variable<D>,
402 ) -> std::fmt::Result {
403 let out = out.fmt_left();
404 match rhs.elem() {
405 Elem::I64 => writeln!(
406 f,
407 "{out} = atomicAdd(reinterpret_cast<{uint}*>({lhs}), {uint}({rhs}));",
408 uint = Elem::<D>::U64
409 ),
410 _ => writeln!(f, "{out} = atomicAdd({lhs}, {rhs});"),
411 }
412 }
413
414 fn compile_atomic_and(
415 f: &mut std::fmt::Formatter<'_>,
416 lhs: &Variable<D>,
417 rhs: &Variable<D>,
418 out: &Variable<D>,
419 ) -> std::fmt::Result {
420 let out = out.fmt_left();
421 writeln!(f, "{out} = atomicAnd({lhs}, {rhs});")
422 }
423
424 fn compile_atomic_cas(
425 f: &mut std::fmt::Formatter<'_>,
426 input: &Variable<D>,
427 cmp: &Variable<D>,
428 val: &Variable<D>,
429 out: &Variable<D>,
430 ) -> std::fmt::Result {
431 let out = out.fmt_left();
432 writeln!(f, "{out} = atomicCAS({input}, {cmp}, {val});")
433 }
434
435 fn compile_atomic_load(
436 f: &mut std::fmt::Formatter<'_>,
437 input: &Variable<D>,
438 out: &Variable<D>,
439 ) -> std::fmt::Result {
440 let out = out.fmt_left();
441 writeln!(f, "{out} = atomicAdd({input}, 0);")
442 }
443
444 fn compile_atomic_max(
445 f: &mut std::fmt::Formatter<'_>,
446 lhs: &Variable<D>,
447 rhs: &Variable<D>,
448 out: &Variable<D>,
449 ) -> std::fmt::Result {
450 let out = out.fmt_left();
451 writeln!(f, "{out} = atomicMax({lhs}, {rhs});")
452 }
453
454 fn compile_atomic_min(
455 f: &mut std::fmt::Formatter<'_>,
456 lhs: &Variable<D>,
457 rhs: &Variable<D>,
458 out: &Variable<D>,
459 ) -> std::fmt::Result {
460 let out = out.fmt_left();
461 writeln!(f, "{out} = atomicMin({lhs}, {rhs});")
462 }
463
464 fn compile_atomic_or(
465 f: &mut std::fmt::Formatter<'_>,
466 lhs: &Variable<D>,
467 rhs: &Variable<D>,
468 out: &Variable<D>,
469 ) -> std::fmt::Result {
470 let out = out.fmt_left();
471 writeln!(f, "{out} = atomicOr({lhs}, {rhs});")
472 }
473
474 fn compile_atomic_store(
475 f: &mut std::fmt::Formatter<'_>,
476 input: &Variable<D>,
477 out: &Variable<D>,
478 ) -> std::fmt::Result {
479 writeln!(f, "atomicExch({out}, {input});")
480 }
481
482 fn compile_atomic_sub(
483 f: &mut std::fmt::Formatter<'_>,
484 lhs: &Variable<D>,
485 rhs: &Variable<D>,
486 out: &Variable<D>,
487 ) -> std::fmt::Result {
488 let out = out.fmt_left();
489 match rhs.elem() {
490 Elem::U32 | Elem::I32 => writeln!(f, "{out} = atomicSub({lhs}, {rhs});"),
491 Elem::U64 => writeln!(f, "{out} = atomicAdd({lhs}, -{rhs});"),
492 Elem::I64 => writeln!(
493 f,
494 "{out} = atomicAdd(reinterpret_cast<{uint}*>({lhs}), {uint}(-{rhs}));",
495 uint = Elem::<D>::U64
496 ),
497 _ => writeln!(f, "{out} = atomicAdd({lhs}, -{rhs});"),
498 }
499 }
500
501 fn compile_atomic_swap(
502 f: &mut std::fmt::Formatter<'_>,
503 lhs: &Variable<D>,
504 rhs: &Variable<D>,
505 out: &Variable<D>,
506 ) -> std::fmt::Result {
507 let out = out.fmt_left();
508 writeln!(f, "{out} = atomicExch({lhs}, {rhs});")
509 }
510
511 fn compile_atomic_xor(
512 f: &mut std::fmt::Formatter<'_>,
513 lhs: &Variable<D>,
514 rhs: &Variable<D>,
515 out: &Variable<D>,
516 ) -> std::fmt::Result {
517 let out = out.fmt_left();
518 writeln!(f, "{out} = atomicXor({lhs}, {rhs});")
519 }
520
521 fn compile_saturating_add(
522 f: &mut std::fmt::Formatter<'_>,
523 lhs: impl Display,
524 rhs: impl Display,
525 item: Item<D>,
526 ) -> std::fmt::Result;
527
528 fn compile_saturating_sub(
529 f: &mut std::fmt::Formatter<'_>,
530 lhs: impl Display,
531 rhs: impl Display,
532 item: Item<D>,
533 ) -> std::fmt::Result;
534
535 fn compile_instruction_printf(
537 f: &mut std::fmt::Formatter<'_>,
538 format_string: &str,
539 args: &[Variable<D>],
540 ) -> std::fmt::Result {
541 let args = args.iter().map(|arg| format!("{arg}")).collect::<Vec<_>>();
542 let args = match args.is_empty() {
543 true => "".to_string(),
544 false => format!(", {}", args.join(",")),
545 };
546 writeln!(f, "printf({format_string:?}{args});")
547 }
548
549 fn compile_instruction_log1p_scalar<T: Component<D>>(
551 f: &mut std::fmt::Formatter<'_>,
552 input: T,
553 ) -> std::fmt::Result {
554 let elem = input.elem();
555 match elem {
556 Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
557 write!(f, "{elem}(log1p(float({input})))")
558 }
559 _ => write!(f, "log1p({input})"),
560 }
561 }
562
563 fn compile_instruction_sync_threads(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
565 fn compile_instruction_sync_warp(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
566 fn compile_instruction_thread_fence(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
567
568 fn compile_instruction_tanh_scalar<T: Component<D>>(
570 f: &mut std::fmt::Formatter<'_>,
571 input: T,
572 ) -> std::fmt::Result {
573 let elem = input.elem();
574 match elem {
575 Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
576 write!(f, "{elem}(tanh(float({input})))")
577 }
578 _ => write!(f, "tanh({input})"),
579 }
580 }
581
582 fn compile_instruction_find_first_set<T: Component<D>>(
584 f: &mut std::fmt::Formatter<'_>,
585 input: T,
586 out_elem: Elem<D>,
587 ) -> std::fmt::Result;
588 fn compile_instruction_leading_zeros_scalar<T: Component<D>>(
589 f: &mut std::fmt::Formatter<'_>,
590 input: T,
591 out_elem: Elem<D>,
592 ) -> std::fmt::Result;
593
594 fn compile_instruction_popcount_scalar<T: Component<D>>(
595 f: &mut std::fmt::Formatter<'_>,
596 input: T,
597 out_elem: Elem<D>,
598 ) -> std::fmt::Result {
599 write!(f, "{out_elem}(")?;
600 match input.elem() {
601 Elem::I32 => write!(f, "__popc({}({input}))", Elem::<D>::U32),
602 Elem::U32 => write!(f, "__popc({input})"),
603 Elem::I64 => write!(f, "__popcll({}({input}))", Elem::<D>::U64),
604 Elem::U64 => write!(f, "__popcll({input})"),
605 _ => write!(f, "__popc({})", super::unary::zero_extend(input)),
606 }?;
607 write!(f, ")")
608 }
609
610 fn compile_instruction_reverse_bits_scalar<T: Component<D>>(
611 f: &mut std::fmt::Formatter<'_>,
612 input: T,
613 out_elem: Elem<D>,
614 ) -> std::fmt::Result {
615 write!(f, "{out_elem}(")?;
616 match out_elem {
617 Elem::I32 => write!(f, "__brev({}({input}))", Elem::<D>::U32),
618 Elem::U32 => write!(f, "__brev({input})"),
619 Elem::I64 => write!(f, "__brevll({}({input}))", Elem::<D>::U64),
620 Elem::U64 => write!(f, "__brevll({input})"),
621 _ => write!(
622 f,
623 "__brev({}) >> {}",
624 super::unary::zero_extend(input),
625 (size_of::<u32>() - out_elem.size()) * 8
626 ),
627 }?;
628 write!(f, ")")
629 }
630
631 fn compile_instruction_max_function_name(
633 f: &mut std::fmt::Formatter<'_>,
634 item: Item<D>,
635 ) -> std::fmt::Result;
636
637 fn compile_instruction_min_function_name(
638 f: &mut std::fmt::Formatter<'_>,
639 item: Item<D>,
640 ) -> std::fmt::Result;
641
642 fn compile_instruction_powf(
643 f: &mut std::fmt::Formatter<'_>,
644 lhs: &str,
645 rhs: &str,
646 elem: Elem<D>,
647 ) -> std::fmt::Result {
648 match elem {
649 Elem::F32 => write!(f, "powf({lhs}, {rhs})"),
650 Elem::F64 => write!(f, "pow({lhs}, {rhs})"),
651 _ => write!(f, "#error Unsupported type for powf: {elem}"),
652 }
653 }
654
655 fn compile_instruction_hypot(
656 f: &mut std::fmt::Formatter<'_>,
657 lhs: &str,
658 rhs: &str,
659 elem: Elem<D>,
660 ) -> std::fmt::Result {
661 match elem {
662 Elem::F32 => write!(f, "hypotf({lhs}, {rhs})"),
663 Elem::F64 => write!(f, "hypot({lhs}, {rhs})"),
664 _ => write!(f, "#error Unsupported type for hypot: {elem}"),
665 }
666 }
667
668 fn compile_instruction_rhypot(
669 f: &mut std::fmt::Formatter<'_>,
670 lhs: &str,
671 rhs: &str,
672 elem: Elem<D>,
673 ) -> std::fmt::Result {
674 match elem {
675 Elem::F32 => write!(f, "rhypotf({lhs}, {rhs})"),
676 Elem::F64 => write!(f, "rhypot({lhs}, {rhs})"),
677 _ => write!(f, "#error Unsupported type for rhypot: {elem}"),
678 }
679 }
680
681 fn compile_instruction_half_function_name_prefix() -> &'static str {
682 "h"
683 }
684
685 fn compile_instruction_half2_function_name_prefix() -> &'static str {
686 "h2"
687 }
688
689 fn compile_warp_shuffle(
691 f: &mut std::fmt::Formatter<'_>,
692 var: &str,
693 source: &str,
694 ) -> std::fmt::Result;
695 fn compile_warp_shuffle_xor(
696 f: &mut std::fmt::Formatter<'_>,
697 var: &str,
698 elem: &Elem<D>,
699 offset: &str,
700 ) -> std::fmt::Result;
701 fn compile_warp_shuffle_up(
702 f: &mut std::fmt::Formatter<'_>,
703 var: &str,
704 offset: &str,
705 ) -> std::fmt::Result;
706 fn compile_warp_shuffle_down(
707 f: &mut std::fmt::Formatter<'_>,
708 var: &str,
709 offset: &str,
710 ) -> std::fmt::Result;
711 fn compile_warp_all<T: Component<D>>(
712 f: &mut std::fmt::Formatter<'_>,
713 input: &T,
714 ) -> std::fmt::Result;
715 fn compile_warp_any<T: Component<D>>(
716 f: &mut std::fmt::Formatter<'_>,
717 input: &T,
718 ) -> std::fmt::Result;
719 fn compile_warp_ballot(
720 f: &mut std::fmt::Formatter<'_>,
721 input: &Variable<D>,
722 out_elem: &Elem<D>,
723 ) -> std::fmt::Result;
724 fn compile_warp_elect(f: &mut std::fmt::Formatter<'_>, out: &str) -> std::fmt::Result {
725 write!(
726 f,
727 "
728unsigned int mask = __activemask();
729unsigned int leader = __ffs(mask) - 1;
730{out} = threadIdx.x % warpSize == leader;
731 "
732 )
733 }
734}
735
736#[derive(Debug, Clone, Copy, new)]
737pub struct ManualMma<'a, D: Dialect> {
738 pub shape: MmaShape<D>,
739 pub frag_a: &'a Variable<D>,
740 pub frag_b: &'a Variable<D>,
741 pub frag_c: &'a Variable<D>,
742 pub frag_d: &'a Variable<D>,
743}
744
745pub trait DialectWarpReduceCompiler<D: Dialect>:
746 Default + Clone + Copy + Debug + Send + Sync + Eq + Hash + 'static
747{
748 fn warp_reduce_sum(
749 f: &mut core::fmt::Formatter<'_>,
750 input: &Variable<D>,
751 out: &Variable<D>,
752 ) -> core::fmt::Result {
753 reduce_operator(f, input, out, "+=")
754 }
755 fn warp_reduce_prod(
756 f: &mut core::fmt::Formatter<'_>,
757 input: &Variable<D>,
758 out: &Variable<D>,
759 ) -> core::fmt::Result {
760 reduce_operator(f, input, out, "*=")
761 }
762 fn warp_reduce_max(
763 f: &mut core::fmt::Formatter<'_>,
764 input: &Variable<D>,
765 out: &Variable<D>,
766 ) -> core::fmt::Result {
767 reduce_comparison(f, input, out, D::compile_instruction_max_function_name)
768 }
769 fn warp_reduce_min(
770 f: &mut core::fmt::Formatter<'_>,
771 input: &Variable<D>,
772 out: &Variable<D>,
773 ) -> core::fmt::Result {
774 reduce_comparison(f, input, out, D::compile_instruction_min_function_name)
775 }
776 fn warp_reduce_all(
777 f: &mut core::fmt::Formatter<'_>,
778 input: &Variable<D>,
779 out: &Variable<D>,
780 ) -> core::fmt::Result {
781 reduce_quantifier(f, input, out, D::compile_warp_all::<IndexedVariable<D>>)
782 }
783 fn warp_reduce_any(
784 f: &mut core::fmt::Formatter<'_>,
785 input: &Variable<D>,
786 out: &Variable<D>,
787 ) -> core::fmt::Result {
788 reduce_quantifier(f, input, out, D::compile_warp_any::<IndexedVariable<D>>)
789 }
790 fn warp_reduce_sum_inclusive(
791 f: &mut core::fmt::Formatter<'_>,
792 input: &Variable<D>,
793 out: &Variable<D>,
794 ) -> core::fmt::Result {
795 reduce_inclusive(f, input, out, "+=")
796 }
797 fn warp_reduce_prod_inclusive(
798 f: &mut core::fmt::Formatter<'_>,
799 input: &Variable<D>,
800 out: &Variable<D>,
801 ) -> core::fmt::Result {
802 reduce_inclusive(f, input, out, "*=")
803 }
804 fn warp_reduce_sum_exclusive(
805 f: &mut core::fmt::Formatter<'_>,
806 input: &Variable<D>,
807 out: &Variable<D>,
808 ) -> core::fmt::Result {
809 reduce_exclusive(f, input, out, "+=", "0")
810 }
811 fn warp_reduce_prod_exclusive(
812 f: &mut core::fmt::Formatter<'_>,
813 input: &Variable<D>,
814 out: &Variable<D>,
815 ) -> core::fmt::Result {
816 reduce_exclusive(f, input, out, "*=", "1")
817 }
818}
819
820pub trait DialectWmmaCompiler<D: Dialect>:
821 Default + Clone + Copy + Debug + Send + Sync + Eq + Hash + 'static
822{
823 #[allow(unused_variables)]
824 fn compile_wmma_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags) -> std::fmt::Result {
825 Ok(())
826 }
827 #[allow(unused_variables)]
828 fn compile_wmma_type_definitions(
829 f: &mut std::fmt::Formatter<'_>,
830 flags: &Flags,
831 ) -> std::fmt::Result {
832 Ok(())
833 }
834 #[allow(unused_variables)]
835 fn compile_wmma_local_variables(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
836 Ok(())
837 }
838 #[allow(unused_variables)]
839 fn compile_wwma_fragment_ident(
840 f: &mut std::fmt::Formatter<'_>,
841 ident: &FragmentIdent<D>,
842 ) -> std::fmt::Result {
843 Ok(())
844 }
845 #[allow(unused_variables)]
846 fn compile_wmma_fragment_layout(
847 f: &mut std::fmt::Formatter<'_>,
848 layout: &FragmentLayout<D>,
849 ) -> std::fmt::Result {
850 Ok(())
851 }
852 #[allow(unused_variables)]
853 fn compile_wmma_fragment(
854 f: &mut std::fmt::Formatter<'_>,
855 fragment: &Fragment<D>,
856 ) -> std::fmt::Result {
857 Ok(())
858 }
859
860 fn compile_wmma_fragment_declaration(
861 f: &mut std::fmt::Formatter<'_>,
862 var: &Variable<D>,
863 ) -> std::fmt::Result;
864
865 fn compile_wmma_instruction(
866 f: &mut std::fmt::Formatter<'_>,
867 instruction: &WmmaInstruction<D>,
868 ) -> std::fmt::Result;
869 fn compile_manual_mma(f: &mut std::fmt::Formatter<'_>, mma: ManualMma<D>) -> std::fmt::Result;
870 fn compile_scaled_mma(
871 f: &mut std::fmt::Formatter<'_>,
872 mma: ManualMma<D>,
873 scales_a: Variable<D>,
874 scales_b: Variable<D>,
875 scales_factor: u32,
876 ) -> std::fmt::Result;
877 fn supported_wmma_combinations(arch: &D::Architecture) -> SupportedMmaCombinations;
878 fn supported_mma_combinations(arch: &D::Architecture) -> SupportedMmaCombinations;
879 fn supported_scaled_mma_combinations(
880 _arch: &D::Architecture,
881 ) -> SupportedScaledMmaCombinations {
882 Vec::new()
883 }
884}
885
886pub trait DialectProcessors<D: Dialect> {
889 fn processors() -> Vec<Box<dyn Processor>>;
890}