cubecl_core/codegen/
metadata.rs1use bytemuck::cast_slice_mut;
29use cubecl_ir::StorageType;
30use cubecl_runtime::server::MetadataBinding;
31
32use crate::prelude::InputScalar;
33
/// Section index of the buffer lengths inside the base part.
const BUFFER_LEN: u32 = 0;
/// Section index of the lengths inside the base part.
const LENGTH: u32 = 1;
/// Number of sections in the base part (each holds `num_meta` entries).
const BASE_LEN: u32 = 2;

/// Section index of the ranks inside the extended part.
const RANK: u32 = 0;
/// Section index of the shape offsets inside the extended part.
const SHAPE_OFFSETS: u32 = 1;
/// Section index of the stride offsets inside the extended part.
const STRIDE_OFFSETS: u32 = 2;
/// Number of sections in the extended part (each holds `num_extended_meta` entries).
const EXTENDED_LEN: u32 = 3;

/// Index calculator for a metadata buffer.
///
/// The buffer begins with a base part made of `BASE_LEN` sections of
/// `num_meta` entries each (buffer lengths, then lengths). It is followed by an
/// extended part made of `EXTENDED_LEN` sections of `num_extended_meta` entries
/// each (ranks, then shape offsets, then stride offsets).
#[derive(Clone, Debug, Default)]
pub struct Metadata {
    num_meta: u32,
    num_extended_meta: u32,
}

impl Metadata {
    /// Builds a layout with `num_meta` entries per base section and
    /// `num_extended_meta` entries per extended section.
    pub fn new(num_meta: u32, num_extended_meta: u32) -> Self {
        Self {
            num_meta,
            num_extended_meta,
        }
    }

    /// Index of the first entry of base section `section`.
    fn offset_of(&self, section: u32) -> u32 {
        section * self.num_meta
    }

    /// Number of entries occupied by the whole base part.
    fn base_len(&self) -> u32 {
        self.offset_of(BASE_LEN)
    }

    /// Number of entries occupied by the base and extended parts together.
    pub fn static_len(&self) -> u32 {
        // One-past-the-end of the last extended section.
        self.offset_of_extended(EXTENDED_LEN)
    }

    /// Index of the first entry of extended section `section`; the extended
    /// part starts right after the base part.
    fn offset_of_extended(&self, section: u32) -> u32 {
        self.base_len() + section * self.num_extended_meta
    }

    /// Index of the buffer-length entry for `buffer_idx`.
    pub fn buffer_len_index(&self, buffer_idx: u32) -> u32 {
        let section_start = self.offset_of(BUFFER_LEN);
        section_start + buffer_idx
    }

    /// Index of the length entry for `buffer_idx`.
    pub fn len_index(&self, buffer_idx: u32) -> u32 {
        let section_start = self.offset_of(LENGTH);
        section_start + buffer_idx
    }

    /// Index of the rank entry for extended buffer `buffer_idx`.
    pub fn rank_index(&self, buffer_idx: u32) -> u32 {
        let section_start = self.offset_of_extended(RANK);
        section_start + buffer_idx
    }

    /// Index of the shape-offset entry for extended buffer `buffer_idx`.
    pub fn shape_offset_index(&self, buffer_idx: u32) -> u32 {
        let section_start = self.offset_of_extended(SHAPE_OFFSETS);
        section_start + buffer_idx
    }

    /// Index of the stride-offset entry for extended buffer `buffer_idx`.
    pub fn stride_offset_index(&self, buffer_idx: u32) -> u32 {
        let section_start = self.offset_of_extended(STRIDE_OFFSETS);
        section_start + buffer_idx
    }
}
96
97pub struct MetadataBuilder {
101 buffer_lens: Vec<InputScalar>,
102 lengths: Vec<InputScalar>,
103 ranks: Vec<InputScalar>,
104 shapes: Vec<Vec<InputScalar>>,
105 strides: Vec<Vec<InputScalar>>,
106
107 address_type: StorageType,
108}
109
110impl MetadataBuilder {
111 pub fn new(address_type: StorageType) -> Self {
112 Self {
113 buffer_lens: Default::default(),
114 lengths: Default::default(),
115 ranks: Default::default(),
116 shapes: Default::default(),
117 strides: Default::default(),
118 address_type,
119 }
120 }
121
122 pub fn with_array(&mut self, buffer_len: u64, len: u64) {
124 self.buffer_lens
125 .push(InputScalar::new(buffer_len, self.address_type));
126 self.lengths.push(InputScalar::new(len, self.address_type));
127 }
128
129 pub fn with_tensor(
131 &mut self,
132 rank: u64,
133 buffer_len: u64,
134 len: u64,
135 shape: Vec<u64>,
136 strides: Vec<u64>,
137 ) {
138 self.buffer_lens
139 .push(InputScalar::new(buffer_len, self.address_type));
140 self.lengths.push(InputScalar::new(len, self.address_type));
141 self.ranks.push(InputScalar::new(rank, self.address_type));
142 self.shapes.push(
143 shape
144 .into_iter()
145 .map(|s| InputScalar::new(s, self.address_type))
146 .collect(),
147 );
148 self.strides.push(
149 strides
150 .into_iter()
151 .map(|s| InputScalar::new(s, self.address_type))
152 .collect(),
153 );
154 }
155
156 pub fn finish(self) -> MetadataBinding {
158 let addr_size = self.address_type.size();
159 let mut meta = self
160 .buffer_lens
161 .iter()
162 .flat_map(|it| it.as_bytes())
163 .collect::<Vec<_>>();
164
165 meta.extend(self.lengths.iter().flat_map(|it| it.as_bytes()));
166 meta.extend(self.ranks.iter().flat_map(|it| it.as_bytes()));
167
168 let num_ext = self.ranks.len();
169 let mut shape_offsets = Vec::with_capacity(num_ext * addr_size);
170 let mut stride_offsets = Vec::with_capacity(num_ext * addr_size);
171
172 let mut current_offset = meta.len() / addr_size + num_ext * 2; for shape in self.shapes.iter() {
175 let offset = InputScalar::new(current_offset, self.address_type);
176 shape_offsets.extend(offset.as_bytes());
177 current_offset += shape.len();
178 }
179
180 meta.extend(shape_offsets);
181
182 for stride in self.strides.iter() {
183 let offset = InputScalar::new(current_offset, self.address_type);
184 stride_offsets.extend(offset.as_bytes());
185 current_offset += stride.len();
186 }
187
188 meta.extend(stride_offsets);
189
190 let static_len = meta.len() / addr_size;
191
192 meta.extend(self.shapes.iter().flatten().flat_map(|it| it.as_bytes()));
193 meta.extend(
194 self.strides
195 .into_iter()
196 .flatten()
197 .flat_map(|it| it.as_bytes()),
198 );
199
200 let total_size_64 = meta.len().div_ceil(size_of::<u64>());
201 let mut meta_64 = vec![0u64; total_size_64];
202 cast_slice_mut::<u64, u8>(&mut meta_64)[..meta.len()].copy_from_slice(&meta);
203
204 MetadataBinding::new(meta_64, static_len)
205 }
206}