cubecl_core/codegen/
scalars.rs1use alloc::vec::Vec;
2
3use cubecl_ir::StorageType;
4
5use crate::{INFO_ALIGN, ScalarArgType};
6
/// Growable byte buffer holding the raw bytes of scalar values.
pub type ScalarValues = Vec<u8>;
9
/// Accumulates scalar values as raw bytes, grouped by their [`StorageType`].
#[derive(Default)]
pub struct ScalarBuilder {
    /// One byte buffer per storage type. New entries are inserted in front of
    /// the first entry whose storage type compares `>=` to the new one.
    scalars: Vec<(StorageType, ScalarValues)>,
}
15
16impl ScalarBuilder {
17 pub fn push<T: ScalarArgType>(&mut self, val: T) {
19 let val = [val];
20 let bytes = T::as_bytes(&val);
21 self.get_or_insert_mut(T::cube_type())
22 .extend(bytes.iter().copied());
23 }
24
25 pub fn push_raw(&mut self, bytes: &[u8], dtype: StorageType) {
27 self.get_or_insert_mut(dtype).extend(bytes.iter().copied());
28 }
29
30 fn get_or_insert_mut(&mut self, ty: StorageType) -> &mut ScalarValues {
31 let pos = self.scalars.iter().position(|(k, _)| *k >= ty);
32
33 match pos {
34 Some(i) if self.scalars[i].0 == ty => &mut self.scalars[i].1,
35 Some(i) => {
36 self.scalars.insert(i, (ty, Vec::new()));
37 &mut self.scalars[i].1
38 }
39 None => {
40 self.scalars.push((ty, Vec::new()));
41 &mut self.scalars.last_mut().unwrap().1
42 }
43 }
44 }
45
46 pub fn len_aligned(&self) -> usize {
47 self.scalars
48 .iter()
49 .map(|(_, v)| v.len().div_ceil(INFO_ALIGN))
50 .sum()
51 }
52
53 pub fn finish(&mut self, out: &mut [u64]) {
54 let mut out_u8 = bytemuck::cast_slice_mut::<u64, u8>(out);
55
56 for (_, values) in self.scalars.iter_mut().filter(|(_, v)| !v.is_empty()) {
57 let len_padded = values.len().next_multiple_of(INFO_ALIGN);
58
59 out_u8[0..values.len()].copy_from_slice(values);
60 out_u8 = &mut out_u8[len_padded..];
61 values.clear();
62 }
63 }
64}