// cubecl_core/codegen/metadata.rs

//! Metadata helpers to easily get offsets etc.
//!
//! Conceptually, metadata is represented like this:
//! ```rust
//! # type Addr = u32;
//! struct Metadata<const NUM_BUFS: usize, const NUM_EXT: usize> {
//!     base: BaseMeta<NUM_BUFS>,
//!     extended: ExtendedMeta<NUM_EXT>,
//! }
//!
//! struct BaseMeta<const N: usize> {
//!     buffer_lengths: [Addr; N],
//!     logical_lengths: [Addr; N],
//! }
//!
//! struct ExtendedMeta<const N: usize> {
//!     ranks: [Addr; N],
//!     shape_offsets: [Addr; N],
//!     stride_offsets: [Addr; N],
//!     shapes: Vec<Addr>,
//!     strides: Vec<Addr>,
//! }
//! ```
//! where `Addr` is the address type the metadata was built with (every field, offsets included,
//! is stored at that width), and `Vec` isn't an actual `Vec`, just a dynamically sized series of
//! values. `NUM_BUFS` counts every array and tensor binding, while `NUM_EXT` counts only tensors.
//! Shape and stride offsets are indices, in `Addr`-sized elements, from the start of the metadata.
//!
//! Ranks and lengths have a constant offset, while shapes/strides involve loading the tensor's
//! offset, then adding `dim` to the offset to get each shape/stride.
27
28use bytemuck::cast_slice_mut;
29use cubecl_ir::StorageType;
30use cubecl_runtime::server::MetadataBinding;
31
32use crate::prelude::InputScalar;
33
// Base metadata: section indices (one entry per binding in each section) and section count.
/// Section holding each binding's buffer length.
const BUFFER_LEN: u32 = 0;
/// Section holding each binding's logical length.
const LENGTH: u32 = 1;
/// Number of base sections.
const BASE_LEN: u32 = 2;

// Extended metadata: section indices and section count; these sections follow the base ones.
/// Section holding each entry's rank.
const RANK: u32 = 0;
/// Section holding each entry's shape offset.
const SHAPE_OFFSETS: u32 = 1;
/// Section holding each entry's stride offset.
const STRIDE_OFFSETS: u32 = 2;
/// Number of extended sections.
const EXTENDED_LEN: u32 = 3;

/// Helper to calculate metadata offsets based on buffer count and position.
///
/// The fixed-size part of the metadata is `BASE_LEN` sections of `num_meta` entries followed by
/// `EXTENDED_LEN` sections of `num_extended_meta` entries. Every index returned here counts
/// entries from the start of the metadata.
#[derive(Clone, Debug, Default)]
pub struct Metadata {
    /// Number of entries in each base section.
    num_meta: u32,
    /// Number of entries in each extended section.
    num_extended_meta: u32,
}

impl Metadata {
    /// Creates an offset helper with `num_meta` entries per base section and
    /// `num_extended_meta` entries per extended section.
    pub fn new(num_meta: u32, num_extended_meta: u32) -> Self {
        Metadata {
            num_meta,
            num_extended_meta,
        }
    }

    /// First index of base section `id`.
    fn offset_of(&self, id: u32) -> u32 {
        id * self.num_meta
    }

    /// Number of entries taken up by all base sections together.
    fn base_len(&self) -> u32 {
        self.offset_of(BASE_LEN)
    }

    /// Number of entries in the fixed-size part (all base plus all extended sections).
    pub fn static_len(&self) -> u32 {
        self.offset_of_extended(EXTENDED_LEN)
    }

    /// First index of extended section `id`; extended sections start right after the base ones.
    fn offset_of_extended(&self, id: u32) -> u32 {
        let base = self.base_len();
        base + id * self.num_extended_meta
    }

    /// Index of the buffer length entry for binding `buffer_idx`.
    pub fn buffer_len_index(&self, buffer_idx: u32) -> u32 {
        let section_start = self.offset_of(BUFFER_LEN);
        section_start + buffer_idx
    }

    /// Index of the logical length entry for binding `buffer_idx`.
    pub fn len_index(&self, buffer_idx: u32) -> u32 {
        let section_start = self.offset_of(LENGTH);
        section_start + buffer_idx
    }

    /// Index of the rank entry for extended entry `buffer_idx`.
    pub fn rank_index(&self, buffer_idx: u32) -> u32 {
        let section_start = self.offset_of_extended(RANK);
        section_start + buffer_idx
    }

    /// Index of the shape-offset entry for extended entry `buffer_idx`.
    pub fn shape_offset_index(&self, buffer_idx: u32) -> u32 {
        let section_start = self.offset_of_extended(SHAPE_OFFSETS);
        section_start + buffer_idx
    }

    /// Index of the stride-offset entry for extended entry `buffer_idx`.
    pub fn stride_offset_index(&self, buffer_idx: u32) -> u32 {
        let section_start = self.offset_of_extended(STRIDE_OFFSETS);
        section_start + buffer_idx
    }
}
96
97/// Builder for a serialized metadata struct
98///
99/// Inputs/Outputs must be added in the same order they're defined in the bind group
100pub struct MetadataBuilder {
101    buffer_lens: Vec<InputScalar>,
102    lengths: Vec<InputScalar>,
103    ranks: Vec<InputScalar>,
104    shapes: Vec<Vec<InputScalar>>,
105    strides: Vec<Vec<InputScalar>>,
106
107    address_type: StorageType,
108}
109
110impl MetadataBuilder {
111    pub fn new(address_type: StorageType) -> Self {
112        Self {
113            buffer_lens: Default::default(),
114            lengths: Default::default(),
115            ranks: Default::default(),
116            shapes: Default::default(),
117            strides: Default::default(),
118            address_type,
119        }
120    }
121
122    /// Add an array to a builder
123    pub fn with_array(&mut self, buffer_len: u64, len: u64) {
124        self.buffer_lens
125            .push(InputScalar::new(buffer_len, self.address_type));
126        self.lengths.push(InputScalar::new(len, self.address_type));
127    }
128
129    /// Add a tensor to a builder
130    pub fn with_tensor(
131        &mut self,
132        rank: u64,
133        buffer_len: u64,
134        len: u64,
135        shape: Vec<u64>,
136        strides: Vec<u64>,
137    ) {
138        self.buffer_lens
139            .push(InputScalar::new(buffer_len, self.address_type));
140        self.lengths.push(InputScalar::new(len, self.address_type));
141        self.ranks.push(InputScalar::new(rank, self.address_type));
142        self.shapes.push(
143            shape
144                .into_iter()
145                .map(|s| InputScalar::new(s, self.address_type))
146                .collect(),
147        );
148        self.strides.push(
149            strides
150                .into_iter()
151                .map(|s| InputScalar::new(s, self.address_type))
152                .collect(),
153        );
154    }
155
156    /// Build the final serialized metadata struct
157    pub fn finish(self) -> MetadataBinding {
158        let addr_size = self.address_type.size();
159        let mut meta = self
160            .buffer_lens
161            .iter()
162            .flat_map(|it| it.as_bytes())
163            .collect::<Vec<_>>();
164
165        meta.extend(self.lengths.iter().flat_map(|it| it.as_bytes()));
166        meta.extend(self.ranks.iter().flat_map(|it| it.as_bytes()));
167
168        let num_ext = self.ranks.len();
169        let mut shape_offsets = Vec::with_capacity(num_ext * addr_size);
170        let mut stride_offsets = Vec::with_capacity(num_ext * addr_size);
171
172        let mut current_offset = meta.len() / addr_size + num_ext * 2; // Total fields in static portion
173
174        for shape in self.shapes.iter() {
175            let offset = InputScalar::new(current_offset, self.address_type);
176            shape_offsets.extend(offset.as_bytes());
177            current_offset += shape.len();
178        }
179
180        meta.extend(shape_offsets);
181
182        for stride in self.strides.iter() {
183            let offset = InputScalar::new(current_offset, self.address_type);
184            stride_offsets.extend(offset.as_bytes());
185            current_offset += stride.len();
186        }
187
188        meta.extend(stride_offsets);
189
190        let static_len = meta.len() / addr_size;
191
192        meta.extend(self.shapes.iter().flatten().flat_map(|it| it.as_bytes()));
193        meta.extend(
194            self.strides
195                .into_iter()
196                .flatten()
197                .flat_map(|it| it.as_bytes()),
198        );
199
200        let total_size_64 = meta.len().div_ceil(size_of::<u64>());
201        let mut meta_64 = vec![0u64; total_size_64];
202        cast_slice_mut::<u64, u8>(&mut meta_64)[..meta.len()].copy_from_slice(&meta);
203
204        MetadataBinding::new(meta_64, static_len)
205    }
206}