cubecl_cpp/metal/
binding.rs1use cubecl_core::compute::Visibility;
2
3use crate::{
4 Dialect,
5 metal::AddressSpace,
6 shared::{Binding, Component, MslComputeKernel, Variable},
7};
8
9pub fn bindings(repr: &MslComputeKernel) -> Vec<(usize, Visibility)> {
10 let mut bindings: Vec<(usize, Visibility)> = vec![];
11 let mut buffer_idx = 0;
13 for b in repr.buffers.iter() {
14 bindings.push((buffer_idx, b.vis));
15 buffer_idx += 1;
16 }
17 if repr.meta_static_len > 0 {
18 bindings.push((buffer_idx, Visibility::Read));
19 buffer_idx += 1;
20 }
21 for _ in repr.scalars.iter() {
22 bindings.push((buffer_idx, Visibility::Read));
23 buffer_idx += 1;
24 }
25 bindings
26}
27
28pub fn format_global_binding_arg<D: Dialect>(
29 name: &str,
30 binding: &Binding<D>,
31 suffix: Option<&str>,
32 attr_idx: &mut usize,
33 f: &mut core::fmt::Formatter<'_>,
34) -> core::fmt::Result {
35 let suffix = suffix.map_or("".into(), |s| format!("_{s}"));
36 let (pointer, size) = match binding.size {
37 Some(size) => ("".to_string(), format!("[{}]", size)),
38 None => (" *".to_string(), "".to_string()),
39 };
40
41 let comma = if *attr_idx > 0 { "," } else { "" };
42 let address_space = AddressSpace::from(binding);
43 let ty = binding.item;
44 let attribute = address_space.attribute();
45
46 write!(
47 f,
48 "{comma}\n {address_space} {ty}{pointer} {name}{suffix}",
49 )?;
50 attribute.indexed_fmt(*attr_idx, f)?;
52 write!(f, "{size}")?;
53 *attr_idx += 1;
54 Ok(())
55}
56
57pub fn format_metal_builtin_binding_arg<D: Dialect>(
58 f: &mut core::fmt::Formatter<'_>,
59 variable: &Variable<D>,
60 comma: bool,
61) -> core::fmt::Result {
62 let ty = variable.item();
63 let attribute = variable.attribute();
64 let comma = if comma { "," } else { "" };
65 write!(f, "{comma}\n {ty} {variable} {attribute}",)?;
66 Ok(())
67}