cubecl_core/codegen/
metadata.rs

//! Metadata helpers to easily get offsets etc.
//!
//! Conceptually, metadata is represented like this:
//! ```rust
//! struct Metadata<const NUM_BUFS: usize, const NUM_EXT: usize> {
//!     base: BaseMeta<NUM_BUFS>,
//!     extended: ExtendedMeta<NUM_EXT>,
//! }
//!
//! struct BaseMeta<const N: usize> {
//!     buffer_lengths: [u32; N],
//!     logical_lengths: [u32; N],
//! }
//!
//! struct ExtendedMeta<const N: usize> {
//!     ranks: [u32; N],
//!     shape_offsets: [u32; N],
//!     stride_offsets: [u32; N],
//!     shapes: Vec<u32>,
//!     strides: Vec<u32>
//! }
//! ```
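//! where `Vec` isn't an actual `Vec`, just a dynamically sized series of values. Every field is
//! serialized as a `u32`, and each shape/stride offset is an absolute index into the serialized
//! metadata buffer.
//!
//! Ranks and lengths live at constant offsets, while shapes/strides require first loading the
//! tensor's offset, then adding `dim` to that offset to get each shape/stride.
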
// Base metadata: section indices within the base portion of the buffer. Each section holds one
// entry per buffer.

/// Section holding each buffer's buffer length.
const BUFFER_LEN: u32 = 0;
/// Section holding each buffer's logical length.
const LENGTH: u32 = BUFFER_LEN + 1;
/// Number of base-metadata sections.
const BASE_LEN: u32 = LENGTH + 1;

// Extended metadata: section indices within the extended portion, which follows the base portion.

/// Section holding each tensor's rank.
const RANK: u32 = 0;
/// Section holding each tensor's offset to the start of its shape.
const SHAPE_OFFSETS: u32 = RANK + 1;
/// Section holding each tensor's offset to the start of its strides.
const STRIDE_OFFSETS: u32 = SHAPE_OFFSETS + 1;

38/// Helper to calculate metadata offsets based on buffer count and position
39#[derive(Clone, Debug, Default)]
40pub struct Metadata {
41    num_meta: u32,
42    num_extended_meta: u32,
43}
44
45impl Metadata {
46    pub fn new(num_meta: u32, num_extended_meta: u32) -> Self {
47        Self {
48            num_meta,
49            num_extended_meta,
50        }
51    }
52
53    fn offset_of(&self, id: u32) -> u32 {
54        self.num_meta * id
55    }
56
57    fn base_len(&self) -> u32 {
58        self.num_meta * BASE_LEN
59    }
60
61    fn offset_of_extended(&self, id: u32) -> u32 {
62        self.base_len() + self.num_extended_meta * id
63    }
64
65    pub fn buffer_len_index(&self, buffer_idx: u32) -> u32 {
66        self.offset_of(BUFFER_LEN) + buffer_idx
67    }
68
69    pub fn len_index(&self, buffer_idx: u32) -> u32 {
70        self.offset_of(LENGTH) + buffer_idx
71    }
72
73    pub fn rank_index(&self, buffer_idx: u32) -> u32 {
74        self.offset_of_extended(RANK) + buffer_idx
75    }
76
77    pub fn shape_offset_index(&self, buffer_idx: u32) -> u32 {
78        self.offset_of_extended(SHAPE_OFFSETS) + buffer_idx
79    }
80
81    pub fn stride_offset_index(&self, buffer_idx: u32) -> u32 {
82        self.offset_of_extended(STRIDE_OFFSETS) + buffer_idx
83    }
84}
85
/// Builder for a serialized metadata struct.
///
/// Inputs/Outputs must be added in the same order they're defined in the bind group.
///
/// [`MetadataBuilder::finish`] produces a flat `Vec<u32>` laid out as:
/// `buffer_lens | lengths | ranks | shape_offsets | stride_offsets | shapes... | strides...`
/// where `buffer_lens`/`lengths` have one entry per added array or tensor, and
/// `ranks`/`shape_offsets`/`stride_offsets` have one entry per added tensor.
#[derive(Clone, Debug, Default)]
pub struct MetadataBuilder {
    buffer_lens: Vec<u32>,
    lengths: Vec<u32>,
    ranks: Vec<u32>,
    shapes: Vec<Vec<u32>>,
    strides: Vec<Vec<u32>>,
}

impl MetadataBuilder {
    /// Add an array to the builder. Arrays only contribute base metadata
    /// (`buffer_len` and logical `len`).
    pub fn with_array(&mut self, buffer_len: u32, len: u32) {
        self.buffer_lens.push(buffer_len);
        self.lengths.push(len);
    }

    /// Add a tensor to the builder: base metadata (`buffer_len`, logical `len`) plus extended
    /// metadata (`rank`, `shape`, `strides`).
    ///
    /// `rank` is stored as given; it is expected to match `shape.len()` and `strides.len()`,
    /// but this is not checked here.
    pub fn with_tensor(
        &mut self,
        rank: u32,
        buffer_len: u32,
        len: u32,
        shape: Vec<u32>,
        strides: Vec<u32>,
    ) {
        self.buffer_lens.push(buffer_len);
        self.lengths.push(len);
        self.ranks.push(rank);
        self.shapes.push(shape);
        self.strides.push(strides);
    }

    /// Build the final serialized metadata buffer (layout described on [`MetadataBuilder`]).
    ///
    /// Each shape/stride offset is the absolute index in the returned buffer at which that
    /// tensor's shape/strides start.
    ///
    /// # Panics
    ///
    /// Panics if an offset does not fit in a `u32`, rather than silently truncating it.
    pub fn finish(self) -> Vec<u32> {
        let num_ext = self.ranks.len();
        // Fixed-size portion: buffer lengths, logical lengths, then one rank, one shape offset
        // and one stride offset per tensor.
        let static_len = self.buffer_lens.len() + self.lengths.len() + num_ext * 3;
        let dynamic_len: usize = self
            .shapes
            .iter()
            .chain(&self.strides)
            .map(Vec::len)
            .sum();

        // Shapes are packed right after the fixed portion, strides right after the shapes.
        let mut cursor = static_len;
        let shape_offsets = Self::running_offsets(&mut cursor, &self.shapes);
        let stride_offsets = Self::running_offsets(&mut cursor, &self.strides);

        let mut meta = Vec::with_capacity(static_len + dynamic_len);
        meta.extend(self.buffer_lens);
        meta.extend(self.lengths);
        meta.extend(self.ranks);
        meta.extend(shape_offsets);
        meta.extend(stride_offsets);
        meta.extend(self.shapes.into_iter().flatten());
        meta.extend(self.strides.into_iter().flatten());
        debug_assert_eq!(meta.len(), cursor);
        meta
    }

    /// Start offset of each entry of `parts` when packed back-to-back from `*cursor`;
    /// advances `cursor` past the last entry.
    fn running_offsets(cursor: &mut usize, parts: &[Vec<u32>]) -> Vec<u32> {
        parts
            .iter()
            .map(|part| {
                let offset = u32::try_from(*cursor)
                    .expect("serialized metadata offset does not fit in u32");
                *cursor += part.len();
                offset
            })
            .collect()
    }
}