1use std::{collections::HashSet, fmt::Debug};
2use std::{fmt::Display, hash::Hash};
3
4use cubecl_core::ir::Processor;
5
6use crate::shared::{
7 FmtLeft, IndexedVariable, MmaShape, SupportedMmaCombinations, SupportedScaledMmaCombinations,
8 reduce_comparison, reduce_exclusive, reduce_inclusive, reduce_operator, reduce_quantifier,
9};
10
11use super::{
12 Architecture, AtomicKind, Binding, Body, Component, CubeIndexFlags, Elem, Flags, Fragment,
13 FragmentIdent, FragmentLayout, Instruction, Item, SharedMemory, Variable, WarpInstruction,
14 WmmaInstruction,
15};
16
17pub trait Dialect:
20 DialectIncludes<Self>
21 + DialectTypes<Self>
22 + DialectBindings<Self>
23 + DialectWarpReduceCompiler<Self>
24 + DialectCubeBuiltins<Self>
25 + DialectInstructions<Self>
26 + DialectWmmaCompiler<Self>
27 + DialectProcessors<Self>
28 + Default
29 + Clone
30 + Copy
31 + Debug
32 + Send
33 + Sync
34 + Eq
35 + Hash
36 + 'static
37{
38 type Architecture: Architecture;
39}
40
41pub trait DialectIncludes<D: Dialect> {
44 type Extension: Debug + Clone + Sync + Send;
45
46 fn compile_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags<D>) -> std::fmt::Result;
47 fn compile_extensions(
48 f: &mut std::fmt::Formatter<'_>,
49 extensions: &[Self::Extension],
50 ) -> std::fmt::Result;
51 fn register_instruction_extension(
52 extensions: &mut Vec<Self::Extension>,
53 instruction: &Instruction<D>,
54 );
55 fn register_warp_instruction_extension(
56 extensions: &mut Vec<Self::Extension>,
57 instruction: &WarpInstruction<D>,
58 );
59 #[allow(unused_variables)]
60 fn register_wmma_instruction_extension(
61 extensions: &mut Vec<Self::Extension>,
62 instruction: &WmmaInstruction<D>,
63 ) {
64 }
65}
66
67pub trait DialectTypes<D: Dialect> {
70 fn item_can_be_optimized() -> bool;
71 fn compile_elem(
72 f: &mut std::fmt::Formatter<'_>,
73 elem: &Elem<D>,
74 word: bool,
75 ) -> std::fmt::Result;
76
77 fn compile_atomic_kind(
78 f: &mut std::fmt::Formatter<'_>,
79 kind: &AtomicKind<D>,
80 ) -> std::fmt::Result {
81 match kind {
82 AtomicKind::I32 => write!(f, "{}", Elem::<D>::I32),
83 AtomicKind::I64 => write!(f, "{}", Elem::<D>::I64),
84 AtomicKind::U32 => write!(f, "{}", Elem::<D>::U32),
85 AtomicKind::U64 => write!(f, "{}", Elem::<D>::U64),
86 AtomicKind::F16 => write!(f, "{}", Elem::<D>::F16),
87 AtomicKind::BF16 => write!(f, "{}", Elem::<D>::BF16),
88 AtomicKind::F32 => write!(f, "{}", Elem::<D>::F32),
89 AtomicKind::F64 => write!(f, "{}", Elem::<D>::F64),
90 AtomicKind::_Dialect(_) => Ok(()),
91 }
92 }
93
94 fn compile_item(f: &mut std::fmt::Formatter<'_>, item: &Item<D>) -> std::fmt::Result;
95 fn compile_type_definitions(
96 f: &mut std::fmt::Formatter<'_>,
97 items: &HashSet<Item<D>>,
98 scalars: &[(Elem<D>, usize)],
99 flags: &Flags<D>,
100 ) -> std::fmt::Result;
101 fn compile_local_memory_qualifier(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
102 fn compile_shared_memory_declaration(
103 f: &mut std::fmt::Formatter<'_>,
104 shared: &SharedMemory<D>,
105 ) -> std::fmt::Result {
106 match shared {
107 SharedMemory::Array {
108 index,
109 item,
110 length,
111 offset,
112 ..
113 } => {
114 let size_bytes = *length * item.size();
115 writeln!(f, "// Shared array size: {length}, {size_bytes} bytes")?;
116 writeln!(
117 f,
118 "{item} *shared_memory_{index} = reinterpret_cast<{item}*>(&dynamic_shared_mem[{offset}]);"
119 )
120 }
121 SharedMemory::Value {
122 index,
123 item,
124 offset,
125 ..
126 } => {
127 let size_bytes = item.size() as u32;
128 writeln!(f, "// Shared value size: {size_bytes} bytes")?;
129 writeln!(
130 f,
131 "{item} &shared_memory_{index} = reinterpret_cast<{item}&>(dynamic_shared_mem[{offset}]);"
132 )
133 }
134 }
135 }
136 fn compile_polyfills(_f: &mut std::fmt::Formatter<'_>, _flags: &Flags<D>) -> std::fmt::Result {
137 Ok(())
138 }
139 fn address_space_for_variable(_variable: &Variable<D>) -> String {
141 "".to_string()
142 }
143}
144
145pub trait DialectBindings<D: Dialect> {
148 fn compile_kernel_signature(
149 f: &mut std::fmt::Formatter<'_>,
150 kernel_name: &str,
151 tensor_maps: &[Binding<D>],
152 buffers: &[Binding<D>],
153 scalars: &[(Elem<D>, usize)],
154 flags: &Flags<D>,
155 ) -> std::fmt::Result;
156 fn compile_bindings_body(
157 _f: &mut std::fmt::Formatter<'_>,
158 _body: &Body<D>,
159 ) -> std::fmt::Result {
160 Ok(())
161 }
162}
163
164pub trait DialectCubeBuiltins<D: Dialect> {
167 fn builtin_rules(flags: &CubeIndexFlags) -> CubeIndexFlags {
173 let unit_pos_plane = flags.unit_pos_plane;
174 let plane_dim_checked = flags.plane_dim_checked;
175 let plane_dim = flags.plane_dim || plane_dim_checked || unit_pos_plane;
176 let plane_index = flags.plane_index;
177 let absolute_pos = flags.absolute_pos || unit_pos_plane;
178 let absolute_pos_tuple = flags.absolute_pos_tuple || absolute_pos;
179 let cube_dim = flags.cube_dim;
180 let cube_dim_tuple = flags.cube_dim_tuple || cube_dim || absolute_pos || plane_dim_checked;
181 let unit_pos = flags.unit_pos;
182 let unit_pos_tuple = flags.unit_pos_tuple || unit_pos;
183 let cube_count = flags.cube_count;
184 let cube_count_tuple = flags.cube_count_tuple || absolute_pos;
185 let cube_pos = flags.cube_pos;
186 let cube_pos_tuple = flags.cube_pos_tuple || cube_pos;
187 let cluster_group = flags.cluster_pos;
188
189 CubeIndexFlags {
190 absolute_pos,
191 absolute_pos_tuple,
192 cube_count,
193 cube_count_tuple,
194 cube_dim,
195 cube_dim_tuple,
196 cube_pos,
197 cube_pos_tuple,
198 plane_dim,
199 plane_dim_checked,
200 plane_index,
201 unit_pos_tuple,
202 unit_pos,
203 unit_pos_plane,
204 cluster_pos: cluster_group,
205 }
206 }
207
208 fn compile_absolute_pos_tuple_computation(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
209 let variable = Variable::<D>::AbsolutePosBaseName;
210 let ty = variable.item();
211 let cube_pos_x = Variable::<D>::CubePosX;
212 let cube_pos_y = Variable::<D>::CubePosY;
213 let cube_pos_z = Variable::<D>::CubePosZ;
214 let cube_dim_x = Variable::<D>::CubeDimX;
215 let cube_dim_y = Variable::<D>::CubeDimY;
216 let cube_dim_z = Variable::<D>::CubeDimZ;
217 let unit_pos_x = Variable::<D>::UnitPosX;
218 let unit_pos_y = Variable::<D>::UnitPosY;
219 let unit_pos_z = Variable::<D>::UnitPosZ;
220 writeln!(
221 f,
222 "{ty} {variable} = make_{ty}(
223 {cube_pos_x} * {cube_dim_x} + {unit_pos_x},
224 {cube_pos_y} * {cube_dim_y} + {unit_pos_y},
225 {cube_pos_z} * {cube_dim_z} + {unit_pos_z}
226);"
227 )
228 }
229
230 fn compile_absolute_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
231 f.write_str("absoluteIdx")
232 }
233
234 fn compile_absolute_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
235 f.write_str("idxGlobal")
236 }
237
238 fn compile_absolute_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
239 Self::compile_absolute_pos_base_name(f)?;
240 write!(f, ".x")
241 }
242
243 fn compile_absolute_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
244 Self::compile_absolute_pos_base_name(f)?;
245 write!(f, ".y")
246 }
247
248 fn compile_absolute_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
249 Self::compile_absolute_pos_base_name(f)?;
250 write!(f, ".z")
251 }
252
253 fn compile_cube_count_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
254 f.write_str("gridDim")
255 }
256
257 fn compile_cube_count(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
258 f.write_str("gridDimGlobal")
259 }
260
261 fn compile_cube_count_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
262 Self::compile_cube_count_base_name(f)?;
263 write!(f, ".x")
264 }
265
266 fn compile_cube_count_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
267 Self::compile_cube_count_base_name(f)?;
268 write!(f, ".y")
269 }
270
271 fn compile_cube_count_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
272 Self::compile_cube_count_base_name(f)?;
273 write!(f, ".z")
274 }
275
276 fn compile_cube_dim_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
277 f.write_str("blockDim")
278 }
279
280 fn compile_cube_dim(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
281 f.write_str("blockDimGlobal")
282 }
283
284 fn compile_cube_dim_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
285 Self::compile_cube_dim_base_name(f)?;
286 write!(f, ".x")
287 }
288
289 fn compile_cube_dim_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
290 Self::compile_cube_dim_base_name(f)?;
291 write!(f, ".y")
292 }
293
294 fn compile_cube_dim_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
295 Self::compile_cube_dim_base_name(f)?;
296 write!(f, ".z")
297 }
298
299 fn compile_cube_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
300 f.write_str("blockIdx")
301 }
302
303 fn compile_cube_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
304 f.write_str("blockIdxGlobal")
305 }
306
307 fn compile_cube_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
308 Self::compile_cube_pos_base_name(f)?;
309 write!(f, ".x")
310 }
311
312 fn compile_cube_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
313 Self::compile_cube_pos_base_name(f)?;
314 write!(f, ".y")
315 }
316
317 fn compile_cube_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
318 Self::compile_cube_pos_base_name(f)?;
319 write!(f, ".z")
320 }
321
322 fn compile_unit_pos_computation(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
323 let variable = Variable::<D>::UnitPos;
324 let ty = variable.item();
325 let cube_dim_x = Variable::<D>::CubeDimX;
326 let cube_dim_y = Variable::<D>::CubeDimY;
327 let unit_pos_x = Variable::<D>::UnitPosX;
328 let unit_pos_y = Variable::<D>::UnitPosY;
329 let unit_pos_z = Variable::<D>::UnitPosZ;
330 writeln!(
331 f,
332 "{ty} {variable} = {unit_pos_x} + {unit_pos_y} * {cube_dim_x} + {unit_pos_z} * ({cube_dim_x} * {cube_dim_y});"
333 )
334 }
335
336 fn compile_unit_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
337 f.write_str("threadIdxGlobal")
338 }
339
340 fn compile_unit_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
341 f.write_str("threadIdx")
342 }
343
344 fn compile_unit_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
345 Self::compile_unit_pos_base_name(f)?;
346 write!(f, ".x")
347 }
348
349 fn compile_unit_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
350 Self::compile_unit_pos_base_name(f)?;
351 write!(f, ".y")
352 }
353
354 fn compile_unit_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
355 Self::compile_unit_pos_base_name(f)?;
356 write!(f, ".z")
357 }
358
359 fn compile_plane_dim(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
360 f.write_str("warpSize")
361 }
362
363 fn compile_plane_dim_checked(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
364 f.write_str("warpSizeChecked")
365 }
366
367 fn compile_plane_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
368 let unit_pos_x = Variable::<D>::UnitPosX;
369 let plane_dim = Variable::<D>::PlaneDim;
370 write!(f, "{unit_pos_x} / {plane_dim}")
371 }
372
373 fn compile_unit_pos_plane(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
374 let absolute_pos = Variable::<D>::AbsolutePos(Elem::U32);
375 let plane_dim = Variable::<D>::PlaneDim;
376 let ty = plane_dim.item();
377 write!(f, "{ty}({absolute_pos}) % {plane_dim}")
378 }
379
380 fn compile_cluster_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
381 write!(f, "0")
382 }
383 fn compile_cluster_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
384 write!(f, "0")
385 }
386 fn compile_cluster_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
387 write!(f, "0")
388 }
389 fn compile_cluster_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
390 write!(f, "0")
391 }
392}
393
394pub trait DialectInstructions<D: Dialect> {
397 fn compile_atomic_add(
399 f: &mut std::fmt::Formatter<'_>,
400 lhs: &Variable<D>,
401 rhs: &Variable<D>,
402 out: &Variable<D>,
403 ) -> std::fmt::Result {
404 let out = out.fmt_left();
405 match rhs.elem() {
406 Elem::I64 => writeln!(
407 f,
408 "{out} = atomicAdd(reinterpret_cast<{uint}*>({lhs}), {uint}({rhs}));",
409 uint = Elem::<D>::U64
410 ),
411 _ => writeln!(f, "{out} = atomicAdd({lhs}, {rhs});"),
412 }
413 }
414
415 fn compile_atomic_and(
416 f: &mut std::fmt::Formatter<'_>,
417 lhs: &Variable<D>,
418 rhs: &Variable<D>,
419 out: &Variable<D>,
420 ) -> std::fmt::Result {
421 let out = out.fmt_left();
422 writeln!(f, "{out} = atomicAnd({lhs}, {rhs});")
423 }
424
425 fn compile_atomic_cas(
426 f: &mut std::fmt::Formatter<'_>,
427 input: &Variable<D>,
428 cmp: &Variable<D>,
429 val: &Variable<D>,
430 out: &Variable<D>,
431 ) -> std::fmt::Result {
432 let out = out.fmt_left();
433 writeln!(f, "{out} = atomicCAS({input}, {cmp}, {val});")
434 }
435
436 fn compile_atomic_load(
437 f: &mut std::fmt::Formatter<'_>,
438 input: &Variable<D>,
439 out: &Variable<D>,
440 ) -> std::fmt::Result {
441 let out = out.fmt_left();
442 writeln!(f, "{out} = atomicAdd({input}, 0);")
443 }
444
445 fn compile_atomic_max(
446 f: &mut std::fmt::Formatter<'_>,
447 lhs: &Variable<D>,
448 rhs: &Variable<D>,
449 out: &Variable<D>,
450 ) -> std::fmt::Result {
451 let out = out.fmt_left();
452 writeln!(f, "{out} = atomicMax({lhs}, {rhs});")
453 }
454
455 fn compile_atomic_min(
456 f: &mut std::fmt::Formatter<'_>,
457 lhs: &Variable<D>,
458 rhs: &Variable<D>,
459 out: &Variable<D>,
460 ) -> std::fmt::Result {
461 let out = out.fmt_left();
462 writeln!(f, "{out} = atomicMin({lhs}, {rhs});")
463 }
464
465 fn compile_atomic_or(
466 f: &mut std::fmt::Formatter<'_>,
467 lhs: &Variable<D>,
468 rhs: &Variable<D>,
469 out: &Variable<D>,
470 ) -> std::fmt::Result {
471 let out = out.fmt_left();
472 writeln!(f, "{out} = atomicOr({lhs}, {rhs});")
473 }
474
475 fn compile_atomic_store(
476 f: &mut std::fmt::Formatter<'_>,
477 input: &Variable<D>,
478 out: &Variable<D>,
479 ) -> std::fmt::Result {
480 writeln!(f, "atomicExch({out}, {input});")
481 }
482
483 fn compile_atomic_sub(
484 f: &mut std::fmt::Formatter<'_>,
485 lhs: &Variable<D>,
486 rhs: &Variable<D>,
487 out: &Variable<D>,
488 ) -> std::fmt::Result {
489 let out = out.fmt_left();
490 match rhs.elem() {
491 Elem::U32 | Elem::I32 => writeln!(f, "{out} = atomicSub({lhs}, {rhs});"),
492 Elem::U64 => writeln!(f, "{out} = atomicAdd({lhs}, -{rhs});"),
493 Elem::I64 => writeln!(
494 f,
495 "{out} = atomicAdd(reinterpret_cast<{uint}*>({lhs}), {uint}(-{rhs}));",
496 uint = Elem::<D>::U64
497 ),
498 _ => writeln!(f, "{out} = atomicAdd({lhs}, -{rhs});"),
499 }
500 }
501
502 fn compile_atomic_swap(
503 f: &mut std::fmt::Formatter<'_>,
504 lhs: &Variable<D>,
505 rhs: &Variable<D>,
506 out: &Variable<D>,
507 ) -> std::fmt::Result {
508 let out = out.fmt_left();
509 writeln!(f, "{out} = atomicExch({lhs}, {rhs});")
510 }
511
512 fn compile_atomic_xor(
513 f: &mut std::fmt::Formatter<'_>,
514 lhs: &Variable<D>,
515 rhs: &Variable<D>,
516 out: &Variable<D>,
517 ) -> std::fmt::Result {
518 let out = out.fmt_left();
519 writeln!(f, "{out} = atomicXor({lhs}, {rhs});")
520 }
521
522 fn compile_saturating_add(
523 f: &mut std::fmt::Formatter<'_>,
524 lhs: impl Display,
525 rhs: impl Display,
526 item: Item<D>,
527 ) -> std::fmt::Result;
528
529 fn compile_saturating_sub(
530 f: &mut std::fmt::Formatter<'_>,
531 lhs: impl Display,
532 rhs: impl Display,
533 item: Item<D>,
534 ) -> std::fmt::Result;
535
536 fn compile_instruction_printf(
538 f: &mut std::fmt::Formatter<'_>,
539 format_string: &str,
540 args: &[Variable<D>],
541 ) -> std::fmt::Result {
542 let args = args.iter().map(|arg| format!("{arg}")).collect::<Vec<_>>();
543 let args = match args.is_empty() {
544 true => "".to_string(),
545 false => format!(", {}", args.join(",")),
546 };
547 writeln!(f, "printf({format_string:?}{args});")
548 }
549
550 fn compile_instruction_log1p_scalar<T: Component<D>>(
552 f: &mut std::fmt::Formatter<'_>,
553 input: T,
554 ) -> std::fmt::Result {
555 let elem = input.elem();
556 match elem {
557 Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
558 write!(f, "{elem}(log1p(float({input})))")
559 }
560 _ => write!(f, "log1p({input})"),
561 }
562 }
563
564 fn compile_instruction_sync_threads(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
566 fn compile_instruction_sync_warp(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
567 fn compile_instruction_thread_fence(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
568
569 fn compile_instruction_tanh_scalar<T: Component<D>>(
571 f: &mut std::fmt::Formatter<'_>,
572 input: T,
573 ) -> std::fmt::Result {
574 let elem = input.elem();
575 match elem {
576 Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
577 write!(f, "{elem}(tanh(float({input})))")
578 }
579 _ => write!(f, "tanh({input})"),
580 }
581 }
582
583 fn compile_instruction_find_first_set<T: Component<D>>(
585 f: &mut std::fmt::Formatter<'_>,
586 input: T,
587 out_elem: Elem<D>,
588 ) -> std::fmt::Result;
589 fn compile_instruction_leading_zeros_scalar<T: Component<D>>(
590 f: &mut std::fmt::Formatter<'_>,
591 input: T,
592 out_elem: Elem<D>,
593 ) -> std::fmt::Result;
594
595 fn compile_instruction_popcount_scalar<T: Component<D>>(
596 f: &mut std::fmt::Formatter<'_>,
597 input: T,
598 out_elem: Elem<D>,
599 ) -> std::fmt::Result {
600 write!(f, "{out_elem}(")?;
601 match input.elem() {
602 Elem::I32 => write!(f, "__popc({}({input}))", Elem::<D>::U32),
603 Elem::U32 => write!(f, "__popc({input})"),
604 Elem::I64 => write!(f, "__popcll({}({input}))", Elem::<D>::U64),
605 Elem::U64 => write!(f, "__popcll({input})"),
606 _ => write!(f, "__popc({})", super::unary::zero_extend(input)),
607 }?;
608 write!(f, ")")
609 }
610
611 fn compile_instruction_reverse_bits_scalar<T: Component<D>>(
612 f: &mut std::fmt::Formatter<'_>,
613 input: T,
614 out_elem: Elem<D>,
615 ) -> std::fmt::Result {
616 write!(f, "{out_elem}(")?;
617 match out_elem {
618 Elem::I32 => write!(f, "__brev({}({input}))", Elem::<D>::U32),
619 Elem::U32 => write!(f, "__brev({input})"),
620 Elem::I64 => write!(f, "__brevll({}({input}))", Elem::<D>::U64),
621 Elem::U64 => write!(f, "__brevll({input})"),
622 _ => write!(
623 f,
624 "__brev({}) >> {}",
625 super::unary::zero_extend(input),
626 (size_of::<u32>() - out_elem.size()) * 8
627 ),
628 }?;
629 write!(f, ")")
630 }
631
632 fn compile_instruction_max_function_name(
634 f: &mut std::fmt::Formatter<'_>,
635 item: Item<D>,
636 ) -> std::fmt::Result;
637
638 fn compile_instruction_min_function_name(
639 f: &mut std::fmt::Formatter<'_>,
640 item: Item<D>,
641 ) -> std::fmt::Result;
642
643 fn compile_instruction_powf(
644 f: &mut std::fmt::Formatter<'_>,
645 lhs: &str,
646 rhs: &str,
647 elem: Elem<D>,
648 ) -> std::fmt::Result {
649 match elem {
650 Elem::F32 => write!(f, "powf({lhs}, {rhs})"),
651 Elem::F64 => write!(f, "pow({lhs}, {rhs})"),
652 _ => write!(f, "#error Unsupported type for powf: {elem}"),
653 }
654 }
655
656 fn compile_instruction_hypot(
657 f: &mut std::fmt::Formatter<'_>,
658 lhs: &str,
659 rhs: &str,
660 elem: Elem<D>,
661 ) -> std::fmt::Result {
662 match elem {
663 Elem::F32 => write!(f, "hypotf({lhs}, {rhs})"),
664 Elem::F64 => write!(f, "hypot({lhs}, {rhs})"),
665 _ => write!(f, "#error Unsupported type for hypot: {elem}"),
666 }
667 }
668
669 fn compile_instruction_rhypot(
670 f: &mut std::fmt::Formatter<'_>,
671 lhs: &str,
672 rhs: &str,
673 elem: Elem<D>,
674 ) -> std::fmt::Result {
675 match elem {
676 Elem::F32 => write!(f, "rhypotf({lhs}, {rhs})"),
677 Elem::F64 => write!(f, "rhypot({lhs}, {rhs})"),
678 _ => write!(f, "#error Unsupported type for rhypot: {elem}"),
679 }
680 }
681
682 fn compile_instruction_half_function_name_prefix() -> &'static str {
683 "h"
684 }
685
686 fn compile_instruction_half2_function_name_prefix() -> &'static str {
687 "h2"
688 }
689
690 fn compile_warp_shuffle(
692 f: &mut std::fmt::Formatter<'_>,
693 var: &str,
694 source: &str,
695 ) -> std::fmt::Result;
696 fn compile_warp_shuffle_xor(
697 f: &mut std::fmt::Formatter<'_>,
698 var: &str,
699 elem: &Elem<D>,
700 offset: &str,
701 ) -> std::fmt::Result;
702 fn compile_warp_shuffle_up(
703 f: &mut std::fmt::Formatter<'_>,
704 var: &str,
705 offset: &str,
706 ) -> std::fmt::Result;
707 fn compile_warp_shuffle_down(
708 f: &mut std::fmt::Formatter<'_>,
709 var: &str,
710 offset: &str,
711 ) -> std::fmt::Result;
712 fn compile_warp_all<T: Component<D>>(
713 f: &mut std::fmt::Formatter<'_>,
714 input: &T,
715 ) -> std::fmt::Result;
716 fn compile_warp_any<T: Component<D>>(
717 f: &mut std::fmt::Formatter<'_>,
718 input: &T,
719 ) -> std::fmt::Result;
720 fn compile_warp_ballot(
721 f: &mut std::fmt::Formatter<'_>,
722 input: &Variable<D>,
723 out_elem: &Elem<D>,
724 ) -> std::fmt::Result;
725 fn compile_warp_elect(f: &mut std::fmt::Formatter<'_>, out: &str) -> std::fmt::Result {
726 write!(
727 f,
728 "
729unsigned int mask = __activemask();
730unsigned int leader = __ffs(mask) - 1;
731{out} = threadIdx.x % warpSize == leader;
732 "
733 )
734 }
735}
736
/// Operands of a single manually emitted matrix multiply-accumulate, passed to
/// `DialectWmmaCompiler::compile_manual_mma` / `compile_scaled_mma`.
///
/// Field order matters: `#[derive(new)]` generates `new(shape, frag_a, frag_b, frag_c, frag_d)`.
/// NOTE(review): presumably computes `d = a * b + c` — confirm against the dialect impls.
#[derive(Debug, Clone, Copy, new)]
pub struct ManualMma<'a, D: Dialect> {
    /// Shape of the MMA operation.
    pub shape: MmaShape<D>,
    /// Fragment holding the `A` operand.
    pub frag_a: &'a Variable<D>,
    /// Fragment holding the `B` operand.
    pub frag_b: &'a Variable<D>,
    /// Fragment holding the `C` (accumulator input) operand.
    pub frag_c: &'a Variable<D>,
    /// Fragment receiving the `D` result.
    pub frag_d: &'a Variable<D>,
}
745
746pub trait DialectWarpReduceCompiler<D: Dialect>:
747 Default + Clone + Copy + Debug + Send + Sync + Eq + Hash + 'static
748{
749 fn warp_reduce_sum(
750 f: &mut core::fmt::Formatter<'_>,
751 input: &Variable<D>,
752 out: &Variable<D>,
753 ) -> core::fmt::Result {
754 reduce_operator(f, input, out, "+=")
755 }
756 fn warp_reduce_prod(
757 f: &mut core::fmt::Formatter<'_>,
758 input: &Variable<D>,
759 out: &Variable<D>,
760 ) -> core::fmt::Result {
761 reduce_operator(f, input, out, "*=")
762 }
763 fn warp_reduce_max(
764 f: &mut core::fmt::Formatter<'_>,
765 input: &Variable<D>,
766 out: &Variable<D>,
767 ) -> core::fmt::Result {
768 reduce_comparison(f, input, out, D::compile_instruction_max_function_name)
769 }
770 fn warp_reduce_min(
771 f: &mut core::fmt::Formatter<'_>,
772 input: &Variable<D>,
773 out: &Variable<D>,
774 ) -> core::fmt::Result {
775 reduce_comparison(f, input, out, D::compile_instruction_min_function_name)
776 }
777 fn warp_reduce_all(
778 f: &mut core::fmt::Formatter<'_>,
779 input: &Variable<D>,
780 out: &Variable<D>,
781 ) -> core::fmt::Result {
782 reduce_quantifier(f, input, out, D::compile_warp_all::<IndexedVariable<D>>)
783 }
784 fn warp_reduce_any(
785 f: &mut core::fmt::Formatter<'_>,
786 input: &Variable<D>,
787 out: &Variable<D>,
788 ) -> core::fmt::Result {
789 reduce_quantifier(f, input, out, D::compile_warp_any::<IndexedVariable<D>>)
790 }
791 fn warp_reduce_sum_inclusive(
792 f: &mut core::fmt::Formatter<'_>,
793 input: &Variable<D>,
794 out: &Variable<D>,
795 ) -> core::fmt::Result {
796 reduce_inclusive(f, input, out, "+=")
797 }
798 fn warp_reduce_prod_inclusive(
799 f: &mut core::fmt::Formatter<'_>,
800 input: &Variable<D>,
801 out: &Variable<D>,
802 ) -> core::fmt::Result {
803 reduce_inclusive(f, input, out, "*=")
804 }
805 fn warp_reduce_sum_exclusive(
806 f: &mut core::fmt::Formatter<'_>,
807 input: &Variable<D>,
808 out: &Variable<D>,
809 ) -> core::fmt::Result {
810 reduce_exclusive(f, input, out, "+=", "0")
811 }
812 fn warp_reduce_prod_exclusive(
813 f: &mut core::fmt::Formatter<'_>,
814 input: &Variable<D>,
815 out: &Variable<D>,
816 ) -> core::fmt::Result {
817 reduce_exclusive(f, input, out, "*=", "1")
818 }
819}
820
821pub trait DialectWmmaCompiler<D: Dialect>:
822 Default + Clone + Copy + Debug + Send + Sync + Eq + Hash + 'static
823{
824 #[allow(unused_variables)]
825 fn compile_wmma_includes(
826 f: &mut std::fmt::Formatter<'_>,
827 flags: &Flags<D>,
828 ) -> std::fmt::Result {
829 Ok(())
830 }
831 #[allow(unused_variables)]
832 fn compile_wmma_type_definitions(
833 f: &mut std::fmt::Formatter<'_>,
834 flags: &Flags<D>,
835 ) -> std::fmt::Result {
836 Ok(())
837 }
838 #[allow(unused_variables)]
839 fn compile_wmma_local_variables(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
840 Ok(())
841 }
842 #[allow(unused_variables)]
843 fn compile_wwma_fragment_ident(
844 f: &mut std::fmt::Formatter<'_>,
845 ident: &FragmentIdent<D>,
846 ) -> std::fmt::Result {
847 Ok(())
848 }
849 #[allow(unused_variables)]
850 fn compile_wmma_fragment_layout(
851 f: &mut std::fmt::Formatter<'_>,
852 layout: &FragmentLayout<D>,
853 ) -> std::fmt::Result {
854 Ok(())
855 }
856 #[allow(unused_variables)]
857 fn compile_wmma_fragment(
858 f: &mut std::fmt::Formatter<'_>,
859 fragment: &Fragment<D>,
860 ) -> std::fmt::Result {
861 Ok(())
862 }
863
864 fn compile_wmma_fragment_declaration(
865 f: &mut std::fmt::Formatter<'_>,
866 var: &Variable<D>,
867 ) -> std::fmt::Result;
868
869 fn compile_wmma_instruction(
870 f: &mut std::fmt::Formatter<'_>,
871 instruction: &WmmaInstruction<D>,
872 ) -> std::fmt::Result;
873 fn compile_manual_mma(f: &mut std::fmt::Formatter<'_>, mma: ManualMma<D>) -> std::fmt::Result;
874 fn compile_scaled_mma(
875 f: &mut std::fmt::Formatter<'_>,
876 mma: ManualMma<D>,
877 scales_a: Variable<D>,
878 scales_b: Variable<D>,
879 scales_factor: u32,
880 ) -> std::fmt::Result;
881 fn supported_wmma_combinations(arch: &D::Architecture) -> SupportedMmaCombinations;
882 fn supported_mma_combinations(arch: &D::Architecture) -> SupportedMmaCombinations;
883 fn supported_scaled_mma_combinations(
884 _arch: &D::Architecture,
885 ) -> SupportedScaledMmaCombinations {
886 Vec::new()
887 }
888}
889
890pub trait DialectProcessors<D: Dialect> {
893 fn processors() -> Vec<Box<dyn Processor>>;
894}