1use core::panic;
2use std::fmt::Display;
3
4use crate::{
5 Dialect,
6 shared::{
7 self, AtomicKind, Binding, Component, CubeIndexFlags, DialectBindings, DialectCubeBuiltins,
8 DialectIncludes, DialectInstructions, DialectProcessors, DialectTypes,
9 DialectWarpReduceCompiler, DialectWmmaCompiler, Elem, Flags, FmtLeft, Fragment,
10 FragmentIdent, FragmentLayout, Instruction, Item, ManualMma, SharedMemory,
11 SupportedMmaCombinations, Variable, WarpInstruction, WmmaInstruction, wmma_api_base,
12 },
13};
14use cubecl_core::{
15 ir::{self as gpu},
16 prelude::{Location, Visibility},
17};
18use cubecl_runtime::MmaConfig;
19
20use super::{
21 AddressSpace, Extension,
22 arch::MetalArchitecture,
23 extension::{format_ffs, format_mulhi},
24 format_erf, format_global_binding_arg, format_metal_builtin_binding_arg, format_safe_tanh,
25};
26
/// Marker type selecting the Metal Shading Language (MSL) code generation dialect.
///
/// The struct holds no data; all behavior is provided by the `Dialect*` trait
/// implementations in this file.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct MslDialect {}
29
/// Registers [`MslDialect`] as a code generation dialect.
impl Dialect for MslDialect {
    // Architecture description type associated with this dialect.
    type Architecture = MetalArchitecture;
}
35
36impl MslDialect {
37 fn warp_op_vectorized(
38 f: &mut core::fmt::Formatter<'_>,
39 input: &Variable<Self>,
40 out: &Variable<Self>,
41 simd_op_prefix: &str,
42 simd_op_suffix: &str,
43 ) -> core::fmt::Result {
44 let out = out.fmt_left();
45 let vectorization = input.item().vectorization;
46
47 f.write_fmt(format_args!("{out} = {} {{", input.item()))?;
48
49 for k in 0..vectorization {
50 let index = if vectorization > 1 {
51 format!(".i_{k}")
52 } else {
53 String::new()
54 };
55 let comma = if k + 1 < vectorization { "," } else { "" };
56
57 writeln!(f, "{simd_op_prefix}{input}{index}{simd_op_suffix}{comma}")?;
58 }
59
60 f.write_fmt(format_args!("}};\n"))
61 }
62}
63
64impl DialectWarpReduceCompiler<Self> for MslDialect {
65 fn warp_reduce_sum(
66 f: &mut core::fmt::Formatter<'_>,
67 input: &Variable<Self>,
68 out: &Variable<Self>,
69 ) -> core::fmt::Result {
70 Self::warp_op_vectorized(f, input, out, "simd_sum(", ")")
71 }
72 fn warp_reduce_prod(
73 f: &mut core::fmt::Formatter<'_>,
74 input: &Variable<Self>,
75 out: &Variable<Self>,
76 ) -> core::fmt::Result {
77 Self::warp_op_vectorized(f, input, out, "simd_product(", ")")
78 }
79 fn warp_reduce_max(
80 f: &mut core::fmt::Formatter<'_>,
81 input: &Variable<Self>,
82 out: &Variable<Self>,
83 ) -> core::fmt::Result {
84 Self::warp_op_vectorized(f, input, out, "simd_max(", ")")
85 }
86 fn warp_reduce_min(
87 f: &mut core::fmt::Formatter<'_>,
88 input: &Variable<Self>,
89 out: &Variable<Self>,
90 ) -> core::fmt::Result {
91 Self::warp_op_vectorized(f, input, out, "simd_min(", ")")
92 }
93 fn warp_reduce_all(
94 f: &mut core::fmt::Formatter<'_>,
95 input: &Variable<Self>,
96 out: &Variable<Self>,
97 ) -> core::fmt::Result {
98 Self::warp_op_vectorized(f, input, out, "simd_and(", "? 1u : 0u) != 0u")
99 }
100 fn warp_reduce_any(
101 f: &mut core::fmt::Formatter<'_>,
102 input: &Variable<Self>,
103 out: &Variable<Self>,
104 ) -> core::fmt::Result {
105 Self::warp_op_vectorized(f, input, out, "simd_or(", "? 1u : 0u) != 0u")
106 }
107 fn warp_reduce_sum_inclusive(
108 f: &mut core::fmt::Formatter<'_>,
109 input: &Variable<Self>,
110 out: &Variable<Self>,
111 ) -> core::fmt::Result {
112 Self::warp_op_vectorized(f, input, out, "simd_prefix_inclusive_sum(", ")")
113 }
114 fn warp_reduce_prod_inclusive(
115 f: &mut core::fmt::Formatter<'_>,
116 input: &Variable<Self>,
117 out: &Variable<Self>,
118 ) -> core::fmt::Result {
119 Self::warp_op_vectorized(f, input, out, "simd_prefix_inclusive_product(", ")")
120 }
121 fn warp_reduce_sum_exclusive(
122 f: &mut core::fmt::Formatter<'_>,
123 input: &Variable<Self>,
124 out: &Variable<Self>,
125 ) -> core::fmt::Result {
126 Self::warp_op_vectorized(f, input, out, "simd_prefix_exclusive_sum(", ")")
127 }
128 fn warp_reduce_prod_exclusive(
129 f: &mut core::fmt::Formatter<'_>,
130 input: &Variable<Self>,
131 out: &Variable<Self>,
132 ) -> core::fmt::Result {
133 Self::warp_op_vectorized(f, input, out, "simd_prefix_exclusive_product(", ")")
134 }
135}
136
137impl DialectIncludes<Self> for MslDialect {
140 type Extension = Extension<Self>;
141
142 fn compile_includes(f: &mut std::fmt::Formatter<'_>, _flags: &Flags) -> std::fmt::Result {
143 write!(
144 f,
145 "
146#include <metal_stdlib>
147using namespace metal;
148"
149 )?;
150 Ok(())
151 }
152
153 fn compile_extensions(
154 f: &mut std::fmt::Formatter<'_>,
155 extensions: &[Self::Extension],
156 ) -> std::fmt::Result {
157 for extension in extensions {
158 match extension {
159 Extension::Erf(input, output) => format_erf::<Self>(f, input, output)?,
160 Extension::Ffs(elem) => format_ffs(f, elem)?,
161 Extension::MulHi(elem) => format_mulhi(f, elem)?,
162 Extension::SafeTanh(item) => format_safe_tanh::<Self>(f, item)?,
163 Extension::NoExtension => {}
164 }
165 }
166 Ok(())
167 }
168
169 fn register_instruction_extension(
170 extensions: &mut Vec<Self::Extension>,
171 instruction: &Instruction<Self>,
172 ) {
173 let mut register_extension = |extension: Self::Extension| {
174 if !extensions.contains(&extension) {
175 extensions.push(extension);
176 }
177 };
178 #[allow(clippy::single_match)]
179 match instruction {
180 shared::Instruction::<Self>::Erf(instruction) => {
181 register_extension(Extension::Erf(
182 instruction.input.elem(),
183 instruction.out.elem(),
184 ));
185 }
186 shared::Instruction::<Self>::FindFirstSet(instruction) => {
187 let input_elem = instruction.input.elem();
188 match input_elem {
189 Elem::U32 | Elem::U64 => {
190 register_extension(Extension::Ffs(instruction.input.elem()));
191 }
192 Elem::I32 => {
193 register_extension(Extension::Ffs(Elem::<Self>::U32));
194 register_extension(Extension::Ffs(instruction.input.elem()));
195 }
196 Elem::I64 => {
197 register_extension(Extension::Ffs(Elem::<Self>::U64));
198 register_extension(Extension::Ffs(instruction.input.elem()));
199 }
200 _ => {
201 register_extension(Extension::Ffs(Elem::<Self>::U32));
202 }
203 }
204 }
205 shared::Instruction::<Self>::HiMul(instruction) => {
206 register_extension(Extension::MulHi(instruction.out.elem()));
207 }
208 shared::Instruction::<Self>::Tanh(instruction) => {
209 register_extension(Extension::SafeTanh(instruction.input.item()));
210 }
211 _ => {}
212 }
213 }
214
215 fn register_warp_instruction_extension(
216 _extensions: &mut Vec<Self::Extension>,
217 _instruction: &WarpInstruction<Self>,
218 ) {
219 }
220}
221
222impl DialectTypes<Self> for MslDialect {
225 fn item_can_be_optimized() -> bool {
226 false
227 }
228
229 fn compile_type_definitions(
230 f: &mut std::fmt::Formatter<'_>,
231 items: &std::collections::HashSet<crate::shared::Item<Self>>,
232 _scalars: &[(Elem<Self>, usize)],
233 _flags: &Flags,
234 ) -> std::fmt::Result {
235 for item in items.iter() {
236 let elem = item.elem;
237 let size = item.vectorization;
238 let alignment = elem.size() * size;
239 if size > 1 {
240 write!(
241 f,
242 "
243struct alignas({alignment}) {item} {{"
244 )?;
245
246 for i in 0..size {
247 write!(
248 f,
249 "
250 {elem} i_{i};"
251 )?;
252 }
253
254 f.write_str("\n};\n")?;
255 }
256 }
257 Ok(())
258 }
259
260 fn compile_elem(
261 f: &mut std::fmt::Formatter<'_>,
262 elem: &shared::Elem<Self>,
263 _words: bool,
264 ) -> std::fmt::Result {
265 match elem {
267 shared::Elem::FP4(_)
268 | shared::Elem::FP4x2(_)
269 | shared::Elem::FP6(_)
270 | shared::Elem::FP6x2(_)
271 | shared::Elem::FP8(_)
272 | shared::Elem::FP8x2(_) => f.write_str("#error FP4/FP6/FP8 not supported in Metal\n"),
273 shared::Elem::F16 => f.write_str("half"),
274 shared::Elem::F16x2 => f.write_str("#error type F162 not supported!\n"),
275 shared::Elem::F32 => f.write_str("float"),
276 shared::Elem::F64 => f.write_str("#error type double not supported!\n"),
277 shared::Elem::BF16 => f.write_str("bfloat"),
278 shared::Elem::BF16x2 => f.write_str("#error type BF162 not supported!\n"),
279 shared::Elem::TF32 => f.write_str("float"),
280 shared::Elem::I8 => f.write_str("char"),
281 shared::Elem::I16 => f.write_str("short"),
282 shared::Elem::I32 => f.write_str("int"),
283 shared::Elem::I64 => f.write_str("long"),
284 shared::Elem::U8 => f.write_str("uchar"),
285 shared::Elem::U16 => f.write_str("ushort"),
286 shared::Elem::U32 => f.write_str("uint"),
287 shared::Elem::U64 => f.write_str("uint64_t"), shared::Elem::Bool => f.write_str("bool"),
289 shared::Elem::Barrier(_) => unimplemented!("metal doesn't support barrier object"),
290 shared::Elem::Atomic(inner) => inner.fmt(f),
291 shared::Elem::_Dialect(_) => Ok(()),
292 }
293 }
294
295 fn compile_item(f: &mut std::fmt::Formatter<'_>, item: &Item<Self>) -> std::fmt::Result {
296 if 1 == item.vectorization {
297 return write!(f, "{}", item.elem);
298 }
299 if item.native {
300 write!(f, "{}{}", item.elem, item.vectorization)
301 } else {
302 write!(f, "{}_{}", item.elem, item.vectorization)
303 }
304 }
305
306 fn compile_atomic_kind(
307 f: &mut std::fmt::Formatter<'_>,
308 kind: &AtomicKind<Self>,
309 ) -> std::fmt::Result {
310 match kind {
311 AtomicKind::I32 => write!(f, "atomic_int"),
312 AtomicKind::I64 => panic!("I64 atomic kind no supported."),
313 AtomicKind::U32 => write!(f, "atomic_uint"),
314 AtomicKind::U64 => write!(f, "atomic_ulong"),
315 AtomicKind::F16 => panic!("F16 atomic kind no supported."),
316 AtomicKind::BF16 => panic!("BF16 atomic kind no supported."),
317 AtomicKind::F32 => write!(f, "atomic_float"), AtomicKind::F64 => panic!("F64 atomic kind no supported."),
319 AtomicKind::_Dialect(_) => Ok(()),
320 }
321 }
322
323 fn address_space_for_variable(variable: &Variable<Self>) -> String {
324 format!("{} ", AddressSpace::from(variable))
325 }
326
327 fn compile_local_memory_qualifier(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
328 write!(f, "thread")
329 }
330
331 fn compile_shared_memory_declaration(
332 f: &mut std::fmt::Formatter<'_>,
333 shared: &SharedMemory<Self>,
334 ) -> std::fmt::Result {
335 match shared {
336 SharedMemory::Array {
337 index,
338 item,
339 length,
340 offset,
341 ..
342 } => {
343 let size_bytes = length * item.size() as u32;
344 writeln!(f, "// Shared array size: {length}, {size_bytes} bytes")?;
345 writeln!(
346 f,
347 "threadgroup {item}* shared_memory_{index} = reinterpret_cast<threadgroup {item}*>(&dynamic_shared_mem[{offset}]);"
348 )
349 }
350 SharedMemory::Value {
351 index,
352 item,
353 offset,
354 ..
355 } => {
356 let size_bytes = item.size() as u32;
357 writeln!(f, "// Shared value size: {size_bytes} bytes")?;
358 writeln!(
359 f,
360 "threadgroup {item}& shared_memory_{index} = reinterpret_cast<threadgroup {item}&>(dynamic_shared_mem[{offset}]);"
361 )
362 }
363 }
364 }
365}
366
367impl DialectBindings<Self> for MslDialect {
370 fn compile_kernel_signature(
371 f: &mut std::fmt::Formatter<'_>,
372 kernel_name: &str,
373 tensor_maps: &[Binding<Self>],
374 buffers: &[Binding<Self>],
375 scalars: &[(Elem<Self>, usize)],
376 flags: &Flags,
377 ) -> std::fmt::Result {
378 write!(
379 (f),
380 "
381[[kernel]]
382void {kernel_name}("
383 )?;
384 let mut buffer_idx = 0;
386 debug_assert!(
387 tensor_maps.is_empty(),
388 "Tensor maps aren't supported for metal"
389 );
390 for (i, b) in buffers.iter().enumerate() {
391 format_global_binding_arg("buffer", b, Some(&i.to_string()), &mut buffer_idx, f)?;
392 }
393 if flags.static_meta_length > 0 {
394 let binding = Binding {
395 id: 0,
396 item: Item::scalar(Elem::<Self>::U32, true),
397 location: Location::Storage,
398 size: None,
399 vis: Visibility::Read,
400 };
401 format_global_binding_arg("info", &binding, None, &mut buffer_idx, f)?;
402 }
403 for (elem, _) in scalars.iter() {
404 let binding = Binding {
405 id: 0,
406 item: Item::scalar(*elem, true),
407 location: Location::Storage,
408 size: None,
409 vis: Visibility::Read,
410 };
411
412 let name = format!("scalars_{elem}");
413 format_global_binding_arg(&name, &binding, None, &mut buffer_idx, f)?;
414 }
415
416 let builtins = vec![
418 (
419 flags.indexes.absolute_pos_tuple,
420 Variable::<Self>::AbsolutePosBaseName,
421 ),
422 (
423 flags.indexes.cube_dim_tuple,
424 Variable::<Self>::CubeDimBaseName,
425 ),
426 (
427 flags.indexes.cube_count_tuple,
428 Variable::<Self>::CubeCountBaseName,
429 ),
430 (flags.indexes.unit_pos, Variable::<Self>::UnitPos),
431 (
432 flags.indexes.unit_pos_tuple,
433 Variable::<Self>::UnitPosBaseName,
434 ),
435 (
436 flags.indexes.cube_pos_tuple,
437 Variable::<Self>::CubePosBaseName,
438 ),
439 (flags.indexes.unit_pos_plane, Variable::<Self>::UnitPosPlane),
440 (flags.indexes.plane_dim, Variable::<Self>::PlaneDim),
441 (flags.indexes.plane_index, Variable::<Self>::PlanePos),
442 ];
443 let comma = !buffers.is_empty() || flags.static_meta_length > 0 || !scalars.is_empty();
444 builtins
445 .iter()
446 .filter(|(cond, _)| *cond)
447 .try_for_each(|(_, var)| format_metal_builtin_binding_arg(f, var, comma))?;
448 f.write_str("\n)")
449 }
450
451 fn compile_bindings_body(
452 f: &mut std::fmt::Formatter<'_>,
453 body: &shared::Body<Self>,
454 ) -> std::fmt::Result {
455 if !body.shared_memories.is_empty() {
456 let size = body
457 .shared_memories
458 .iter()
459 .map(|it| it.offset() + it.size())
460 .max()
461 .unwrap();
462
463 writeln!(f, "threadgroup uchar dynamic_shared_mem[{size}];",)?;
464 }
465 Ok(())
466 }
467}
468
469impl DialectCubeBuiltins<Self> for MslDialect {
472 fn builtin_rules(flags: &CubeIndexFlags) -> CubeIndexFlags {
477 let absolute_pos = flags.absolute_pos;
478 let cube_count = flags.cube_count;
479 let cube_dim = flags.cube_dim;
480 let cube_pos = flags.cube_pos;
481 let plane_dim_checked = flags.plane_dim_checked;
482 let plane_index = flags.plane_index;
483 let unit_pos = flags.unit_pos;
484 let absolute_pos_tuple = flags.absolute_pos_tuple || absolute_pos;
485 let cube_count_tuple = flags.cube_count_tuple || cube_count || cube_pos || absolute_pos;
486 let cube_dim_tuple = flags.cube_dim_tuple || cube_dim || absolute_pos || plane_dim_checked;
487 let cube_pos_tuple = flags.cube_pos_tuple || cube_pos;
488 let cluster_pos = flags.cluster_pos;
489 let plane_dim = flags.plane_dim || plane_dim_checked || plane_index;
490 let unit_pos_plane = flags.unit_pos_plane || plane_index;
491 let unit_pos_tuple = flags.unit_pos_tuple || unit_pos;
492 CubeIndexFlags {
493 absolute_pos_tuple,
494 absolute_pos,
495 cube_count_tuple,
496 cube_count,
497 cube_dim_tuple,
498 cube_dim,
499 cube_pos_tuple,
500 cube_pos,
501 plane_dim,
502 plane_dim_checked,
503 plane_index,
504 unit_pos_tuple,
505 unit_pos,
506 unit_pos_plane,
507 cluster_pos,
508 }
509 }
510
511 fn compile_absolute_pos_tuple_computation(
512 _f: &mut std::fmt::Formatter<'_>,
513 ) -> std::fmt::Result {
514 Ok(())
516 }
517
518 fn compile_absolute_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
519 f.write_str("thread_pos_in_grid")
520 }
521
522 fn compile_absolute_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
523 f.write_str("thread_index_in_grid")
524 }
525
526 fn compile_absolute_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
527 Self::compile_absolute_pos_base_name(f)?;
528 write!(f, ".x")
529 }
530
531 fn compile_absolute_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
532 Self::compile_absolute_pos_base_name(f)?;
533 write!(f, ".y")
534 }
535
536 fn compile_absolute_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
537 Self::compile_absolute_pos_base_name(f)?;
538 write!(f, ".z")
539 }
540
541 fn compile_cube_count_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
542 f.write_str("threadgroups_per_grid")
543 }
544
545 fn compile_cube_count(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
546 f.write_str("total_threadgroups_in_grid")
547 }
548
549 fn compile_cube_count_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
550 Self::compile_cube_count_base_name(f)?;
551 write!(f, ".x")
552 }
553
554 fn compile_cube_count_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
555 Self::compile_cube_count_base_name(f)?;
556 write!(f, ".y")
557 }
558
559 fn compile_cube_count_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
560 Self::compile_cube_count_base_name(f)?;
561 write!(f, ".z")
562 }
563
564 fn compile_cube_dim_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
565 f.write_str("threads_per_threadgroup")
566 }
567
568 fn compile_cube_dim(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
569 f.write_str("total_thread_in_threadgroup")
570 }
571
572 fn compile_cube_dim_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
573 Self::compile_cube_dim_base_name(f)?;
574 write!(f, ".x")
575 }
576
577 fn compile_cube_dim_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
578 Self::compile_cube_dim_base_name(f)?;
579 write!(f, ".y")
580 }
581
582 fn compile_cube_dim_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
583 Self::compile_cube_dim_base_name(f)?;
584 write!(f, ".z")
585 }
586
587 fn compile_cube_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
588 f.write_str("threadgroup_pos_in_grid")
589 }
590
591 fn compile_cube_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
592 f.write_str("threadgroup_index_in_grid")
593 }
594
595 fn compile_cube_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
596 Self::compile_cube_pos_base_name(f)?;
597 write!(f, ".x")
598 }
599
600 fn compile_cube_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
601 Self::compile_cube_pos_base_name(f)?;
602 write!(f, ".y")
603 }
604
605 fn compile_cube_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
606 Self::compile_cube_pos_base_name(f)?;
607 write!(f, ".z")
608 }
609
610 fn compile_unit_pos_computation(_f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
611 Ok(())
613 }
614
615 fn compile_unit_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
616 f.write_str("thread_pos_in_threadgroup")
617 }
618
619 fn compile_unit_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
620 f.write_str("thread_index_in_threadgroup")
621 }
622
623 fn compile_unit_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
624 Self::compile_unit_pos_base_name(f)?;
625 write!(f, ".x")
626 }
627
628 fn compile_unit_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
629 Self::compile_unit_pos_base_name(f)?;
630 write!(f, ".y")
631 }
632
633 fn compile_unit_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
634 Self::compile_unit_pos_base_name(f)?;
635 write!(f, ".z")
636 }
637
638 fn compile_plane_dim(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
639 f.write_str("simd_size")
640 }
641
642 fn compile_plane_dim_checked(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
643 f.write_str("threads_per_simdgroup_checked")
644 }
645
646 fn compile_plane_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
647 f.write_str("simd_group_id")
648 }
649
650 fn compile_unit_pos_plane(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
651 f.write_str("simd_lane_id")
652 }
653}
654
655impl DialectInstructions<Self> for MslDialect {
658 fn compile_atomic_add(
660 f: &mut std::fmt::Formatter<'_>,
661 lhs: &Variable<Self>,
662 rhs: &Variable<Self>,
663 out: &Variable<Self>,
664 ) -> std::fmt::Result {
665 let out = out.fmt_left();
666 writeln!(
667 f,
668 "{out} = atomic_fetch_add_explicit({lhs}, {rhs}, memory_order_relaxed);"
669 )
670 }
671
672 fn compile_atomic_and(
673 f: &mut std::fmt::Formatter<'_>,
674 lhs: &Variable<Self>,
675 rhs: &Variable<Self>,
676 out: &Variable<Self>,
677 ) -> std::fmt::Result {
678 let out = out.fmt_left();
679 writeln!(
680 f,
681 "{out} = atomic_fetch_and_explicit({lhs}, {rhs}, memory_order_relaxed);"
682 )
683 }
684
685 fn compile_atomic_cas(
686 f: &mut std::fmt::Formatter<'_>,
687 input: &Variable<Self>,
688 cmp: &Variable<Self>,
689 val: &Variable<Self>,
690 out: &Variable<Self>,
691 ) -> std::fmt::Result {
692 let out = out.fmt_left();
693 writeln!(
694 f,
695 "{out} = atomic_compare_exchange_weak_explicit({input}, &{cmp}, {val}, memory_order_relaxed, memory_order_relaxed);"
696 )
697 }
698
699 fn compile_atomic_load(
700 f: &mut std::fmt::Formatter<'_>,
701 input: &Variable<Self>,
702 out: &Variable<Self>,
703 ) -> std::fmt::Result {
704 let out = out.fmt_left();
705 writeln!(
706 f,
707 "{out} = atomic_load_explicit({input}, memory_order_relaxed);"
708 )
709 }
710
711 fn compile_atomic_max(
712 f: &mut std::fmt::Formatter<'_>,
713 lhs: &Variable<Self>,
714 rhs: &Variable<Self>,
715 out: &Variable<Self>,
716 ) -> std::fmt::Result {
717 let out = out.fmt_left();
718 writeln!(
719 f,
720 "{out} = atomic_fetch_max_explicit({lhs}, {rhs}, memory_order_relaxed);"
721 )
722 }
723
724 fn compile_atomic_min(
725 f: &mut std::fmt::Formatter<'_>,
726 lhs: &Variable<Self>,
727 rhs: &Variable<Self>,
728 out: &Variable<Self>,
729 ) -> std::fmt::Result {
730 let out = out.fmt_left();
731 writeln!(
732 f,
733 "{out} = atomic_fetch_min_explicit({lhs}, {rhs}, memory_order_relaxed);"
734 )
735 }
736
737 fn compile_atomic_or(
738 f: &mut std::fmt::Formatter<'_>,
739 lhs: &Variable<Self>,
740 rhs: &Variable<Self>,
741 out: &Variable<Self>,
742 ) -> std::fmt::Result {
743 let out = out.fmt_left();
744 writeln!(
745 f,
746 "{out} = atomic_fetch_or_explicit({lhs}, {rhs}, memory_order_relaxed);"
747 )
748 }
749
750 fn compile_atomic_store(
751 f: &mut std::fmt::Formatter<'_>,
752 input: &Variable<Self>,
753 out: &Variable<Self>,
754 ) -> std::fmt::Result {
755 writeln!(
756 f,
757 "atomic_store_explicit({out}, {input}, memory_order_relaxed);"
758 )
759 }
760
761 fn compile_atomic_sub(
762 f: &mut std::fmt::Formatter<'_>,
763 lhs: &Variable<Self>,
764 rhs: &Variable<Self>,
765 out: &Variable<Self>,
766 ) -> std::fmt::Result {
767 let out = out.fmt_left();
768 writeln!(
769 f,
770 "{out} = atomic_fetch_sub_explicit({lhs}, {rhs}, memory_order_relaxed);"
771 )
772 }
773
774 fn compile_atomic_swap(
775 f: &mut std::fmt::Formatter<'_>,
776 lhs: &Variable<Self>,
777 rhs: &Variable<Self>,
778 out: &Variable<Self>,
779 ) -> std::fmt::Result {
780 let out = out.fmt_left();
781 writeln!(
782 f,
783 "{out} = atomic_exchange_explicit({lhs}, {rhs}, memory_order_relaxed);"
784 )
785 }
786
787 fn compile_atomic_xor(
788 f: &mut std::fmt::Formatter<'_>,
789 lhs: &Variable<Self>,
790 rhs: &Variable<Self>,
791 out: &Variable<Self>,
792 ) -> std::fmt::Result {
793 let out = out.fmt_left();
794 writeln!(
795 f,
796 "{out} = atomic_fetch_xor_explicit({lhs}, {rhs}, memory_order_relaxed);"
797 )
798 }
799
800 fn compile_saturating_add(
801 f: &mut std::fmt::Formatter<'_>,
802 lhs: impl Display,
803 rhs: impl Display,
804 _item: Item<Self>,
805 ) -> std::fmt::Result {
806 write!(f, "addsat({lhs}, {rhs})")
807 }
808
809 fn compile_saturating_sub(
810 f: &mut std::fmt::Formatter<'_>,
811 lhs: impl Display,
812 rhs: impl Display,
813 _item: Item<Self>,
814 ) -> std::fmt::Result {
815 write!(f, "subsat({lhs}, {rhs})")
816 }
817
818 fn compile_instruction_printf(
820 f: &mut std::fmt::Formatter<'_>,
821 format_string: &str,
822 args: &[Variable<Self>],
823 ) -> std::fmt::Result {
824 let args = args.iter().map(|arg| format!("{arg}")).collect::<Vec<_>>();
825 let args = match args.is_empty() {
826 true => "".to_string(),
827 false => format!(", {}", args.join(",")),
828 };
829 writeln!(f, "os_log_default.log({format_string:?}{args});")
830 }
831
832 fn compile_instruction_log1p_scalar<T: Component<Self>>(
834 f: &mut std::fmt::Formatter<'_>,
835 input: T,
836 ) -> std::fmt::Result {
837 match input.elem() {
838 Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
839 write!(f, "log(half(1.0f) + {input})")
840 }
841 _ => write!(f, "log(1.0f + {input})"),
842 }
843 }
844
845 fn compile_instruction_sync_threads(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
847 writeln!(f, "threadgroup_barrier(mem_flags::mem_threadgroup);")
848 }
849
850 fn compile_instruction_sync_warp(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
851 writeln!(f, "simdgroup_barrier(mem_flags::mem_none);")
852 }
853
854 fn compile_instruction_thread_fence(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
855 writeln!(f, "threadgroup_thread_fence(mem_flags::mem_device);")
856 }
857
858 fn compile_instruction_tanh_scalar<T: Component<Self>>(
860 f: &mut std::fmt::Formatter<'_>,
861 input: T,
862 ) -> std::fmt::Result {
863 write!(f, "safe_tanh_scalar({input})")
864 }
865
866 fn compile_instruction_find_first_set<T: Component<Self>>(
868 f: &mut std::fmt::Formatter<'_>,
869 input: T,
870 out_elem: Elem<Self>,
871 ) -> std::fmt::Result {
872 write!(f, "{out_elem}(")?;
873 match input.elem() {
874 Elem::I32 | Elem::U32 => write!(f, "__ffs({input})"),
875 Elem::I64 | Elem::U64 => write!(f, "__ffsll({input})"),
876 _ => write!(f, "__ffs({}({input}))", Elem::<Self>::I32),
877 }?;
878 write!(f, ")")
879 }
880
881 fn compile_instruction_leading_zeros_scalar<T: Component<Self>>(
882 f: &mut std::fmt::Formatter<'_>,
883 input: T,
884 out_elem: Elem<Self>,
885 ) -> std::fmt::Result {
886 write!(f, "{out_elem}(clz({input}))")
887 }
888
889 fn compile_instruction_popcount_scalar<T: Component<Self>>(
890 f: &mut std::fmt::Formatter<'_>,
891 input: T,
892 out_elem: Elem<Self>,
893 ) -> std::fmt::Result {
894 write!(f, "{out_elem}(")?;
895 match input.elem() {
896 Elem::I32 | Elem::U32 | Elem::I64 | Elem::U64 => write!(f, "popcount({input})"),
897 _ => write!(f, "popcount({})", shared::unary::zero_extend(input)),
898 }?;
899 write!(f, ")")
900 }
901
902 fn compile_instruction_reverse_bits_scalar<T: Component<Self>>(
903 f: &mut std::fmt::Formatter<'_>,
904 input: T,
905 out_elem: Elem<Self>,
906 ) -> std::fmt::Result {
907 write!(f, "{out_elem}(")?;
908 match out_elem {
909 Elem::I32 | Elem::U32 | Elem::I64 | Elem::U64 => write!(f, "reverse_bits({input})"),
910 _ => write!(
911 f,
912 "reverse_bits({}) >> {}",
913 shared::unary::zero_extend(input),
914 (size_of::<u32>() - out_elem.size()) * 8
915 ),
916 }?;
917 write!(f, ")")
918 }
919
920 fn compile_instruction_max_function_name(
922 f: &mut std::fmt::Formatter<'_>,
923 _item: Item<Self>,
924 ) -> std::fmt::Result {
925 write!(f, "max")
926 }
927
928 fn compile_instruction_min_function_name(
929 f: &mut std::fmt::Formatter<'_>,
930 _item: Item<Self>,
931 ) -> std::fmt::Result {
932 write!(f, "min")
933 }
934
935 fn compile_instruction_powf(
936 f: &mut std::fmt::Formatter<'_>,
937 lhs: &str,
938 rhs: &str,
939 elem: Elem<Self>,
940 ) -> std::fmt::Result {
941 write!(f, "pow({lhs}, {elem}({rhs}))")
942 }
943
944 fn compile_instruction_half_function_name_prefix() -> &'static str {
945 ""
946 }
947
948 fn compile_instruction_half2_function_name_prefix() -> &'static str {
949 ""
950 }
951
952 fn compile_warp_shuffle(
954 f: &mut std::fmt::Formatter<'_>,
955 var: &str,
956 source: &str,
957 ) -> std::fmt::Result {
958 write!(f, "simd_shuffle({var}, {source})")
959 }
960
961 fn compile_warp_shuffle_xor(
962 f: &mut std::fmt::Formatter<'_>,
963 var: &str,
964 _elem: &Elem<Self>,
965 offset: &str,
966 ) -> std::fmt::Result {
967 write!(f, "simd_shuffle_xor({var}, {offset})")
968 }
969
970 fn compile_warp_shuffle_up(
971 f: &mut std::fmt::Formatter<'_>,
972 var: &str,
973 offset: &str,
974 ) -> std::fmt::Result {
975 write!(f, "simd_shuffle_up({var}, {offset})")
976 }
977
978 fn compile_warp_shuffle_down(
979 f: &mut std::fmt::Formatter<'_>,
980 var: &str,
981 offset: &str,
982 ) -> std::fmt::Result {
983 write!(f, "simd_shuffle_down({var}, {offset})")
984 }
985
986 fn compile_warp_all<T: Component<Self>>(
987 f: &mut std::fmt::Formatter<'_>,
988 input: &T,
989 ) -> std::fmt::Result {
990 write!(f, "simd_all({input})")
991 }
992
993 fn compile_warp_any<T: Component<Self>>(
994 f: &mut std::fmt::Formatter<'_>,
995 input: &T,
996 ) -> std::fmt::Result {
997 write!(f, "simd_any({input})")
998 }
999
1000 fn compile_warp_ballot(
1001 f: &mut std::fmt::Formatter<'_>,
1002 input: &Variable<Self>,
1003 out_elem: &Elem<Self>,
1004 ) -> std::fmt::Result {
1005 write!(f, "{out_elem}(uint64_t(simd_ballot({input})))")
1006 }
1007}
1008
1009impl DialectWmmaCompiler<Self> for MslDialect {
1012 fn compile_wmma_includes(f: &mut std::fmt::Formatter<'_>, _flags: &Flags) -> std::fmt::Result {
1013 writeln!(f, "#include <metal_simdgroup_matrix>")
1014 }
1015
1016 fn compile_wmma_local_variables(_f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1017 Ok(())
1019 }
1020
1021 fn compile_wmma_fragment_declaration(
1022 f: &mut std::fmt::Formatter<'_>,
1023 var: &crate::shared::Variable<MslDialect>,
1024 ) -> std::fmt::Result {
1025 wmma_api_base::compile_fragment_declaration(f, var)
1026 }
1027
1028 fn compile_wwma_fragment_ident(
1029 _f: &mut std::fmt::Formatter<'_>,
1030 _ident: &FragmentIdent<Self>,
1031 ) -> std::fmt::Result {
1032 Ok(())
1034 }
1035
1036 fn compile_wmma_fragment_layout(
1037 _f: &mut std::fmt::Formatter<'_>,
1038 _layout: &FragmentLayout<Self>,
1039 ) -> std::fmt::Result {
1040 Ok(())
1042 }
1043
1044 fn compile_wmma_fragment(
1045 f: &mut std::fmt::Formatter<'_>,
1046 fragment: &Fragment<Self>,
1047 ) -> std::fmt::Result {
1048 let ty = fragment.elem;
1049 let m = fragment.m;
1051 let n = fragment.n;
1052 let k = fragment.k;
1053 if m != 8 || n != 8 || k != 8 {
1054 panic!("{m}x{n}x{k} fragments not supported. Only 8x8x8 fragments are supported.");
1055 }
1056 write!(f, "simdgroup_{ty}8x8")
1057 }
1058
1059 fn compile_wmma_instruction(
1060 f: &mut std::fmt::Formatter<'_>,
1061 instruction: &WmmaInstruction<Self>,
1062 ) -> std::fmt::Result {
1063 match instruction {
1064 WmmaInstruction::Fill { frag, value } => {
1065 match frag {
1066 Variable::WmmaFragment { .. } => {
1067 let ty = frag.elem();
1068 writeln!(
1070 f,
1071 "{frag} = make_filled_simdgroup_matrix<{ty}, 8, 8>({value});"
1072 )
1073 }
1074 _ => panic!("should be a fragment"),
1075 }
1076 }
1077 WmmaInstruction::Load {
1078 frag,
1079 value,
1080 stride,
1081 offset,
1082 layout: _layout,
1083 } => {
1084 let transpose = match frag {
1085 Variable::WmmaFragment { frag: inner, .. } => match inner.layout {
1086 Some(FragmentLayout::RowMajor) => false,
1087 Some(FragmentLayout::ColMajor) => true,
1088 _ => false,
1089 },
1090 _ => panic!("should be a fragment"),
1091 };
1092 let item = value.item();
1093 if item.vectorization > 1 {
1094 let elem = item.elem;
1095 match value {
1096 Variable::GlobalInputArray(..) => writeln!(
1097 f,
1098 "simdgroup_load({frag}, (device {elem}*)({value} + {offset}), {stride}, 0, {transpose});"
1099 ),
1100 Variable::SharedArray(..) => writeln!(
1101 f,
1102 "simdgroup_load({frag}, reinterpret_cast<threadgroup {elem} *>({value} + {offset}), {stride}, 0, {transpose});"
1103 ),
1104 _ => panic!(
1105 "Vectorized wmma load is only supported from global or shared memory."
1106 ),
1107 }
1108 } else {
1109 writeln!(
1110 f,
1111 "simdgroup_load({frag}, {value} + {offset}, {stride}, 0, {transpose});"
1112 )
1113 }
1114 }
1115 WmmaInstruction::Execute {
1116 frag_a: a,
1117 frag_b: b,
1118 frag_c: c,
1119 frag_d: d,
1120 ..
1121 } => {
1122 writeln!(f, "simdgroup_multiply_accumulate({d}, {a}, {b}, {c});")
1123 }
1124 WmmaInstruction::Store {
1125 output,
1126 frag,
1127 stride,
1128 offset,
1129 layout: _layout,
1130 } => {
1131 let item = output.item();
1132 let mut reinterpret_cast = item.vectorization > 1;
1133 let elem = match item.elem {
1134 Elem::BF16 => {
1135 reinterpret_cast = true;
1136 Elem::F16
1137 }
1138 _ => item.elem,
1139 };
1140 if reinterpret_cast {
1141 writeln!(
1142 f,
1143 "simdgroup_store({frag}, reinterpret_cast<threadgroup {elem} *>({output} + {offset}), {stride});"
1144 )
1145 } else {
1146 writeln!(f, "simdgroup_store({frag}, {output} + {offset}, {stride});")
1147 }?;
1148 writeln!(f, "simdgroup_barrier(mem_flags::mem_none);")
1149 }
1150 WmmaInstruction::Cast { input, output } => {
1151 writeln!(f, "simdgroup_barrier(mem_flags::mem_none);")?;
1152 let ty = match output {
1153 Variable::WmmaFragment { frag, .. } => frag.elem,
1154 _ => panic!("should be a fragment"),
1155 };
1156 match ty {
1157 Elem::BF16 => {
1158 let addr_space = Self::address_space_for_variable(output);
1159 let elem = Elem::<Self>::F16;
1160 writeln!(
1163 f,
1164 "for(int e=0; e<8; e++) {{
1165 {ty} elem = {ty}({input}.thread_elements()[e]);
1166 {output}.thread_elements()[e] = *reinterpret_cast<{addr_space}{elem} *>(&elem);
1167}}"
1168 )
1169 }
1170 _ => {
1171 writeln!(
1172 f,
1173 "for(int e=0; e<8; e++) {{
1174 {output}.thread_elements()[e] = {ty}({input}.thread_elements()[e]);
1175}}"
1176 )
1177 }
1178 }
1179 }
1180 WmmaInstruction::ExecuteManual {
1181 shape,
1182 frag_a,
1183 frag_b,
1184 frag_c,
1185 frag_d,
1186 } => {
1187 Self::compile_manual_mma(f, ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d))
1188 }
1189 WmmaInstruction::ExecuteScaled {
1190 shape,
1191 frag_a,
1192 frag_b,
1193 frag_c,
1194 frag_d,
1195 scales_a,
1196 scales_b,
1197 scales_factor,
1198 } => Self::compile_scaled_mma(
1199 f,
1200 ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d),
1201 *scales_a,
1202 *scales_b,
1203 *scales_factor,
1204 ),
1205 WmmaInstruction::LdMatrix { .. } | WmmaInstruction::StMatrix { .. } => {
1206 f.write_str("#error WmmaInstruction Ld & St Matrix not supported on Metal\n")
1207 }
1208 }
1209 }
1210
1211 fn compile_manual_mma(
1212 f: &mut std::fmt::Formatter<'_>,
1213 _mma: shared::ManualMma<Self>,
1214 ) -> std::fmt::Result {
1215 f.write_str("#error manual mma not supported on Metal\n")
1216 }
1217
1218 fn compile_scaled_mma(
1219 f: &mut std::fmt::Formatter<'_>,
1220 _mma: shared::ManualMma<Self>,
1221 _scales_a: Variable<Self>,
1222 _scales_b: Variable<Self>,
1223 _scales_factor: u32,
1224 ) -> std::fmt::Result {
1225 f.write_str("#error scaled mma not supported on Metal\n")
1226 }
1227
1228 fn supported_wmma_combinations(_arch: &MetalArchitecture) -> SupportedMmaCombinations {
1229 let types = vec![
1230 (
1231 gpu::ElemType::Float(gpu::FloatKind::F16).into(),
1232 gpu::ElemType::Float(gpu::FloatKind::F16).into(),
1233 gpu::ElemType::Float(gpu::FloatKind::F16).into(),
1234 ),
1235 (
1236 gpu::ElemType::Float(gpu::FloatKind::F16).into(),
1237 gpu::ElemType::Float(gpu::FloatKind::F16).into(),
1238 gpu::ElemType::Float(gpu::FloatKind::F32).into(),
1239 ),
1240 (
1241 gpu::ElemType::Float(gpu::FloatKind::BF16).into(),
1242 gpu::ElemType::Float(gpu::FloatKind::BF16).into(),
1243 gpu::ElemType::Float(gpu::FloatKind::BF16).into(),
1244 ),
1245 (
1246 gpu::ElemType::Float(gpu::FloatKind::F32).into(),
1247 gpu::ElemType::Float(gpu::FloatKind::F32).into(),
1248 gpu::ElemType::Float(gpu::FloatKind::F32).into(),
1249 ),
1250 ];
1251 types
1252 .into_iter()
1253 .map(|(a_type, b_type, cd_type)| MmaConfig {
1254 a_type,
1255 b_type,
1256 cd_type,
1257 m: 8,
1258 n: 8,
1259 k: 8,
1260 })
1261 .collect()
1262 }
1263
1264 fn supported_mma_combinations(_arch: &MetalArchitecture) -> SupportedMmaCombinations {
1265 Vec::new()
1266 }
1267}
1268
1269impl DialectProcessors<Self> for MslDialect {
1272 fn processors() -> Vec<Box<dyn gpu::Processor>> {
1273 Vec::new()
1274 }
1275}