1use super::{
2 AddressSpace, Extension,
3 arch::MetalArchitecture,
4 extension::{format_ffs, format_mulhi},
5 format_erf, format_global_binding_arg, format_metal_builtin_binding_arg, format_safe_tanh,
6};
7use crate::{
8 Dialect,
9 shared::{
10 self, AtomicKind, Binding, Component, CubeIndexFlags, DialectBindings, DialectCubeBuiltins,
11 DialectIncludes, DialectInstructions, DialectProcessors, DialectTypes,
12 DialectWarpReduceCompiler, DialectWmmaCompiler, Elem, Flags, FmtLeft, Fragment,
13 FragmentIdent, FragmentLayout, Instruction, Item, ManualMma, SharedMemory,
14 SupportedMmaCombinations, Variable, WarpInstruction, WmmaInstruction, wmma_api_base,
15 },
16};
17use core::panic;
18use cubecl_core::{
19 ir::{self as gpu, features::MmaConfig},
20 prelude::{Location, Visibility},
21};
22use std::fmt::Display;
23
/// Marker type selecting the Metal Shading Language (MSL) flavour of the C++-like code
/// generator. It carries no state; all behaviour lives in the dialect trait impls below.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct MslDialect {}
26
27impl Dialect for MslDialect {
30 type Architecture = MetalArchitecture;
31}
32
33impl MslDialect {
34 fn warp_op_vectorized(
35 f: &mut core::fmt::Formatter<'_>,
36 input: &Variable<Self>,
37 out: &Variable<Self>,
38 simd_op_prefix: &str,
39 simd_op_suffix: &str,
40 ) -> core::fmt::Result {
41 let out = out.fmt_left();
42 let vectorization = input.item().vectorization;
43
44 f.write_fmt(format_args!("{out} = {} {{", input.item()))?;
45
46 for k in 0..vectorization {
47 let index = if vectorization > 1 {
48 format!(".i_{k}")
49 } else {
50 String::new()
51 };
52 let comma = if k + 1 < vectorization { "," } else { "" };
53
54 writeln!(f, "{simd_op_prefix}{input}{index}{simd_op_suffix}{comma}")?;
55 }
56
57 f.write_fmt(format_args!("}};\n"))
58 }
59}
60
61impl DialectWarpReduceCompiler<Self> for MslDialect {
62 fn warp_reduce_sum(
63 f: &mut core::fmt::Formatter<'_>,
64 input: &Variable<Self>,
65 out: &Variable<Self>,
66 ) -> core::fmt::Result {
67 Self::warp_op_vectorized(f, input, out, "simd_sum(", ")")
68 }
69 fn warp_reduce_prod(
70 f: &mut core::fmt::Formatter<'_>,
71 input: &Variable<Self>,
72 out: &Variable<Self>,
73 ) -> core::fmt::Result {
74 Self::warp_op_vectorized(f, input, out, "simd_product(", ")")
75 }
76 fn warp_reduce_max(
77 f: &mut core::fmt::Formatter<'_>,
78 input: &Variable<Self>,
79 out: &Variable<Self>,
80 ) -> core::fmt::Result {
81 Self::warp_op_vectorized(f, input, out, "simd_max(", ")")
82 }
83 fn warp_reduce_min(
84 f: &mut core::fmt::Formatter<'_>,
85 input: &Variable<Self>,
86 out: &Variable<Self>,
87 ) -> core::fmt::Result {
88 Self::warp_op_vectorized(f, input, out, "simd_min(", ")")
89 }
90 fn warp_reduce_all(
91 f: &mut core::fmt::Formatter<'_>,
92 input: &Variable<Self>,
93 out: &Variable<Self>,
94 ) -> core::fmt::Result {
95 Self::warp_op_vectorized(f, input, out, "simd_and(", "? 1u : 0u) != 0u")
96 }
97 fn warp_reduce_any(
98 f: &mut core::fmt::Formatter<'_>,
99 input: &Variable<Self>,
100 out: &Variable<Self>,
101 ) -> core::fmt::Result {
102 Self::warp_op_vectorized(f, input, out, "simd_or(", "? 1u : 0u) != 0u")
103 }
104 fn warp_reduce_sum_inclusive(
105 f: &mut core::fmt::Formatter<'_>,
106 input: &Variable<Self>,
107 out: &Variable<Self>,
108 ) -> core::fmt::Result {
109 Self::warp_op_vectorized(f, input, out, "simd_prefix_inclusive_sum(", ")")
110 }
111 fn warp_reduce_prod_inclusive(
112 f: &mut core::fmt::Formatter<'_>,
113 input: &Variable<Self>,
114 out: &Variable<Self>,
115 ) -> core::fmt::Result {
116 Self::warp_op_vectorized(f, input, out, "simd_prefix_inclusive_product(", ")")
117 }
118 fn warp_reduce_sum_exclusive(
119 f: &mut core::fmt::Formatter<'_>,
120 input: &Variable<Self>,
121 out: &Variable<Self>,
122 ) -> core::fmt::Result {
123 Self::warp_op_vectorized(f, input, out, "simd_prefix_exclusive_sum(", ")")
124 }
125 fn warp_reduce_prod_exclusive(
126 f: &mut core::fmt::Formatter<'_>,
127 input: &Variable<Self>,
128 out: &Variable<Self>,
129 ) -> core::fmt::Result {
130 Self::warp_op_vectorized(f, input, out, "simd_prefix_exclusive_product(", ")")
131 }
132}
133
134impl DialectIncludes<Self> for MslDialect {
137 type Extension = Extension<Self>;
138
139 fn compile_includes(f: &mut std::fmt::Formatter<'_>, _flags: &Flags<Self>) -> std::fmt::Result {
140 write!(
141 f,
142 "
143#include <metal_stdlib>
144using namespace metal;
145"
146 )?;
147 Ok(())
148 }
149
150 fn compile_extensions(
151 f: &mut std::fmt::Formatter<'_>,
152 extensions: &[Self::Extension],
153 ) -> std::fmt::Result {
154 for extension in extensions {
155 match extension {
156 Extension::Erf(input, output) => format_erf::<Self>(f, input, output)?,
157 Extension::Ffs(elem) => format_ffs(f, elem)?,
158 Extension::MulHi(elem) => format_mulhi(f, elem)?,
159 Extension::SafeTanh(item) => format_safe_tanh::<Self>(f, item)?,
160 Extension::NoExtension => {}
161 }
162 }
163 Ok(())
164 }
165
166 fn register_instruction_extension(
167 extensions: &mut Vec<Self::Extension>,
168 instruction: &Instruction<Self>,
169 ) {
170 let mut register_extension = |extension: Self::Extension| {
171 if !extensions.contains(&extension) {
172 extensions.push(extension);
173 }
174 };
175 #[allow(clippy::single_match)]
176 match instruction {
177 shared::Instruction::<Self>::Erf(instruction) => {
178 register_extension(Extension::Erf(
179 instruction.input.elem(),
180 instruction.out.elem(),
181 ));
182 }
183 shared::Instruction::<Self>::FindFirstSet(instruction) => {
184 let input_elem = instruction.input.elem();
185 match input_elem {
186 Elem::U32 | Elem::U64 => {
187 register_extension(Extension::Ffs(instruction.input.elem()));
188 }
189 Elem::I32 => {
190 register_extension(Extension::Ffs(Elem::<Self>::U32));
191 register_extension(Extension::Ffs(instruction.input.elem()));
192 }
193 Elem::I64 => {
194 register_extension(Extension::Ffs(Elem::<Self>::U64));
195 register_extension(Extension::Ffs(instruction.input.elem()));
196 }
197 _ => {
198 register_extension(Extension::Ffs(Elem::<Self>::U32));
199 }
200 }
201 }
202 shared::Instruction::<Self>::HiMul(instruction) => {
203 register_extension(Extension::MulHi(instruction.out.elem()));
204 }
205 shared::Instruction::<Self>::Tanh(instruction) => {
206 register_extension(Extension::SafeTanh(instruction.input.item()));
207 }
208 _ => {}
209 }
210 }
211
212 fn register_warp_instruction_extension(
213 _extensions: &mut Vec<Self::Extension>,
214 _instruction: &WarpInstruction<Self>,
215 ) {
216 }
217}
218
219impl DialectTypes<Self> for MslDialect {
222 fn item_can_be_optimized() -> bool {
223 false
224 }
225
226 fn compile_type_definitions(
227 f: &mut std::fmt::Formatter<'_>,
228 items: &std::collections::HashSet<crate::shared::Item<Self>>,
229 _scalars: &[(Elem<Self>, usize)],
230 _flags: &Flags<Self>,
231 ) -> std::fmt::Result {
232 for item in items.iter() {
233 let elem = item.elem;
234 let size = item.vectorization;
235 let alignment = elem.size() * size;
236 if size > 1 {
237 write!(
238 f,
239 "
240struct alignas({alignment}) {item} {{"
241 )?;
242
243 for i in 0..size {
244 write!(
245 f,
246 "
247 {elem} i_{i};"
248 )?;
249 }
250
251 f.write_str("\n};\n")?;
252 }
253 }
254 Ok(())
255 }
256
257 fn compile_elem(
258 f: &mut std::fmt::Formatter<'_>,
259 elem: &shared::Elem<Self>,
260 _words: bool,
261 ) -> std::fmt::Result {
262 match elem {
264 shared::Elem::FP4(_)
265 | shared::Elem::FP4x2(_)
266 | shared::Elem::FP6(_)
267 | shared::Elem::FP6x2(_)
268 | shared::Elem::FP8(_)
269 | shared::Elem::FP8x2(_) => f.write_str("#error FP4/FP6/FP8 not supported in Metal\n"),
270 shared::Elem::F16 => f.write_str("half"),
271 shared::Elem::F16x2 => f.write_str("#error type F162 not supported!\n"),
272 shared::Elem::F32 => f.write_str("float"),
273 shared::Elem::F64 => f.write_str("#error type double not supported!\n"),
274 shared::Elem::BF16 => f.write_str("bfloat"),
275 shared::Elem::BF16x2 => f.write_str("#error type BF162 not supported!\n"),
276 shared::Elem::TF32 => f.write_str("float"),
277 shared::Elem::I8 => f.write_str("char"),
278 shared::Elem::I16 => f.write_str("short"),
279 shared::Elem::I32 => f.write_str("int"),
280 shared::Elem::I64 => f.write_str("long"),
281 shared::Elem::U8 => f.write_str("uchar"),
282 shared::Elem::U16 => f.write_str("ushort"),
283 shared::Elem::U32 => f.write_str("uint"),
284 shared::Elem::U64 => f.write_str("uint64_t"), shared::Elem::Bool => f.write_str("bool"),
286 shared::Elem::Barrier(_) => unimplemented!("metal doesn't support barrier object"),
287 shared::Elem::Atomic(inner) => inner.fmt(f),
288 shared::Elem::_Dialect(_) => Ok(()),
289 }
290 }
291
292 fn compile_item(f: &mut std::fmt::Formatter<'_>, item: &Item<Self>) -> std::fmt::Result {
293 if 1 == item.vectorization {
294 return write!(f, "{}", item.elem);
295 }
296 if item.native {
297 write!(f, "{}{}", item.elem, item.vectorization)
298 } else {
299 write!(f, "{}_{}", item.elem, item.vectorization)
300 }
301 }
302
303 fn compile_atomic_kind(
304 f: &mut std::fmt::Formatter<'_>,
305 kind: &AtomicKind<Self>,
306 ) -> std::fmt::Result {
307 match kind {
308 AtomicKind::I32 => write!(f, "atomic_int"),
309 AtomicKind::I64 => panic!("I64 atomic kind no supported."),
310 AtomicKind::U32 => write!(f, "atomic_uint"),
311 AtomicKind::U64 => write!(f, "atomic_ulong"),
312 AtomicKind::F16 => panic!("F16 atomic kind no supported."),
313 AtomicKind::BF16 => panic!("BF16 atomic kind no supported."),
314 AtomicKind::F32 => write!(f, "atomic_float"), AtomicKind::F64 => panic!("F64 atomic kind no supported."),
316 AtomicKind::_Dialect(_) => Ok(()),
317 }
318 }
319
320 fn address_space_for_variable(variable: &Variable<Self>) -> String {
321 format!("{} ", AddressSpace::from(variable))
322 }
323
324 fn compile_local_memory_qualifier(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
325 write!(f, "thread")
326 }
327
328 fn compile_shared_memory_declaration(
329 f: &mut std::fmt::Formatter<'_>,
330 shared: &SharedMemory<Self>,
331 ) -> std::fmt::Result {
332 match shared {
333 SharedMemory::Array {
334 index,
335 item,
336 length,
337 offset,
338 ..
339 } => {
340 let size_bytes = length * item.size();
341 writeln!(f, "// Shared array size: {length}, {size_bytes} bytes")?;
342 writeln!(
343 f,
344 "threadgroup {item}* shared_memory_{index} = reinterpret_cast<threadgroup {item}*>(&dynamic_shared_mem[{offset}]);"
345 )
346 }
347 SharedMemory::Value {
348 index,
349 item,
350 offset,
351 ..
352 } => {
353 let size_bytes = item.size();
354 writeln!(f, "// Shared value size: {size_bytes} bytes")?;
355 writeln!(
356 f,
357 "threadgroup {item}& shared_memory_{index} = reinterpret_cast<threadgroup {item}&>(dynamic_shared_mem[{offset}]);"
358 )
359 }
360 }
361 }
362}
363
364impl DialectBindings<Self> for MslDialect {
367 fn compile_kernel_signature(
368 f: &mut std::fmt::Formatter<'_>,
369 kernel_name: &str,
370 tensor_maps: &[Binding<Self>],
371 buffers: &[Binding<Self>],
372 scalars: &[(Elem<Self>, usize)],
373 flags: &Flags<Self>,
374 ) -> std::fmt::Result {
375 write!(
376 (f),
377 "
378[[kernel]]
379void {kernel_name}("
380 )?;
381 let mut buffer_idx = 0;
383 debug_assert!(
384 tensor_maps.is_empty(),
385 "Tensor maps aren't supported for metal"
386 );
387 for (i, b) in buffers.iter().enumerate() {
388 format_global_binding_arg("buffer", b, Some(&i.to_string()), &mut buffer_idx, f)?;
389 }
390 if flags.static_meta_length > 0 {
391 let binding = Binding {
392 id: 0,
393 item: Item::scalar(Elem::<Self>::U32, true),
394 location: Location::Storage,
395 size: None,
396 vis: Visibility::Read,
397 };
398 format_global_binding_arg("info", &binding, None, &mut buffer_idx, f)?;
399 }
400 for (elem, _) in scalars.iter() {
401 let binding = Binding {
402 id: 0,
403 item: Item::scalar(*elem, true),
404 location: Location::Storage,
405 size: None,
406 vis: Visibility::Read,
407 };
408
409 let name = format!("scalars_{elem}");
410 format_global_binding_arg(&name, &binding, None, &mut buffer_idx, f)?;
411 }
412
413 let builtins = vec![
415 (
416 flags.indexes.absolute_pos_tuple,
417 Variable::<Self>::AbsolutePosBaseName,
418 ),
419 (
420 flags.indexes.cube_dim_tuple,
421 Variable::<Self>::CubeDimBaseName,
422 ),
423 (
424 flags.indexes.cube_count_tuple,
425 Variable::<Self>::CubeCountBaseName,
426 ),
427 (flags.indexes.unit_pos, Variable::<Self>::UnitPos),
428 (
429 flags.indexes.unit_pos_tuple,
430 Variable::<Self>::UnitPosBaseName,
431 ),
432 (
433 flags.indexes.cube_pos_tuple,
434 Variable::<Self>::CubePosBaseName,
435 ),
436 (flags.indexes.unit_pos_plane, Variable::<Self>::UnitPosPlane),
437 (flags.indexes.plane_dim, Variable::<Self>::PlaneDim),
438 (flags.indexes.plane_index, Variable::<Self>::PlanePos),
439 ];
440 let comma = !buffers.is_empty() || flags.static_meta_length > 0 || !scalars.is_empty();
441 builtins
442 .iter()
443 .filter(|(cond, _)| *cond)
444 .try_for_each(|(_, var)| format_metal_builtin_binding_arg(f, var, comma))?;
445 f.write_str("\n)")
446 }
447
448 fn compile_bindings_body(
449 f: &mut std::fmt::Formatter<'_>,
450 body: &shared::Body<Self>,
451 ) -> std::fmt::Result {
452 if !body.shared_memories.is_empty() {
453 let size = body
454 .shared_memories
455 .iter()
456 .map(|it| it.offset() + it.size())
457 .max()
458 .unwrap();
459
460 writeln!(f, "threadgroup uchar dynamic_shared_mem[{size}];",)?;
461 }
462 Ok(())
463 }
464}
465
466impl DialectCubeBuiltins<Self> for MslDialect {
469 fn builtin_rules(flags: &CubeIndexFlags) -> CubeIndexFlags {
474 let absolute_pos = flags.absolute_pos;
475 let cube_count = flags.cube_count;
476 let cube_dim = flags.cube_dim;
477 let cube_pos = flags.cube_pos;
478 let plane_dim_checked = flags.plane_dim_checked;
479 let plane_index = flags.plane_index;
480 let unit_pos = flags.unit_pos;
481 let absolute_pos_tuple = flags.absolute_pos_tuple || absolute_pos;
482 let cube_count_tuple = flags.cube_count_tuple || cube_count || cube_pos || absolute_pos;
483 let cube_dim_tuple = flags.cube_dim_tuple || cube_dim || absolute_pos || plane_dim_checked;
484 let cube_pos_tuple = flags.cube_pos_tuple || cube_pos;
485 let cluster_pos = flags.cluster_pos;
486 let plane_dim = flags.plane_dim || plane_dim_checked || plane_index;
487 let unit_pos_plane = flags.unit_pos_plane || plane_index;
488 let unit_pos_tuple = flags.unit_pos_tuple || unit_pos;
489 CubeIndexFlags {
490 absolute_pos_tuple,
491 absolute_pos,
492 cube_count_tuple,
493 cube_count,
494 cube_dim_tuple,
495 cube_dim,
496 cube_pos_tuple,
497 cube_pos,
498 plane_dim,
499 plane_dim_checked,
500 plane_index,
501 unit_pos_tuple,
502 unit_pos,
503 unit_pos_plane,
504 cluster_pos,
505 }
506 }
507
508 fn compile_absolute_pos_tuple_computation(
509 _f: &mut std::fmt::Formatter<'_>,
510 ) -> std::fmt::Result {
511 Ok(())
513 }
514
515 fn compile_absolute_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
516 f.write_str("thread_pos_in_grid")
517 }
518
519 fn compile_absolute_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
520 f.write_str("thread_index_in_grid")
521 }
522
523 fn compile_absolute_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
524 Self::compile_absolute_pos_base_name(f)?;
525 write!(f, ".x")
526 }
527
528 fn compile_absolute_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
529 Self::compile_absolute_pos_base_name(f)?;
530 write!(f, ".y")
531 }
532
533 fn compile_absolute_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
534 Self::compile_absolute_pos_base_name(f)?;
535 write!(f, ".z")
536 }
537
538 fn compile_cube_count_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
539 f.write_str("threadgroups_per_grid")
540 }
541
542 fn compile_cube_count(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
543 f.write_str("total_threadgroups_in_grid")
544 }
545
546 fn compile_cube_count_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
547 Self::compile_cube_count_base_name(f)?;
548 write!(f, ".x")
549 }
550
551 fn compile_cube_count_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
552 Self::compile_cube_count_base_name(f)?;
553 write!(f, ".y")
554 }
555
556 fn compile_cube_count_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
557 Self::compile_cube_count_base_name(f)?;
558 write!(f, ".z")
559 }
560
561 fn compile_cube_dim_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
562 f.write_str("threads_per_threadgroup")
563 }
564
565 fn compile_cube_dim(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
566 f.write_str("total_thread_in_threadgroup")
567 }
568
569 fn compile_cube_dim_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
570 Self::compile_cube_dim_base_name(f)?;
571 write!(f, ".x")
572 }
573
574 fn compile_cube_dim_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
575 Self::compile_cube_dim_base_name(f)?;
576 write!(f, ".y")
577 }
578
579 fn compile_cube_dim_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
580 Self::compile_cube_dim_base_name(f)?;
581 write!(f, ".z")
582 }
583
584 fn compile_cube_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
585 f.write_str("threadgroup_pos_in_grid")
586 }
587
588 fn compile_cube_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
589 f.write_str("threadgroup_index_in_grid")
590 }
591
592 fn compile_cube_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
593 Self::compile_cube_pos_base_name(f)?;
594 write!(f, ".x")
595 }
596
597 fn compile_cube_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
598 Self::compile_cube_pos_base_name(f)?;
599 write!(f, ".y")
600 }
601
602 fn compile_cube_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
603 Self::compile_cube_pos_base_name(f)?;
604 write!(f, ".z")
605 }
606
607 fn compile_unit_pos_computation(_f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
608 Ok(())
610 }
611
612 fn compile_unit_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
613 f.write_str("thread_pos_in_threadgroup")
614 }
615
616 fn compile_unit_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
617 f.write_str("thread_index_in_threadgroup")
618 }
619
620 fn compile_unit_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
621 Self::compile_unit_pos_base_name(f)?;
622 write!(f, ".x")
623 }
624
625 fn compile_unit_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
626 Self::compile_unit_pos_base_name(f)?;
627 write!(f, ".y")
628 }
629
630 fn compile_unit_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
631 Self::compile_unit_pos_base_name(f)?;
632 write!(f, ".z")
633 }
634
635 fn compile_plane_dim(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
636 f.write_str("simd_size")
637 }
638
639 fn compile_plane_dim_checked(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
640 f.write_str("threads_per_simdgroup_checked")
641 }
642
643 fn compile_plane_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
644 f.write_str("simd_group_id")
645 }
646
647 fn compile_unit_pos_plane(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
648 f.write_str("simd_lane_id")
649 }
650}
651
652impl DialectInstructions<Self> for MslDialect {
655 fn compile_atomic_add(
657 f: &mut std::fmt::Formatter<'_>,
658 lhs: &Variable<Self>,
659 rhs: &Variable<Self>,
660 out: &Variable<Self>,
661 ) -> std::fmt::Result {
662 let out = out.fmt_left();
663 writeln!(
664 f,
665 "{out} = atomic_fetch_add_explicit({lhs}, {rhs}, memory_order_relaxed);"
666 )
667 }
668
669 fn compile_atomic_and(
670 f: &mut std::fmt::Formatter<'_>,
671 lhs: &Variable<Self>,
672 rhs: &Variable<Self>,
673 out: &Variable<Self>,
674 ) -> std::fmt::Result {
675 let out = out.fmt_left();
676 writeln!(
677 f,
678 "{out} = atomic_fetch_and_explicit({lhs}, {rhs}, memory_order_relaxed);"
679 )
680 }
681
682 fn compile_atomic_cas(
683 f: &mut std::fmt::Formatter<'_>,
684 input: &Variable<Self>,
685 cmp: &Variable<Self>,
686 val: &Variable<Self>,
687 out: &Variable<Self>,
688 ) -> std::fmt::Result {
689 let out = out.fmt_left();
690 writeln!(
691 f,
692 "{out} = atomic_compare_exchange_weak_explicit({input}, &{cmp}, {val}, memory_order_relaxed, memory_order_relaxed);"
693 )
694 }
695
696 fn compile_atomic_load(
697 f: &mut std::fmt::Formatter<'_>,
698 input: &Variable<Self>,
699 out: &Variable<Self>,
700 ) -> std::fmt::Result {
701 let out = out.fmt_left();
702 writeln!(
703 f,
704 "{out} = atomic_load_explicit({input}, memory_order_relaxed);"
705 )
706 }
707
708 fn compile_atomic_max(
709 f: &mut std::fmt::Formatter<'_>,
710 lhs: &Variable<Self>,
711 rhs: &Variable<Self>,
712 out: &Variable<Self>,
713 ) -> std::fmt::Result {
714 let out = out.fmt_left();
715 writeln!(
716 f,
717 "{out} = atomic_fetch_max_explicit({lhs}, {rhs}, memory_order_relaxed);"
718 )
719 }
720
721 fn compile_atomic_min(
722 f: &mut std::fmt::Formatter<'_>,
723 lhs: &Variable<Self>,
724 rhs: &Variable<Self>,
725 out: &Variable<Self>,
726 ) -> std::fmt::Result {
727 let out = out.fmt_left();
728 writeln!(
729 f,
730 "{out} = atomic_fetch_min_explicit({lhs}, {rhs}, memory_order_relaxed);"
731 )
732 }
733
734 fn compile_atomic_or(
735 f: &mut std::fmt::Formatter<'_>,
736 lhs: &Variable<Self>,
737 rhs: &Variable<Self>,
738 out: &Variable<Self>,
739 ) -> std::fmt::Result {
740 let out = out.fmt_left();
741 writeln!(
742 f,
743 "{out} = atomic_fetch_or_explicit({lhs}, {rhs}, memory_order_relaxed);"
744 )
745 }
746
747 fn compile_atomic_store(
748 f: &mut std::fmt::Formatter<'_>,
749 input: &Variable<Self>,
750 out: &Variable<Self>,
751 ) -> std::fmt::Result {
752 writeln!(
753 f,
754 "atomic_store_explicit({out}, {input}, memory_order_relaxed);"
755 )
756 }
757
758 fn compile_atomic_sub(
759 f: &mut std::fmt::Formatter<'_>,
760 lhs: &Variable<Self>,
761 rhs: &Variable<Self>,
762 out: &Variable<Self>,
763 ) -> std::fmt::Result {
764 let out = out.fmt_left();
765 writeln!(
766 f,
767 "{out} = atomic_fetch_sub_explicit({lhs}, {rhs}, memory_order_relaxed);"
768 )
769 }
770
771 fn compile_atomic_swap(
772 f: &mut std::fmt::Formatter<'_>,
773 lhs: &Variable<Self>,
774 rhs: &Variable<Self>,
775 out: &Variable<Self>,
776 ) -> std::fmt::Result {
777 let out = out.fmt_left();
778 writeln!(
779 f,
780 "{out} = atomic_exchange_explicit({lhs}, {rhs}, memory_order_relaxed);"
781 )
782 }
783
784 fn compile_atomic_xor(
785 f: &mut std::fmt::Formatter<'_>,
786 lhs: &Variable<Self>,
787 rhs: &Variable<Self>,
788 out: &Variable<Self>,
789 ) -> std::fmt::Result {
790 let out = out.fmt_left();
791 writeln!(
792 f,
793 "{out} = atomic_fetch_xor_explicit({lhs}, {rhs}, memory_order_relaxed);"
794 )
795 }
796
797 fn compile_saturating_add(
798 f: &mut std::fmt::Formatter<'_>,
799 lhs: impl Display,
800 rhs: impl Display,
801 _item: Item<Self>,
802 ) -> std::fmt::Result {
803 write!(f, "addsat({lhs}, {rhs})")
804 }
805
806 fn compile_saturating_sub(
807 f: &mut std::fmt::Formatter<'_>,
808 lhs: impl Display,
809 rhs: impl Display,
810 _item: Item<Self>,
811 ) -> std::fmt::Result {
812 write!(f, "subsat({lhs}, {rhs})")
813 }
814
815 fn compile_instruction_printf(
817 f: &mut std::fmt::Formatter<'_>,
818 format_string: &str,
819 args: &[Variable<Self>],
820 ) -> std::fmt::Result {
821 let args = args.iter().map(|arg| format!("{arg}")).collect::<Vec<_>>();
822 let args = match args.is_empty() {
823 true => "".to_string(),
824 false => format!(", {}", args.join(",")),
825 };
826 writeln!(f, "os_log_default.log({format_string:?}{args});")
827 }
828
829 fn compile_instruction_log1p_scalar<T: Component<Self>>(
831 f: &mut std::fmt::Formatter<'_>,
832 input: T,
833 ) -> std::fmt::Result {
834 match input.elem() {
835 Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
836 write!(f, "log(half(1.0f) + {input})")
837 }
838 _ => write!(f, "log(1.0f + {input})"),
839 }
840 }
841
842 fn compile_instruction_sync_threads(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
844 writeln!(f, "threadgroup_barrier(mem_flags::mem_threadgroup);")
845 }
846
847 fn compile_instruction_sync_warp(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
848 writeln!(f, "simdgroup_barrier(mem_flags::mem_none);")
849 }
850
851 fn compile_instruction_thread_fence(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
852 writeln!(f, "threadgroup_thread_fence(mem_flags::mem_device);")
853 }
854
855 fn compile_instruction_tanh_scalar<T: Component<Self>>(
857 f: &mut std::fmt::Formatter<'_>,
858 input: T,
859 ) -> std::fmt::Result {
860 write!(f, "safe_tanh_scalar({input})")
861 }
862
863 fn compile_instruction_find_first_set<T: Component<Self>>(
865 f: &mut std::fmt::Formatter<'_>,
866 input: T,
867 out_elem: Elem<Self>,
868 ) -> std::fmt::Result {
869 write!(f, "{out_elem}(")?;
870 match input.elem() {
871 Elem::I32 | Elem::U32 => write!(f, "__ffs({input})"),
872 Elem::I64 | Elem::U64 => write!(f, "__ffsll({input})"),
873 _ => write!(f, "__ffs({}({input}))", Elem::<Self>::I32),
874 }?;
875 write!(f, ")")
876 }
877
878 fn compile_instruction_leading_zeros_scalar<T: Component<Self>>(
879 f: &mut std::fmt::Formatter<'_>,
880 input: T,
881 out_elem: Elem<Self>,
882 ) -> std::fmt::Result {
883 write!(f, "{out_elem}(clz({input}))")
884 }
885
886 fn compile_instruction_popcount_scalar<T: Component<Self>>(
887 f: &mut std::fmt::Formatter<'_>,
888 input: T,
889 out_elem: Elem<Self>,
890 ) -> std::fmt::Result {
891 write!(f, "{out_elem}(")?;
892 match input.elem() {
893 Elem::I32 | Elem::U32 | Elem::I64 | Elem::U64 => write!(f, "popcount({input})"),
894 _ => write!(f, "popcount({})", shared::unary::zero_extend(input)),
895 }?;
896 write!(f, ")")
897 }
898
899 fn compile_instruction_reverse_bits_scalar<T: Component<Self>>(
900 f: &mut std::fmt::Formatter<'_>,
901 input: T,
902 out_elem: Elem<Self>,
903 ) -> std::fmt::Result {
904 write!(f, "{out_elem}(")?;
905 match out_elem {
906 Elem::I32 | Elem::U32 | Elem::I64 | Elem::U64 => write!(f, "reverse_bits({input})"),
907 _ => write!(
908 f,
909 "reverse_bits({}) >> {}",
910 shared::unary::zero_extend(input),
911 (size_of::<u32>() - out_elem.size()) * 8
912 ),
913 }?;
914 write!(f, ")")
915 }
916
917 fn compile_instruction_max_function_name(
919 f: &mut std::fmt::Formatter<'_>,
920 _item: Item<Self>,
921 ) -> std::fmt::Result {
922 write!(f, "max")
923 }
924
925 fn compile_instruction_min_function_name(
926 f: &mut std::fmt::Formatter<'_>,
927 _item: Item<Self>,
928 ) -> std::fmt::Result {
929 write!(f, "min")
930 }
931
932 fn compile_instruction_powf(
933 f: &mut std::fmt::Formatter<'_>,
934 lhs: &str,
935 rhs: &str,
936 elem: Elem<Self>,
937 ) -> std::fmt::Result {
938 write!(f, "pow({lhs}, {elem}({rhs}))")
939 }
940
941 fn compile_instruction_half_function_name_prefix() -> &'static str {
942 ""
943 }
944
945 fn compile_instruction_half2_function_name_prefix() -> &'static str {
946 ""
947 }
948
949 fn compile_warp_shuffle(
951 f: &mut std::fmt::Formatter<'_>,
952 var: &str,
953 source: &str,
954 ) -> std::fmt::Result {
955 write!(f, "simd_shuffle({var}, {source})")
956 }
957
958 fn compile_warp_shuffle_xor(
959 f: &mut std::fmt::Formatter<'_>,
960 var: &str,
961 _elem: &Elem<Self>,
962 offset: &str,
963 ) -> std::fmt::Result {
964 write!(f, "simd_shuffle_xor({var}, {offset})")
965 }
966
967 fn compile_warp_shuffle_up(
968 f: &mut std::fmt::Formatter<'_>,
969 var: &str,
970 offset: &str,
971 ) -> std::fmt::Result {
972 write!(f, "simd_shuffle_up({var}, {offset})")
973 }
974
975 fn compile_warp_shuffle_down(
976 f: &mut std::fmt::Formatter<'_>,
977 var: &str,
978 offset: &str,
979 ) -> std::fmt::Result {
980 write!(f, "simd_shuffle_down({var}, {offset})")
981 }
982
983 fn compile_warp_all<T: Component<Self>>(
984 f: &mut std::fmt::Formatter<'_>,
985 input: &T,
986 ) -> std::fmt::Result {
987 write!(f, "simd_all({input})")
988 }
989
990 fn compile_warp_any<T: Component<Self>>(
991 f: &mut std::fmt::Formatter<'_>,
992 input: &T,
993 ) -> std::fmt::Result {
994 write!(f, "simd_any({input})")
995 }
996
997 fn compile_warp_ballot(
998 f: &mut std::fmt::Formatter<'_>,
999 input: &Variable<Self>,
1000 out_elem: &Elem<Self>,
1001 ) -> std::fmt::Result {
1002 write!(f, "{out_elem}(uint64_t(simd_ballot({input})))")
1003 }
1004}
1005
1006impl DialectWmmaCompiler<Self> for MslDialect {
1009 fn compile_wmma_includes(
1010 f: &mut std::fmt::Formatter<'_>,
1011 _flags: &Flags<Self>,
1012 ) -> std::fmt::Result {
1013 writeln!(f, "#include <metal_simdgroup_matrix>")
1014 }
1015
1016 fn compile_wmma_local_variables(_f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1017 Ok(())
1019 }
1020
1021 fn compile_wmma_fragment_declaration(
1022 f: &mut std::fmt::Formatter<'_>,
1023 var: &crate::shared::Variable<MslDialect>,
1024 ) -> std::fmt::Result {
1025 wmma_api_base::compile_fragment_declaration(f, var)
1026 }
1027
1028 fn compile_wwma_fragment_ident(
1029 _f: &mut std::fmt::Formatter<'_>,
1030 _ident: &FragmentIdent<Self>,
1031 ) -> std::fmt::Result {
1032 Ok(())
1034 }
1035
1036 fn compile_wmma_fragment_layout(
1037 _f: &mut std::fmt::Formatter<'_>,
1038 _layout: &FragmentLayout<Self>,
1039 ) -> std::fmt::Result {
1040 Ok(())
1042 }
1043
1044 fn compile_wmma_fragment(
1045 f: &mut std::fmt::Formatter<'_>,
1046 fragment: &Fragment<Self>,
1047 ) -> std::fmt::Result {
1048 let ty = fragment.elem;
1049 let m = fragment.m;
1051 let n = fragment.n;
1052 let k = fragment.k;
1053 if m != 8 || n != 8 || k != 8 {
1054 panic!("{m}x{n}x{k} fragments not supported. Only 8x8x8 fragments are supported.");
1055 }
1056 write!(f, "simdgroup_{ty}8x8")
1057 }
1058
1059 fn compile_wmma_instruction(
1060 f: &mut std::fmt::Formatter<'_>,
1061 instruction: &WmmaInstruction<Self>,
1062 ) -> std::fmt::Result {
1063 match instruction {
1064 WmmaInstruction::Fill { frag, value } => {
1065 match frag {
1066 Variable::WmmaFragment { .. } => {
1067 let ty = frag.elem();
1068 writeln!(
1070 f,
1071 "{frag} = make_filled_simdgroup_matrix<{ty}, 8, 8>({value});"
1072 )
1073 }
1074 _ => panic!("should be a fragment"),
1075 }
1076 }
1077 WmmaInstruction::Load {
1078 frag,
1079 value,
1080 stride,
1081 offset,
1082 layout: _layout,
1083 } => {
1084 let transpose = match frag {
1085 Variable::WmmaFragment { frag: inner, .. } => match inner.layout {
1086 Some(FragmentLayout::RowMajor) => false,
1087 Some(FragmentLayout::ColMajor) => true,
1088 _ => false,
1089 },
1090 _ => panic!("should be a fragment"),
1091 };
1092 let item = value.item();
1093 if item.vectorization > 1 {
1094 let elem = item.elem;
1095 match value {
1096 Variable::GlobalInputArray(..) => writeln!(
1097 f,
1098 "simdgroup_load({frag}, (device {elem}*)({value} + {offset}), {stride}, 0, {transpose});"
1099 ),
1100 Variable::SharedArray(..) => writeln!(
1101 f,
1102 "simdgroup_load({frag}, reinterpret_cast<threadgroup {elem} *>({value} + {offset}), {stride}, 0, {transpose});"
1103 ),
1104 _ => panic!(
1105 "Vectorized wmma load is only supported from global or shared memory."
1106 ),
1107 }
1108 } else {
1109 writeln!(
1110 f,
1111 "simdgroup_load({frag}, {value} + {offset}, {stride}, 0, {transpose});"
1112 )
1113 }
1114 }
1115 WmmaInstruction::Execute {
1116 frag_a: a,
1117 frag_b: b,
1118 frag_c: c,
1119 frag_d: d,
1120 ..
1121 } => {
1122 writeln!(f, "simdgroup_multiply_accumulate({d}, {a}, {b}, {c});")
1123 }
1124 WmmaInstruction::Store {
1125 output,
1126 frag,
1127 stride,
1128 offset,
1129 layout: _layout,
1130 } => {
1131 let item = output.item();
1132 let mut reinterpret_cast = item.vectorization > 1;
1133 let elem = match item.elem {
1134 Elem::BF16 => {
1135 reinterpret_cast = true;
1136 Elem::F16
1137 }
1138 _ => item.elem,
1139 };
1140 if reinterpret_cast {
1141 writeln!(
1142 f,
1143 "simdgroup_store({frag}, reinterpret_cast<threadgroup {elem} *>({output} + {offset}), {stride});"
1144 )
1145 } else {
1146 writeln!(f, "simdgroup_store({frag}, {output} + {offset}, {stride});")
1147 }?;
1148 writeln!(f, "simdgroup_barrier(mem_flags::mem_none);")
1149 }
1150 WmmaInstruction::Cast { input, output } => {
1151 writeln!(f, "simdgroup_barrier(mem_flags::mem_none);")?;
1152 let ty = match output {
1153 Variable::WmmaFragment { frag, .. } => frag.elem,
1154 _ => panic!("should be a fragment"),
1155 };
1156 match ty {
1157 Elem::BF16 => {
1158 let addr_space = Self::address_space_for_variable(output);
1159 let elem = Elem::<Self>::F16;
1160 writeln!(
1163 f,
1164 "for(int e=0; e<8; e++) {{
1165 {ty} elem = {ty}({input}.thread_elements()[e]);
1166 {output}.thread_elements()[e] = *reinterpret_cast<{addr_space}{elem} *>(&elem);
1167}}"
1168 )
1169 }
1170 _ => {
1171 writeln!(
1172 f,
1173 "for(int e=0; e<8; e++) {{
1174 {output}.thread_elements()[e] = {ty}({input}.thread_elements()[e]);
1175}}"
1176 )
1177 }
1178 }
1179 }
1180 WmmaInstruction::ExecuteManual {
1181 shape,
1182 frag_a,
1183 frag_b,
1184 frag_c,
1185 frag_d,
1186 } => {
1187 Self::compile_manual_mma(f, ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d))
1188 }
1189 WmmaInstruction::ExecuteScaled {
1190 shape,
1191 frag_a,
1192 frag_b,
1193 frag_c,
1194 frag_d,
1195 scales_a,
1196 scales_b,
1197 scales_factor,
1198 } => Self::compile_scaled_mma(
1199 f,
1200 ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d),
1201 *scales_a,
1202 *scales_b,
1203 *scales_factor,
1204 ),
1205 WmmaInstruction::LdMatrix { .. } | WmmaInstruction::StMatrix { .. } => {
1206 f.write_str("#error WmmaInstruction Ld & St Matrix not supported on Metal\n")
1207 }
1208 }
1209 }
1210
1211 fn compile_manual_mma(
1212 f: &mut std::fmt::Formatter<'_>,
1213 _mma: shared::ManualMma<Self>,
1214 ) -> std::fmt::Result {
1215 f.write_str("#error manual mma not supported on Metal\n")
1216 }
1217
1218 fn compile_scaled_mma(
1219 f: &mut std::fmt::Formatter<'_>,
1220 _mma: shared::ManualMma<Self>,
1221 _scales_a: Variable<Self>,
1222 _scales_b: Variable<Self>,
1223 _scales_factor: u32,
1224 ) -> std::fmt::Result {
1225 f.write_str("#error scaled mma not supported on Metal\n")
1226 }
1227
1228 fn supported_wmma_combinations(_arch: &MetalArchitecture) -> SupportedMmaCombinations {
1229 let types = vec![
1230 (
1231 gpu::ElemType::Float(gpu::FloatKind::F16).into(),
1232 gpu::ElemType::Float(gpu::FloatKind::F16).into(),
1233 gpu::ElemType::Float(gpu::FloatKind::F16).into(),
1234 ),
1235 (
1236 gpu::ElemType::Float(gpu::FloatKind::F16).into(),
1237 gpu::ElemType::Float(gpu::FloatKind::F16).into(),
1238 gpu::ElemType::Float(gpu::FloatKind::F32).into(),
1239 ),
1240 (
1241 gpu::ElemType::Float(gpu::FloatKind::BF16).into(),
1242 gpu::ElemType::Float(gpu::FloatKind::BF16).into(),
1243 gpu::ElemType::Float(gpu::FloatKind::BF16).into(),
1244 ),
1245 (
1246 gpu::ElemType::Float(gpu::FloatKind::F32).into(),
1247 gpu::ElemType::Float(gpu::FloatKind::F32).into(),
1248 gpu::ElemType::Float(gpu::FloatKind::F32).into(),
1249 ),
1250 ];
1251 types
1252 .into_iter()
1253 .map(|(a_type, b_type, cd_type)| MmaConfig {
1254 a_type,
1255 b_type,
1256 cd_type,
1257 m: 8,
1258 n: 8,
1259 k: 8,
1260 })
1261 .collect()
1262 }
1263
1264 fn supported_mma_combinations(_arch: &MetalArchitecture) -> SupportedMmaCombinations {
1265 Vec::new()
1266 }
1267}
1268
1269impl DialectProcessors<Self> for MslDialect {
1272 fn processors() -> Vec<Box<dyn gpu::Processor>> {
1273 Vec::new()
1274 }
1275}