cubecl_core/codegen/
metadata.rs1use alloc::{vec, vec::Vec};
29use bytemuck::Pod;
30use cubecl_ir::AddressType;
31use cubecl_runtime::server::MetadataBinding;
32use num_traits::NumCast;
33
/// Base-section field index holding each registered buffer's `buffer_len`.
const BUFFER_LEN: u32 = 0;
/// Base-section field index holding each registered buffer's `len`.
const LENGTH: u32 = 1;
/// Number of base-section fields (one column of entries per field).
const BASE_LEN: u32 = LENGTH + 1;
38
/// Extended-section field index holding each tensor's rank.
const RANK: u32 = 0;
/// Extended-section field index holding, per tensor, the element offset of its
/// shape within the metadata buffer.
const SHAPE_OFFSETS: u32 = 1;
/// Extended-section field index holding, per tensor, the element offset of its
/// strides within the metadata buffer.
const STRIDE_OFFSETS: u32 = 2;
/// Number of extended-section fields (one column of entries per field).
const EXTENDED_LEN: u32 = STRIDE_OFFSETS + 1;
44
/// Computes indices into the static part of the metadata buffer.
///
/// The static part is a base section (one column of `num_meta` entries per
/// base field) followed by an extended section (one column of
/// `num_extended_meta` entries per extended field).
#[derive(Default, Debug, Clone)]
pub struct Metadata {
    /// Entries per base-section field.
    num_meta: u32,
    /// Entries per extended-section field.
    num_extended_meta: u32,
}
51
52impl Metadata {
53 pub fn new(num_meta: u32, num_extended_meta: u32) -> Self {
54 Self {
55 num_meta,
56 num_extended_meta,
57 }
58 }
59
60 fn offset_of(&self, id: u32) -> u32 {
61 self.num_meta * id
62 }
63
64 fn base_len(&self) -> u32 {
65 self.num_meta * BASE_LEN
66 }
67
68 pub fn static_len(&self) -> u32 {
69 self.num_meta * BASE_LEN + self.num_extended_meta * EXTENDED_LEN
70 }
71
72 fn offset_of_extended(&self, id: u32) -> u32 {
73 self.base_len() + self.num_extended_meta * id
74 }
75
76 pub fn buffer_len_index(&self, buffer_idx: u32) -> u32 {
77 self.offset_of(BUFFER_LEN) + buffer_idx
78 }
79
80 pub fn len_index(&self, buffer_idx: u32) -> u32 {
81 self.offset_of(LENGTH) + buffer_idx
82 }
83
84 pub fn rank_index(&self, buffer_idx: u32) -> u32 {
85 self.offset_of_extended(RANK) + buffer_idx
86 }
87
88 pub fn shape_offset_index(&self, buffer_idx: u32) -> u32 {
89 self.offset_of_extended(SHAPE_OFFSETS) + buffer_idx
90 }
91
92 pub fn stride_offset_index(&self, buffer_idx: u32) -> u32 {
93 self.offset_of_extended(STRIDE_OFFSETS) + buffer_idx
94 }
95}
96
97#[derive(Default)]
101pub struct MetadataBuilder {
102 state_32: State<u32>,
103 state_64: State<u64>,
104}
105
106#[derive(Default)]
107struct State<T: Pod> {
108 buffer_lens: Vec<T>,
109 lengths: Vec<T>,
110 ranks: Vec<T>,
111 shapes: Vec<T>,
112 strides: Vec<T>,
113
114 offsets: Vec<usize>,
115}
116
117impl MetadataBuilder {
118 pub fn register_array(&mut self, buffer_len: u64, len: u64, address_type: AddressType) {
120 match address_type {
121 AddressType::U64 => {
122 self.state_64.buffer_lens.push(buffer_len);
123 self.state_64.lengths.push(len);
124 }
125 AddressType::U32 => {
126 self.state_32.buffer_lens.push(buffer_len as u32);
127 self.state_32.lengths.push(len as u32);
128 }
129 }
130 }
131
132 pub fn register_tensor(
134 &mut self,
135 rank: u64,
136 buffer_len: u64,
137 len: u64,
138 shape: &[usize],
139 strides: &[usize],
140 address_type: AddressType,
141 ) {
142 match address_type {
143 AddressType::U64 => {
144 let state = &mut self.state_64;
145 state.buffer_lens.push(buffer_len);
146 state.lengths.push(len);
147 state.ranks.push(rank);
148 state.offsets.push(state.shapes.len());
149 state.shapes.extend(shape.iter().map(|s| *s as u64));
150 state.strides.extend(strides.iter().map(|s| *s as u64));
151 }
152 AddressType::U32 => {
153 let state = &mut self.state_32;
154 state.buffer_lens.push(buffer_len as u32);
155 state.lengths.push(len as u32);
156 state.ranks.push(rank as u32);
157 state.offsets.push(state.shapes.len());
158 state.shapes.extend(shape.iter().map(|s| *s as u32));
159 state.strides.extend(strides.iter().map(|s| *s as u32));
160 }
161 }
162 }
163
164 pub fn finish(&mut self, address_type: AddressType) -> MetadataBinding {
166 fn finish_inner<T: Pod + NumCast>(state: &mut State<T>) -> MetadataBinding {
167 let num_base = state.buffer_lens.len();
168 let num_ext = state.ranks.len();
169
170 let static_len = num_base * BASE_LEN as usize + num_ext * EXTENDED_LEN as usize;
172 let dynamic_len = state.shapes.len() + state.strides.len();
173 let total_len = static_len + dynamic_len;
174
175 let len_u64 = (total_len * size_of::<T>()).div_ceil(size_of::<u64>());
176 let mut meta_64 = vec![0u64; len_u64];
177 let mut meta = bytemuck::cast_slice_mut::<u64, u8>(&mut meta_64);
178
179 {
180 let buffer_lens = bytemuck::cast_slice::<T, u8>(&state.buffer_lens);
181 let lengths = bytemuck::cast_slice::<T, u8>(&state.lengths);
182 let ranks = bytemuck::cast_slice::<T, u8>(&state.ranks);
183
184 meta[..buffer_lens.len()].copy_from_slice(buffer_lens);
185 meta = &mut meta[buffer_lens.len()..];
186
187 meta[..lengths.len()].copy_from_slice(lengths);
188 meta = &mut meta[lengths.len()..];
189
190 meta[..ranks.len()].copy_from_slice(ranks);
191 meta = &mut meta[ranks.len()..];
192 }
193
194 state.buffer_lens.clear();
195 state.lengths.clear();
196 state.ranks.clear();
197
198 let shape_offset_base = static_len;
199 let strides_offset_base = shape_offset_base + state.shapes.len();
200
201 for offs in state.offsets.iter() {
202 let offset = [T::from(shape_offset_base + *offs).unwrap()];
203 let bytes = bytemuck::cast_slice(&offset);
204 meta[..bytes.len()].copy_from_slice(bytes);
205 meta = &mut meta[size_of::<T>()..];
206 }
207
208 for offs in state.offsets.drain(..) {
209 let offset = [T::from(strides_offset_base + offs).unwrap()];
210 let bytes = bytemuck::cast_slice(&offset);
211 meta[..bytes.len()].copy_from_slice(bytes);
212 meta = &mut meta[size_of::<T>()..];
213 }
214
215 {
216 let shapes = bytemuck::cast_slice::<T, u8>(&state.shapes);
217 let strides = bytemuck::cast_slice::<T, u8>(&state.strides);
218
219 meta[..shapes.len()].copy_from_slice(shapes);
220 meta = &mut meta[shapes.len()..];
221
222 meta[..strides.len()].copy_from_slice(strides);
223 }
224
225 state.shapes.clear();
226 state.strides.clear();
227
228 MetadataBinding::new(meta_64, static_len)
229 }
230
231 match address_type {
232 AddressType::U32 => finish_inner(&mut self.state_32),
233 AddressType::U64 => finish_inner(&mut self.state_64),
234 }
235 }
236}