//! cubek_std/input_binding.rs — kernel input bindings (plain and quantized).
use cubecl::{
    Runtime,
    client::ComputeClient,
    ir::{AddressType, StorageType},
    prelude::{CubePrimitive, TensorBinding},
    server::LaunchError,
    zspace::Shape,
};
use cubecl::{
    quant::scheme::{BlockSize, QuantLevel},
    std::tensor::{into_contiguous_packed, into_contiguous_pitched},
};
use cubecl_common::quant::scheme::{QuantScheme, QuantStore, QuantValue};

/// A tensor input to a kernel launch: either a plain tensor binding or a
/// quantized one (packed data plus its quantization scales and scheme).
#[derive(Debug)]
#[allow(clippy::large_enum_variant)]
pub enum InputBinding<R: Runtime> {
    /// A regular tensor binding together with its element storage type.
    Normal(TensorBinding<R>, StorageType),
    /// A quantized tensor, carrying both the (possibly packed) values and
    /// the scales needed to dequantize them.
    Quantized {
        // Binding holding the (possibly packed) quantized values.
        data: TensorBinding<R>,
        // Storage type of the quantized data elements.
        data_dtype: StorageType,
        // Binding holding the quantization scales.
        scale: TensorBinding<R>,
        // Storage type of the scale elements.
        scale_dtype: StorageType,
        /// Unpacked shape, excluding padding
        shape: Shape,
        // Quantization scheme (value type, level, storage layout).
        scheme: QuantScheme,
    },
}
29
30impl<R: Runtime> Clone for InputBinding<R> {
31    fn clone(&self) -> Self {
32        match self {
33            Self::Normal(arg0, arg1) => Self::Normal(arg0.clone(), *arg1),
34            Self::Quantized {
35                data,
36                data_dtype,
37                scale,
38                scale_dtype,
39                shape,
40                scheme,
41            } => Self::Quantized {
42                data: data.clone(),
43                data_dtype: *data_dtype,
44                scale: scale.clone(),
45                scale_dtype: *scale_dtype,
46                shape: shape.clone(),
47                scheme: *scheme,
48            },
49        }
50    }
51}
52
impl<R: Runtime> InputBinding<R> {
    /// Creates a non-quantized binding from a tensor handle and its storage type.
    pub fn new(data: TensorBinding<R>, dtype: StorageType) -> Self {
        Self::Normal(data, dtype)
    }

    /// Swaps dimensions `dim0` and `dim1` of the binding in place.
    ///
    /// For quantized bindings, the quantization metadata is kept consistent
    /// with the new layout: with block-level quantization the scale layout
    /// and per-dimension block sizes are swapped too, and the packed
    /// dimension recorded in the store is remapped when it refers to one of
    /// the swapped axes.
    pub fn swap_dims(&mut self, dim0: usize, dim1: usize) {
        match self {
            Self::Normal(handle, _dtype) => {
                handle.shape.swap(dim0, dim1);
                handle.strides.swap(dim0, dim1);
            }
            Self::Quantized {
                data,
                scale,
                shape,
                scheme,
                data_dtype: _,
                scale_dtype: _,
            } => {
                // Rank taken from the (packed) data handle; `shape` below is
                // the unpacked logical shape with the same rank.
                let rank = data.shape.len();

                data.shape.swap(dim0, dim1);
                data.strides.swap(dim0, dim1);

                // Swap dims for scale and block size if block scaled quant is used
                if let QuantLevel::Block(block) = &mut scheme.level {
                    scale.shape.swap(dim0, dim1);
                    scale.strides.swap(dim0, dim1);

                    // Block sizes are stored in a trimmed form; expand to one
                    // entry per dimension, swap, then re-trim.
                    let mut block_size = block.to_dim_vec(rank);
                    block_size.swap(dim0, dim1);
                    *block = BlockSize::new_trim(block_size)
                }

                shape.swap(dim0, dim1);

                // Swap packed dim if packed dim is either of `dim0` or `dim1`
                // NOTE(review): `packed_dim` appears to be counted from the
                // innermost axis (`rank - axis - 1`) — confirm against the
                // `QuantStore` definition.
                if let QuantStore::PackedU32(packed_dim) | QuantStore::PackedNative(packed_dim) =
                    &mut scheme.store
                {
                    if *packed_dim == rank - dim0 - 1 {
                        *packed_dim = rank - dim1 - 1;
                    } else if *packed_dim == rank - dim1 - 1 {
                        *packed_dim = rank - dim0 - 1;
                    }
                }
            }
        }
    }

    /// Creates a quantized binding from data/scale handles plus quantization
    /// metadata. `shape` is the unpacked logical shape, excluding padding.
    pub fn quantized(
        data: TensorBinding<R>,
        scale: TensorBinding<R>,
        shape: Shape,
        scheme: QuantScheme,
        data_dtype: StorageType,
        scale_dtype: StorageType,
    ) -> Self {
        Self::Quantized {
            data,
            scale,
            shape,
            scheme,
            data_dtype,
            scale_dtype,
        }
    }

    /// Returns the data binding (the quantized values for `Quantized`).
    pub fn data(&self) -> &TensorBinding<R> {
        match self {
            InputBinding::Normal(handle, ..) => handle,
            InputBinding::Quantized { data, .. } => data,
        }
    }

    /// Size in bytes of one stored data element (the packed element type for
    /// quantized bindings).
    pub fn data_elem_size(&self) -> usize {
        match self {
            InputBinding::Normal(_, ty) => ty.size(),
            InputBinding::Quantized { data_dtype, .. } => data_dtype.size(),
        }
    }

    /// Consumes the binding, returning the data handle and discarding any
    /// quantization metadata.
    pub fn into_data(self) -> TensorBinding<R> {
        match self {
            InputBinding::Normal(handle, ..) => handle,
            InputBinding::Quantized { data, .. } => data,
        }
    }

    /// Mutable access to the data binding.
    pub fn data_mut(&mut self) -> &mut TensorBinding<R> {
        match self {
            InputBinding::Normal(handle, ..) => handle,
            InputBinding::Quantized { data, .. } => data,
        }
    }

    /// Returns the scale binding, or `None` for non-quantized inputs.
    pub fn scale(&self) -> Option<&TensorBinding<R>> {
        match self {
            InputBinding::Normal(..) => None,
            InputBinding::Quantized { scale, .. } => Some(scale),
        }
    }

    /// Returns the quantization scheme, or `None` for non-quantized inputs.
    pub fn scheme(&self) -> Option<&QuantScheme> {
        match self {
            InputBinding::Normal(..) => None,
            InputBinding::Quantized { scheme, .. } => Some(scheme),
        }
    }

    /// Logical shape of the tensor: the handle's shape for `Normal`, the
    /// unpacked (padding-free) shape for `Quantized`.
    pub fn shape(&self) -> &Shape {
        match self {
            InputBinding::Normal(handle, ..) => &handle.shape,
            InputBinding::Quantized { shape, .. } => shape,
        }
    }

    /// Rewrites the binding into a contiguous layout, launching copy/repack
    /// kernels on `client` as needed.
    ///
    /// Packed quantized data is re-packed with `into_contiguous_packed`, and
    /// the scheme's packed dimension is reset to `0` afterwards (presumably
    /// the innermost axis under the store's dim convention — see
    /// `swap_dims`). Everything else goes through `into_contiguous_pitched`.
    pub fn into_contiguous(self, client: &ComputeClient<R>) -> Result<Self, LaunchError> {
        let val = match self {
            Self::Normal(data, dtype) => Self::Normal(
                into_contiguous_pitched(client, data, dtype).binding(),
                dtype,
            ),
            Self::Quantized {
                data,
                scale,
                shape,
                scheme,
                data_dtype,
                scale_dtype,
            } => {
                // Local rebinding so the repack branches can update the store.
                let mut scheme = scheme;
                let data = match scheme.store {
                    // e2m1 has native packing (e2m1x2) so also needs to be re-packed
                    QuantStore::PackedNative(packed_dim) if scheme.value == QuantValue::E2M1 => {
                        // Repack through a u8 view, then restore the original
                        // data dtype on the result.
                        let mut data = into_contiguous_packed(
                            client,
                            data,
                            packed_dim,
                            &shape,
                            scheme.num_quants(),
                            u8::as_type_native_unchecked().storage_type(),
                        );
                        scheme = scheme.with_store(QuantStore::PackedNative(0));
                        data.dtype = data_dtype;
                        data
                    }
                    QuantStore::PackedU32(packed_dim) => {
                        // Repack through a u32 view, then restore the original
                        // data dtype on the result.
                        let mut data = into_contiguous_packed(
                            client,
                            data,
                            packed_dim,
                            &shape,
                            scheme.num_quants(),
                            u32::as_type_native_unchecked().storage_type(),
                        );
                        data.dtype = data_dtype;
                        scheme = scheme.with_store(QuantStore::PackedU32(0));
                        data
                    }
                    // Unpacked (or non-E2M1 native) stores: plain pitched copy.
                    _ => into_contiguous_pitched(client, data, data_dtype),
                };

                Self::Quantized {
                    data: data.binding(),
                    scale,
                    shape,
                    scheme,
                    data_dtype,
                    scale_dtype,
                }
            }
        };

        Ok(val)
    }

    /// Address width required to index this input.
    ///
    /// For quantized inputs this is the wider of what the physical (packed)
    /// buffer needs and what addressing the conceptual, unpacked element
    /// count needs.
    pub fn required_address_type(&self) -> AddressType {
        match self {
            InputBinding::Normal(handle, ty) => handle.required_address_type(ty.size()),
            InputBinding::Quantized {
                data,
                shape,
                scheme,
                ..
            } => {
                // NOTE(review): assumes `size_bits_stored()` is a whole number
                // of bytes (>= 8) — confirm for sub-byte stored types.
                let handle_addr = data.required_address_type(scheme.size_bits_stored() / 8);
                let conceptual_addr = AddressType::from_len(shape.iter().product());
                handle_addr.max(conceptual_addr)
            }
        }
    }
}
244}