cubecl_core/codegen/
metadata.rs1use alloc::vec::Vec;
29use bytemuck::Pod;
30use cubecl_ir::AddressType;
31use cubecl_zspace::{Shape, Strides};
32use num_traits::NumCast;
33
// Field ids of the per-binding "base" metadata group. The static section is
// laid out struct-of-arrays style: all values of field 0 first, then all
// values of field 1, and so on (see `Metadata::offset_of`).
const BUFFER_LEN: u32 = 0; // length of the backing buffer allocation
const LENGTH: u32 = 1; // the `len` value supplied at registration
const BASE_LEN: u32 = 2; // number of base fields per binding

// Field ids of the "extended" (tensor-only) metadata group, stored after all
// base groups (see `Metadata::offset_of_extended`).
const RANK: u32 = 0; // tensor rank
const SHAPE_OFFSETS: u32 = 1; // offset of the tensor's shape within the dynamic section
const STRIDE_OFFSETS: u32 = 2; // offset of the tensor's strides within the dynamic section
const EXTENDED_LEN: u32 = 3; // number of extended fields per tensor

/// Describes a kernel's metadata layout: how many bindings carry base
/// metadata, and how many of those additionally carry extended
/// (rank/shape/stride) metadata.
#[derive(Clone, Copy, Debug, Default)]
pub struct Metadata {
    // Number of bindings with base fields (buffer length, length).
    num_meta: u32,
    // Number of bindings that also have rank/shape/stride fields.
    num_extended_meta: u32,
}
51
52impl Metadata {
53 pub fn new(num_meta: u32, num_extended_meta: u32) -> Self {
54 Self {
55 num_meta,
56 num_extended_meta,
57 }
58 }
59
60 fn offset_of(&self, id: u32) -> u32 {
61 self.num_meta * id
62 }
63
64 fn base_len(&self) -> u32 {
65 self.num_meta * BASE_LEN
66 }
67
68 pub fn static_len(&self) -> u32 {
69 self.num_meta * BASE_LEN + self.num_extended_meta * EXTENDED_LEN
70 }
71
72 pub fn num_meta(&self) -> u32 {
73 self.num_meta
74 }
75
76 pub fn num_extended_meta(&self) -> u32 {
77 self.num_extended_meta
78 }
79
80 fn offset_of_extended(&self, id: u32) -> u32 {
81 self.base_len() + self.num_extended_meta * id
82 }
83
84 pub fn buffer_len_index(&self, buffer_idx: u32) -> u32 {
85 self.offset_of(BUFFER_LEN) + buffer_idx
86 }
87
88 pub fn len_index(&self, buffer_idx: u32) -> u32 {
89 self.offset_of(LENGTH) + buffer_idx
90 }
91
92 pub fn rank_index(&self, buffer_idx: u32) -> u32 {
93 self.offset_of_extended(RANK) + buffer_idx
94 }
95
96 pub fn shape_offset_index(&self, buffer_idx: u32) -> u32 {
97 self.offset_of_extended(SHAPE_OFFSETS) + buffer_idx
98 }
99
100 pub fn stride_offset_index(&self, buffer_idx: u32) -> u32 {
101 self.offset_of_extended(STRIDE_OFFSETS) + buffer_idx
102 }
103}
104
105#[derive(Default)]
109pub struct MetadataBuilder {
110 state_32: State<u32>,
111 state_64: State<u64>,
112}
113
/// Staging buffers for one address width, generic over the integer type the
/// metadata is serialized as (`u32` or `u64`).
///
/// Note: the `T: Pod` bound was removed from the struct definition — bounds
/// belong on the code that needs them (`finish_inner` already requires
/// `T: Pod + NumCast`), and keeping the struct unbounded avoids dragging the
/// `Pod` requirement into the derived `Default` impl.
#[derive(Default)]
struct State<T> {
    // One entry per registered binding: backing buffer length.
    buffer_lens: Vec<T>,
    // One entry per registered binding: the `len` value at registration.
    lengths: Vec<T>,
    // One entry per registered tensor.
    ranks: Vec<T>,
    // Concatenated shape values of all registered tensors.
    shapes: Vec<T>,
    // Concatenated stride values of all registered tensors.
    strides: Vec<T>,

    // Start of each tensor's shape within `shapes`; also used to derive the
    // stride offsets when serializing.
    offsets: Vec<usize>,
}
124
125impl MetadataBuilder {
126 pub fn register_array(&mut self, buffer_len: u64, len: u64, address_type: AddressType) {
128 match address_type {
129 AddressType::U64 => {
130 self.state_64.buffer_lens.push(buffer_len);
131 self.state_64.lengths.push(len);
132 }
133 AddressType::U32 => {
134 self.state_32.buffer_lens.push(buffer_len as u32);
135 self.state_32.lengths.push(len as u32);
136 }
137 }
138 }
139
140 pub fn register_tensor(
142 &mut self,
143 rank: u64,
144 buffer_len: u64,
145 len: u64,
146 shape: Shape,
147 strides: Strides,
148 address_type: AddressType,
149 ) {
150 match address_type {
151 AddressType::U64 => {
152 let state = &mut self.state_64;
153 state.buffer_lens.push(buffer_len);
154 state.lengths.push(len);
155 state.ranks.push(rank);
156 state.offsets.push(state.shapes.len());
157 state.shapes.extend(shape.iter().map(|s| *s as u64));
158 state.strides.extend(strides.iter().map(|s| *s as u64));
159 }
160 AddressType::U32 => {
161 let state = &mut self.state_32;
162 state.buffer_lens.push(buffer_len as u32);
163 state.lengths.push(len as u32);
164 state.ranks.push(rank as u32);
165 state.offsets.push(state.shapes.len());
166 state.shapes.extend(shape.iter().map(|s| *s as u32));
167 state.strides.extend(strides.iter().map(|s| *s as u32));
168 }
169 }
170 }
171
172 pub fn static_len(&self, address_type: AddressType) -> usize {
173 let (base, ext) = match address_type {
174 AddressType::U32 => (self.state_32.buffer_lens.len(), self.state_32.ranks.len()),
175 AddressType::U64 => (self.state_64.buffer_lens.len(), self.state_64.ranks.len()),
176 };
177 base * BASE_LEN as usize + ext * EXTENDED_LEN as usize
178 }
179
180 pub fn dynamic_len(&self, address_type: AddressType) -> usize {
181 match address_type {
182 AddressType::U32 => self.state_32.shapes.len() + self.state_32.strides.len(),
183 AddressType::U64 => self.state_64.shapes.len() + self.state_64.strides.len(),
184 }
185 }
186
187 pub fn finish(&mut self, address_type: AddressType, out: (&mut [u64], &mut [u64])) {
189 fn finish_inner<T: Pod + NumCast>(state: &mut State<T>, out: (&mut [u64], &mut [u64])) {
190 let mut sized = bytemuck::cast_slice_mut::<u64, u8>(out.0);
191 let mut dynamic = bytemuck::cast_slice_mut::<u64, u8>(out.1);
192
193 {
194 let buffer_lens = bytemuck::cast_slice::<T, u8>(&state.buffer_lens);
195 let lengths = bytemuck::cast_slice::<T, u8>(&state.lengths);
196 let ranks = bytemuck::cast_slice::<T, u8>(&state.ranks);
197
198 sized[..buffer_lens.len()].copy_from_slice(buffer_lens);
199 sized = &mut sized[buffer_lens.len()..];
200
201 sized[..lengths.len()].copy_from_slice(lengths);
202 sized = &mut sized[lengths.len()..];
203
204 sized[..ranks.len()].copy_from_slice(ranks);
205 sized = &mut sized[ranks.len()..];
206 }
207
208 state.buffer_lens.clear();
209 state.lengths.clear();
210 state.ranks.clear();
211
212 let strides_offset_base = state.shapes.len();
213
214 for offs in state.offsets.iter() {
215 let offset = [T::from(*offs).unwrap()];
216 let bytes = bytemuck::cast_slice(&offset);
217 sized[..bytes.len()].copy_from_slice(bytes);
218 sized = &mut sized[size_of::<T>()..];
219 }
220
221 for offs in state.offsets.drain(..) {
222 let offset = [T::from(strides_offset_base + offs).unwrap()];
223 let bytes = bytemuck::cast_slice(&offset);
224 sized[..bytes.len()].copy_from_slice(bytes);
225 sized = &mut sized[size_of::<T>()..];
226 }
227
228 {
229 let shapes = bytemuck::cast_slice::<T, u8>(&state.shapes);
230 let strides = bytemuck::cast_slice::<T, u8>(&state.strides);
231
232 dynamic[..shapes.len()].copy_from_slice(shapes);
233 dynamic = &mut dynamic[shapes.len()..];
234
235 dynamic[..strides.len()].copy_from_slice(strides);
236 }
237
238 state.shapes.clear();
239 state.strides.clear();
240 }
241
242 match address_type {
243 AddressType::U32 => finish_inner(&mut self.state_32, out),
244 AddressType::U64 => finish_inner(&mut self.state_64, out),
245 }
246 }
247}