Skip to main content

cubecl_core/codegen/
metadata.rs

1//! Metadata helpers to easily get offsets etc.
2//!
3//! Conceptually, metadata is represented like this:
4//! ```rust
5//! struct Metadata<const NUM_BUFS: usize, const NUM_EXT: usize> {
6//!     base: BaseMeta<NUM_BUFS>,
7//!     extended: ExtendedMeta<NUM_EXT>,
8//! }
9//!
10//! struct BaseMeta<const N: usize> {
11//!     buffer_lengths: [u32; N],
12//!     logical_lengths: [u32; N],
13//! }
14//!
15//! struct ExtendedMeta<const N: usize> {
16//!     ranks: [u32; N],
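//!     shape_offsets: [u32; N],
//!     stride_offsets: [u32; N],
//!     shapes: Vec<u32>,
//!     strides: Vec<u32>
//! }
//! ```
//! where `Vec` isn't an actual `Vec`, just a dynamically sized series of values. Every value,
//! including the shape/stride offsets, is serialized with the same width: `u32` (as shown here)
//! for 32-bit addressing, or `u64` for 64-bit addressing. The fixed-size fields and the
//! dynamically sized `shapes`/`strides` are written to two separate output buffers.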
24//!
25//! Ranks and lengths have a constant offset, while shapes/strides involve loading the tensor's
26//! offset, then adding `dim` to the offset to get each shape/stride.
27
28use alloc::vec::Vec;
29use bytemuck::Pod;
30use cubecl_ir::AddressType;
31use cubecl_zspace::{Shape, Strides};
32use num_traits::NumCast;
33
/// Base field id: per-buffer buffer length. Each base field is a run of `num_meta` values.
const BUFFER_LEN: u32 = 0;
/// Base field id: per-buffer logical length.
const LENGTH: u32 = BUFFER_LEN + 1;
/// Number of base fields (one past the last base field id).
const BASE_LEN: u32 = LENGTH + 1;

/// Extended field id: rank. Each extended field is a run of `num_extended_meta` values.
const RANK: u32 = 0;
/// Extended field id: offset of the shape within the dynamic section.
const SHAPE_OFFSETS: u32 = RANK + 1;
/// Extended field id: offset of the strides within the dynamic section.
const STRIDE_OFFSETS: u32 = SHAPE_OFFSETS + 1;
/// Number of extended fields (one past the last extended field id).
const EXTENDED_LEN: u32 = STRIDE_OFFSETS + 1;
44
/// Computes indices into serialized metadata from the number of entries that carry base
/// metadata and the number that carry extended metadata.
#[derive(Debug, Default, Clone, Copy)]
pub struct Metadata {
    /// Number of entries in each base field (buffer length, logical length).
    num_meta: u32,
    /// Number of entries in each extended field (rank, shape offset, stride offset).
    num_extended_meta: u32,
}
51
52impl Metadata {
53    pub fn new(num_meta: u32, num_extended_meta: u32) -> Self {
54        Self {
55            num_meta,
56            num_extended_meta,
57        }
58    }
59
60    fn offset_of(&self, id: u32) -> u32 {
61        self.num_meta * id
62    }
63
64    fn base_len(&self) -> u32 {
65        self.num_meta * BASE_LEN
66    }
67
68    pub fn static_len(&self) -> u32 {
69        self.num_meta * BASE_LEN + self.num_extended_meta * EXTENDED_LEN
70    }
71
72    pub fn num_meta(&self) -> u32 {
73        self.num_meta
74    }
75
76    pub fn num_extended_meta(&self) -> u32 {
77        self.num_extended_meta
78    }
79
80    fn offset_of_extended(&self, id: u32) -> u32 {
81        self.base_len() + self.num_extended_meta * id
82    }
83
84    pub fn buffer_len_index(&self, buffer_idx: u32) -> u32 {
85        self.offset_of(BUFFER_LEN) + buffer_idx
86    }
87
88    pub fn len_index(&self, buffer_idx: u32) -> u32 {
89        self.offset_of(LENGTH) + buffer_idx
90    }
91
92    pub fn rank_index(&self, buffer_idx: u32) -> u32 {
93        self.offset_of_extended(RANK) + buffer_idx
94    }
95
96    pub fn shape_offset_index(&self, buffer_idx: u32) -> u32 {
97        self.offset_of_extended(SHAPE_OFFSETS) + buffer_idx
98    }
99
100    pub fn stride_offset_index(&self, buffer_idx: u32) -> u32 {
101        self.offset_of_extended(STRIDE_OFFSETS) + buffer_idx
102    }
103}
104
105/// Builder for a serialized metadata struct
106///
107/// Inputs/Outputs must be added in the same order they're defined in the bind group
108#[derive(Default)]
109pub struct MetadataBuilder {
110    state_32: State<u32>,
111    state_64: State<u64>,
112}
113
/// Metadata collected for one address type, with every value stored as `T`.
///
/// No trait bound is placed on the struct itself: `Pod` (and `NumCast`) are only needed to
/// serialize the state, so the serialization code in `MetadataBuilder::finish` requires them.
#[derive(Default)]
struct State<T> {
    /// Buffer length of every registered array and tensor, in registration order.
    buffer_lens: Vec<T>,
    /// Logical length of every registered array and tensor, in registration order.
    lengths: Vec<T>,
    /// Rank of every registered tensor, in registration order.
    ranks: Vec<T>,
    /// Shape values of all registered tensors, concatenated.
    shapes: Vec<T>,
    /// Stride values of all registered tensors, concatenated.
    strides: Vec<T>,

    /// For each registered tensor, the index in `shapes` where its shape starts.
    offsets: Vec<usize>,
}
124
125impl MetadataBuilder {
126    /// Add an array to a builder
127    pub fn register_array(&mut self, buffer_len: u64, len: u64, address_type: AddressType) {
128        match address_type {
129            AddressType::U64 => {
130                self.state_64.buffer_lens.push(buffer_len);
131                self.state_64.lengths.push(len);
132            }
133            AddressType::U32 => {
134                self.state_32.buffer_lens.push(buffer_len as u32);
135                self.state_32.lengths.push(len as u32);
136            }
137        }
138    }
139
140    /// Add a tensor to a builder
141    pub fn register_tensor(
142        &mut self,
143        rank: u64,
144        buffer_len: u64,
145        len: u64,
146        shape: Shape,
147        strides: Strides,
148        address_type: AddressType,
149    ) {
150        match address_type {
151            AddressType::U64 => {
152                let state = &mut self.state_64;
153                state.buffer_lens.push(buffer_len);
154                state.lengths.push(len);
155                state.ranks.push(rank);
156                state.offsets.push(state.shapes.len());
157                state.shapes.extend(shape.iter().map(|s| *s as u64));
158                state.strides.extend(strides.iter().map(|s| *s as u64));
159            }
160            AddressType::U32 => {
161                let state = &mut self.state_32;
162                state.buffer_lens.push(buffer_len as u32);
163                state.lengths.push(len as u32);
164                state.ranks.push(rank as u32);
165                state.offsets.push(state.shapes.len());
166                state.shapes.extend(shape.iter().map(|s| *s as u32));
167                state.strides.extend(strides.iter().map(|s| *s as u32));
168            }
169        }
170    }
171
172    pub fn static_len(&self, address_type: AddressType) -> usize {
173        let (base, ext) = match address_type {
174            AddressType::U32 => (self.state_32.buffer_lens.len(), self.state_32.ranks.len()),
175            AddressType::U64 => (self.state_64.buffer_lens.len(), self.state_64.ranks.len()),
176        };
177        base * BASE_LEN as usize + ext * EXTENDED_LEN as usize
178    }
179
180    pub fn dynamic_len(&self, address_type: AddressType) -> usize {
181        match address_type {
182            AddressType::U32 => self.state_32.shapes.len() + self.state_32.strides.len(),
183            AddressType::U64 => self.state_64.shapes.len() + self.state_64.strides.len(),
184        }
185    }
186
187    /// Build the final serialized metadata struct
188    pub fn finish(&mut self, address_type: AddressType, out: (&mut [u64], &mut [u64])) {
189        fn finish_inner<T: Pod + NumCast>(state: &mut State<T>, out: (&mut [u64], &mut [u64])) {
190            let mut sized = bytemuck::cast_slice_mut::<u64, u8>(out.0);
191            let mut dynamic = bytemuck::cast_slice_mut::<u64, u8>(out.1);
192
193            {
194                let buffer_lens = bytemuck::cast_slice::<T, u8>(&state.buffer_lens);
195                let lengths = bytemuck::cast_slice::<T, u8>(&state.lengths);
196                let ranks = bytemuck::cast_slice::<T, u8>(&state.ranks);
197
198                sized[..buffer_lens.len()].copy_from_slice(buffer_lens);
199                sized = &mut sized[buffer_lens.len()..];
200
201                sized[..lengths.len()].copy_from_slice(lengths);
202                sized = &mut sized[lengths.len()..];
203
204                sized[..ranks.len()].copy_from_slice(ranks);
205                sized = &mut sized[ranks.len()..];
206            }
207
208            state.buffer_lens.clear();
209            state.lengths.clear();
210            state.ranks.clear();
211
212            let strides_offset_base = state.shapes.len();
213
214            for offs in state.offsets.iter() {
215                let offset = [T::from(*offs).unwrap()];
216                let bytes = bytemuck::cast_slice(&offset);
217                sized[..bytes.len()].copy_from_slice(bytes);
218                sized = &mut sized[size_of::<T>()..];
219            }
220
221            for offs in state.offsets.drain(..) {
222                let offset = [T::from(strides_offset_base + offs).unwrap()];
223                let bytes = bytemuck::cast_slice(&offset);
224                sized[..bytes.len()].copy_from_slice(bytes);
225                sized = &mut sized[size_of::<T>()..];
226            }
227
228            {
229                let shapes = bytemuck::cast_slice::<T, u8>(&state.shapes);
230                let strides = bytemuck::cast_slice::<T, u8>(&state.strides);
231
232                dynamic[..shapes.len()].copy_from_slice(shapes);
233                dynamic = &mut dynamic[shapes.len()..];
234
235                dynamic[..strides.len()].copy_from_slice(strides);
236            }
237
238            state.shapes.clear();
239            state.strides.clear();
240        }
241
242        match address_type {
243            AddressType::U32 => finish_inner(&mut self.state_32, out),
244            AddressType::U64 => finish_inner(&mut self.state_64, out),
245        }
246    }
247}