cubecl_cpp/metal/
binding.rs1use cubecl_core::compute::Visibility;
2
3use crate::{
4 Dialect,
5 metal::AddressSpace,
6 shared::{Binding, Component, MslComputeKernel, Variable},
7};
8
9pub fn bindings(repr: &MslComputeKernel) -> (Vec<Visibility>, Vec<Visibility>) {
10 let mut bindings: Vec<Visibility> = vec![];
11 for b in repr.buffers.iter() {
13 bindings.push(b.vis);
14 }
15 let mut meta: Vec<Visibility> = vec![];
16 if repr.meta_static_len > 0 {
17 meta.push(Visibility::Read);
18 }
19 for _ in repr.scalars.iter() {
20 meta.push(Visibility::Read);
21 }
22 (bindings, meta)
23}
24
25pub fn format_global_binding_arg<D: Dialect>(
26 name: &str,
27 binding: &Binding<D>,
28 suffix: Option<&str>,
29 attr_idx: &mut usize,
30 f: &mut core::fmt::Formatter<'_>,
31) -> core::fmt::Result {
32 let suffix = suffix.map_or("".into(), |s| format!("_{s}"));
33 let (pointer, size) = match binding.size {
34 Some(size) => ("".to_string(), format!("[{size}]")),
35 None => (" *".to_string(), "".to_string()),
36 };
37
38 let comma = if *attr_idx > 0 { "," } else { "" };
39 let address_space = AddressSpace::from(binding);
40 let ty = binding.item;
41 let attribute = address_space.attribute();
42
43 write!(
44 f,
45 "{comma}\n {address_space} {ty}{pointer} {name}{suffix}",
46 )?;
47 attribute.indexed_fmt(*attr_idx, f)?;
49 write!(f, "{size}")?;
50 *attr_idx += 1;
51 Ok(())
52}
53
54pub fn format_metal_builtin_binding_arg<D: Dialect>(
55 f: &mut core::fmt::Formatter<'_>,
56 variable: &Variable<D>,
57 comma: bool,
58) -> core::fmt::Result {
59 let ty = variable.item();
60 let attribute = variable.attribute();
61 let comma = if comma { "," } else { "" };
62 write!(f, "{comma}\n {ty} {variable} {attribute}",)?;
63 Ok(())
64}