1use std::{collections::HashSet, fmt::Debug};
2use std::{fmt::Display, hash::Hash};
3
4use cubecl_core::ir::Processor;
5
6use crate::shared::{
7 FmtLeft, IndexedVariable, MmaShape, SupportedMmaCombinations, SupportedScaledMmaCombinations,
8 reduce_comparison, reduce_exclusive, reduce_inclusive, reduce_operator, reduce_quantifier,
9};
10
11use super::{
12 Architecture, AtomicKind, Binding, Body, Component, CubeIndexFlags, Elem, Flags, Fragment,
13 FragmentIdent, FragmentLayout, Instruction, Item, SharedMemory, Variable, WarpInstruction,
14 WmmaInstruction,
15};
16
17pub trait Dialect:
20 DialectIncludes<Self>
21 + DialectTypes<Self>
22 + DialectBindings<Self>
23 + DialectWarpReduceCompiler<Self>
24 + DialectCubeBuiltins<Self>
25 + DialectInstructions<Self>
26 + DialectWmmaCompiler<Self>
27 + DialectProcessors<Self>
28 + Default
29 + Clone
30 + Copy
31 + Debug
32 + Send
33 + Sync
34 + Eq
35 + Hash
36 + 'static
37{
38 type Architecture: Architecture;
39}
40
41pub trait DialectIncludes<D: Dialect> {
44 type Extension: Debug + Clone + Sync + Send;
45
46 fn compile_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags<D>) -> std::fmt::Result;
47 fn compile_extensions(
48 f: &mut std::fmt::Formatter<'_>,
49 extensions: &[Self::Extension],
50 ) -> std::fmt::Result;
51 fn register_instruction_extension(
52 extensions: &mut Vec<Self::Extension>,
53 instruction: &Instruction<D>,
54 );
55 fn register_warp_instruction_extension(
56 extensions: &mut Vec<Self::Extension>,
57 instruction: &WarpInstruction<D>,
58 );
59 #[allow(unused_variables)]
60 fn register_wmma_instruction_extension(
61 extensions: &mut Vec<Self::Extension>,
62 instruction: &WmmaInstruction<D>,
63 ) {
64 }
65}
66
67pub trait DialectTypes<D: Dialect> {
70 fn item_can_be_optimized() -> bool;
71 fn compile_elem(
72 f: &mut std::fmt::Formatter<'_>,
73 elem: &Elem<D>,
74 word: bool,
75 ) -> std::fmt::Result;
76
77 fn compile_atomic_kind(
78 f: &mut std::fmt::Formatter<'_>,
79 kind: &AtomicKind<D>,
80 ) -> std::fmt::Result {
81 match kind {
82 AtomicKind::I32 => write!(f, "{}", Elem::<D>::I32),
83 AtomicKind::I64 => write!(f, "{}", Elem::<D>::I64),
84 AtomicKind::U32 => write!(f, "{}", Elem::<D>::U32),
85 AtomicKind::U64 => write!(f, "{}", Elem::<D>::U64),
86 AtomicKind::F16 => write!(f, "{}", Elem::<D>::F16),
87 AtomicKind::BF16 => write!(f, "{}", Elem::<D>::BF16),
88 AtomicKind::F32 => write!(f, "{}", Elem::<D>::F32),
89 AtomicKind::F64 => write!(f, "{}", Elem::<D>::F64),
90 AtomicKind::_Dialect(_) => Ok(()),
91 }
92 }
93
94 fn compile_item(f: &mut std::fmt::Formatter<'_>, item: &Item<D>) -> std::fmt::Result;
95 fn compile_type_definitions(
96 f: &mut std::fmt::Formatter<'_>,
97 items: &HashSet<Item<D>>,
98 scalars: &[(Elem<D>, usize)],
99 flags: &Flags<D>,
100 ) -> std::fmt::Result;
101 fn compile_local_memory_qualifier(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
102 fn compile_shared_memory_declaration(
103 f: &mut std::fmt::Formatter<'_>,
104 shared: &SharedMemory<D>,
105 ) -> std::fmt::Result {
106 match shared {
107 SharedMemory::Array {
108 index,
109 item,
110 length,
111 offset,
112 ..
113 } => {
114 let size_bytes = *length * item.size();
115 writeln!(f, "// Shared array size: {length}, {size_bytes} bytes")?;
116 writeln!(
117 f,
118 "{item} *shared_memory_{index} = reinterpret_cast<{item}*>(&dynamic_shared_mem[{offset}]);"
119 )
120 }
121 SharedMemory::Value {
122 index,
123 item,
124 offset,
125 ..
126 } => {
127 let size_bytes = item.size() as u32;
128 writeln!(f, "// Shared value size: {size_bytes} bytes")?;
129 writeln!(
130 f,
131 "{item} &shared_memory_{index} = reinterpret_cast<{item}&>(dynamic_shared_mem[{offset}]);"
132 )
133 }
134 }
135 }
136 fn compile_polyfills(_f: &mut std::fmt::Formatter<'_>, _flags: &Flags<D>) -> std::fmt::Result {
137 Ok(())
138 }
139 fn address_space_for_variable(_variable: &Variable<D>) -> String {
141 "".to_string()
142 }
143}
144
145pub trait DialectBindings<D: Dialect> {
148 fn compile_kernel_signature(
149 f: &mut std::fmt::Formatter<'_>,
150 kernel_name: &str,
151 tensor_maps: &[Binding<D>],
152 buffers: &[Binding<D>],
153 scalars: &[(Elem<D>, usize)],
154 flags: &Flags<D>,
155 ) -> std::fmt::Result;
156 fn compile_bindings_body(
157 _f: &mut std::fmt::Formatter<'_>,
158 _body: &Body<D>,
159 ) -> std::fmt::Result {
160 Ok(())
161 }
162}
163
164pub trait DialectCubeBuiltins<D: Dialect> {
167 fn builtin_rules(flags: &CubeIndexFlags) -> CubeIndexFlags {
173 let unit_pos_plane = flags.unit_pos_plane;
174 let plane_dim_checked = flags.plane_dim_checked;
175 let plane_dim = flags.plane_dim || plane_dim_checked || unit_pos_plane;
176 let plane_pos = flags.plane_pos;
177 let absolute_pos = flags.absolute_pos || unit_pos_plane;
178 let absolute_pos_tuple = flags.absolute_pos_tuple || absolute_pos;
179 let cube_dim = flags.cube_dim;
180 let cube_dim_tuple = flags.cube_dim_tuple || cube_dim || absolute_pos || plane_dim_checked;
181 let unit_pos = flags.unit_pos;
182 let unit_pos_tuple = flags.unit_pos_tuple || unit_pos;
183 let cube_count = flags.cube_count;
184 let cube_count_tuple = flags.cube_count_tuple || absolute_pos;
185 let cube_pos = flags.cube_pos;
186 let cube_pos_tuple = flags.cube_pos_tuple || cube_pos;
187 let cluster_group = flags.cluster_pos;
188
189 CubeIndexFlags {
190 absolute_pos,
191 absolute_pos_tuple,
192 cube_count,
193 cube_count_tuple,
194 cube_dim,
195 cube_dim_tuple,
196 cube_pos,
197 cube_pos_tuple,
198 plane_dim,
199 plane_dim_checked,
200 plane_pos,
201 unit_pos_tuple,
202 unit_pos,
203 unit_pos_plane,
204 cluster_pos: cluster_group,
205 }
206 }
207
208 fn compile_absolute_pos_tuple_computation(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
209 let variable = Variable::<D>::AbsolutePosBaseName;
210 let ty = variable.item();
211 let cube_pos_x = Variable::<D>::CubePosX;
212 let cube_pos_y = Variable::<D>::CubePosY;
213 let cube_pos_z = Variable::<D>::CubePosZ;
214 let cube_dim_x = Variable::<D>::CubeDimX;
215 let cube_dim_y = Variable::<D>::CubeDimY;
216 let cube_dim_z = Variable::<D>::CubeDimZ;
217 let unit_pos_x = Variable::<D>::UnitPosX;
218 let unit_pos_y = Variable::<D>::UnitPosY;
219 let unit_pos_z = Variable::<D>::UnitPosZ;
220 writeln!(
221 f,
222 "{ty} {variable} = make_{ty}(
223 {cube_pos_x} * {cube_dim_x} + {unit_pos_x},
224 {cube_pos_y} * {cube_dim_y} + {unit_pos_y},
225 {cube_pos_z} * {cube_dim_z} + {unit_pos_z}
226);"
227 )
228 }
229
230 fn compile_absolute_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
231 f.write_str("absoluteIdx")
232 }
233
234 fn compile_absolute_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
235 f.write_str("idxGlobal")
236 }
237
238 fn compile_absolute_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
239 Self::compile_absolute_pos_base_name(f)?;
240 write!(f, ".x")
241 }
242
243 fn compile_absolute_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
244 Self::compile_absolute_pos_base_name(f)?;
245 write!(f, ".y")
246 }
247
248 fn compile_absolute_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
249 Self::compile_absolute_pos_base_name(f)?;
250 write!(f, ".z")
251 }
252
253 fn compile_cube_count_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
254 f.write_str("gridDim")
255 }
256
257 fn compile_cube_count(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
258 f.write_str("gridDimGlobal")
259 }
260
261 fn compile_cube_count_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
262 Self::compile_cube_count_base_name(f)?;
263 write!(f, ".x")
264 }
265
266 fn compile_cube_count_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
267 Self::compile_cube_count_base_name(f)?;
268 write!(f, ".y")
269 }
270
271 fn compile_cube_count_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
272 Self::compile_cube_count_base_name(f)?;
273 write!(f, ".z")
274 }
275
276 fn compile_cube_dim_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
277 f.write_str("blockDim")
278 }
279
280 fn compile_cube_dim(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
281 f.write_str("blockDimGlobal")
282 }
283
284 fn compile_cube_dim_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
285 Self::compile_cube_dim_base_name(f)?;
286 write!(f, ".x")
287 }
288
289 fn compile_cube_dim_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
290 Self::compile_cube_dim_base_name(f)?;
291 write!(f, ".y")
292 }
293
294 fn compile_cube_dim_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
295 Self::compile_cube_dim_base_name(f)?;
296 write!(f, ".z")
297 }
298
299 fn compile_cube_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
300 f.write_str("blockIdx")
301 }
302
303 fn compile_cube_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
304 f.write_str("blockIdxGlobal")
305 }
306
307 fn compile_cube_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
308 Self::compile_cube_pos_base_name(f)?;
309 write!(f, ".x")
310 }
311
312 fn compile_cube_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
313 Self::compile_cube_pos_base_name(f)?;
314 write!(f, ".y")
315 }
316
317 fn compile_cube_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
318 Self::compile_cube_pos_base_name(f)?;
319 write!(f, ".z")
320 }
321
322 fn compile_unit_pos_computation(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
323 let variable = Variable::<D>::UnitPos;
324 let ty = variable.item();
325 let cube_dim_x = Variable::<D>::CubeDimX;
326 let cube_dim_y = Variable::<D>::CubeDimY;
327 let unit_pos_x = Variable::<D>::UnitPosX;
328 let unit_pos_y = Variable::<D>::UnitPosY;
329 let unit_pos_z = Variable::<D>::UnitPosZ;
330 writeln!(
331 f,
332 "{ty} {variable} = {unit_pos_x} + {unit_pos_y} * {cube_dim_x} + {unit_pos_z} * ({cube_dim_x} * {cube_dim_y});"
333 )
334 }
335
336 fn compile_unit_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
337 f.write_str("threadIdxGlobal")
338 }
339
340 fn compile_unit_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
341 f.write_str("threadIdx")
342 }
343
344 fn compile_unit_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
345 Self::compile_unit_pos_base_name(f)?;
346 write!(f, ".x")
347 }
348
349 fn compile_unit_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
350 Self::compile_unit_pos_base_name(f)?;
351 write!(f, ".y")
352 }
353
354 fn compile_unit_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
355 Self::compile_unit_pos_base_name(f)?;
356 write!(f, ".z")
357 }
358
359 fn compile_plane_dim(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
360 f.write_str("warpSize")
361 }
362
363 fn compile_plane_dim_checked(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
364 f.write_str("warpSizeChecked")
365 }
366
367 fn compile_plane_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
368 let unit_pos_x = Variable::<D>::UnitPosX;
369 let plane_dim = Variable::<D>::PlaneDim;
370 write!(f, "{unit_pos_x} / {plane_dim}")
371 }
372
373 fn compile_unit_pos_plane(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
374 let absolute_pos = Variable::<D>::AbsolutePos(Elem::U32);
375 let plane_dim = Variable::<D>::PlaneDim;
376 let ty = plane_dim.item();
377 write!(f, "{ty}({absolute_pos}) % {plane_dim}")
378 }
379
380 fn compile_cluster_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
381 write!(f, "0")
382 }
383 fn compile_cluster_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
384 write!(f, "0")
385 }
386 fn compile_cluster_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
387 write!(f, "0")
388 }
389 fn compile_cluster_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
390 write!(f, "0")
391 }
392}
393
394pub trait DialectInstructions<D: Dialect> {
397 fn compile_atomic_add(
399 f: &mut std::fmt::Formatter<'_>,
400 lhs: &Variable<D>,
401 rhs: &Variable<D>,
402 out: &Variable<D>,
403 ) -> std::fmt::Result {
404 let out = out.fmt_left();
405 match rhs.elem() {
406 Elem::I64 => writeln!(
407 f,
408 "{out} = atomicAdd(reinterpret_cast<{uint}*>({lhs}), {uint}({rhs}));",
409 uint = Elem::<D>::U64
410 ),
411 _ => writeln!(f, "{out} = atomicAdd({lhs}, {rhs});"),
412 }
413 }
414
415 fn compile_atomic_and(
416 f: &mut std::fmt::Formatter<'_>,
417 lhs: &Variable<D>,
418 rhs: &Variable<D>,
419 out: &Variable<D>,
420 ) -> std::fmt::Result {
421 let out = out.fmt_left();
422 writeln!(f, "{out} = atomicAnd({lhs}, {rhs});")
423 }
424
425 fn compile_atomic_cas(
426 f: &mut std::fmt::Formatter<'_>,
427 input: &Variable<D>,
428 cmp: &Variable<D>,
429 val: &Variable<D>,
430 out: &Variable<D>,
431 ) -> std::fmt::Result {
432 let out = out.fmt_left();
433 writeln!(f, "{out} = atomicCAS({input}, {cmp}, {val});")
434 }
435
436 fn compile_atomic_load(
437 f: &mut std::fmt::Formatter<'_>,
438 input: &Variable<D>,
439 out: &Variable<D>,
440 ) -> std::fmt::Result {
441 let out = out.fmt_left();
442 writeln!(f, "{out} = atomicAdd({input}, 0);")
443 }
444
445 fn compile_atomic_max(
446 f: &mut std::fmt::Formatter<'_>,
447 lhs: &Variable<D>,
448 rhs: &Variable<D>,
449 out: &Variable<D>,
450 ) -> std::fmt::Result {
451 let out = out.fmt_left();
452 writeln!(f, "{out} = atomicMax({lhs}, {rhs});")
453 }
454
455 fn compile_atomic_min(
456 f: &mut std::fmt::Formatter<'_>,
457 lhs: &Variable<D>,
458 rhs: &Variable<D>,
459 out: &Variable<D>,
460 ) -> std::fmt::Result {
461 let out = out.fmt_left();
462 writeln!(f, "{out} = atomicMin({lhs}, {rhs});")
463 }
464
465 fn compile_atomic_or(
466 f: &mut std::fmt::Formatter<'_>,
467 lhs: &Variable<D>,
468 rhs: &Variable<D>,
469 out: &Variable<D>,
470 ) -> std::fmt::Result {
471 let out = out.fmt_left();
472 writeln!(f, "{out} = atomicOr({lhs}, {rhs});")
473 }
474
475 fn compile_atomic_store(
476 f: &mut std::fmt::Formatter<'_>,
477 input: &Variable<D>,
478 out: &Variable<D>,
479 ) -> std::fmt::Result {
480 writeln!(f, "atomicExch({out}, {input});")
481 }
482
483 fn compile_atomic_sub(
484 f: &mut std::fmt::Formatter<'_>,
485 lhs: &Variable<D>,
486 rhs: &Variable<D>,
487 out: &Variable<D>,
488 ) -> std::fmt::Result {
489 let out = out.fmt_left();
490 match rhs.elem() {
491 Elem::U32 | Elem::I32 => writeln!(f, "{out} = atomicSub({lhs}, {rhs});"),
492 Elem::U64 => writeln!(f, "{out} = atomicAdd({lhs}, -{rhs});"),
493 Elem::I64 => writeln!(
494 f,
495 "{out} = atomicAdd(reinterpret_cast<{uint}*>({lhs}), {uint}(-{rhs}));",
496 uint = Elem::<D>::U64
497 ),
498 _ => writeln!(f, "{out} = atomicAdd({lhs}, -{rhs});"),
499 }
500 }
501
502 fn compile_atomic_swap(
503 f: &mut std::fmt::Formatter<'_>,
504 lhs: &Variable<D>,
505 rhs: &Variable<D>,
506 out: &Variable<D>,
507 ) -> std::fmt::Result {
508 let out = out.fmt_left();
509 writeln!(f, "{out} = atomicExch({lhs}, {rhs});")
510 }
511
512 fn compile_atomic_xor(
513 f: &mut std::fmt::Formatter<'_>,
514 lhs: &Variable<D>,
515 rhs: &Variable<D>,
516 out: &Variable<D>,
517 ) -> std::fmt::Result {
518 let out = out.fmt_left();
519 writeln!(f, "{out} = atomicXor({lhs}, {rhs});")
520 }
521
522 fn compile_saturating_add(
523 f: &mut std::fmt::Formatter<'_>,
524 lhs: impl Display,
525 rhs: impl Display,
526 item: Item<D>,
527 ) -> std::fmt::Result;
528
529 fn compile_saturating_sub(
530 f: &mut std::fmt::Formatter<'_>,
531 lhs: impl Display,
532 rhs: impl Display,
533 item: Item<D>,
534 ) -> std::fmt::Result;
535
536 fn compile_instruction_printf(
538 f: &mut std::fmt::Formatter<'_>,
539 format_string: &str,
540 args: &[Variable<D>],
541 ) -> std::fmt::Result {
542 let args = args.iter().map(|arg| format!("{arg}")).collect::<Vec<_>>();
543 let args = match args.is_empty() {
544 true => "".to_string(),
545 false => format!(", {}", args.join(",")),
546 };
547 writeln!(f, "printf({format_string:?}{args});")
548 }
549
550 fn compile_instruction_log1p_scalar<T: Component<D>>(
552 f: &mut std::fmt::Formatter<'_>,
553 input: T,
554 ) -> std::fmt::Result {
555 let elem = input.elem();
556 match elem {
557 Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
558 write!(f, "{elem}(log1p(float({input})))")
559 }
560 _ => write!(f, "log1p({input})"),
561 }
562 }
563
564 fn compile_instruction_sync_threads(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
566 fn compile_instruction_sync_warp(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
567 fn compile_instruction_thread_fence(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
568
569 fn compile_instruction_tanh_scalar<T: Component<D>>(
571 f: &mut std::fmt::Formatter<'_>,
572 input: T,
573 ) -> std::fmt::Result {
574 let elem = input.elem();
575 match elem {
576 Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
577 write!(f, "{elem}(tanh(float({input})))")
578 }
579 _ => write!(f, "tanh({input})"),
580 }
581 }
582
583 fn compile_instruction_find_first_set<T: Component<D>>(
585 f: &mut std::fmt::Formatter<'_>,
586 input: T,
587 out_elem: Elem<D>,
588 ) -> std::fmt::Result;
589 fn compile_instruction_leading_zeros_scalar<T: Component<D>>(
590 f: &mut std::fmt::Formatter<'_>,
591 input: T,
592 out_elem: Elem<D>,
593 ) -> std::fmt::Result;
594
595 fn compile_instruction_trailing_zeros_scalar<T: Component<D>>(
596 f: &mut std::fmt::Formatter<'_>,
597 input: T,
598 out_elem: Elem<D>,
599 ) -> std::fmt::Result;
600
601 fn compile_instruction_popcount_scalar<T: Component<D>>(
602 f: &mut std::fmt::Formatter<'_>,
603 input: T,
604 out_elem: Elem<D>,
605 ) -> std::fmt::Result {
606 write!(f, "{out_elem}(")?;
607 match input.elem() {
608 Elem::I32 => write!(f, "__popc({}({input}))", Elem::<D>::U32),
609 Elem::U32 => write!(f, "__popc({input})"),
610 Elem::I64 => write!(f, "__popcll({}({input}))", Elem::<D>::U64),
611 Elem::U64 => write!(f, "__popcll({input})"),
612 _ => write!(f, "__popc({})", super::unary::zero_extend(input)),
613 }?;
614 write!(f, ")")
615 }
616
617 fn compile_instruction_reverse_bits_scalar<T: Component<D>>(
618 f: &mut std::fmt::Formatter<'_>,
619 input: T,
620 out_elem: Elem<D>,
621 ) -> std::fmt::Result {
622 write!(f, "{out_elem}(")?;
623 match out_elem {
624 Elem::I32 => write!(f, "__brev({}({input}))", Elem::<D>::U32),
625 Elem::U32 => write!(f, "__brev({input})"),
626 Elem::I64 => write!(f, "__brevll({}({input}))", Elem::<D>::U64),
627 Elem::U64 => write!(f, "__brevll({input})"),
628 _ => write!(
629 f,
630 "__brev({}) >> {}",
631 super::unary::zero_extend(input),
632 (size_of::<u32>() - out_elem.size()) * 8
633 ),
634 }?;
635 write!(f, ")")
636 }
637
638 fn compile_instruction_max_function_name(
640 f: &mut std::fmt::Formatter<'_>,
641 item: Item<D>,
642 ) -> std::fmt::Result;
643
644 fn compile_instruction_min_function_name(
645 f: &mut std::fmt::Formatter<'_>,
646 item: Item<D>,
647 ) -> std::fmt::Result;
648
649 fn compile_instruction_powf(
650 f: &mut std::fmt::Formatter<'_>,
651 lhs: &str,
652 rhs: &str,
653 elem: Elem<D>,
654 ) -> std::fmt::Result {
655 match elem {
656 Elem::F32 => write!(f, "powf({lhs}, {rhs})"),
657 Elem::F64 => write!(f, "pow({lhs}, {rhs})"),
658 _ => write!(f, "#error Unsupported type for powf: {elem}"),
659 }
660 }
661
662 fn compile_instruction_hypot(
663 f: &mut std::fmt::Formatter<'_>,
664 lhs: &str,
665 rhs: &str,
666 elem: Elem<D>,
667 ) -> std::fmt::Result {
668 match elem {
669 Elem::F32 => write!(f, "hypotf({lhs}, {rhs})"),
670 Elem::F64 => write!(f, "hypot({lhs}, {rhs})"),
671 _ => write!(f, "#error Unsupported type for hypot: {elem}"),
672 }
673 }
674
675 fn compile_instruction_rhypot(
676 f: &mut std::fmt::Formatter<'_>,
677 lhs: &str,
678 rhs: &str,
679 elem: Elem<D>,
680 ) -> std::fmt::Result {
681 match elem {
682 Elem::F32 => write!(f, "rhypotf({lhs}, {rhs})"),
683 Elem::F64 => write!(f, "rhypot({lhs}, {rhs})"),
684 _ => write!(f, "#error Unsupported type for rhypot: {elem}"),
685 }
686 }
687
688 fn compile_instruction_half_function_name_prefix() -> &'static str {
689 "h"
690 }
691
692 fn compile_instruction_half2_function_name_prefix() -> &'static str {
693 "h2"
694 }
695
696 fn compile_warp_shuffle(
698 f: &mut std::fmt::Formatter<'_>,
699 var: &str,
700 source: &str,
701 ) -> std::fmt::Result;
702 fn compile_warp_shuffle_xor(
703 f: &mut std::fmt::Formatter<'_>,
704 var: &str,
705 elem: &Elem<D>,
706 offset: &str,
707 ) -> std::fmt::Result;
708 fn compile_warp_shuffle_up(
709 f: &mut std::fmt::Formatter<'_>,
710 var: &str,
711 offset: &str,
712 ) -> std::fmt::Result;
713 fn compile_warp_shuffle_down(
714 f: &mut std::fmt::Formatter<'_>,
715 var: &str,
716 offset: &str,
717 ) -> std::fmt::Result;
718 fn compile_warp_all<T: Component<D>>(
719 f: &mut std::fmt::Formatter<'_>,
720 input: &T,
721 ) -> std::fmt::Result;
722 fn compile_warp_any<T: Component<D>>(
723 f: &mut std::fmt::Formatter<'_>,
724 input: &T,
725 ) -> std::fmt::Result;
726 fn compile_warp_ballot(
727 f: &mut std::fmt::Formatter<'_>,
728 input: &Variable<D>,
729 out_elem: &Elem<D>,
730 ) -> std::fmt::Result;
731 fn compile_warp_elect(f: &mut std::fmt::Formatter<'_>, out: &str) -> std::fmt::Result {
732 write!(
733 f,
734 "
735unsigned int mask = __activemask();
736unsigned int leader = __ffs(mask) - 1;
737{out} = threadIdx.x % warpSize == leader;
738 "
739 )
740 }
741}
742
743#[derive(Debug, Clone, Copy, new)]
744pub struct ManualMma<'a, D: Dialect> {
745 pub shape: MmaShape<D>,
746 pub frag_a: &'a Variable<D>,
747 pub frag_b: &'a Variable<D>,
748 pub frag_c: &'a Variable<D>,
749 pub frag_d: &'a Variable<D>,
750}
751
752pub trait DialectWarpReduceCompiler<D: Dialect>:
753 Default + Clone + Copy + Debug + Send + Sync + Eq + Hash + 'static
754{
755 fn warp_reduce_sum(
756 f: &mut core::fmt::Formatter<'_>,
757 input: &Variable<D>,
758 out: &Variable<D>,
759 ) -> core::fmt::Result {
760 reduce_operator(f, input, out, "+=")
761 }
762 fn warp_reduce_prod(
763 f: &mut core::fmt::Formatter<'_>,
764 input: &Variable<D>,
765 out: &Variable<D>,
766 ) -> core::fmt::Result {
767 reduce_operator(f, input, out, "*=")
768 }
769 fn warp_reduce_max(
770 f: &mut core::fmt::Formatter<'_>,
771 input: &Variable<D>,
772 out: &Variable<D>,
773 ) -> core::fmt::Result {
774 reduce_comparison(f, input, out, D::compile_instruction_max_function_name)
775 }
776 fn warp_reduce_min(
777 f: &mut core::fmt::Formatter<'_>,
778 input: &Variable<D>,
779 out: &Variable<D>,
780 ) -> core::fmt::Result {
781 reduce_comparison(f, input, out, D::compile_instruction_min_function_name)
782 }
783 fn warp_reduce_all(
784 f: &mut core::fmt::Formatter<'_>,
785 input: &Variable<D>,
786 out: &Variable<D>,
787 ) -> core::fmt::Result {
788 reduce_quantifier(f, input, out, D::compile_warp_all::<IndexedVariable<D>>)
789 }
790 fn warp_reduce_any(
791 f: &mut core::fmt::Formatter<'_>,
792 input: &Variable<D>,
793 out: &Variable<D>,
794 ) -> core::fmt::Result {
795 reduce_quantifier(f, input, out, D::compile_warp_any::<IndexedVariable<D>>)
796 }
797 fn warp_reduce_sum_inclusive(
798 f: &mut core::fmt::Formatter<'_>,
799 input: &Variable<D>,
800 out: &Variable<D>,
801 ) -> core::fmt::Result {
802 reduce_inclusive(f, input, out, "+=")
803 }
804 fn warp_reduce_prod_inclusive(
805 f: &mut core::fmt::Formatter<'_>,
806 input: &Variable<D>,
807 out: &Variable<D>,
808 ) -> core::fmt::Result {
809 reduce_inclusive(f, input, out, "*=")
810 }
811 fn warp_reduce_sum_exclusive(
812 f: &mut core::fmt::Formatter<'_>,
813 input: &Variable<D>,
814 out: &Variable<D>,
815 ) -> core::fmt::Result {
816 reduce_exclusive(f, input, out, "+=", "0")
817 }
818 fn warp_reduce_prod_exclusive(
819 f: &mut core::fmt::Formatter<'_>,
820 input: &Variable<D>,
821 out: &Variable<D>,
822 ) -> core::fmt::Result {
823 reduce_exclusive(f, input, out, "*=", "1")
824 }
825}
826
827pub trait DialectWmmaCompiler<D: Dialect>:
828 Default + Clone + Copy + Debug + Send + Sync + Eq + Hash + 'static
829{
830 #[allow(unused_variables)]
831 fn compile_wmma_includes(
832 f: &mut std::fmt::Formatter<'_>,
833 flags: &Flags<D>,
834 ) -> std::fmt::Result {
835 Ok(())
836 }
837 #[allow(unused_variables)]
838 fn compile_wmma_type_definitions(
839 f: &mut std::fmt::Formatter<'_>,
840 flags: &Flags<D>,
841 ) -> std::fmt::Result {
842 Ok(())
843 }
844 #[allow(unused_variables)]
845 fn compile_wmma_local_variables(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
846 Ok(())
847 }
848 #[allow(unused_variables)]
849 fn compile_wwma_fragment_ident(
850 f: &mut std::fmt::Formatter<'_>,
851 ident: &FragmentIdent<D>,
852 ) -> std::fmt::Result {
853 Ok(())
854 }
855 #[allow(unused_variables)]
856 fn compile_wmma_fragment_layout(
857 f: &mut std::fmt::Formatter<'_>,
858 layout: &FragmentLayout<D>,
859 ) -> std::fmt::Result {
860 Ok(())
861 }
862 #[allow(unused_variables)]
863 fn compile_wmma_fragment(
864 f: &mut std::fmt::Formatter<'_>,
865 fragment: &Fragment<D>,
866 ) -> std::fmt::Result {
867 Ok(())
868 }
869
870 fn compile_wmma_fragment_declaration(
871 f: &mut std::fmt::Formatter<'_>,
872 var: &Variable<D>,
873 ) -> std::fmt::Result;
874
875 fn compile_wmma_instruction(
876 f: &mut std::fmt::Formatter<'_>,
877 instruction: &WmmaInstruction<D>,
878 ) -> std::fmt::Result;
879 fn compile_manual_mma(f: &mut std::fmt::Formatter<'_>, mma: ManualMma<D>) -> std::fmt::Result;
880 fn compile_scaled_mma(
881 f: &mut std::fmt::Formatter<'_>,
882 mma: ManualMma<D>,
883 scales_a: Variable<D>,
884 scales_b: Variable<D>,
885 scales_factor: u32,
886 ) -> std::fmt::Result;
887 fn supported_wmma_combinations(arch: &D::Architecture) -> SupportedMmaCombinations;
888 fn supported_mma_combinations(arch: &D::Architecture) -> SupportedMmaCombinations;
889 fn supported_scaled_mma_combinations(
890 _arch: &D::Architecture,
891 ) -> SupportedScaledMmaCombinations {
892 Vec::new()
893 }
894}
895
896pub trait DialectProcessors<D: Dialect> {
899 fn processors() -> Vec<Box<dyn Processor>>;
900}