Skip to main content

cubecl_cpp/metal/
binding.rs

1use cubecl_core::{prelude::Visibility, server::KernelArguments};
2
3use crate::{
4    Dialect,
5    metal::AddressSpace,
6    shared::{Component, KernelArg, MslComputeKernel, Variable},
7};
8
9pub fn bindings(
10    repr: &MslComputeKernel,
11    args: &KernelArguments,
12) -> (Vec<Visibility>, Option<Visibility>, bool) {
13    let mut bindings: Vec<Visibility> = vec![];
14    // must be in the same order as the compilation order: inputs, outputs and named
15    for b in repr.buffers.iter() {
16        bindings.push(b.vis);
17    }
18    let info = (!args.info.data.is_empty()).then_some(Visibility::Read);
19    let uniform = args.info.dynamic_metadata_offset >= args.info.data.len();
20    (bindings, info, uniform)
21}
22
23pub fn format_global_binding_arg<D: Dialect>(
24    name: &str,
25    binding: &KernelArg<D>,
26    suffix: Option<&str>,
27    attr_idx: &mut usize,
28    f: &mut core::fmt::Formatter<'_>,
29) -> core::fmt::Result {
30    let suffix = suffix.map_or("".into(), |s| format!("_{s}"));
31    let (pointer, size) = match binding.size {
32        Some(size) => ("".to_string(), format!("[{size}]")),
33        None => (" *".to_string(), "".to_string()),
34    };
35
36    let comma = if *attr_idx > 0 { "," } else { "" };
37    let address_space = AddressSpace::from(binding);
38    let ty = binding.item;
39    let attribute = address_space.attribute();
40
41    write!(
42        f,
43        "{comma}\n    {address_space} {ty}{pointer} {name}{suffix}",
44    )?;
45    // attribute
46    attribute.indexed_fmt(*attr_idx, f)?;
47    write!(f, "{size}")?;
48    *attr_idx += 1;
49    Ok(())
50}
51
52pub fn format_metal_builtin_binding_arg<D: Dialect>(
53    f: &mut core::fmt::Formatter<'_>,
54    variable: &Variable<D>,
55    comma: bool,
56) -> core::fmt::Result {
57    let ty = variable.item();
58    let attribute = variable.attribute();
59    let comma = if comma { "," } else { "" };
60    write!(f, "{comma}\n    {ty} {variable} {attribute}",)?;
61    Ok(())
62}