Skip to main content

cubecl_core/codegen/
scalars.rs

1use alloc::vec::Vec;
2
3use cubecl_ir::StorageType;
4
5use crate::{INFO_ALIGN, ScalarArgType};
6
/// Concatenated byte representation of all scalar arguments of a single storage type,
/// in push order. The storage type itself is not stored here; `ScalarBuilder` keeps it
/// alongside each buffer.
pub type ScalarValues = Vec<u8>;
9
10#[derive(Default)]
11pub struct ScalarBuilder {
12    /// Sorted list of scalars, should be faster than `BTreeMap` for this purpose. Benchmark later.
13    scalars: Vec<(StorageType, ScalarValues)>,
14}
15
16impl ScalarBuilder {
17    /// Add a new scalar value to the state.
18    pub fn push<T: ScalarArgType>(&mut self, val: T) {
19        let val = [val];
20        let bytes = T::as_bytes(&val);
21        self.get_or_insert_mut(T::cube_type())
22            .extend(bytes.iter().copied());
23    }
24
25    /// Add a new raw value to the state.
26    pub fn push_raw(&mut self, bytes: &[u8], dtype: StorageType) {
27        self.get_or_insert_mut(dtype).extend(bytes.iter().copied());
28    }
29
30    fn get_or_insert_mut(&mut self, ty: StorageType) -> &mut ScalarValues {
31        let pos = self.scalars.iter().position(|(k, _)| *k >= ty);
32
33        match pos {
34            Some(i) if self.scalars[i].0 == ty => &mut self.scalars[i].1,
35            Some(i) => {
36                self.scalars.insert(i, (ty, Vec::new()));
37                &mut self.scalars[i].1
38            }
39            None => {
40                self.scalars.push((ty, Vec::new()));
41                &mut self.scalars.last_mut().unwrap().1
42            }
43        }
44    }
45
46    pub fn len_aligned(&self) -> usize {
47        self.scalars
48            .iter()
49            .map(|(_, v)| v.len().div_ceil(INFO_ALIGN))
50            .sum()
51    }
52
53    pub fn finish(&mut self, out: &mut [u64]) {
54        let mut out_u8 = bytemuck::cast_slice_mut::<u64, u8>(out);
55
56        for (_, values) in self.scalars.iter_mut().filter(|(_, v)| !v.is_empty()) {
57            let len_padded = values.len().next_multiple_of(INFO_ALIGN);
58
59            out_u8[0..values.len()].copy_from_slice(values);
60            out_u8 = &mut out_u8[len_padded..];
61            values.clear();
62        }
63    }
64}