cubecl_core/codegen/
metadata.rs1use cubecl_runtime::server::MetadataBinding;
29
/// Section index of the per-buffer `buffer_len` entries within the base metadata.
const BUFFER_LEN: u32 = 0;
/// Section index of the per-buffer `len` entries within the base metadata.
const LENGTH: u32 = 1;
/// Number of base sections; each section holds `num_meta` entries.
const BASE_LEN: u32 = 2;

/// Section index of the rank entries within the extended metadata.
const RANK: u32 = 0;
/// Section index of the shape-offset entries within the extended metadata.
const SHAPE_OFFSETS: u32 = 1;
/// Section index of the stride-offset entries within the extended metadata.
const STRIDE_OFFSETS: u32 = 2;
/// Number of extended sections; each section holds `num_extended_meta` entries.
const EXTENDED_LEN: u32 = 3;

/// Index calculator for the packed metadata buffer.
///
/// The buffer starts with `BASE_LEN` sections of `num_meta` entries each
/// (buffer lengths, then lengths). These are followed by `EXTENDED_LEN` sections
/// of `num_extended_meta` entries each (ranks, shape offsets, stride offsets).
#[derive(Clone, Debug, Default)]
pub struct Metadata {
    num_meta: u32,
    num_extended_meta: u32,
}

impl Metadata {
    /// Builds a layout with `num_meta` entries per base section and
    /// `num_extended_meta` entries per extended section.
    pub fn new(num_meta: u32, num_extended_meta: u32) -> Self {
        Metadata {
            num_meta,
            num_extended_meta,
        }
    }

    /// First index of base section `id`.
    fn offset_of(&self, id: u32) -> u32 {
        let section_width = self.num_meta;
        section_width * id
    }

    /// Number of entries occupied by all base sections together.
    fn base_len(&self) -> u32 {
        self.offset_of(BASE_LEN)
    }

    /// Number of entries in the statically sized part of the buffer
    /// (all base sections plus all extended sections).
    pub fn static_len(&self) -> u32 {
        self.offset_of_extended(EXTENDED_LEN)
    }

    /// First index of extended section `id`; extended sections start right after
    /// the base sections.
    fn offset_of_extended(&self, id: u32) -> u32 {
        let extended_start = self.base_len();
        let section_width = self.num_extended_meta;
        extended_start + section_width * id
    }

    /// Index of the `buffer_len` entry for buffer `buffer_idx`.
    pub fn buffer_len_index(&self, buffer_idx: u32) -> u32 {
        let section_start = self.offset_of(BUFFER_LEN);
        section_start + buffer_idx
    }

    /// Index of the `len` entry for buffer `buffer_idx`.
    pub fn len_index(&self, buffer_idx: u32) -> u32 {
        let section_start = self.offset_of(LENGTH);
        section_start + buffer_idx
    }

    /// Index of the rank entry for extended entry `buffer_idx`.
    pub fn rank_index(&self, buffer_idx: u32) -> u32 {
        let section_start = self.offset_of_extended(RANK);
        section_start + buffer_idx
    }

    /// Index of the entry holding where extended entry `buffer_idx`'s shape starts.
    pub fn shape_offset_index(&self, buffer_idx: u32) -> u32 {
        let section_start = self.offset_of_extended(SHAPE_OFFSETS);
        section_start + buffer_idx
    }

    /// Index of the entry holding where extended entry `buffer_idx`'s strides start.
    pub fn stride_offset_index(&self, buffer_idx: u32) -> u32 {
        let section_start = self.offset_of_extended(STRIDE_OFFSETS);
        section_start + buffer_idx
    }
}
93#[derive(Default)]
97pub struct MetadataBuilder {
98 buffer_lens: Vec<u32>,
99 lengths: Vec<u32>,
100 ranks: Vec<u32>,
101 shapes: Vec<Vec<u32>>,
102 strides: Vec<Vec<u32>>,
103}
104
105impl MetadataBuilder {
106 pub fn with_array(&mut self, buffer_len: u32, len: u32) {
108 self.buffer_lens.push(buffer_len);
109 self.lengths.push(len);
110 }
111
112 pub fn with_tensor(
114 &mut self,
115 rank: u32,
116 buffer_len: u32,
117 len: u32,
118 shape: Vec<u32>,
119 strides: Vec<u32>,
120 ) {
121 self.buffer_lens.push(buffer_len);
122 self.lengths.push(len);
123 self.ranks.push(rank);
124 self.shapes.push(shape);
125 self.strides.push(strides);
126 }
127
128 pub fn finish(self) -> MetadataBinding {
130 let mut meta = self.buffer_lens;
131 meta.extend(self.lengths);
132 meta.extend(self.ranks.clone());
133
134 let num_ext = self.ranks.len();
135 let mut shape_offsets = Vec::with_capacity(num_ext);
136 let mut stride_offsets = Vec::with_capacity(num_ext);
137
138 let mut current_offset = meta.len() + num_ext * 2; for shape in self.shapes.iter() {
141 shape_offsets.push(current_offset as u32);
142 current_offset += shape.len();
143 }
144
145 meta.extend(shape_offsets);
146
147 for stride in self.strides.iter() {
148 stride_offsets.push(current_offset as u32);
149 current_offset += stride.len();
150 }
151
152 meta.extend(stride_offsets);
153
154 let static_len = meta.len();
155
156 meta.extend(self.shapes.into_iter().flatten());
157 meta.extend(self.strides.into_iter().flatten());
158
159 MetadataBinding::new(meta, static_len)
160 }
161}