cubecl_cpp/metal/binding.rs

1use cubecl_core::compute::Visibility;
2
3use crate::{
4    Dialect,
5    metal::AddressSpace,
6    shared::{Binding, Component, MslComputeKernel, Variable},
7};
8
9pub fn bindings(repr: &MslComputeKernel) -> Vec<(usize, Visibility)> {
10    let mut bindings: Vec<(usize, Visibility)> = vec![];
11    // must be in the same order as the compilation order: inputs, outputs and named
12    let mut buffer_idx = 0;
13    for b in repr.buffers.iter() {
14        bindings.push((buffer_idx, b.vis));
15        buffer_idx += 1;
16    }
17    if repr.meta_static_len > 0 {
18        bindings.push((buffer_idx, Visibility::Read));
19        buffer_idx += 1;
20    }
21    for _ in repr.scalars.iter() {
22        bindings.push((buffer_idx, Visibility::Read));
23        buffer_idx += 1;
24    }
25    bindings
26}
27
28pub fn format_global_binding_arg<D: Dialect>(
29    name: &str,
30    binding: &Binding<D>,
31    suffix: Option<&str>,
32    attr_idx: &mut usize,
33    f: &mut core::fmt::Formatter<'_>,
34) -> core::fmt::Result {
35    let suffix = suffix.map_or("".into(), |s| format!("_{s}"));
36    let (pointer, size) = match binding.size {
37        Some(size) => ("".to_string(), format!("[{}]", size)),
38        None => (" *".to_string(), "".to_string()),
39    };
40
41    let comma = if *attr_idx > 0 { "," } else { "" };
42    let address_space = AddressSpace::from(binding);
43    let ty = binding.item;
44    let attribute = address_space.attribute();
45
46    write!(
47        f,
48        "{comma}\n    {address_space} {ty}{pointer} {name}{suffix}",
49    )?;
50    // attribute
51    attribute.indexed_fmt(*attr_idx, f)?;
52    write!(f, "{size}")?;
53    *attr_idx += 1;
54    Ok(())
55}
56
57pub fn format_metal_builtin_binding_arg<D: Dialect>(
58    f: &mut core::fmt::Formatter<'_>,
59    variable: &Variable<D>,
60    comma: bool,
61) -> core::fmt::Result {
62    let ty = variable.item();
63    let attribute = variable.attribute();
64    let comma = if comma { "," } else { "" };
65    write!(f, "{comma}\n    {ty} {variable} {attribute}",)?;
66    Ok(())
67}