// cubecl_cpp/metal/binding.rs

1use cubecl_core::compute::Visibility;
2
3use crate::{
4    Dialect,
5    metal::AddressSpace,
6    shared::{Binding, Component, MslComputeKernel, Variable},
7};
8
9pub fn bindings(repr: &MslComputeKernel) -> (Vec<Visibility>, Vec<Visibility>) {
10    let mut bindings: Vec<Visibility> = vec![];
11    // must be in the same order as the compilation order: inputs, outputs and named
12    for b in repr.buffers.iter() {
13        bindings.push(b.vis);
14    }
15    let mut meta: Vec<Visibility> = vec![];
16    if repr.meta_static_len > 0 {
17        meta.push(Visibility::Read);
18    }
19    for _ in repr.scalars.iter() {
20        meta.push(Visibility::Read);
21    }
22    (bindings, meta)
23}
24
25pub fn format_global_binding_arg<D: Dialect>(
26    name: &str,
27    binding: &Binding<D>,
28    suffix: Option<&str>,
29    attr_idx: &mut usize,
30    f: &mut core::fmt::Formatter<'_>,
31) -> core::fmt::Result {
32    let suffix = suffix.map_or("".into(), |s| format!("_{s}"));
33    let (pointer, size) = match binding.size {
34        Some(size) => ("".to_string(), format!("[{size}]")),
35        None => (" *".to_string(), "".to_string()),
36    };
37
38    let comma = if *attr_idx > 0 { "," } else { "" };
39    let address_space = AddressSpace::from(binding);
40    let ty = binding.item;
41    let attribute = address_space.attribute();
42
43    write!(
44        f,
45        "{comma}\n    {address_space} {ty}{pointer} {name}{suffix}",
46    )?;
47    // attribute
48    attribute.indexed_fmt(*attr_idx, f)?;
49    write!(f, "{size}")?;
50    *attr_idx += 1;
51    Ok(())
52}
53
54pub fn format_metal_builtin_binding_arg<D: Dialect>(
55    f: &mut core::fmt::Formatter<'_>,
56    variable: &Variable<D>,
57    comma: bool,
58) -> core::fmt::Result {
59    let ty = variable.item();
60    let attribute = variable.attribute();
61    let comma = if comma { "," } else { "" };
62    write!(f, "{comma}\n    {ty} {variable} {attribute}",)?;
63    Ok(())
64}