1use super::{
2 AddressSpace, Extension,
3 arch::MetalArchitecture,
4 extension::{format_ffs, format_mulhi},
5 format_erf, format_global_binding_arg, format_metal_builtin_binding_arg, format_safe_tanh,
6};
7use crate::{
8 Dialect,
9 shared::{
10 self, AtomicKind, Component, CubeIndexFlags, DialectBindings, DialectCubeBuiltins,
11 DialectIncludes, DialectInstructions, DialectProcessors, DialectTypes,
12 DialectWarpReduceCompiler, DialectWmmaCompiler, Elem, Flags, FmtLeft, Fragment,
13 FragmentIdent, FragmentLayout, Instruction, Item, KernelArg, ManualMma, SharedMemory,
14 SupportedMmaCombinations, Variable, WarpInstruction, WmmaInstruction, wmma_api_base,
15 },
16};
17use core::panic;
18use cubecl_core::ir::{self as gpu, features::MmaConfig};
19use std::fmt::Display;
20
/// Dialect implementation targeting Apple's Metal Shading Language (MSL).
///
/// A zero-sized marker type: all behavior is provided through the various
/// `Dialect*` trait implementations below.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct MslDialect {}
23
impl Dialect for MslDialect {
    // Architecture family used for Metal capability/feature checks.
    type Architecture = MetalArchitecture;
}
29
30impl MslDialect {
31 fn warp_op_vectorized(
32 f: &mut core::fmt::Formatter<'_>,
33 input: &Variable<Self>,
34 out: &Variable<Self>,
35 simd_op_prefix: &str,
36 simd_op_suffix: &str,
37 ) -> core::fmt::Result {
38 let out = out.fmt_left();
39 let vectorization = input.item().vectorization;
40
41 f.write_fmt(format_args!("{out} = {} {{", input.item()))?;
42
43 for k in 0..vectorization {
44 let index = if vectorization > 1 {
45 format!(".i_{k}")
46 } else {
47 String::new()
48 };
49 let comma = if k + 1 < vectorization { "," } else { "" };
50
51 writeln!(f, "{simd_op_prefix}{input}{index}{simd_op_suffix}{comma}")?;
52 }
53
54 f.write_fmt(format_args!("}};\n"))
55 }
56}
57
58impl DialectWarpReduceCompiler<Self> for MslDialect {
59 fn warp_reduce_sum(
60 f: &mut core::fmt::Formatter<'_>,
61 input: &Variable<Self>,
62 out: &Variable<Self>,
63 ) -> core::fmt::Result {
64 Self::warp_op_vectorized(f, input, out, "simd_sum(", ")")
65 }
66 fn warp_reduce_prod(
67 f: &mut core::fmt::Formatter<'_>,
68 input: &Variable<Self>,
69 out: &Variable<Self>,
70 ) -> core::fmt::Result {
71 Self::warp_op_vectorized(f, input, out, "simd_product(", ")")
72 }
73 fn warp_reduce_max(
74 f: &mut core::fmt::Formatter<'_>,
75 input: &Variable<Self>,
76 out: &Variable<Self>,
77 ) -> core::fmt::Result {
78 Self::warp_op_vectorized(f, input, out, "simd_max(", ")")
79 }
80 fn warp_reduce_min(
81 f: &mut core::fmt::Formatter<'_>,
82 input: &Variable<Self>,
83 out: &Variable<Self>,
84 ) -> core::fmt::Result {
85 Self::warp_op_vectorized(f, input, out, "simd_min(", ")")
86 }
87 fn warp_reduce_all(
88 f: &mut core::fmt::Formatter<'_>,
89 input: &Variable<Self>,
90 out: &Variable<Self>,
91 ) -> core::fmt::Result {
92 Self::warp_op_vectorized(f, input, out, "simd_and(", "? 1u : 0u) != 0u")
93 }
94 fn warp_reduce_any(
95 f: &mut core::fmt::Formatter<'_>,
96 input: &Variable<Self>,
97 out: &Variable<Self>,
98 ) -> core::fmt::Result {
99 Self::warp_op_vectorized(f, input, out, "simd_or(", "? 1u : 0u) != 0u")
100 }
101 fn warp_reduce_sum_inclusive(
102 f: &mut core::fmt::Formatter<'_>,
103 input: &Variable<Self>,
104 out: &Variable<Self>,
105 ) -> core::fmt::Result {
106 Self::warp_op_vectorized(f, input, out, "simd_prefix_inclusive_sum(", ")")
107 }
108 fn warp_reduce_prod_inclusive(
109 f: &mut core::fmt::Formatter<'_>,
110 input: &Variable<Self>,
111 out: &Variable<Self>,
112 ) -> core::fmt::Result {
113 Self::warp_op_vectorized(f, input, out, "simd_prefix_inclusive_product(", ")")
114 }
115 fn warp_reduce_sum_exclusive(
116 f: &mut core::fmt::Formatter<'_>,
117 input: &Variable<Self>,
118 out: &Variable<Self>,
119 ) -> core::fmt::Result {
120 Self::warp_op_vectorized(f, input, out, "simd_prefix_exclusive_sum(", ")")
121 }
122 fn warp_reduce_prod_exclusive(
123 f: &mut core::fmt::Formatter<'_>,
124 input: &Variable<Self>,
125 out: &Variable<Self>,
126 ) -> core::fmt::Result {
127 Self::warp_op_vectorized(f, input, out, "simd_prefix_exclusive_product(", ")")
128 }
129}
130
131impl DialectIncludes<Self> for MslDialect {
134 type Extension = Extension<Self>;
135
136 fn compile_includes(f: &mut std::fmt::Formatter<'_>, _flags: &Flags<Self>) -> std::fmt::Result {
137 write!(
138 f,
139 "
140#include <metal_stdlib>
141using namespace metal;
142"
143 )?;
144 Ok(())
145 }
146
147 fn compile_extensions(
148 f: &mut std::fmt::Formatter<'_>,
149 extensions: &[Self::Extension],
150 ) -> std::fmt::Result {
151 for extension in extensions {
152 match extension {
153 Extension::Erf(input, output) => format_erf::<Self>(f, input, output)?,
154 Extension::Ffs(elem) => format_ffs(f, elem)?,
155 Extension::MulHi(elem) => format_mulhi(f, elem)?,
156 Extension::SafeTanh(item) => format_safe_tanh::<Self>(f, item)?,
157 Extension::NoExtension => {}
158 }
159 }
160 Ok(())
161 }
162
163 fn register_instruction_extension(
164 extensions: &mut Vec<Self::Extension>,
165 instruction: &Instruction<Self>,
166 ) {
167 let mut register_extension = |extension: Self::Extension| {
168 if !extensions.contains(&extension) {
169 extensions.push(extension);
170 }
171 };
172 #[allow(clippy::single_match)]
173 match instruction {
174 shared::Instruction::<Self>::Erf(instruction) => {
175 register_extension(Extension::Erf(
176 instruction.input.elem(),
177 instruction.out.elem(),
178 ));
179 }
180 shared::Instruction::<Self>::FindFirstSet(instruction) => {
181 let input_elem = instruction.input.elem();
182 match input_elem {
183 Elem::U32 | Elem::U64 => {
184 register_extension(Extension::Ffs(instruction.input.elem()));
185 }
186 Elem::I32 => {
187 register_extension(Extension::Ffs(Elem::<Self>::U32));
188 register_extension(Extension::Ffs(instruction.input.elem()));
189 }
190 Elem::I64 => {
191 register_extension(Extension::Ffs(Elem::<Self>::U64));
192 register_extension(Extension::Ffs(instruction.input.elem()));
193 }
194 _ => {
195 register_extension(Extension::Ffs(Elem::<Self>::U32));
196 }
197 }
198 }
199 shared::Instruction::<Self>::HiMul(instruction) => {
200 register_extension(Extension::MulHi(instruction.out.elem()));
201 }
202 shared::Instruction::<Self>::Tanh(instruction) => {
203 register_extension(Extension::SafeTanh(instruction.input.item()));
204 }
205 _ => {}
206 }
207 }
208
209 fn register_warp_instruction_extension(
210 _extensions: &mut Vec<Self::Extension>,
211 _instruction: &WarpInstruction<Self>,
212 ) {
213 }
214}
215
216impl DialectTypes<Self> for MslDialect {
219 fn item_can_be_optimized() -> bool {
220 false
221 }
222
223 fn compile_type_definitions(
224 f: &mut std::fmt::Formatter<'_>,
225 items: &std::collections::HashSet<crate::shared::Item<Self>>,
226 scalars: &[(Elem<Self>, usize)],
227 info: &cubecl_core::Info,
228 flags: &Flags<Self>,
229 ) -> std::fmt::Result {
230 for item in items.iter() {
231 let elem = item.elem;
232 let size = item.vectorization;
233 let alignment = elem.size() * size;
234 if size > 1 {
235 write!(
236 f,
237 "
238struct alignas({alignment}) {item} {{"
239 )?;
240
241 for i in 0..size {
242 write!(
243 f,
244 "
245 {elem} i_{i};"
246 )?;
247 }
248
249 f.write_str("\n};\n")?;
250 }
251 }
252
253 shared::type_info_definition_sized(f, info, scalars, flags.address_type)?;
254 Ok(())
255 }
256
257 fn compile_elem(
258 f: &mut std::fmt::Formatter<'_>,
259 elem: &shared::Elem<Self>,
260 _words: bool,
261 ) -> std::fmt::Result {
262 match elem {
264 shared::Elem::FP4(_)
265 | shared::Elem::FP4x2(_)
266 | shared::Elem::FP6(_)
267 | shared::Elem::FP6x2(_)
268 | shared::Elem::FP8(_)
269 | shared::Elem::FP8x2(_) => f.write_str("#error FP4/FP6/FP8 not supported in Metal\n"),
270 shared::Elem::F16 => f.write_str("half"),
271 shared::Elem::F16x2 => f.write_str("#error type F162 not supported!\n"),
272 shared::Elem::F32 => f.write_str("float"),
273 shared::Elem::F64 => f.write_str("#error type double not supported!\n"),
274 shared::Elem::BF16 => f.write_str("bfloat"),
275 shared::Elem::BF16x2 => f.write_str("#error type BF162 not supported!\n"),
276 shared::Elem::TF32 => f.write_str("float"),
277 shared::Elem::I8 => f.write_str("char"),
278 shared::Elem::I16 => f.write_str("short"),
279 shared::Elem::I32 => f.write_str("int"),
280 shared::Elem::I64 => f.write_str("long"),
281 shared::Elem::U8 => f.write_str("uchar"),
282 shared::Elem::U16 => f.write_str("ushort"),
283 shared::Elem::U32 => f.write_str("uint"),
284 shared::Elem::U64 => f.write_str("uint64_t"), shared::Elem::Bool => f.write_str("bool"),
286 shared::Elem::Barrier(_) => unimplemented!("metal doesn't support barrier object"),
287 shared::Elem::Atomic(inner) => inner.fmt(f),
288 shared::Elem::_Dialect(_) => Ok(()),
289 }
290 }
291
292 fn compile_item(f: &mut std::fmt::Formatter<'_>, item: &Item<Self>) -> std::fmt::Result {
293 if 1 == item.vectorization {
294 return write!(f, "{}", item.elem);
295 }
296 if item.native {
297 write!(f, "{}{}", item.elem, item.vectorization)
298 } else {
299 write!(f, "{}_{}", item.elem, item.vectorization)
300 }
301 }
302
303 fn compile_atomic_kind(
304 f: &mut std::fmt::Formatter<'_>,
305 kind: &AtomicKind<Self>,
306 ) -> std::fmt::Result {
307 match kind {
308 AtomicKind::I32 => write!(f, "atomic_int"),
309 AtomicKind::I64 => panic!("I64 atomic kind no supported."),
310 AtomicKind::U32 => write!(f, "atomic_uint"),
311 AtomicKind::U64 => write!(f, "atomic_ulong"),
312 AtomicKind::F16 | AtomicKind::F16x2 => panic!("F16 atomic kind no supported."),
313 AtomicKind::BF16 | AtomicKind::BF16x2 => panic!("BF16 atomic kind no supported."),
314 AtomicKind::F32 => write!(f, "atomic_float"), AtomicKind::F64 => panic!("F64 atomic kind no supported."),
316 AtomicKind::_Dialect(_) => Ok(()),
317 }
318 }
319
320 fn address_space_for_variable(variable: &Variable<Self>) -> String {
321 format!("{} ", AddressSpace::from(variable))
322 }
323
324 fn compile_local_memory_qualifier(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
325 write!(f, "thread")
326 }
327
328 fn compile_shared_memory_declaration(
329 f: &mut std::fmt::Formatter<'_>,
330 shared: &SharedMemory<Self>,
331 ) -> std::fmt::Result {
332 match shared {
333 SharedMemory::Array {
334 index,
335 item,
336 length,
337 offset,
338 ..
339 } => {
340 let size_bytes = length * item.size();
341 writeln!(f, "// Shared array size: {length}, {size_bytes} bytes")?;
342 writeln!(
343 f,
344 "threadgroup {item}* shared_memory_{index} = reinterpret_cast<threadgroup {item}*>(&dynamic_shared_mem[{offset}]);"
345 )
346 }
347 SharedMemory::Value {
348 index,
349 item,
350 offset,
351 ..
352 } => {
353 let size_bytes = item.size();
354 writeln!(f, "// Shared value size: {size_bytes} bytes")?;
355 writeln!(
356 f,
357 "threadgroup {item}& shared_memory_{index} = reinterpret_cast<threadgroup {item}&>(dynamic_shared_mem[{offset}]);"
358 )
359 }
360 }
361 }
362}
363
364impl DialectBindings<Self> for MslDialect {
367 fn compile_kernel_signature(
368 f: &mut std::fmt::Formatter<'_>,
369 kernel_name: &str,
370 tensor_maps: &[KernelArg<Self>],
371 buffers: &[KernelArg<Self>],
372 flags: &Flags<Self>,
373 ) -> std::fmt::Result {
374 write!(
375 (f),
376 "
377[[kernel]]
378void {kernel_name}("
379 )?;
380 let mut buffer_idx = 0;
382 debug_assert!(
383 tensor_maps.is_empty(),
384 "Tensor maps aren't supported for metal"
385 );
386 for (i, b) in buffers.iter().enumerate() {
387 format_global_binding_arg("buffer", b, Some(&i.to_string()), &mut buffer_idx, f)?;
388 }
389
390 if flags.has_info {
391 let comma = if buffer_idx > 0 { "," } else { "" };
392 let (address_space, var) = match flags.has_dynamic_meta {
393 true => (AddressSpace::ConstDevice, "info_st* info_ptr"),
394 false => (AddressSpace::Constant, "info_st& info"),
395 };
396 let attribute = address_space.attribute();
397
398 write!(f, "{comma}\n {address_space} {var}",)?;
399 attribute.indexed_fmt(buffer_idx, f)?;
401 buffer_idx += 1;
402 }
403
404 let builtins = vec![
406 (
407 flags.indexes.absolute_pos_tuple,
408 Variable::<Self>::AbsolutePosBaseName,
409 ),
410 (
411 flags.indexes.cube_dim_tuple,
412 Variable::<Self>::CubeDimBaseName,
413 ),
414 (
415 flags.indexes.cube_count_tuple,
416 Variable::<Self>::CubeCountBaseName,
417 ),
418 (flags.indexes.unit_pos, Variable::<Self>::UnitPos),
419 (
420 flags.indexes.unit_pos_tuple,
421 Variable::<Self>::UnitPosBaseName,
422 ),
423 (
424 flags.indexes.cube_pos_tuple,
425 Variable::<Self>::CubePosBaseName,
426 ),
427 (flags.indexes.unit_pos_plane, Variable::<Self>::UnitPosPlane),
428 (flags.indexes.plane_dim, Variable::<Self>::PlaneDim),
429 (flags.indexes.plane_pos, Variable::<Self>::PlanePos),
430 ];
431 let comma = buffer_idx > 0;
432 builtins
433 .iter()
434 .filter(|(cond, _)| *cond)
435 .try_for_each(|(_, var)| format_metal_builtin_binding_arg(f, var, comma))?;
436 f.write_str("\n)")
437 }
438
439 fn compile_bindings_body(
440 f: &mut std::fmt::Formatter<'_>,
441 body: &shared::Body<Self>,
442 ) -> std::fmt::Result {
443 if !body.shared_memories.is_empty() {
444 let size = body
445 .shared_memories
446 .iter()
447 .map(|it| it.offset() + it.size())
448 .max()
449 .unwrap();
450
451 writeln!(f, "threadgroup uchar dynamic_shared_mem[{size}];",)?;
452 }
453 if body.info_by_ptr && body.has_dynamic_meta {
454 let address_space = AddressSpace::ConstDevice;
455 writeln!(f, "const {address_space} info_st& info = *info_ptr;")?;
456 writeln!(
458 f,
459 "const {address_space} {addr}* dynamic_meta = reinterpret_cast<const {address_space} {addr}*>(
460 reinterpret_cast<const {address_space} char*>(info_ptr) + sizeof(info_st)
461 );\n",
462 addr = body.address_type,
463 )?;
464 }
465 Ok(())
466 }
467}
468
469impl DialectCubeBuiltins<Self> for MslDialect {
472 fn builtin_rules(flags: &CubeIndexFlags) -> CubeIndexFlags {
477 let absolute_pos = flags.absolute_pos;
478 let cube_count = flags.cube_count;
479 let cube_dim = flags.cube_dim;
480 let cube_pos = flags.cube_pos;
481 let plane_dim_checked = flags.plane_dim_checked;
482 let plane_index = flags.plane_pos;
483 let unit_pos = flags.unit_pos;
484 let absolute_pos_tuple = flags.absolute_pos_tuple || absolute_pos;
485 let cube_count_tuple = flags.cube_count_tuple || cube_count || cube_pos || absolute_pos;
486 let cube_dim_tuple = flags.cube_dim_tuple || cube_dim || absolute_pos || plane_dim_checked;
487 let cube_pos_tuple = flags.cube_pos_tuple || cube_pos;
488 let cluster_pos = flags.cluster_pos;
489 let plane_dim = flags.plane_dim || plane_dim_checked || plane_index;
490 let unit_pos_plane = flags.unit_pos_plane || plane_index;
491 let unit_pos_tuple = flags.unit_pos_tuple || unit_pos;
492 CubeIndexFlags {
493 absolute_pos_tuple,
494 absolute_pos,
495 cube_count_tuple,
496 cube_count,
497 cube_dim_tuple,
498 cube_dim,
499 cube_pos_tuple,
500 cube_pos,
501 plane_dim,
502 plane_dim_checked,
503 plane_pos: plane_index,
504 unit_pos_tuple,
505 unit_pos,
506 unit_pos_plane,
507 cluster_pos,
508 }
509 }
510
511 fn compile_absolute_pos_tuple_computation(
512 _f: &mut std::fmt::Formatter<'_>,
513 ) -> std::fmt::Result {
514 Ok(())
516 }
517
518 fn compile_absolute_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
519 f.write_str("thread_pos_in_grid")
520 }
521
522 fn compile_absolute_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
523 f.write_str("thread_index_in_grid")
524 }
525
526 fn compile_absolute_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
527 Self::compile_absolute_pos_base_name(f)?;
528 write!(f, ".x")
529 }
530
531 fn compile_absolute_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
532 Self::compile_absolute_pos_base_name(f)?;
533 write!(f, ".y")
534 }
535
536 fn compile_absolute_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
537 Self::compile_absolute_pos_base_name(f)?;
538 write!(f, ".z")
539 }
540
541 fn compile_cube_count_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
542 f.write_str("threadgroups_per_grid")
543 }
544
545 fn compile_cube_count(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
546 f.write_str("total_threadgroups_in_grid")
547 }
548
549 fn compile_cube_count_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
550 Self::compile_cube_count_base_name(f)?;
551 write!(f, ".x")
552 }
553
554 fn compile_cube_count_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
555 Self::compile_cube_count_base_name(f)?;
556 write!(f, ".y")
557 }
558
559 fn compile_cube_count_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
560 Self::compile_cube_count_base_name(f)?;
561 write!(f, ".z")
562 }
563
564 fn compile_cube_dim_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
565 f.write_str("threads_per_threadgroup")
566 }
567
568 fn compile_cube_dim(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
569 f.write_str("total_thread_in_threadgroup")
570 }
571
572 fn compile_cube_dim_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
573 Self::compile_cube_dim_base_name(f)?;
574 write!(f, ".x")
575 }
576
577 fn compile_cube_dim_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
578 Self::compile_cube_dim_base_name(f)?;
579 write!(f, ".y")
580 }
581
582 fn compile_cube_dim_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
583 Self::compile_cube_dim_base_name(f)?;
584 write!(f, ".z")
585 }
586
587 fn compile_cube_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
588 f.write_str("threadgroup_pos_in_grid")
589 }
590
591 fn compile_cube_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
592 f.write_str("threadgroup_index_in_grid")
593 }
594
595 fn compile_cube_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
596 Self::compile_cube_pos_base_name(f)?;
597 write!(f, ".x")
598 }
599
600 fn compile_cube_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
601 Self::compile_cube_pos_base_name(f)?;
602 write!(f, ".y")
603 }
604
605 fn compile_cube_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
606 Self::compile_cube_pos_base_name(f)?;
607 write!(f, ".z")
608 }
609
610 fn compile_unit_pos_computation(_f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
611 Ok(())
613 }
614
615 fn compile_unit_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
616 f.write_str("thread_pos_in_threadgroup")
617 }
618
619 fn compile_unit_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
620 f.write_str("thread_index_in_threadgroup")
621 }
622
623 fn compile_unit_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
624 Self::compile_unit_pos_base_name(f)?;
625 write!(f, ".x")
626 }
627
628 fn compile_unit_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
629 Self::compile_unit_pos_base_name(f)?;
630 write!(f, ".y")
631 }
632
633 fn compile_unit_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
634 Self::compile_unit_pos_base_name(f)?;
635 write!(f, ".z")
636 }
637
638 fn compile_plane_dim(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
639 f.write_str("simd_size")
640 }
641
642 fn compile_plane_dim_checked(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
643 f.write_str("threads_per_simdgroup_checked")
644 }
645
646 fn compile_plane_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
647 f.write_str("simd_group_id")
648 }
649
650 fn compile_unit_pos_plane(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
651 f.write_str("simd_lane_id")
652 }
653}
654
655impl DialectInstructions<Self> for MslDialect {
658 fn compile_atomic_add(
660 f: &mut std::fmt::Formatter<'_>,
661 lhs: &Variable<Self>,
662 rhs: &Variable<Self>,
663 out: &Variable<Self>,
664 ) -> std::fmt::Result {
665 let out = out.fmt_left();
666 writeln!(
667 f,
668 "{out} = atomic_fetch_add_explicit({lhs}, {rhs}, memory_order_relaxed);"
669 )
670 }
671
672 fn compile_atomic_and(
673 f: &mut std::fmt::Formatter<'_>,
674 lhs: &Variable<Self>,
675 rhs: &Variable<Self>,
676 out: &Variable<Self>,
677 ) -> std::fmt::Result {
678 let out = out.fmt_left();
679 writeln!(
680 f,
681 "{out} = atomic_fetch_and_explicit({lhs}, {rhs}, memory_order_relaxed);"
682 )
683 }
684
685 fn compile_atomic_cas(
686 f: &mut std::fmt::Formatter<'_>,
687 input: &Variable<Self>,
688 cmp: &Variable<Self>,
689 val: &Variable<Self>,
690 out: &Variable<Self>,
691 ) -> std::fmt::Result {
692 let out = out.fmt_left();
693 writeln!(
694 f,
695 "{out} = atomic_compare_exchange_weak_explicit({input}, &{cmp}, {val}, memory_order_relaxed, memory_order_relaxed);"
696 )
697 }
698
699 fn compile_atomic_load(
700 f: &mut std::fmt::Formatter<'_>,
701 input: &Variable<Self>,
702 out: &Variable<Self>,
703 ) -> std::fmt::Result {
704 let out = out.fmt_left();
705 writeln!(
706 f,
707 "{out} = atomic_load_explicit({input}, memory_order_relaxed);"
708 )
709 }
710
711 fn compile_atomic_max(
712 f: &mut std::fmt::Formatter<'_>,
713 lhs: &Variable<Self>,
714 rhs: &Variable<Self>,
715 out: &Variable<Self>,
716 ) -> std::fmt::Result {
717 let out = out.fmt_left();
718 writeln!(
719 f,
720 "{out} = atomic_fetch_max_explicit({lhs}, {rhs}, memory_order_relaxed);"
721 )
722 }
723
724 fn compile_atomic_min(
725 f: &mut std::fmt::Formatter<'_>,
726 lhs: &Variable<Self>,
727 rhs: &Variable<Self>,
728 out: &Variable<Self>,
729 ) -> std::fmt::Result {
730 let out = out.fmt_left();
731 writeln!(
732 f,
733 "{out} = atomic_fetch_min_explicit({lhs}, {rhs}, memory_order_relaxed);"
734 )
735 }
736
737 fn compile_atomic_or(
738 f: &mut std::fmt::Formatter<'_>,
739 lhs: &Variable<Self>,
740 rhs: &Variable<Self>,
741 out: &Variable<Self>,
742 ) -> std::fmt::Result {
743 let out = out.fmt_left();
744 writeln!(
745 f,
746 "{out} = atomic_fetch_or_explicit({lhs}, {rhs}, memory_order_relaxed);"
747 )
748 }
749
750 fn compile_atomic_store(
751 f: &mut std::fmt::Formatter<'_>,
752 input: &Variable<Self>,
753 out: &Variable<Self>,
754 ) -> std::fmt::Result {
755 writeln!(
756 f,
757 "atomic_store_explicit({out}, {input}, memory_order_relaxed);"
758 )
759 }
760
761 fn compile_atomic_sub(
762 f: &mut std::fmt::Formatter<'_>,
763 lhs: &Variable<Self>,
764 rhs: &Variable<Self>,
765 out: &Variable<Self>,
766 ) -> std::fmt::Result {
767 let out = out.fmt_left();
768 writeln!(
769 f,
770 "{out} = atomic_fetch_sub_explicit({lhs}, {rhs}, memory_order_relaxed);"
771 )
772 }
773
774 fn compile_atomic_swap(
775 f: &mut std::fmt::Formatter<'_>,
776 lhs: &Variable<Self>,
777 rhs: &Variable<Self>,
778 out: &Variable<Self>,
779 ) -> std::fmt::Result {
780 let out = out.fmt_left();
781 writeln!(
782 f,
783 "{out} = atomic_exchange_explicit({lhs}, {rhs}, memory_order_relaxed);"
784 )
785 }
786
787 fn compile_atomic_xor(
788 f: &mut std::fmt::Formatter<'_>,
789 lhs: &Variable<Self>,
790 rhs: &Variable<Self>,
791 out: &Variable<Self>,
792 ) -> std::fmt::Result {
793 let out = out.fmt_left();
794 writeln!(
795 f,
796 "{out} = atomic_fetch_xor_explicit({lhs}, {rhs}, memory_order_relaxed);"
797 )
798 }
799
800 fn compile_saturating_add(
801 f: &mut std::fmt::Formatter<'_>,
802 lhs: impl Display,
803 rhs: impl Display,
804 _item: Item<Self>,
805 ) -> std::fmt::Result {
806 write!(f, "addsat({lhs}, {rhs})")
807 }
808
809 fn compile_saturating_sub(
810 f: &mut std::fmt::Formatter<'_>,
811 lhs: impl Display,
812 rhs: impl Display,
813 _item: Item<Self>,
814 ) -> std::fmt::Result {
815 write!(f, "subsat({lhs}, {rhs})")
816 }
817
818 fn compile_instruction_printf(
820 f: &mut std::fmt::Formatter<'_>,
821 format_string: &str,
822 args: &[Variable<Self>],
823 ) -> std::fmt::Result {
824 let args = args.iter().map(|arg| format!("{arg}")).collect::<Vec<_>>();
825 let args = match args.is_empty() {
826 true => "".to_string(),
827 false => format!(", {}", args.join(",")),
828 };
829 writeln!(f, "os_log_default.log({format_string:?}{args});")
830 }
831
832 fn compile_instruction_log1p_scalar<T: Component<Self>>(
834 f: &mut std::fmt::Formatter<'_>,
835 input: T,
836 ) -> std::fmt::Result {
837 match input.elem() {
838 Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => {
839 write!(f, "log(half(1.0f) + {input})")
840 }
841 _ => write!(f, "log(1.0f + {input})"),
842 }
843 }
844
845 fn compile_instruction_sync_threads(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
847 writeln!(f, "threadgroup_barrier(mem_flags::mem_threadgroup);")
848 }
849
850 fn compile_instruction_sync_warp(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
851 writeln!(f, "simdgroup_barrier(mem_flags::mem_none);")
852 }
853
854 fn compile_instruction_thread_fence(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
855 writeln!(f, "threadgroup_thread_fence(mem_flags::mem_device);")
856 }
857
858 fn compile_instruction_tanh_scalar<T: Component<Self>>(
860 f: &mut std::fmt::Formatter<'_>,
861 input: T,
862 ) -> std::fmt::Result {
863 write!(f, "safe_tanh_scalar({input})")
864 }
865
866 fn compile_instruction_find_first_set<T: Component<Self>>(
868 f: &mut std::fmt::Formatter<'_>,
869 input: T,
870 out_elem: Elem<Self>,
871 ) -> std::fmt::Result {
872 write!(f, "{out_elem}(")?;
873 match input.elem() {
874 Elem::I32 | Elem::U32 => write!(f, "__ffs({input})"),
875 Elem::I64 | Elem::U64 => write!(f, "__ffsll({input})"),
876 _ => write!(f, "__ffs({}({input}))", Elem::<Self>::I32),
877 }?;
878 write!(f, ")")
879 }
880
881 fn compile_instruction_leading_zeros_scalar<T: Component<Self>>(
882 f: &mut std::fmt::Formatter<'_>,
883 input: T,
884 out_elem: Elem<Self>,
885 ) -> std::fmt::Result {
886 write!(f, "{out_elem}(clz({input}))")
887 }
888
889 fn compile_instruction_trailing_zeros_scalar<T: Component<Self>>(
890 f: &mut std::fmt::Formatter<'_>,
891 input: T,
892 out_elem: Elem<Self>,
893 ) -> std::fmt::Result {
894 write!(f, "{out_elem}(ctz({input}))")
895 }
896
897 fn compile_instruction_popcount_scalar<T: Component<Self>>(
898 f: &mut std::fmt::Formatter<'_>,
899 input: T,
900 out_elem: Elem<Self>,
901 ) -> std::fmt::Result {
902 write!(f, "{out_elem}(")?;
903 match input.elem() {
904 Elem::I32 | Elem::U32 | Elem::I64 | Elem::U64 => write!(f, "popcount({input})"),
905 _ => write!(f, "popcount({})", shared::unary::zero_extend(input)),
906 }?;
907 write!(f, ")")
908 }
909
910 fn compile_instruction_reverse_bits_scalar<T: Component<Self>>(
911 f: &mut std::fmt::Formatter<'_>,
912 input: T,
913 out_elem: Elem<Self>,
914 ) -> std::fmt::Result {
915 write!(f, "{out_elem}(")?;
916 match out_elem {
917 Elem::I32 | Elem::U32 | Elem::I64 | Elem::U64 => write!(f, "reverse_bits({input})"),
918 _ => write!(
919 f,
920 "reverse_bits({}) >> {}",
921 shared::unary::zero_extend(input),
922 (size_of::<u32>() - out_elem.size()) * 8
923 ),
924 }?;
925 write!(f, ")")
926 }
927
928 fn compile_instruction_max_function_name(
930 f: &mut std::fmt::Formatter<'_>,
931 _item: Item<Self>,
932 ) -> std::fmt::Result {
933 write!(f, "max")
934 }
935
936 fn compile_instruction_min_function_name(
937 f: &mut std::fmt::Formatter<'_>,
938 _item: Item<Self>,
939 ) -> std::fmt::Result {
940 write!(f, "min")
941 }
942
943 fn compile_instruction_powf(
944 f: &mut std::fmt::Formatter<'_>,
945 lhs: &str,
946 rhs: &str,
947 elem: Elem<Self>,
948 ) -> std::fmt::Result {
949 write!(f, "pow({lhs}, {elem}({rhs}))")
950 }
951
952 fn compile_instruction_hypot(
953 f: &mut std::fmt::Formatter<'_>,
954 lhs: &str,
955 rhs: &str,
956 elem: Elem<Self>,
957 ) -> std::fmt::Result {
958 match elem {
959 Elem::F32 => write!(f, "length(float2({lhs}, {rhs}))"),
960 _ => write!(f, "#error Unsupported type for hypot: {elem}"),
961 }
962 }
963
964 fn compile_instruction_rhypot(
965 f: &mut std::fmt::Formatter<'_>,
966 lhs: &str,
967 rhs: &str,
968 elem: Elem<Self>,
969 ) -> std::fmt::Result {
970 match elem {
971 Elem::F32 => write!(f, "rsqrt({lhs} * {lhs} + {rhs} * {rhs})"),
972 _ => write!(f, "#error Unsupported type for hypot: {elem}"),
973 }
974 }
975
976 fn compile_instruction_half_function_name_prefix() -> &'static str {
977 ""
978 }
979
980 fn compile_instruction_half2_function_name_prefix() -> &'static str {
981 ""
982 }
983
984 fn compile_warp_shuffle(
986 f: &mut std::fmt::Formatter<'_>,
987 var: &str,
988 source: &str,
989 ) -> std::fmt::Result {
990 write!(f, "simd_shuffle({var}, {source})")
991 }
992
993 fn compile_warp_shuffle_xor(
994 f: &mut std::fmt::Formatter<'_>,
995 var: &str,
996 _elem: &Elem<Self>,
997 offset: &str,
998 ) -> std::fmt::Result {
999 write!(f, "simd_shuffle_xor({var}, {offset})")
1000 }
1001
1002 fn compile_warp_shuffle_up(
1003 f: &mut std::fmt::Formatter<'_>,
1004 var: &str,
1005 offset: &str,
1006 ) -> std::fmt::Result {
1007 write!(f, "simd_shuffle_up({var}, {offset})")
1008 }
1009
1010 fn compile_warp_shuffle_down(
1011 f: &mut std::fmt::Formatter<'_>,
1012 var: &str,
1013 offset: &str,
1014 ) -> std::fmt::Result {
1015 write!(f, "simd_shuffle_down({var}, {offset})")
1016 }
1017
1018 fn compile_warp_all<T: Component<Self>>(
1019 f: &mut std::fmt::Formatter<'_>,
1020 input: &T,
1021 ) -> std::fmt::Result {
1022 write!(f, "simd_all({input})")
1023 }
1024
1025 fn compile_warp_any<T: Component<Self>>(
1026 f: &mut std::fmt::Formatter<'_>,
1027 input: &T,
1028 ) -> std::fmt::Result {
1029 write!(f, "simd_any({input})")
1030 }
1031
1032 fn compile_warp_ballot(
1033 f: &mut std::fmt::Formatter<'_>,
1034 input: &Variable<Self>,
1035 out_elem: &Elem<Self>,
1036 ) -> std::fmt::Result {
1037 write!(f, "{out_elem}(uint64_t(simd_ballot({input})))")
1038 }
1039
1040 fn compile_unreachable(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1041 write!(f, "__builtin_unreachable();")
1042 }
1043}
1044
1045impl DialectWmmaCompiler<Self> for MslDialect {
1048 fn compile_wmma_includes(
1049 f: &mut std::fmt::Formatter<'_>,
1050 _flags: &Flags<Self>,
1051 ) -> std::fmt::Result {
1052 writeln!(f, "#include <metal_simdgroup_matrix>")
1053 }
1054
    // No per-kernel local state is needed for simdgroup matrices on Metal.
    fn compile_wmma_local_variables(_f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        Ok(())
    }
1059
    // Fragment declarations follow the shared WMMA-API base implementation.
    fn compile_wmma_fragment_declaration(
        f: &mut std::fmt::Formatter<'_>,
        var: &crate::shared::Variable<MslDialect>,
    ) -> std::fmt::Result {
        wmma_api_base::compile_fragment_declaration(f, var)
    }
1066
    // Fragment idents (a/b/accumulator) are not part of Metal type names, so
    // nothing is emitted. NOTE(review): "wwma" looks like a typo for "wmma" in
    // the trait method name — confirm against the trait declaration before
    // renaming.
    fn compile_wwma_fragment_ident(
        _f: &mut std::fmt::Formatter<'_>,
        _ident: &FragmentIdent<Self>,
    ) -> std::fmt::Result {
        Ok(())
    }
1074
    // Layout is handled at load/store time (via the transpose flag), not in
    // the fragment type, so nothing is emitted here.
    fn compile_wmma_fragment_layout(
        _f: &mut std::fmt::Formatter<'_>,
        _layout: &FragmentLayout<Self>,
    ) -> std::fmt::Result {
        Ok(())
    }
1082
1083 fn compile_wmma_fragment(
1084 f: &mut std::fmt::Formatter<'_>,
1085 fragment: &Fragment<Self>,
1086 ) -> std::fmt::Result {
1087 let ty = fragment.elem;
1088 let m = fragment.m;
1090 let n = fragment.n;
1091 let k = fragment.k;
1092 if m != 8 || n != 8 || k != 8 {
1093 panic!("{m}x{n}x{k} fragments not supported. Only 8x8x8 fragments are supported.");
1094 }
1095 write!(f, "simdgroup_{ty}8x8")
1096 }
1097
    /// Compiles one WMMA instruction to its Metal `simdgroup_*` intrinsic.
    ///
    /// Each arm writes one or more MSL statements to `f`. The emitted strings
    /// are exact MSL source, so they must not be reformatted.
    fn compile_wmma_instruction(
        f: &mut std::fmt::Formatter<'_>,
        instruction: &WmmaInstruction<Self>,
    ) -> std::fmt::Result {
        match instruction {
            // Broadcast a scalar into every element of a fragment.
            WmmaInstruction::Fill { frag, value } => {
                match frag {
                    Variable::WmmaFragment { .. } => {
                        let ty = frag.elem();
                        writeln!(
                            f,
                            "{frag} = make_filled_simdgroup_matrix<{ty}, 8, 8>({value});"
                        )
                    }
                    _ => panic!("should be a fragment"),
                }
            }
            WmmaInstruction::Load {
                frag,
                value,
                stride,
                offset,
                layout: _layout,
            } => {
                // MSL fragments are layout-less; column-major sources are
                // handled by asking `simdgroup_load` to transpose (the final
                // `bool` argument below).
                let transpose = match frag {
                    Variable::WmmaFragment { frag: inner, .. } => match inner.layout {
                        Some(FragmentLayout::RowMajor) => false,
                        Some(FragmentLayout::ColMajor) => true,
                        _ => false,
                    },
                    _ => panic!("should be a fragment"),
                };
                let item = value.item();
                if item.vectorization > 1 {
                    // A vectorized source is a pointer to packed vectors; it
                    // must be re-viewed as a pointer to the scalar element
                    // type, with the cast spelled per address space.
                    let elem = item.elem;
                    match value {
                        Variable::GlobalInputArray(..) => writeln!(
                            f,
                            "simdgroup_load({frag}, (device {elem}*)({value} + {offset}), {stride}, 0, {transpose});"
                        ),
                        Variable::SharedArray(..) => writeln!(
                            f,
                            "simdgroup_load({frag}, reinterpret_cast<threadgroup {elem} *>({value} + {offset}), {stride}, 0, {transpose});"
                        ),
                        _ => panic!(
                            "Vectorized wmma load is only supported from global or shared memory."
                        ),
                    }
                } else {
                    // Scalar (non-vectorized) source: pointer type already
                    // matches, load directly.
                    writeln!(
                        f,
                        "simdgroup_load({frag}, {value} + {offset}, {stride}, 0, {transpose});"
                    )
                }
            }
            // d = a * b + c as a single fused simdgroup operation.
            WmmaInstruction::Execute {
                frag_a: a,
                frag_b: b,
                frag_c: c,
                frag_d: d,
                ..
            } => {
                writeln!(f, "simdgroup_multiply_accumulate({d}, {a}, {b}, {c});")
            }
            WmmaInstruction::Store {
                output,
                frag,
                stride,
                offset,
                layout: _layout,
            } => {
                let item = output.item();
                let mut reinterpret_cast = item.vectorization > 1;
                // bf16 has no `simdgroup_store` overload: force a cast and
                // store the raw bits through an f16-typed pointer instead.
                let elem = match item.elem {
                    Elem::BF16 => {
                        reinterpret_cast = true;
                        Elem::F16
                    }
                    _ => item.elem,
                };
                if reinterpret_cast {
                    // NOTE(review): the cast is always spelled `threadgroup`,
                    // unlike the Load arm which distinguishes `device` from
                    // `threadgroup` — confirm stores only ever target shared
                    // memory here.
                    writeln!(
                        f,
                        "simdgroup_store({frag}, reinterpret_cast<threadgroup {elem} *>({output} + {offset}), {stride});"
                    )
                } else {
                    writeln!(f, "simdgroup_store({frag}, {output} + {offset}, {stride});")
                }?;
                // Barrier after the store before other lanes read the result.
                writeln!(f, "simdgroup_barrier(mem_flags::mem_none);")
            }
            // Element-wise conversion between fragments of different types.
            WmmaInstruction::Cast { input, output } => {
                // Make sure prior simdgroup ops are complete before reading
                // `input` element-wise.
                writeln!(f, "simdgroup_barrier(mem_flags::mem_none);")?;
                let ty = match output {
                    Variable::WmmaFragment { frag, .. } => frag.elem,
                    _ => panic!("should be a fragment"),
                };
                match ty {
                    // Converting into bf16: convert each element to bf16,
                    // then store its raw bits through an f16-typed pointer
                    // (the fragment's storage type), qualified with the
                    // output's address space.
                    Elem::BF16 => {
                        let addr_space = Self::address_space_for_variable(output);
                        let elem = Elem::<Self>::F16;
                        writeln!(
                            f,
                            "for(int e=0; e<8; e++) {{
    {ty} elem = {ty}({input}.thread_elements()[e]);
    {output}.thread_elements()[e] = *reinterpret_cast<{addr_space}{elem} *>(&elem);
}}"
                        )
                    }
                    // All other targets: a plain C-style value conversion.
                    _ => {
                        writeln!(
                            f,
                            "for(int e=0; e<8; e++) {{
    {output}.thread_elements()[e] = {ty}({input}.thread_elements()[e]);
}}"
                        )
                    }
                }
            }
            // Manual/scaled MMA delegate to the stubs below, which emit
            // `#error` since Metal has no such instructions.
            WmmaInstruction::ExecuteManual {
                shape,
                frag_a,
                frag_b,
                frag_c,
                frag_d,
            } => {
                Self::compile_manual_mma(f, ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d))
            }
            WmmaInstruction::ExecuteScaled {
                shape,
                frag_a,
                frag_b,
                frag_c,
                frag_d,
                scales_a,
                scales_b,
                scales_factor,
            } => Self::compile_scaled_mma(
                f,
                ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d),
                *scales_a,
                *scales_b,
                *scales_factor,
            ),
            WmmaInstruction::LdMatrix { .. } | WmmaInstruction::StMatrix { .. } => {
                f.write_str("#error WmmaInstruction Ld & St Matrix not supported on Metal\n")
            }
        }
    }
1249
1250 fn compile_manual_mma(
1251 f: &mut std::fmt::Formatter<'_>,
1252 _mma: shared::ManualMma<Self>,
1253 ) -> std::fmt::Result {
1254 f.write_str("#error manual mma not supported on Metal\n")
1255 }
1256
1257 fn compile_scaled_mma(
1258 f: &mut std::fmt::Formatter<'_>,
1259 _mma: shared::ManualMma<Self>,
1260 _scales_a: Variable<Self>,
1261 _scales_b: Variable<Self>,
1262 _scales_factor: u32,
1263 ) -> std::fmt::Result {
1264 f.write_str("#error scaled mma not supported on Metal\n")
1265 }
1266
1267 fn supported_wmma_combinations(_arch: &MetalArchitecture) -> SupportedMmaCombinations {
1268 let types = vec![
1269 (
1270 gpu::ElemType::Float(gpu::FloatKind::F16).into(),
1271 gpu::ElemType::Float(gpu::FloatKind::F16).into(),
1272 gpu::ElemType::Float(gpu::FloatKind::F16).into(),
1273 ),
1274 (
1275 gpu::ElemType::Float(gpu::FloatKind::F16).into(),
1276 gpu::ElemType::Float(gpu::FloatKind::F16).into(),
1277 gpu::ElemType::Float(gpu::FloatKind::F32).into(),
1278 ),
1279 (
1280 gpu::ElemType::Float(gpu::FloatKind::BF16).into(),
1281 gpu::ElemType::Float(gpu::FloatKind::BF16).into(),
1282 gpu::ElemType::Float(gpu::FloatKind::BF16).into(),
1283 ),
1284 (
1285 gpu::ElemType::Float(gpu::FloatKind::F32).into(),
1286 gpu::ElemType::Float(gpu::FloatKind::F32).into(),
1287 gpu::ElemType::Float(gpu::FloatKind::F32).into(),
1288 ),
1289 ];
1290 types
1291 .into_iter()
1292 .map(|(a_type, b_type, cd_type)| MmaConfig {
1293 a_type,
1294 b_type,
1295 cd_type,
1296 m: 8,
1297 n: 8,
1298 k: 8,
1299 })
1300 .collect()
1301 }
1302
1303 fn supported_mma_combinations(_arch: &MetalArchitecture) -> SupportedMmaCombinations {
1304 Vec::new()
1305 }
1306}
1307
1308impl DialectProcessors<Self> for MslDialect {
1311 fn processors() -> Vec<Box<dyn gpu::Processor>> {
1312 Vec::new()
1313 }
1314}