// The serialized metadata buffer has a base section followed by an extended
// section. Each section is a sequence of fields, and each field stores one
// entry per buffer; the constants below are field ids within their section.

/// Base field id: per-buffer buffer length.
const BUFFER_LEN: u32 = 0;
/// Base field id: per-buffer length.
const LENGTH: u32 = BUFFER_LEN + 1;
/// Number of fields in the base section.
const BASE_LEN: u32 = LENGTH + 1;

/// Extended field id: per-buffer rank.
const RANK: u32 = 0;
/// Extended field id: per-buffer offset at which the shape starts.
const SHAPE_OFFSETS: u32 = RANK + 1;
/// Extended field id: per-buffer offset at which the strides start.
const STRIDE_OFFSETS: u32 = SHAPE_OFFSETS + 1;

38#[derive(Clone, Debug, Default)]
40pub struct Metadata {
41 num_meta: u32,
42 num_extended_meta: u32,
43}
44
45impl Metadata {
46 pub fn new(num_meta: u32, num_extended_meta: u32) -> Self {
47 Self {
48 num_meta,
49 num_extended_meta,
50 }
51 }
52
53 fn offset_of(&self, id: u32) -> u32 {
54 self.num_meta * id
55 }
56
57 fn base_len(&self) -> u32 {
58 self.num_meta * BASE_LEN
59 }
60
61 fn offset_of_extended(&self, id: u32) -> u32 {
62 self.base_len() + self.num_extended_meta * id
63 }
64
65 pub fn buffer_len_index(&self, buffer_idx: u32) -> u32 {
66 self.offset_of(BUFFER_LEN) + buffer_idx
67 }
68
69 pub fn len_index(&self, buffer_idx: u32) -> u32 {
70 self.offset_of(LENGTH) + buffer_idx
71 }
72
73 pub fn rank_index(&self, buffer_idx: u32) -> u32 {
74 self.offset_of_extended(RANK) + buffer_idx
75 }
76
77 pub fn shape_offset_index(&self, buffer_idx: u32) -> u32 {
78 self.offset_of_extended(SHAPE_OFFSETS) + buffer_idx
79 }
80
81 pub fn stride_offset_index(&self, buffer_idx: u32) -> u32 {
82 self.offset_of_extended(STRIDE_OFFSETS) + buffer_idx
83 }
84}
85
/// Collects per-buffer metadata and serializes it into one flat `u32` buffer.
///
/// Every registered buffer (array or tensor) contributes one entry to the
/// buffer-length and length sections. Tensors also contribute a rank, a shape
/// and strides. Those extended entries are numbered in the order `with_tensor`
/// was called, independently of arrays.
#[derive(Default)]
pub struct MetadataBuilder {
    /// One entry per registered buffer, in registration order.
    buffer_lens: Vec<u32>,
    /// One entry per registered buffer, in registration order.
    lengths: Vec<u32>,
    /// One entry per registered tensor.
    ranks: Vec<u32>,
    /// Shape of each registered tensor.
    shapes: Vec<Vec<u32>>,
    /// Strides of each registered tensor.
    strides: Vec<Vec<u32>>,
}

impl MetadataBuilder {
    /// Registers a buffer that only carries base metadata (no rank/shape/strides).
    pub fn with_array(&mut self, buffer_len: u32, len: u32) {
        self.buffer_lens.push(buffer_len);
        self.lengths.push(len);
    }

    /// Registers a buffer carrying base metadata plus rank, shape and strides.
    ///
    /// `rank` is stored as given and is not checked against `shape.len()` or
    /// `strides.len()`. Shapes and strides are located through offset tables,
    /// so the serialized layout stays addressable either way.
    pub fn with_tensor(
        &mut self,
        rank: u32,
        buffer_len: u32,
        len: u32,
        shape: Vec<u32>,
        strides: Vec<u32>,
    ) {
        self.buffer_lens.push(buffer_len);
        self.lengths.push(len);
        self.ranks.push(rank);
        self.shapes.push(shape);
        self.strides.push(strides);
    }

    /// Serializes the collected metadata into its final flat layout:
    ///
    /// ```text
    /// [buffer_lens | lengths | ranks | shape_offsets | stride_offsets | shapes.. | strides..]
    /// ```
    ///
    /// `shape_offsets[i]` and `stride_offsets[i]` are absolute indices into the
    /// returned vector where tensor `i`'s shape and strides begin.
    ///
    /// # Panics
    ///
    /// Panics if an offset does not fit in a `u32`. The previous implementation
    /// silently truncated such offsets instead.
    pub fn finish(self) -> Vec<u32> {
        let Self {
            buffer_lens,
            lengths,
            ranks,
            shapes,
            strides,
        } = self;

        let num_ext = ranks.len();
        // Total fields in the static part: the base section plus the three
        // extended tables (ranks, shape offsets, stride offsets). The
        // dynamically sized shape/stride data starts right after it.
        let header_len = buffer_lens.len() + lengths.len() + 3 * num_ext;
        let data_len: usize = shapes.iter().chain(&strides).map(Vec::len).sum();

        // Reuse the `buffer_lens` allocation and grow it once to the final size.
        let mut meta = buffer_lens;
        meta.reserve(lengths.len() + 3 * num_ext + data_len);
        meta.extend(lengths);
        meta.extend(ranks);

        let mut cursor = header_len;
        let shape_offsets = Self::offset_table(&shapes, &mut cursor);
        let stride_offsets = Self::offset_table(&strides, &mut cursor);
        meta.extend(shape_offsets);
        meta.extend(stride_offsets);

        meta.extend(shapes.into_iter().flatten());
        meta.extend(strides.into_iter().flatten());
        meta
    }

    /// Returns the start index of each entry when the entries are laid out
    /// back-to-back beginning at `*cursor`, and advances `*cursor` past them.
    fn offset_table(entries: &[Vec<u32>], cursor: &mut usize) -> Vec<u32> {
        entries
            .iter()
            .map(|entry| {
                let start =
                    u32::try_from(*cursor).expect("metadata offset does not fit in u32");
                *cursor += entry.len();
                start
            })
            .collect()
    }
}