1use std::{collections::HashSet, fmt::Debug};
2use std::{fmt::Display, hash::Hash};
3
4use cubecl_core::ir::{ConstantValue, Processor};
5
6use crate::shared::{
7 FmtLeft, IndexedVariable, MmaShape, SupportedMmaCombinations, SupportedScaledMmaCombinations,
8 reduce_comparison, reduce_exclusive, reduce_inclusive, reduce_operator, reduce_quantifier,
9};
10
11use super::{
12 Architecture, AtomicKind, Body, Component, CubeIndexFlags, Elem, Flags, Fragment,
13 FragmentIdent, FragmentLayout, Instruction, Item, KernelArg, SharedMemory, Variable,
14 WarpInstruction, WmmaInstruction,
15};
16
17pub trait Dialect:
20 DialectIncludes<Self>
21 + DialectTypes<Self>
22 + DialectBindings<Self>
23 + DialectWarpReduceCompiler<Self>
24 + DialectCubeBuiltins<Self>
25 + DialectInstructions<Self>
26 + DialectWmmaCompiler<Self>
27 + DialectProcessors<Self>
28 + Default
29 + Clone
30 + Copy
31 + Debug
32 + Send
33 + Sync
34 + Eq
35 + Hash
36 + 'static
37{
38 type Architecture: Architecture;
39}
40
41pub trait DialectIncludes<D: Dialect> {
44 type Extension: Debug + Clone + Sync + Send;
45
46 fn compile_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags<D>) -> std::fmt::Result;
47 fn compile_extensions(
48 f: &mut std::fmt::Formatter<'_>,
49 extensions: &[Self::Extension],
50 ) -> std::fmt::Result;
51 fn register_instruction_extension(
52 extensions: &mut Vec<Self::Extension>,
53 instruction: &Instruction<D>,
54 );
55 fn register_warp_instruction_extension(
56 extensions: &mut Vec<Self::Extension>,
57 instruction: &WarpInstruction<D>,
58 );
59 #[allow(unused_variables)]
60 fn register_wmma_instruction_extension(
61 extensions: &mut Vec<Self::Extension>,
62 instruction: &WmmaInstruction<D>,
63 ) {
64 }
65}
66
67pub trait DialectTypes<D: Dialect> {
70 fn item_can_be_optimized() -> bool;
71 fn compile_elem(
72 f: &mut std::fmt::Formatter<'_>,
73 elem: &Elem<D>,
74 word: bool,
75 ) -> std::fmt::Result;
76
77 fn compile_atomic_kind(
78 f: &mut std::fmt::Formatter<'_>,
79 kind: &AtomicKind<D>,
80 ) -> std::fmt::Result {
81 match kind {
82 AtomicKind::I32 => write!(f, "{}", Elem::<D>::I32),
83 AtomicKind::I64 => write!(f, "{}", Elem::<D>::I64),
84 AtomicKind::U32 => write!(f, "{}", Elem::<D>::U32),
85 AtomicKind::U64 => write!(f, "{}", Elem::<D>::U64),
86 AtomicKind::F16 => write!(f, "{}", Elem::<D>::F16),
87 AtomicKind::F16x2 => write!(f, "{}", Elem::<D>::F16x2),
88 AtomicKind::BF16 => write!(f, "{}", Elem::<D>::BF16),
89 AtomicKind::BF16x2 => write!(f, "{}", Elem::<D>::BF16x2),
90 AtomicKind::F32 => write!(f, "{}", Elem::<D>::F32),
91 AtomicKind::F64 => write!(f, "{}", Elem::<D>::F64),
92 AtomicKind::_Dialect(_) => Ok(()),
93 }
94 }
95
96 fn compile_item(f: &mut std::fmt::Formatter<'_>, item: &Item<D>) -> std::fmt::Result;
97 fn compile_type_definitions(
98 f: &mut std::fmt::Formatter<'_>,
99 items: &HashSet<Item<D>>,
100 scalars: &[(Elem<D>, usize)],
101 info: &cubecl_core::Info,
102 flags: &Flags<D>,
103 ) -> std::fmt::Result;
104 fn compile_local_memory_qualifier(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
105 fn compile_shared_memory_declaration(
106 f: &mut std::fmt::Formatter<'_>,
107 shared: &SharedMemory<D>,
108 ) -> std::fmt::Result {
109 match shared {
110 SharedMemory::Array {
111 index,
112 item,
113 length,
114 offset,
115 ..
116 } => {
117 let size_bytes = *length * item.size();
118 writeln!(f, "// Shared array size: {length}, {size_bytes} bytes")?;
119 writeln!(
120 f,
121 "{item} *shared_memory_{index} = reinterpret_cast<{item}*>(&dynamic_shared_mem[{offset}]);"
122 )
123 }
124 SharedMemory::Value {
125 index,
126 item,
127 offset,
128 ..
129 } => {
130 let size_bytes = item.size() as u32;
131 writeln!(f, "// Shared value size: {size_bytes} bytes")?;
132 writeln!(
133 f,
134 "{item} &shared_memory_{index} = reinterpret_cast<{item}&>(dynamic_shared_mem[{offset}]);"
135 )
136 }
137 }
138 }
139 fn compile_polyfills(_f: &mut std::fmt::Formatter<'_>, _flags: &Flags<D>) -> std::fmt::Result {
140 Ok(())
141 }
142 fn address_space_for_variable(_variable: &Variable<D>) -> String {
144 "".to_string()
145 }
146}
147
148pub trait DialectBindings<D: Dialect> {
151 fn compile_kernel_signature(
152 f: &mut std::fmt::Formatter<'_>,
153 kernel_name: &str,
154 tensor_maps: &[KernelArg<D>],
155 buffers: &[KernelArg<D>],
156 flags: &Flags<D>,
157 ) -> std::fmt::Result;
158 fn compile_bindings_body(
159 _f: &mut std::fmt::Formatter<'_>,
160 _body: &Body<D>,
161 ) -> std::fmt::Result {
162 Ok(())
163 }
164}
165
166pub trait DialectCubeBuiltins<D: Dialect> {
169 fn builtin_rules(flags: &CubeIndexFlags) -> CubeIndexFlags {
175 let unit_pos_plane = flags.unit_pos_plane;
176 let plane_dim_checked = flags.plane_dim_checked;
177 let plane_dim = flags.plane_dim || plane_dim_checked || unit_pos_plane;
178 let plane_pos = flags.plane_pos;
179 let absolute_pos = flags.absolute_pos || unit_pos_plane;
180 let absolute_pos_tuple = flags.absolute_pos_tuple || absolute_pos;
181 let cube_dim = flags.cube_dim;
182 let cube_dim_tuple = flags.cube_dim_tuple || cube_dim || absolute_pos || plane_dim_checked;
183 let unit_pos = flags.unit_pos;
184 let unit_pos_tuple = flags.unit_pos_tuple || unit_pos;
185 let cube_count = flags.cube_count;
186 let cube_count_tuple = flags.cube_count_tuple || absolute_pos;
187 let cube_pos = flags.cube_pos;
188 let cube_pos_tuple = flags.cube_pos_tuple || cube_pos;
189 let cluster_group = flags.cluster_pos;
190
191 CubeIndexFlags {
192 absolute_pos,
193 absolute_pos_tuple,
194 cube_count,
195 cube_count_tuple,
196 cube_dim,
197 cube_dim_tuple,
198 cube_pos,
199 cube_pos_tuple,
200 plane_dim,
201 plane_dim_checked,
202 plane_pos,
203 unit_pos_tuple,
204 unit_pos,
205 unit_pos_plane,
206 cluster_pos: cluster_group,
207 }
208 }
209
210 fn compile_absolute_pos_tuple_computation(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
211 let variable = Variable::<D>::AbsolutePosBaseName;
212 let ty = variable.item();
213 let cube_pos_x = Variable::<D>::CubePosX;
214 let cube_pos_y = Variable::<D>::CubePosY;
215 let cube_pos_z = Variable::<D>::CubePosZ;
216 let cube_dim_x = Variable::<D>::CubeDimX;
217 let cube_dim_y = Variable::<D>::CubeDimY;
218 let cube_dim_z = Variable::<D>::CubeDimZ;
219 let unit_pos_x = Variable::<D>::UnitPosX;
220 let unit_pos_y = Variable::<D>::UnitPosY;
221 let unit_pos_z = Variable::<D>::UnitPosZ;
222 writeln!(
223 f,
224 "{ty} {variable} = make_{ty}(
225 {cube_pos_x} * {cube_dim_x} + {unit_pos_x},
226 {cube_pos_y} * {cube_dim_y} + {unit_pos_y},
227 {cube_pos_z} * {cube_dim_z} + {unit_pos_z}
228);"
229 )
230 }
231
232 fn compile_absolute_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
233 f.write_str("absoluteIdx")
234 }
235
236 fn compile_absolute_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
237 f.write_str("idxGlobal")
238 }
239
240 fn compile_absolute_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
241 Self::compile_absolute_pos_base_name(f)?;
242 write!(f, ".x")
243 }
244
245 fn compile_absolute_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
246 Self::compile_absolute_pos_base_name(f)?;
247 write!(f, ".y")
248 }
249
250 fn compile_absolute_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
251 Self::compile_absolute_pos_base_name(f)?;
252 write!(f, ".z")
253 }
254
255 fn compile_cube_count_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
256 f.write_str("gridDim")
257 }
258
259 fn compile_cube_count(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
260 f.write_str("gridDimGlobal")
261 }
262
263 fn compile_cube_count_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
264 Self::compile_cube_count_base_name(f)?;
265 write!(f, ".x")
266 }
267
268 fn compile_cube_count_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
269 Self::compile_cube_count_base_name(f)?;
270 write!(f, ".y")
271 }
272
273 fn compile_cube_count_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
274 Self::compile_cube_count_base_name(f)?;
275 write!(f, ".z")
276 }
277
278 fn compile_cube_dim_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
279 f.write_str("blockDim")
280 }
281
282 fn compile_cube_dim(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
283 f.write_str("blockDimGlobal")
284 }
285
286 fn compile_cube_dim_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
287 Self::compile_cube_dim_base_name(f)?;
288 write!(f, ".x")
289 }
290
291 fn compile_cube_dim_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
292 Self::compile_cube_dim_base_name(f)?;
293 write!(f, ".y")
294 }
295
296 fn compile_cube_dim_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
297 Self::compile_cube_dim_base_name(f)?;
298 write!(f, ".z")
299 }
300
301 fn compile_cube_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
302 f.write_str("blockIdx")
303 }
304
305 fn compile_cube_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
306 f.write_str("blockIdxGlobal")
307 }
308
309 fn compile_cube_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
310 Self::compile_cube_pos_base_name(f)?;
311 write!(f, ".x")
312 }
313
314 fn compile_cube_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
315 Self::compile_cube_pos_base_name(f)?;
316 write!(f, ".y")
317 }
318
319 fn compile_cube_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
320 Self::compile_cube_pos_base_name(f)?;
321 write!(f, ".z")
322 }
323
324 fn compile_unit_pos_computation(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
325 let variable = Variable::<D>::UnitPos;
326 let ty = variable.item();
327 let cube_dim_x = Variable::<D>::CubeDimX;
328 let cube_dim_y = Variable::<D>::CubeDimY;
329 let unit_pos_x = Variable::<D>::UnitPosX;
330 let unit_pos_y = Variable::<D>::UnitPosY;
331 let unit_pos_z = Variable::<D>::UnitPosZ;
332 writeln!(
333 f,
334 "{ty} {variable} = {unit_pos_x} + {unit_pos_y} * {cube_dim_x} + {unit_pos_z} * ({cube_dim_x} * {cube_dim_y});"
335 )
336 }
337
338 fn compile_unit_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
339 f.write_str("threadIdxGlobal")
340 }
341
342 fn compile_unit_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
343 f.write_str("threadIdx")
344 }
345
346 fn compile_unit_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
347 Self::compile_unit_pos_base_name(f)?;
348 write!(f, ".x")
349 }
350
351 fn compile_unit_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
352 Self::compile_unit_pos_base_name(f)?;
353 write!(f, ".y")
354 }
355
356 fn compile_unit_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
357 Self::compile_unit_pos_base_name(f)?;
358 write!(f, ".z")
359 }
360
361 fn compile_plane_dim(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
362 f.write_str("warpSize")
363 }
364
365 fn compile_plane_dim_checked(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
366 f.write_str("warpSizeChecked")
367 }
368
369 fn compile_plane_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
370 let unit_pos_x = Variable::<D>::UnitPosX;
371 let plane_dim = Variable::<D>::PlaneDim;
372 write!(f, "{unit_pos_x} / {plane_dim}")
373 }
374
375 fn compile_unit_pos_plane(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
376 let absolute_pos = Variable::<D>::AbsolutePos(Elem::U32);
377 let plane_dim = Variable::<D>::PlaneDim;
378 let ty = plane_dim.item();
379 write!(f, "{ty}({absolute_pos}) % {plane_dim}")
380 }
381
382 fn compile_cluster_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
383 write!(f, "0")
384 }
385 fn compile_cluster_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
386 write!(f, "0")
387 }
388 fn compile_cluster_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
389 write!(f, "0")
390 }
391 fn compile_cluster_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
392 write!(f, "0")
393 }
394}
395
396pub trait DialectInstructions<D: Dialect> {
399 fn compile_atomic_add(
401 f: &mut std::fmt::Formatter<'_>,
402 lhs: &Variable<D>,
403 rhs: &Variable<D>,
404 out: &Variable<D>,
405 ) -> std::fmt::Result {
406 let optimized = Variable::optimized_args([*lhs, *rhs, *out]);
407 let [lhs, rhs, out_optimized] = optimized.args;
408
409 let addr_space = D::address_space_for_variable(out);
410 let out_item = out.item();
411 let out = out.fmt_left();
412
413 match out_optimized.elem() {
414 Elem::I64 => writeln!(
415 f,
416 "{out} = atomicAdd(reinterpret_cast<{uint}*>({lhs}), {uint}({rhs}));",
417 uint = Elem::<D>::U64
418 ),
419 Elem::F32 if out_item.vectorization > 1 => {
420 let vec_ty = format!("float{}", out_item.vectorization);
423 let out_tmp = Variable::tmp(out_optimized.item());
424 writeln!(
425 f,
426 "{vec_ty} {out_tmp} = atomicAdd(
427 reinterpret_cast<{addr_space}{vec_ty}*>({lhs}),
428 reinterpret_cast<const {addr_space}{vec_ty}&>({rhs}));",
429 )?;
430 writeln!(
431 f,
432 "{out} = reinterpret_cast<{addr_space}{out_item}&>({out_tmp});"
433 )
434 }
435 Elem::F16x2 | Elem::BF16x2 => {
436 let out_tmp = Variable::tmp(out_optimized.item());
437 writeln!(
438 f,
439 "{} = atomicAdd(
440 reinterpret_cast<{addr_space}{}*>({lhs}),
441 reinterpret_cast<const {addr_space}{}&>({rhs}));",
442 out_tmp.fmt_left(),
443 lhs.item(),
444 rhs.item()
445 )?;
446 writeln!(
447 f,
448 "{out} = reinterpret_cast<{addr_space}{out_item}&>({out_tmp});"
449 )
450 }
451 _ => writeln!(f, "{out} = atomicAdd({lhs}, {rhs});"),
452 }
453 }
454
455 fn compile_atomic_and(
456 f: &mut std::fmt::Formatter<'_>,
457 lhs: &Variable<D>,
458 rhs: &Variable<D>,
459 out: &Variable<D>,
460 ) -> std::fmt::Result {
461 let out = out.fmt_left();
462 writeln!(f, "{out} = atomicAnd({lhs}, {rhs});")
463 }
464
465 fn compile_atomic_cas(
466 f: &mut std::fmt::Formatter<'_>,
467 input: &Variable<D>,
468 cmp: &Variable<D>,
469 val: &Variable<D>,
470 out: &Variable<D>,
471 ) -> std::fmt::Result {
472 let out_item = out.item();
473 let out = out.fmt_left();
474 match val.elem() {
475 Elem::F32 if val.item().vectorization == 2 => {
477 let u64 = Item::new(Elem::<D>::U64, 1, true);
478 let out_tmp = Variable::tmp(u64);
479 writeln!(
480 f,
481 "{} = atomicCAS(
482 reinterpret_cast<{u64}*>({input}),
483 reinterpret_cast<{u64}&>({cmp}),
484 reinterpret_cast<{u64}&>({val}));",
485 out_tmp.fmt_left()
486 )?;
487 writeln!(f, "{out} = reinterpret_cast<{out_item}&>({out_tmp});")
488 }
489 Elem::F16 | Elem::BF16 if val.item().vectorization == 2 => {
490 let u32 = Item::new(Elem::<D>::U32, 1, true);
491 let out_tmp = Variable::tmp(u32);
492 writeln!(
493 f,
494 "{} = atomicCAS(
495 reinterpret_cast<{u32}*>({input}),
496 reinterpret_cast<{u32}&>({cmp}),
497 reinterpret_cast<{u32}&>({val}));",
498 out_tmp.fmt_left()
499 )?;
500 writeln!(f, "{out} = reinterpret_cast<{out_item}&>({out_tmp});")
501 }
502 _ => writeln!(f, "{out} = atomicCAS({input}, {cmp}, {val});"),
503 }
504 }
505
506 fn compile_atomic_load(
507 f: &mut std::fmt::Formatter<'_>,
508 input: &Variable<D>,
509 out: &Variable<D>,
510 ) -> std::fmt::Result {
511 let zero = Variable::Constant(ConstantValue::UInt(0), input.item());
512 Self::compile_atomic_add(f, input, &zero, out)
513 }
514
515 fn compile_atomic_max(
516 f: &mut std::fmt::Formatter<'_>,
517 lhs: &Variable<D>,
518 rhs: &Variable<D>,
519 out: &Variable<D>,
520 ) -> std::fmt::Result {
521 let out = out.fmt_left();
522 writeln!(f, "{out} = atomicMax({lhs}, {rhs});")
523 }
524
525 fn compile_atomic_min(
526 f: &mut std::fmt::Formatter<'_>,
527 lhs: &Variable<D>,
528 rhs: &Variable<D>,
529 out: &Variable<D>,
530 ) -> std::fmt::Result {
531 let out = out.fmt_left();
532 writeln!(f, "{out} = atomicMin({lhs}, {rhs});")
533 }
534
535 fn compile_atomic_or(
536 f: &mut std::fmt::Formatter<'_>,
537 lhs: &Variable<D>,
538 rhs: &Variable<D>,
539 out: &Variable<D>,
540 ) -> std::fmt::Result {
541 let out = out.fmt_left();
542 writeln!(f, "{out} = atomicOr({lhs}, {rhs});")
543 }
544
545 fn compile_atomic_store(
546 f: &mut std::fmt::Formatter<'_>,
547 input: &Variable<D>,
548 out: &Variable<D>,
549 ) -> std::fmt::Result {
550 let tmp = Variable::tmp(input.item());
551 Self::compile_atomic_swap(f, out, input, &tmp)
552 }
553
554 fn compile_atomic_sub(
555 f: &mut std::fmt::Formatter<'_>,
556 lhs: &Variable<D>,
557 rhs: &Variable<D>,
558 out: &Variable<D>,
559 ) -> std::fmt::Result {
560 let out = out.fmt_left();
561 match rhs.elem() {
562 Elem::U32 | Elem::I32 => writeln!(f, "{out} = atomicSub({lhs}, {rhs});"),
563 Elem::U64 => writeln!(f, "{out} = atomicAdd({lhs}, -{rhs});"),
564 Elem::I64 => writeln!(
565 f,
566 "{out} = atomicAdd(reinterpret_cast<{uint}*>({lhs}), {uint}(-{rhs}));",
567 uint = Elem::<D>::U64
568 ),
569 _ => writeln!(f, "{out} = atomicAdd({lhs}, -{rhs});"),
570 }
571 }
572
573 fn compile_atomic_swap(
574 f: &mut std::fmt::Formatter<'_>,
575 lhs: &Variable<D>,
576 rhs: &Variable<D>,
577 out: &Variable<D>,
578 ) -> std::fmt::Result {
579 let out_item = out.item();
580 let out = out.fmt_left();
581 match rhs.elem() {
582 Elem::F32 if rhs.item().vectorization == 2 => {
584 let u64 = Item::new(Elem::<D>::U64, 1, true);
585 let out_tmp = Variable::tmp(u64);
586 writeln!(
587 f,
588 "{} = atomicExch(
589 reinterpret_cast<{u64}*>({lhs}),
590 reinterpret_cast<{u64}&>({rhs}));",
591 out_tmp.fmt_left()
592 )?;
593 writeln!(f, "{out} = reinterpret_cast<{out_item}&>({out_tmp});")
594 }
595 Elem::F16 | Elem::BF16 if rhs.item().vectorization == 2 => {
596 let u32 = Item::new(Elem::<D>::U32, 1, true);
597 let out_tmp = Variable::tmp(u32);
598 writeln!(
599 f,
600 "{} = atomicExch(
601 reinterpret_cast<{u32}*>({lhs}),
602 reinterpret_cast<{u32}&>({rhs}));",
603 out_tmp.fmt_left()
604 )?;
605 writeln!(f, "{out} = reinterpret_cast<{out_item}&>({out_tmp});")
606 }
607 _ => writeln!(f, "{out} = atomicExch({lhs}, {rhs});"),
608 }
609 }
610
611 fn compile_atomic_xor(
612 f: &mut std::fmt::Formatter<'_>,
613 lhs: &Variable<D>,
614 rhs: &Variable<D>,
615 out: &Variable<D>,
616 ) -> std::fmt::Result {
617 let out = out.fmt_left();
618 writeln!(f, "{out} = atomicXor({lhs}, {rhs});")
619 }
620
621 fn compile_saturating_add(
622 f: &mut std::fmt::Formatter<'_>,
623 lhs: impl Display,
624 rhs: impl Display,
625 item: Item<D>,
626 ) -> std::fmt::Result;
627
628 fn compile_saturating_sub(
629 f: &mut std::fmt::Formatter<'_>,
630 lhs: impl Display,
631 rhs: impl Display,
632 item: Item<D>,
633 ) -> std::fmt::Result;
634
635 fn compile_instruction_printf(
637 f: &mut std::fmt::Formatter<'_>,
638 format_string: &str,
639 args: &[Variable<D>],
640 ) -> std::fmt::Result {
641 let args = args.iter().map(|arg| format!("{arg}")).collect::<Vec<_>>();
642 let args = match args.is_empty() {
643 true => "".to_string(),
644 false => format!(", {}", args.join(",")),
645 };
646 writeln!(f, "printf({format_string:?}{args});")
647 }
648
649 fn compile_instruction_log1p_scalar<T: Component<D>>(
651 f: &mut std::fmt::Formatter<'_>,
652 input: T,
653 ) -> std::fmt::Result {
654 let elem = input.elem();
655 match elem {
656 Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
657 write!(f, "{elem}(log1p(float({input})))")
658 }
659 _ => write!(f, "log1p({input})"),
660 }
661 }
662
663 fn compile_instruction_sync_threads(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
665 fn compile_instruction_sync_warp(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
666 fn compile_instruction_thread_fence(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
667
668 fn compile_instruction_tanh_scalar<T: Component<D>>(
670 f: &mut std::fmt::Formatter<'_>,
671 input: T,
672 ) -> std::fmt::Result {
673 let elem = input.elem();
674 match elem {
675 Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
676 write!(f, "{elem}(tanh(float({input})))")
677 }
678 _ => write!(f, "tanh({input})"),
679 }
680 }
681
682 fn compile_instruction_find_first_set<T: Component<D>>(
684 f: &mut std::fmt::Formatter<'_>,
685 input: T,
686 out_elem: Elem<D>,
687 ) -> std::fmt::Result;
688 fn compile_instruction_leading_zeros_scalar<T: Component<D>>(
689 f: &mut std::fmt::Formatter<'_>,
690 input: T,
691 out_elem: Elem<D>,
692 ) -> std::fmt::Result;
693
694 fn compile_instruction_trailing_zeros_scalar<T: Component<D>>(
695 f: &mut std::fmt::Formatter<'_>,
696 input: T,
697 out_elem: Elem<D>,
698 ) -> std::fmt::Result;
699
700 fn compile_instruction_popcount_scalar<T: Component<D>>(
701 f: &mut std::fmt::Formatter<'_>,
702 input: T,
703 out_elem: Elem<D>,
704 ) -> std::fmt::Result {
705 write!(f, "{out_elem}(")?;
706 match input.elem() {
707 Elem::I32 => write!(f, "__popc({}({input}))", Elem::<D>::U32),
708 Elem::U32 => write!(f, "__popc({input})"),
709 Elem::I64 => write!(f, "__popcll({}({input}))", Elem::<D>::U64),
710 Elem::U64 => write!(f, "__popcll({input})"),
711 _ => write!(f, "__popc({})", super::unary::zero_extend(input)),
712 }?;
713 write!(f, ")")
714 }
715
716 fn compile_instruction_reverse_bits_scalar<T: Component<D>>(
717 f: &mut std::fmt::Formatter<'_>,
718 input: T,
719 out_elem: Elem<D>,
720 ) -> std::fmt::Result {
721 write!(f, "{out_elem}(")?;
722 match out_elem {
723 Elem::I32 => write!(f, "__brev({}({input}))", Elem::<D>::U32),
724 Elem::U32 => write!(f, "__brev({input})"),
725 Elem::I64 => write!(f, "__brevll({}({input}))", Elem::<D>::U64),
726 Elem::U64 => write!(f, "__brevll({input})"),
727 _ => write!(
728 f,
729 "__brev({}) >> {}",
730 super::unary::zero_extend(input),
731 (size_of::<u32>() - out_elem.size()) * 8
732 ),
733 }?;
734 write!(f, ")")
735 }
736
737 fn compile_instruction_max_function_name(
739 f: &mut std::fmt::Formatter<'_>,
740 item: Item<D>,
741 ) -> std::fmt::Result;
742
743 fn compile_instruction_min_function_name(
744 f: &mut std::fmt::Formatter<'_>,
745 item: Item<D>,
746 ) -> std::fmt::Result;
747
748 fn compile_instruction_powf(
749 f: &mut std::fmt::Formatter<'_>,
750 lhs: &str,
751 rhs: &str,
752 elem: Elem<D>,
753 ) -> std::fmt::Result {
754 match elem {
755 Elem::F32 => write!(f, "powf({lhs}, {rhs})"),
756 Elem::F64 => write!(f, "pow({lhs}, {rhs})"),
757 _ => write!(f, "#error Unsupported type for powf: {elem}"),
758 }
759 }
760
761 fn compile_instruction_hypot(
762 f: &mut std::fmt::Formatter<'_>,
763 lhs: &str,
764 rhs: &str,
765 elem: Elem<D>,
766 ) -> std::fmt::Result {
767 match elem {
768 Elem::F32 => write!(f, "hypotf({lhs}, {rhs})"),
769 Elem::F64 => write!(f, "hypot({lhs}, {rhs})"),
770 _ => write!(f, "#error Unsupported type for hypot: {elem}"),
771 }
772 }
773
774 fn compile_instruction_rhypot(
775 f: &mut std::fmt::Formatter<'_>,
776 lhs: &str,
777 rhs: &str,
778 elem: Elem<D>,
779 ) -> std::fmt::Result {
780 match elem {
781 Elem::F32 => write!(f, "rhypotf({lhs}, {rhs})"),
782 Elem::F64 => write!(f, "rhypot({lhs}, {rhs})"),
783 _ => write!(f, "#error Unsupported type for rhypot: {elem}"),
784 }
785 }
786
787 fn compile_instruction_half_function_name_prefix() -> &'static str {
788 "h"
789 }
790
791 fn compile_instruction_half2_function_name_prefix() -> &'static str {
792 "h2"
793 }
794
795 fn compile_warp_shuffle(
797 f: &mut std::fmt::Formatter<'_>,
798 var: &str,
799 source: &str,
800 ) -> std::fmt::Result;
801 fn compile_warp_shuffle_xor(
802 f: &mut std::fmt::Formatter<'_>,
803 var: &str,
804 elem: &Elem<D>,
805 offset: &str,
806 ) -> std::fmt::Result;
807 fn compile_warp_shuffle_up(
808 f: &mut std::fmt::Formatter<'_>,
809 var: &str,
810 offset: &str,
811 ) -> std::fmt::Result;
812 fn compile_warp_shuffle_down(
813 f: &mut std::fmt::Formatter<'_>,
814 var: &str,
815 offset: &str,
816 ) -> std::fmt::Result;
817 fn compile_warp_all<T: Component<D>>(
818 f: &mut std::fmt::Formatter<'_>,
819 input: &T,
820 ) -> std::fmt::Result;
821 fn compile_warp_any<T: Component<D>>(
822 f: &mut std::fmt::Formatter<'_>,
823 input: &T,
824 ) -> std::fmt::Result;
825 fn compile_warp_ballot(
826 f: &mut std::fmt::Formatter<'_>,
827 input: &Variable<D>,
828 out_elem: &Elem<D>,
829 ) -> std::fmt::Result;
830 fn compile_warp_elect(f: &mut std::fmt::Formatter<'_>, out: &str) -> std::fmt::Result {
831 write!(
832 f,
833 "
834unsigned int mask = __activemask();
835unsigned int leader = __ffs(mask) - 1;
836{out} = threadIdx.x % warpSize == leader;
837 "
838 )
839 }
840 fn compile_unreachable(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
841}
842
843#[derive(Debug, Clone, Copy, new)]
844pub struct ManualMma<'a, D: Dialect> {
845 pub shape: MmaShape<D>,
846 pub frag_a: &'a Variable<D>,
847 pub frag_b: &'a Variable<D>,
848 pub frag_c: &'a Variable<D>,
849 pub frag_d: &'a Variable<D>,
850}
851
852pub trait DialectWarpReduceCompiler<D: Dialect>:
853 Default + Clone + Copy + Debug + Send + Sync + Eq + Hash + 'static
854{
855 fn warp_reduce_sum(
856 f: &mut core::fmt::Formatter<'_>,
857 input: &Variable<D>,
858 out: &Variable<D>,
859 ) -> core::fmt::Result {
860 reduce_operator(f, input, out, "+=")
861 }
862 fn warp_reduce_prod(
863 f: &mut core::fmt::Formatter<'_>,
864 input: &Variable<D>,
865 out: &Variable<D>,
866 ) -> core::fmt::Result {
867 reduce_operator(f, input, out, "*=")
868 }
869 fn warp_reduce_max(
870 f: &mut core::fmt::Formatter<'_>,
871 input: &Variable<D>,
872 out: &Variable<D>,
873 ) -> core::fmt::Result {
874 reduce_comparison(f, input, out, D::compile_instruction_max_function_name)
875 }
876 fn warp_reduce_min(
877 f: &mut core::fmt::Formatter<'_>,
878 input: &Variable<D>,
879 out: &Variable<D>,
880 ) -> core::fmt::Result {
881 reduce_comparison(f, input, out, D::compile_instruction_min_function_name)
882 }
883 fn warp_reduce_all(
884 f: &mut core::fmt::Formatter<'_>,
885 input: &Variable<D>,
886 out: &Variable<D>,
887 ) -> core::fmt::Result {
888 reduce_quantifier(f, input, out, D::compile_warp_all::<IndexedVariable<D>>)
889 }
890 fn warp_reduce_any(
891 f: &mut core::fmt::Formatter<'_>,
892 input: &Variable<D>,
893 out: &Variable<D>,
894 ) -> core::fmt::Result {
895 reduce_quantifier(f, input, out, D::compile_warp_any::<IndexedVariable<D>>)
896 }
897 fn warp_reduce_sum_inclusive(
898 f: &mut core::fmt::Formatter<'_>,
899 input: &Variable<D>,
900 out: &Variable<D>,
901 ) -> core::fmt::Result {
902 reduce_inclusive(f, input, out, "+=")
903 }
904 fn warp_reduce_prod_inclusive(
905 f: &mut core::fmt::Formatter<'_>,
906 input: &Variable<D>,
907 out: &Variable<D>,
908 ) -> core::fmt::Result {
909 reduce_inclusive(f, input, out, "*=")
910 }
911 fn warp_reduce_sum_exclusive(
912 f: &mut core::fmt::Formatter<'_>,
913 input: &Variable<D>,
914 out: &Variable<D>,
915 ) -> core::fmt::Result {
916 reduce_exclusive(f, input, out, "+=", "0")
917 }
918 fn warp_reduce_prod_exclusive(
919 f: &mut core::fmt::Formatter<'_>,
920 input: &Variable<D>,
921 out: &Variable<D>,
922 ) -> core::fmt::Result {
923 reduce_exclusive(f, input, out, "*=", "1")
924 }
925}
926
927pub trait DialectWmmaCompiler<D: Dialect>:
928 Default + Clone + Copy + Debug + Send + Sync + Eq + Hash + 'static
929{
930 #[allow(unused_variables)]
931 fn compile_wmma_includes(
932 f: &mut std::fmt::Formatter<'_>,
933 flags: &Flags<D>,
934 ) -> std::fmt::Result {
935 Ok(())
936 }
937 #[allow(unused_variables)]
938 fn compile_wmma_type_definitions(
939 f: &mut std::fmt::Formatter<'_>,
940 flags: &Flags<D>,
941 ) -> std::fmt::Result {
942 Ok(())
943 }
944 #[allow(unused_variables)]
945 fn compile_wmma_local_variables(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
946 Ok(())
947 }
948 #[allow(unused_variables)]
949 fn compile_wwma_fragment_ident(
950 f: &mut std::fmt::Formatter<'_>,
951 ident: &FragmentIdent<D>,
952 ) -> std::fmt::Result {
953 Ok(())
954 }
955 #[allow(unused_variables)]
956 fn compile_wmma_fragment_layout(
957 f: &mut std::fmt::Formatter<'_>,
958 layout: &FragmentLayout<D>,
959 ) -> std::fmt::Result {
960 Ok(())
961 }
962 #[allow(unused_variables)]
963 fn compile_wmma_fragment(
964 f: &mut std::fmt::Formatter<'_>,
965 fragment: &Fragment<D>,
966 ) -> std::fmt::Result {
967 Ok(())
968 }
969
970 fn compile_wmma_fragment_declaration(
971 f: &mut std::fmt::Formatter<'_>,
972 var: &Variable<D>,
973 ) -> std::fmt::Result;
974
975 fn compile_wmma_instruction(
976 f: &mut std::fmt::Formatter<'_>,
977 instruction: &WmmaInstruction<D>,
978 ) -> std::fmt::Result;
979 fn compile_manual_mma(f: &mut std::fmt::Formatter<'_>, mma: ManualMma<D>) -> std::fmt::Result;
980 fn compile_scaled_mma(
981 f: &mut std::fmt::Formatter<'_>,
982 mma: ManualMma<D>,
983 scales_a: Variable<D>,
984 scales_b: Variable<D>,
985 scales_factor: u32,
986 ) -> std::fmt::Result;
987 fn supported_wmma_combinations(arch: &D::Architecture) -> SupportedMmaCombinations;
988 fn supported_mma_combinations(arch: &D::Architecture) -> SupportedMmaCombinations;
989 fn supported_scaled_mma_combinations(
990 _arch: &D::Architecture,
991 ) -> SupportedScaledMmaCombinations {
992 Vec::new()
993 }
994}
995
996pub trait DialectProcessors<D: Dialect> {
999 fn processors() -> Vec<Box<dyn Processor>>;
1000}