1use core::panic;
2use std::fmt::Display;
3
4use crate::{
5 Dialect,
6 shared::{
7 self, AtomicKind, Binding, Component, CubeIndexFlags, DialectBindings, DialectCubeBuiltins,
8 DialectIncludes, DialectInstructions, DialectProcessors, DialectTypes,
9 DialectWarpReduceCompiler, DialectWmmaCompiler, Elem, Flags, FmtLeft, Fragment,
10 FragmentIdent, FragmentLayout, Instruction, Item, ManualMma, SharedMemory,
11 SupportedMmaCombinations, Variable, WarpInstruction, WmmaInstruction, wmma_api_base,
12 },
13};
14use cubecl_core::{
15 compute::{Location, Visibility},
16 ir::{self as gpu},
17};
18use cubecl_runtime::MmaConfig;
19
20use super::{
21 AddressSpace, Extension,
22 arch::MetalArchitecture,
23 extension::{format_ffs, format_mulhi},
24 format_erf, format_global_binding_arg, format_metal_builtin_binding_arg, format_safe_tanh,
25};
26
/// Marker type selecting the Metal Shading Language (MSL) code-generation dialect.
///
/// The struct carries no data; its behaviour comes entirely from the `Dialect*`
/// trait implementations found in this file.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct MslDialect {}
29
30impl Dialect for MslDialect {
33 type Architecture = MetalArchitecture;
34}
35
36impl MslDialect {
37 fn warp_op_vectorized(
38 f: &mut core::fmt::Formatter<'_>,
39 input: &Variable<Self>,
40 out: &Variable<Self>,
41 simd_op_prefix: &str,
42 simd_op_suffix: &str,
43 ) -> core::fmt::Result {
44 let out = out.fmt_left();
45 let vectorization = input.item().vectorization;
46
47 f.write_fmt(format_args!("{out} = {} {{", input.item()))?;
48
49 for k in 0..vectorization {
50 let index = if vectorization > 1 {
51 format!(".i_{k}")
52 } else {
53 String::new()
54 };
55 let comma = if k + 1 < vectorization { "," } else { "" };
56
57 writeln!(f, "{simd_op_prefix}{input}{index}{simd_op_suffix}{comma}")?;
58 }
59
60 f.write_fmt(format_args!("}};\n"))
61 }
62}
63
64impl DialectWarpReduceCompiler<Self> for MslDialect {
65 fn warp_reduce_sum(
66 f: &mut core::fmt::Formatter<'_>,
67 input: &Variable<Self>,
68 out: &Variable<Self>,
69 ) -> core::fmt::Result {
70 Self::warp_op_vectorized(f, input, out, "simd_sum(", ")")
71 }
72 fn warp_reduce_prod(
73 f: &mut core::fmt::Formatter<'_>,
74 input: &Variable<Self>,
75 out: &Variable<Self>,
76 ) -> core::fmt::Result {
77 Self::warp_op_vectorized(f, input, out, "simd_product(", ")")
78 }
79 fn warp_reduce_max(
80 f: &mut core::fmt::Formatter<'_>,
81 input: &Variable<Self>,
82 out: &Variable<Self>,
83 ) -> core::fmt::Result {
84 Self::warp_op_vectorized(f, input, out, "simd_max(", ")")
85 }
86 fn warp_reduce_min(
87 f: &mut core::fmt::Formatter<'_>,
88 input: &Variable<Self>,
89 out: &Variable<Self>,
90 ) -> core::fmt::Result {
91 Self::warp_op_vectorized(f, input, out, "simd_min(", ")")
92 }
93 fn warp_reduce_all(
94 f: &mut core::fmt::Formatter<'_>,
95 input: &Variable<Self>,
96 out: &Variable<Self>,
97 ) -> core::fmt::Result {
98 Self::warp_op_vectorized(f, input, out, "simd_and(", "? 1u : 0u) != 0u")
99 }
100 fn warp_reduce_any(
101 f: &mut core::fmt::Formatter<'_>,
102 input: &Variable<Self>,
103 out: &Variable<Self>,
104 ) -> core::fmt::Result {
105 Self::warp_op_vectorized(f, input, out, "simd_or(", "? 1u : 0u) != 0u")
106 }
107 fn warp_reduce_sum_inclusive(
108 f: &mut core::fmt::Formatter<'_>,
109 input: &Variable<Self>,
110 out: &Variable<Self>,
111 ) -> core::fmt::Result {
112 Self::warp_op_vectorized(f, input, out, "simd_prefix_inclusive_sum(", ")")
113 }
114 fn warp_reduce_prod_inclusive(
115 f: &mut core::fmt::Formatter<'_>,
116 input: &Variable<Self>,
117 out: &Variable<Self>,
118 ) -> core::fmt::Result {
119 Self::warp_op_vectorized(f, input, out, "simd_prefix_inclusive_product(", ")")
120 }
121 fn warp_reduce_sum_exclusive(
122 f: &mut core::fmt::Formatter<'_>,
123 input: &Variable<Self>,
124 out: &Variable<Self>,
125 ) -> core::fmt::Result {
126 Self::warp_op_vectorized(f, input, out, "simd_prefix_exclusive_sum(", ")")
127 }
128 fn warp_reduce_prod_exclusive(
129 f: &mut core::fmt::Formatter<'_>,
130 input: &Variable<Self>,
131 out: &Variable<Self>,
132 ) -> core::fmt::Result {
133 Self::warp_op_vectorized(f, input, out, "simd_prefix_exclusive_product(", ")")
134 }
135}
136
137impl DialectIncludes<Self> for MslDialect {
140 type Extension = Extension<Self>;
141
142 fn compile_includes(f: &mut std::fmt::Formatter<'_>, _flags: &Flags) -> std::fmt::Result {
143 write!(
144 f,
145 "
146#include <metal_stdlib>
147using namespace metal;
148"
149 )?;
150 Ok(())
151 }
152
153 fn compile_extensions(
154 f: &mut std::fmt::Formatter<'_>,
155 extensions: &[Self::Extension],
156 ) -> std::fmt::Result {
157 for extension in extensions {
158 match extension {
159 Extension::Erf(input, output) => format_erf::<Self>(f, input, output)?,
160 Extension::Ffs(elem) => format_ffs(f, elem)?,
161 Extension::MulHi(elem) => format_mulhi(f, elem)?,
162 Extension::SafeTanh(item) => format_safe_tanh::<Self>(f, item)?,
163 Extension::NoExtension => {}
164 }
165 }
166 Ok(())
167 }
168
169 fn register_instruction_extension(
170 extensions: &mut Vec<Self::Extension>,
171 instruction: &Instruction<Self>,
172 ) {
173 let mut register_extension = |extension: Self::Extension| {
174 if !extensions.contains(&extension) {
175 extensions.push(extension);
176 }
177 };
178 #[allow(clippy::single_match)]
179 match instruction {
180 shared::Instruction::<Self>::Erf(instruction) => {
181 register_extension(Extension::Erf(
182 instruction.input.elem(),
183 instruction.out.elem(),
184 ));
185 }
186 shared::Instruction::<Self>::FindFirstSet(instruction) => {
187 let input_elem = instruction.input.elem();
188 match input_elem {
189 Elem::U32 | Elem::U64 => {
190 register_extension(Extension::Ffs(instruction.input.elem()));
191 }
192 Elem::I32 => {
193 register_extension(Extension::Ffs(Elem::<Self>::U32));
194 register_extension(Extension::Ffs(instruction.input.elem()));
195 }
196 Elem::I64 => {
197 register_extension(Extension::Ffs(Elem::<Self>::U64));
198 register_extension(Extension::Ffs(instruction.input.elem()));
199 }
200 _ => {
201 register_extension(Extension::Ffs(Elem::<Self>::U32));
202 }
203 }
204 }
205 shared::Instruction::<Self>::HiMul(instruction) => {
206 register_extension(Extension::MulHi(instruction.out.elem()));
207 }
208 shared::Instruction::<Self>::Tanh(instruction) => {
209 register_extension(Extension::SafeTanh(instruction.input.item()));
210 }
211 _ => {}
212 }
213 }
214
215 fn register_warp_instruction_extension(
216 _extensions: &mut Vec<Self::Extension>,
217 _instruction: &WarpInstruction<Self>,
218 ) {
219 }
220}
221
222impl DialectTypes<Self> for MslDialect {
225 fn item_can_be_optimized() -> bool {
226 false
227 }
228
229 fn compile_type_definitions(
230 f: &mut std::fmt::Formatter<'_>,
231 items: &std::collections::HashSet<crate::shared::Item<Self>>,
232 _scalars: &[(Elem<Self>, usize)],
233 _flags: &Flags,
234 ) -> std::fmt::Result {
235 for item in items.iter() {
236 let elem = item.elem;
237 let size = item.vectorization;
238 let alignment = elem.size() * size;
239 if size > 1 {
240 write!(
241 f,
242 "
243struct alignas({alignment}) {item} {{"
244 )?;
245
246 for i in 0..size {
247 write!(
248 f,
249 "
250 {elem} i_{i};"
251 )?;
252 }
253
254 f.write_str("\n};\n")?;
255 }
256 }
257 Ok(())
258 }
259
260 fn compile_elem(
261 f: &mut std::fmt::Formatter<'_>,
262 elem: &shared::Elem<Self>,
263 _words: bool,
264 ) -> std::fmt::Result {
265 match elem {
267 shared::Elem::FP4(_)
268 | shared::Elem::FP4x2(_)
269 | shared::Elem::FP6(_)
270 | shared::Elem::FP6x2(_)
271 | shared::Elem::FP8(_)
272 | shared::Elem::FP8x2(_) => unimplemented!("FP4/FP6/FP8 not supported in Metal"),
273 shared::Elem::F16 => f.write_str("half"),
274 shared::Elem::F16x2 => panic!("type F162 not supported!"),
275 shared::Elem::F32 => f.write_str("float"),
276 shared::Elem::F64 => panic!("type double not supported!"),
277 shared::Elem::BF16 => f.write_str("bfloat"),
278 shared::Elem::BF16x2 => panic!("type BF162 not supported!"),
279 shared::Elem::TF32 => f.write_str("float"),
280 shared::Elem::I8 => f.write_str("char"),
281 shared::Elem::I16 => f.write_str("short"),
282 shared::Elem::I32 => f.write_str("int"),
283 shared::Elem::I64 => f.write_str("long"),
284 shared::Elem::U8 => f.write_str("uchar"),
285 shared::Elem::U16 => f.write_str("ushort"),
286 shared::Elem::U32 => f.write_str("uint"),
287 shared::Elem::U64 => f.write_str("uint64_t"), shared::Elem::Bool => f.write_str("bool"),
289 shared::Elem::Atomic(inner) => inner.fmt(f),
290 shared::Elem::_Dialect(_) => Ok(()),
291 }
292 }
293
294 fn compile_item(f: &mut std::fmt::Formatter<'_>, item: &Item<Self>) -> std::fmt::Result {
295 if 1 == item.vectorization {
296 return write!(f, "{}", item.elem);
297 }
298 if item.native {
299 write!(f, "{}{}", item.elem, item.vectorization)
300 } else {
301 write!(f, "{}_{}", item.elem, item.vectorization)
302 }
303 }
304
305 fn compile_atomic_kind(
306 f: &mut std::fmt::Formatter<'_>,
307 kind: &AtomicKind<Self>,
308 ) -> std::fmt::Result {
309 match kind {
310 AtomicKind::I32 => write!(f, "atomic_int"),
311 AtomicKind::I64 => panic!("I64 atomic kind no supported."),
312 AtomicKind::U32 => write!(f, "atomic_uint"),
313 AtomicKind::U64 => write!(f, "atomic_ulong"),
314 AtomicKind::F16 => panic!("F16 atomic kind no supported."),
315 AtomicKind::BF16 => panic!("BF16 atomic kind no supported."),
316 AtomicKind::F32 => write!(f, "atomic_float"), AtomicKind::F64 => panic!("F64 atomic kind no supported."),
318 AtomicKind::_Dialect(_) => Ok(()),
319 }
320 }
321
322 fn address_space_for_variable(variable: &Variable<Self>) -> String {
323 format!("{} ", AddressSpace::from(variable))
324 }
325
326 fn compile_local_memory_qualifier(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
327 write!(f, "thread")
328 }
329
330 fn compile_shared_memory_declaration(
331 f: &mut std::fmt::Formatter<'_>,
332 shared: &SharedMemory<Self>,
333 ) -> std::fmt::Result {
334 let item = shared.item;
335 let index = shared.index;
336 let offset = shared.offset;
337 let size = shared.length;
338 let size_bytes = size * shared.item.size() as u32;
339 writeln!(f, "// Shared memory size: {size}, {size_bytes} bytes")?;
340 writeln!(
341 f,
342 "threadgroup {item}* shared_memory_{index} = reinterpret_cast<threadgroup {item}*>(&dynamic_shared_mem[{offset}]);"
343 )
344 }
345}
346
347impl DialectBindings<Self> for MslDialect {
350 fn compile_kernel_signature(
351 f: &mut std::fmt::Formatter<'_>,
352 kernel_name: &str,
353 tensor_maps: &[Binding<Self>],
354 buffers: &[Binding<Self>],
355 scalars: &[(Elem<Self>, usize)],
356 flags: &Flags,
357 ) -> std::fmt::Result {
358 write!(
359 (f),
360 "
361[[kernel]]
362void {kernel_name}("
363 )?;
364 let mut buffer_idx = 0;
366 debug_assert!(
367 tensor_maps.is_empty(),
368 "Tensor maps aren't supported for metal"
369 );
370 for (i, b) in buffers.iter().enumerate() {
371 format_global_binding_arg("buffer", b, Some(&i.to_string()), &mut buffer_idx, f)?;
372 }
373 if flags.static_meta_length > 0 {
374 let binding = Binding {
375 id: 0,
376 item: Item::scalar(Elem::<Self>::U32, true),
377 location: Location::Storage,
378 size: None,
379 vis: Visibility::Read,
380 };
381 format_global_binding_arg("info", &binding, None, &mut buffer_idx, f)?;
382 }
383 for (elem, _) in scalars.iter() {
384 let binding = Binding {
385 id: 0,
386 item: Item::scalar(*elem, true),
387 location: Location::Storage,
388 size: None,
389 vis: Visibility::Read,
390 };
391
392 let name = format!("scalars_{elem}");
393 format_global_binding_arg(&name, &binding, None, &mut buffer_idx, f)?;
394 }
395
396 let builtins = vec![
398 (
399 flags.indexes.absolute_pos_tuple,
400 Variable::<Self>::AbsolutePosBaseName,
401 ),
402 (
403 flags.indexes.cube_dim_tuple,
404 Variable::<Self>::CubeDimBaseName,
405 ),
406 (
407 flags.indexes.cube_count_tuple,
408 Variable::<Self>::CubeCountBaseName,
409 ),
410 (flags.indexes.unit_pos, Variable::<Self>::UnitPos),
411 (
412 flags.indexes.unit_pos_tuple,
413 Variable::<Self>::UnitPosBaseName,
414 ),
415 (
416 flags.indexes.cube_pos_tuple,
417 Variable::<Self>::CubePosBaseName,
418 ),
419 (flags.indexes.unit_pos_plane, Variable::<Self>::UnitPosPlane),
420 (flags.indexes.plane_dim, Variable::<Self>::PlaneDim),
421 (flags.indexes.plane_index, Variable::<Self>::PlanePos),
422 ];
423 let comma = !buffers.is_empty() || flags.static_meta_length > 0 || !scalars.is_empty();
424 builtins
425 .iter()
426 .filter(|(cond, _)| *cond)
427 .try_for_each(|(_, var)| format_metal_builtin_binding_arg(f, var, comma))?;
428 f.write_str("\n)")
429 }
430
431 fn compile_bindings_body(
432 f: &mut std::fmt::Formatter<'_>,
433 body: &shared::Body<Self>,
434 ) -> std::fmt::Result {
435 if !body.shared_memories.is_empty() {
436 let size = body
437 .shared_memories
438 .iter()
439 .map(|it| it.offset + it.size())
440 .max()
441 .unwrap();
442
443 writeln!(f, "threadgroup uchar dynamic_shared_mem[{size}];",)?;
444 }
445 Ok(())
446 }
447}
448
449impl DialectCubeBuiltins<Self> for MslDialect {
452 fn builtin_rules(flags: &CubeIndexFlags) -> CubeIndexFlags {
457 let absolute_pos = flags.absolute_pos;
458 let cube_count = flags.cube_count;
459 let cube_dim = flags.cube_dim;
460 let cube_pos = flags.cube_pos;
461 let plane_dim_checked = flags.plane_dim_checked;
462 let plane_index = flags.plane_index;
463 let unit_pos = flags.unit_pos;
464 let absolute_pos_tuple = flags.absolute_pos_tuple || absolute_pos;
465 let cube_count_tuple = flags.cube_count_tuple || cube_count || cube_pos || absolute_pos;
466 let cube_dim_tuple = flags.cube_dim_tuple || cube_dim || absolute_pos || plane_dim_checked;
467 let cube_pos_tuple = flags.cube_pos_tuple || cube_pos;
468 let cluster_pos = flags.cluster_pos;
469 let plane_dim = flags.plane_dim || plane_dim_checked || plane_index;
470 let unit_pos_plane = flags.unit_pos_plane || plane_index;
471 let unit_pos_tuple = flags.unit_pos_tuple || unit_pos;
472 CubeIndexFlags {
473 absolute_pos_tuple,
474 absolute_pos,
475 cube_count_tuple,
476 cube_count,
477 cube_dim_tuple,
478 cube_dim,
479 cube_pos_tuple,
480 cube_pos,
481 plane_dim,
482 plane_dim_checked,
483 plane_index,
484 unit_pos_tuple,
485 unit_pos,
486 unit_pos_plane,
487 cluster_pos,
488 }
489 }
490
491 fn compile_absolute_pos_tuple_computation(
492 _f: &mut std::fmt::Formatter<'_>,
493 ) -> std::fmt::Result {
494 Ok(())
496 }
497
498 fn compile_absolute_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
499 f.write_str("thread_pos_in_grid")
500 }
501
502 fn compile_absolute_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
503 f.write_str("thread_index_in_grid")
504 }
505
506 fn compile_absolute_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
507 Self::compile_absolute_pos_base_name(f)?;
508 write!(f, ".x")
509 }
510
511 fn compile_absolute_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
512 Self::compile_absolute_pos_base_name(f)?;
513 write!(f, ".y")
514 }
515
516 fn compile_absolute_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
517 Self::compile_absolute_pos_base_name(f)?;
518 write!(f, ".z")
519 }
520
521 fn compile_cube_count_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
522 f.write_str("threadgroups_per_grid")
523 }
524
525 fn compile_cube_count(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
526 f.write_str("total_threadgroups_in_grid")
527 }
528
529 fn compile_cube_count_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
530 Self::compile_cube_count_base_name(f)?;
531 write!(f, ".x")
532 }
533
534 fn compile_cube_count_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
535 Self::compile_cube_count_base_name(f)?;
536 write!(f, ".y")
537 }
538
539 fn compile_cube_count_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
540 Self::compile_cube_count_base_name(f)?;
541 write!(f, ".z")
542 }
543
544 fn compile_cube_dim_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
545 f.write_str("threads_per_threadgroup")
546 }
547
548 fn compile_cube_dim(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
549 f.write_str("total_thread_in_threadgroup")
550 }
551
552 fn compile_cube_dim_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
553 Self::compile_cube_dim_base_name(f)?;
554 write!(f, ".x")
555 }
556
557 fn compile_cube_dim_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
558 Self::compile_cube_dim_base_name(f)?;
559 write!(f, ".y")
560 }
561
562 fn compile_cube_dim_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
563 Self::compile_cube_dim_base_name(f)?;
564 write!(f, ".z")
565 }
566
567 fn compile_cube_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
568 f.write_str("threadgroup_pos_in_grid")
569 }
570
571 fn compile_cube_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
572 f.write_str("threadgroup_index_in_grid")
573 }
574
575 fn compile_cube_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
576 Self::compile_cube_pos_base_name(f)?;
577 write!(f, ".x")
578 }
579
580 fn compile_cube_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
581 Self::compile_cube_pos_base_name(f)?;
582 write!(f, ".y")
583 }
584
585 fn compile_cube_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
586 Self::compile_cube_pos_base_name(f)?;
587 write!(f, ".z")
588 }
589
590 fn compile_unit_pos_computation(_f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
591 Ok(())
593 }
594
595 fn compile_unit_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
596 f.write_str("thread_pos_in_threadgroup")
597 }
598
599 fn compile_unit_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
600 f.write_str("thread_index_in_threadgroup")
601 }
602
603 fn compile_unit_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
604 Self::compile_unit_pos_base_name(f)?;
605 write!(f, ".x")
606 }
607
608 fn compile_unit_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
609 Self::compile_unit_pos_base_name(f)?;
610 write!(f, ".y")
611 }
612
613 fn compile_unit_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
614 Self::compile_unit_pos_base_name(f)?;
615 write!(f, ".z")
616 }
617
618 fn compile_plane_dim(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
619 f.write_str("simd_size")
620 }
621
622 fn compile_plane_dim_checked(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
623 f.write_str("threads_per_simdgroup_checked")
624 }
625
626 fn compile_plane_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
627 f.write_str("simd_group_id")
628 }
629
630 fn compile_unit_pos_plane(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
631 f.write_str("simd_lane_id")
632 }
633}
634
635impl DialectInstructions<Self> for MslDialect {
638 fn compile_atomic_add(
640 f: &mut std::fmt::Formatter<'_>,
641 lhs: &Variable<Self>,
642 rhs: &Variable<Self>,
643 out: &Variable<Self>,
644 ) -> std::fmt::Result {
645 let out = out.fmt_left();
646 writeln!(
647 f,
648 "{out} = atomic_fetch_add_explicit({lhs}, {rhs}, memory_order_relaxed);"
649 )
650 }
651
652 fn compile_atomic_and(
653 f: &mut std::fmt::Formatter<'_>,
654 lhs: &Variable<Self>,
655 rhs: &Variable<Self>,
656 out: &Variable<Self>,
657 ) -> std::fmt::Result {
658 let out = out.fmt_left();
659 writeln!(
660 f,
661 "{out} = atomic_fetch_and_explicit({lhs}, {rhs}, memory_order_relaxed);"
662 )
663 }
664
665 fn compile_atomic_cas(
666 f: &mut std::fmt::Formatter<'_>,
667 input: &Variable<Self>,
668 cmp: &Variable<Self>,
669 val: &Variable<Self>,
670 out: &Variable<Self>,
671 ) -> std::fmt::Result {
672 let out = out.fmt_left();
673 writeln!(
674 f,
675 "{out} = atomic_compare_exchange_weak_explicit({input}, &{cmp}, {val}, memory_order_relaxed, memory_order_relaxed);"
676 )
677 }
678
679 fn compile_atomic_load(
680 f: &mut std::fmt::Formatter<'_>,
681 input: &Variable<Self>,
682 out: &Variable<Self>,
683 ) -> std::fmt::Result {
684 let out = out.fmt_left();
685 writeln!(
686 f,
687 "{out} = atomic_load_explicit({input}, memory_order_relaxed);"
688 )
689 }
690
691 fn compile_atomic_max(
692 f: &mut std::fmt::Formatter<'_>,
693 lhs: &Variable<Self>,
694 rhs: &Variable<Self>,
695 out: &Variable<Self>,
696 ) -> std::fmt::Result {
697 let out = out.fmt_left();
698 writeln!(
699 f,
700 "{out} = atomic_fetch_max_explicit({lhs}, {rhs}, memory_order_relaxed);"
701 )
702 }
703
704 fn compile_atomic_min(
705 f: &mut std::fmt::Formatter<'_>,
706 lhs: &Variable<Self>,
707 rhs: &Variable<Self>,
708 out: &Variable<Self>,
709 ) -> std::fmt::Result {
710 let out = out.fmt_left();
711 writeln!(
712 f,
713 "{out} = atomic_fetch_min_explicit({lhs}, {rhs}, memory_order_relaxed);"
714 )
715 }
716
717 fn compile_atomic_or(
718 f: &mut std::fmt::Formatter<'_>,
719 lhs: &Variable<Self>,
720 rhs: &Variable<Self>,
721 out: &Variable<Self>,
722 ) -> std::fmt::Result {
723 let out = out.fmt_left();
724 writeln!(
725 f,
726 "{out} = atomic_fetch_or_explicit({lhs}, {rhs}, memory_order_relaxed);"
727 )
728 }
729
730 fn compile_atomic_store(
731 f: &mut std::fmt::Formatter<'_>,
732 input: &Variable<Self>,
733 out: &Variable<Self>,
734 ) -> std::fmt::Result {
735 writeln!(
736 f,
737 "atomic_store_explicit({out}, {input}, memory_order_relaxed);"
738 )
739 }
740
741 fn compile_atomic_sub(
742 f: &mut std::fmt::Formatter<'_>,
743 lhs: &Variable<Self>,
744 rhs: &Variable<Self>,
745 out: &Variable<Self>,
746 ) -> std::fmt::Result {
747 let out = out.fmt_left();
748 writeln!(
749 f,
750 "{out} = atomic_fetch_sub_explicit({lhs}, {rhs}, memory_order_relaxed);"
751 )
752 }
753
754 fn compile_atomic_swap(
755 f: &mut std::fmt::Formatter<'_>,
756 lhs: &Variable<Self>,
757 rhs: &Variable<Self>,
758 out: &Variable<Self>,
759 ) -> std::fmt::Result {
760 let out = out.fmt_left();
761 writeln!(
762 f,
763 "{out} = atomic_exchange_explicit({lhs}, {rhs}, memory_order_relaxed);"
764 )
765 }
766
767 fn compile_atomic_xor(
768 f: &mut std::fmt::Formatter<'_>,
769 lhs: &Variable<Self>,
770 rhs: &Variable<Self>,
771 out: &Variable<Self>,
772 ) -> std::fmt::Result {
773 let out = out.fmt_left();
774 writeln!(
775 f,
776 "{out} = atomic_fetch_xor_explicit({lhs}, {rhs}, memory_order_relaxed);"
777 )
778 }
779
780 fn compile_saturating_add(
781 f: &mut std::fmt::Formatter<'_>,
782 lhs: impl Display,
783 rhs: impl Display,
784 _item: Item<Self>,
785 ) -> std::fmt::Result {
786 write!(f, "addsat({lhs}, {rhs})")
787 }
788
789 fn compile_saturating_sub(
790 f: &mut std::fmt::Formatter<'_>,
791 lhs: impl Display,
792 rhs: impl Display,
793 _item: Item<Self>,
794 ) -> std::fmt::Result {
795 write!(f, "subsat({lhs}, {rhs})")
796 }
797
798 fn compile_instruction_printf(
800 f: &mut std::fmt::Formatter<'_>,
801 format_string: &str,
802 args: &[Variable<Self>],
803 ) -> std::fmt::Result {
804 let args = args.iter().map(|arg| format!("{arg}")).collect::<Vec<_>>();
805 let args = match args.is_empty() {
806 true => "".to_string(),
807 false => format!(", {}", args.join(",")),
808 };
809 writeln!(f, "os_log_default.log({format_string:?}{args});")
810 }
811
812 fn compile_instruction_log1p_scalar<T: Component<Self>>(
814 f: &mut std::fmt::Formatter<'_>,
815 input: T,
816 ) -> std::fmt::Result {
817 match input.elem() {
818 Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
819 write!(f, "log(half(1.0f) + {input})")
820 }
821 _ => write!(f, "log(1.0f + {input})"),
822 }
823 }
824
825 fn compile_instruction_sync_threads(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
827 writeln!(f, "threadgroup_barrier(mem_flags::mem_threadgroup);")
828 }
829
830 fn compile_instruction_sync_warp(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
831 writeln!(f, "simdgroup_barrier(mem_flags::mem_none);")
832 }
833
834 fn compile_instruction_thread_fence(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
835 writeln!(f, "threadgroup_thread_fence(mem_flags::mem_device);")
836 }
837
838 fn compile_instruction_tanh_scalar<T: Component<Self>>(
840 f: &mut std::fmt::Formatter<'_>,
841 input: T,
842 ) -> std::fmt::Result {
843 write!(f, "safe_tanh_scalar({input})")
844 }
845
846 fn compile_instruction_find_first_set<T: Component<Self>>(
848 f: &mut std::fmt::Formatter<'_>,
849 input: T,
850 out_elem: Elem<Self>,
851 ) -> std::fmt::Result {
852 write!(f, "{out_elem}(")?;
853 match input.elem() {
854 Elem::I32 | Elem::U32 => write!(f, "__ffs({input})"),
855 Elem::I64 | Elem::U64 => write!(f, "__ffsll({input})"),
856 _ => write!(f, "__ffs({}({input}))", Elem::<Self>::I32),
857 }?;
858 write!(f, ")")
859 }
860
861 fn compile_instruction_leading_zeros_scalar<T: Component<Self>>(
862 f: &mut std::fmt::Formatter<'_>,
863 input: T,
864 out_elem: Elem<Self>,
865 ) -> std::fmt::Result {
866 write!(f, "{out_elem}(clz({input}))")
867 }
868
869 fn compile_instruction_popcount_scalar<T: Component<Self>>(
870 f: &mut std::fmt::Formatter<'_>,
871 input: T,
872 out_elem: Elem<Self>,
873 ) -> std::fmt::Result {
874 write!(f, "{out_elem}(")?;
875 match input.elem() {
876 Elem::I32 | Elem::U32 | Elem::I64 | Elem::U64 => write!(f, "popcount({input})"),
877 _ => write!(f, "popcount({})", shared::unary::zero_extend(input)),
878 }?;
879 write!(f, ")")
880 }
881
882 fn compile_instruction_reverse_bits_scalar<T: Component<Self>>(
883 f: &mut std::fmt::Formatter<'_>,
884 input: T,
885 out_elem: Elem<Self>,
886 ) -> std::fmt::Result {
887 write!(f, "{out_elem}(")?;
888 match out_elem {
889 Elem::I32 | Elem::U32 | Elem::I64 | Elem::U64 => write!(f, "reverse_bits({input})"),
890 _ => write!(
891 f,
892 "reverse_bits({}) >> {}",
893 shared::unary::zero_extend(input),
894 (size_of::<u32>() - out_elem.size()) * 8
895 ),
896 }?;
897 write!(f, ")")
898 }
899
900 fn compile_instruction_max_function_name(
902 f: &mut std::fmt::Formatter<'_>,
903 _item: Item<Self>,
904 ) -> std::fmt::Result {
905 write!(f, "max")
906 }
907
908 fn compile_instruction_min_function_name(
909 f: &mut std::fmt::Formatter<'_>,
910 _item: Item<Self>,
911 ) -> std::fmt::Result {
912 write!(f, "min")
913 }
914
915 fn compile_instruction_powf(
916 f: &mut std::fmt::Formatter<'_>,
917 lhs: &str,
918 rhs: &str,
919 elem: Elem<Self>,
920 ) -> std::fmt::Result {
921 write!(f, "pow({lhs}, {elem}({rhs}))")
922 }
923
924 fn compile_instruction_half_function_name_prefix() -> &'static str {
925 ""
926 }
927
928 fn compile_instruction_half2_function_name_prefix() -> &'static str {
929 ""
930 }
931
932 fn compile_warp_shuffle(
934 f: &mut std::fmt::Formatter<'_>,
935 var: &str,
936 source: &str,
937 ) -> std::fmt::Result {
938 write!(f, "simd_shuffle({var}, {source})")
939 }
940
941 fn compile_warp_shuffle_xor(
942 f: &mut std::fmt::Formatter<'_>,
943 var: &str,
944 _elem: &Elem<Self>,
945 offset: &str,
946 ) -> std::fmt::Result {
947 write!(f, "simd_shuffle_xor({var}, {offset})")
948 }
949
950 fn compile_warp_shuffle_up(
951 f: &mut std::fmt::Formatter<'_>,
952 var: &str,
953 offset: &str,
954 ) -> std::fmt::Result {
955 write!(f, "simd_shuffle_up({var}, {offset})")
956 }
957
958 fn compile_warp_shuffle_down(
959 f: &mut std::fmt::Formatter<'_>,
960 var: &str,
961 offset: &str,
962 ) -> std::fmt::Result {
963 write!(f, "simd_shuffle_down({var}, {offset})")
964 }
965
966 fn compile_warp_all<T: Component<Self>>(
967 f: &mut std::fmt::Formatter<'_>,
968 input: &T,
969 ) -> std::fmt::Result {
970 write!(f, "simd_all({input})")
971 }
972
973 fn compile_warp_any<T: Component<Self>>(
974 f: &mut std::fmt::Formatter<'_>,
975 input: &T,
976 ) -> std::fmt::Result {
977 write!(f, "simd_any({input})")
978 }
979
980 fn compile_warp_ballot(
981 f: &mut std::fmt::Formatter<'_>,
982 input: &Variable<Self>,
983 out_elem: &Elem<Self>,
984 ) -> std::fmt::Result {
985 write!(f, "{out_elem}(uint64_t(simd_ballot({input})))")
986 }
987}
988
989impl DialectWmmaCompiler<Self> for MslDialect {
992 fn compile_wmma_includes(f: &mut std::fmt::Formatter<'_>, _flags: &Flags) -> std::fmt::Result {
993 writeln!(f, "#include <metal_simdgroup_matrix>")
994 }
995
996 fn compile_wmma_local_variables(_f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
997 Ok(())
999 }
1000
1001 fn compile_wmma_fragment_declaration(
1002 f: &mut std::fmt::Formatter<'_>,
1003 var: &crate::shared::Variable<MslDialect>,
1004 ) -> std::fmt::Result {
1005 wmma_api_base::compile_fragment_declaration(f, var)
1006 }
1007
1008 fn compile_wwma_fragment_ident(
1009 _f: &mut std::fmt::Formatter<'_>,
1010 _ident: &FragmentIdent<Self>,
1011 ) -> std::fmt::Result {
1012 Ok(())
1014 }
1015
1016 fn compile_wmma_fragment_layout(
1017 _f: &mut std::fmt::Formatter<'_>,
1018 _layout: &FragmentLayout<Self>,
1019 ) -> std::fmt::Result {
1020 Ok(())
1022 }
1023
1024 fn compile_wmma_fragment(
1025 f: &mut std::fmt::Formatter<'_>,
1026 fragment: &Fragment<Self>,
1027 ) -> std::fmt::Result {
1028 let ty = fragment.elem;
1029 let m = fragment.m;
1031 let n = fragment.n;
1032 let k = fragment.k;
1033 if m != 8 || n != 8 || k != 8 {
1034 panic!("{m}x{n}x{k} fragments not supported. Only 8x8x8 fragments are supported.");
1035 }
1036 write!(f, "simdgroup_{ty}8x8")
1037 }
1038
1039 fn compile_wmma_instruction(
1040 f: &mut std::fmt::Formatter<'_>,
1041 instruction: &WmmaInstruction<Self>,
1042 ) -> std::fmt::Result {
1043 match instruction {
1044 WmmaInstruction::Fill { frag, value } => {
1045 match frag {
1046 Variable::WmmaFragment { .. } => {
1047 let ty = frag.elem();
1048 writeln!(
1050 f,
1051 "{frag} = make_filled_simdgroup_matrix<{ty}, 8, 8>({value});"
1052 )
1053 }
1054 _ => panic!("should be a fragment"),
1055 }
1056 }
1057 WmmaInstruction::Load {
1058 frag,
1059 value,
1060 stride,
1061 offset,
1062 layout: _layout,
1063 } => {
1064 let transpose = match frag {
1065 Variable::WmmaFragment { frag: inner, .. } => match inner.layout {
1066 Some(FragmentLayout::RowMajor) => false,
1067 Some(FragmentLayout::ColMajor) => true,
1068 _ => false,
1069 },
1070 _ => panic!("should be a fragment"),
1071 };
1072 let item = value.item();
1073 if item.vectorization > 1 {
1074 let elem = item.elem;
1075 writeln!(
1076 f,
1077 "simdgroup_load({frag}, reinterpret_cast<threadgroup {elem} *>({value} + {offset}), {stride}, 0, {transpose});"
1078 )
1079 } else {
1080 writeln!(
1081 f,
1082 "simdgroup_load({frag}, {value} + {offset}, {stride}, 0, {transpose});"
1083 )
1084 }
1085 }
1086 WmmaInstruction::Execute {
1087 frag_a: a,
1088 frag_b: b,
1089 frag_c: c,
1090 frag_d: d,
1091 ..
1092 } => {
1093 writeln!(f, "simdgroup_multiply_accumulate({d}, {a}, {b}, {c});")
1094 }
1095 WmmaInstruction::Store {
1096 output,
1097 frag,
1098 stride,
1099 offset,
1100 layout: _layout,
1101 } => {
1102 let item = output.item();
1103 let mut reinterpret_cast = item.vectorization > 1;
1104 let elem = match item.elem {
1105 Elem::BF16 => {
1106 reinterpret_cast = true;
1107 Elem::F16
1108 }
1109 _ => item.elem,
1110 };
1111 if reinterpret_cast {
1112 writeln!(
1113 f,
1114 "simdgroup_store({frag}, reinterpret_cast<threadgroup {elem} *>({output} + {offset}), {stride});"
1115 )
1116 } else {
1117 writeln!(f, "simdgroup_store({frag}, {output} + {offset}, {stride});")
1118 }?;
1119 writeln!(f, "simdgroup_barrier(mem_flags::mem_none);")
1120 }
1121 WmmaInstruction::Cast { input, output } => {
1122 writeln!(f, "simdgroup_barrier(mem_flags::mem_none);")?;
1123 let ty = match output {
1124 Variable::WmmaFragment { frag, .. } => frag.elem,
1125 _ => panic!("should be a fragment"),
1126 };
1127 match ty {
1128 Elem::BF16 => {
1129 let addr_space = Self::address_space_for_variable(output);
1130 let elem = Elem::<Self>::F16;
1131 writeln!(
1134 f,
1135 "for(int e=0; e<8; e++) {{
1136 {ty} elem = {ty}({input}.thread_elements()[e]);
1137 {output}.thread_elements()[e] = *reinterpret_cast<{addr_space}{elem} *>(&elem);
1138}}"
1139 )
1140 }
1141 _ => {
1142 writeln!(
1143 f,
1144 "for(int e=0; e<8; e++) {{
1145 {output}.thread_elements()[e] = {ty}({input}.thread_elements()[e]);
1146}}"
1147 )
1148 }
1149 }
1150 }
1151 WmmaInstruction::ExecuteManual {
1152 shape,
1153 frag_a,
1154 frag_b,
1155 frag_c,
1156 frag_d,
1157 } => {
1158 Self::compile_manual_mma(f, ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d))
1159 }
1160 WmmaInstruction::ExecuteScaled {
1161 shape,
1162 frag_a,
1163 frag_b,
1164 frag_c,
1165 frag_d,
1166 scales_a,
1167 scales_b,
1168 scales_factor,
1169 } => Self::compile_scaled_mma(
1170 f,
1171 ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d),
1172 *scales_a,
1173 *scales_b,
1174 *scales_factor,
1175 ),
1176 }
1177 }
1178
1179 fn compile_manual_mma(
1180 _f: &mut std::fmt::Formatter<'_>,
1181 _mma: shared::ManualMma<Self>,
1182 ) -> std::fmt::Result {
1183 unimplemented!("Not supported")
1184 }
1185
1186 fn compile_scaled_mma(
1187 _f: &mut std::fmt::Formatter<'_>,
1188 _mma: shared::ManualMma<Self>,
1189 _scales_a: Variable<Self>,
1190 _scales_b: Variable<Self>,
1191 _scales_factor: u32,
1192 ) -> std::fmt::Result {
1193 unimplemented!("Not supported")
1194 }
1195
1196 fn supported_wmma_combinations(_arch: &MetalArchitecture) -> SupportedMmaCombinations {
1197 let types = vec![
1198 (
1199 gpu::ElemType::Float(gpu::FloatKind::F16).into(),
1200 gpu::ElemType::Float(gpu::FloatKind::F16).into(),
1201 gpu::ElemType::Float(gpu::FloatKind::F16).into(),
1202 ),
1203 (
1204 gpu::ElemType::Float(gpu::FloatKind::F16).into(),
1205 gpu::ElemType::Float(gpu::FloatKind::F16).into(),
1206 gpu::ElemType::Float(gpu::FloatKind::F32).into(),
1207 ),
1208 (
1209 gpu::ElemType::Float(gpu::FloatKind::BF16).into(),
1210 gpu::ElemType::Float(gpu::FloatKind::BF16).into(),
1211 gpu::ElemType::Float(gpu::FloatKind::BF16).into(),
1212 ),
1213 (
1214 gpu::ElemType::Float(gpu::FloatKind::F32).into(),
1215 gpu::ElemType::Float(gpu::FloatKind::F32).into(),
1216 gpu::ElemType::Float(gpu::FloatKind::F32).into(),
1217 ),
1218 ];
1219 types
1220 .into_iter()
1221 .map(|(a_type, b_type, cd_type)| MmaConfig {
1222 a_type,
1223 b_type,
1224 cd_type,
1225 m: 8,
1226 n: 8,
1227 k: 8,
1228 })
1229 .collect()
1230 }
1231
1232 fn supported_mma_combinations(_arch: &MetalArchitecture) -> SupportedMmaCombinations {
1233 Vec::new()
1234 }
1235}
1236
1237impl DialectProcessors<Self> for MslDialect {
1240 fn processors() -> Vec<Box<dyn gpu::Processor>> {
1241 Vec::new()
1242 }
1243}