1use cubecl::{
2 Runtime,
3 client::ComputeClient,
4 ir::{AddressType, StorageType},
5 prelude::{CubePrimitive, TensorBinding},
6 server::LaunchError,
7 zspace::Shape,
8};
9use cubecl::{
10 quant::scheme::{BlockSize, QuantLevel},
11 std::tensor::{into_contiguous_packed, into_contiguous_pitched},
12};
13use cubecl_common::quant::scheme::{QuantScheme, QuantStore, QuantValue};
14
/// A tensor input bound for a kernel launch: either a plain tensor binding or
/// a quantized one that carries its packed data, scales, and quantization
/// scheme together.
#[derive(Debug)]
#[allow(clippy::large_enum_variant)]
pub enum InputBinding<R: Runtime> {
    /// A regular (non-quantized) tensor binding plus its storage type.
    Normal(TensorBinding<R>, StorageType),
    /// A quantized tensor binding.
    Quantized {
        /// Handle holding the (possibly packed) quantized values.
        data: TensorBinding<R>,
        /// Storage type of the quantized data handle.
        data_dtype: StorageType,
        /// Handle holding the quantization scales.
        scale: TensorBinding<R>,
        /// Storage type of the scale handle.
        scale_dtype: StorageType,
        /// Logical tensor shape, tracked separately from `data.shape`
        /// (the data handle's own shape may differ when values are packed).
        shape: Shape,
        /// Quantization scheme (level, value type, storage layout).
        scheme: QuantScheme,
    },
}
29
30impl<R: Runtime> Clone for InputBinding<R> {
31 fn clone(&self) -> Self {
32 match self {
33 Self::Normal(arg0, arg1) => Self::Normal(arg0.clone(), *arg1),
34 Self::Quantized {
35 data,
36 data_dtype,
37 scale,
38 scale_dtype,
39 shape,
40 scheme,
41 } => Self::Quantized {
42 data: data.clone(),
43 data_dtype: *data_dtype,
44 scale: scale.clone(),
45 scale_dtype: *scale_dtype,
46 shape: shape.clone(),
47 scheme: *scheme,
48 },
49 }
50 }
51}
52
impl<R: Runtime> InputBinding<R> {
    /// Creates a plain (non-quantized) input binding.
    pub fn new(data: TensorBinding<R>, dtype: StorageType) -> Self {
        Self::Normal(data, dtype)
    }

    /// Swaps dimensions `dim0` and `dim1` of this binding in place.
    ///
    /// For quantized inputs the auxiliary metadata is kept coherent with the
    /// swap: the logical shape is swapped, and for block-level quantization
    /// the scale tensor and per-dimension block sizes are swapped as well.
    /// The packed dimension recorded in the store is remapped if it pointed
    /// at one of the two swapped axes.
    pub fn swap_dims(&mut self, dim0: usize, dim1: usize) {
        match self {
            Self::Normal(handle, _dtype) => {
                handle.shape.swap(dim0, dim1);
                handle.strides.swap(dim0, dim1);
            }
            Self::Quantized {
                data,
                scale,
                shape,
                scheme,
                data_dtype: _,
                scale_dtype: _,
            } => {
                let rank = data.shape.len();

                data.shape.swap(dim0, dim1);
                data.strides.swap(dim0, dim1);

                // Block-level quantization keeps per-dimension scale layout
                // and block sizes, so both must follow the dimension swap.
                if let QuantLevel::Block(block) = &mut scheme.level {
                    scale.shape.swap(dim0, dim1);
                    scale.strides.swap(dim0, dim1);

                    let mut block_size = block.to_dim_vec(rank);
                    block_size.swap(dim0, dim1);
                    *block = BlockSize::new_trim(block_size)
                }

                shape.swap(dim0, dim1);

                // The store's packed dimension is indexed from the innermost
                // axis (see the `rank - dim - 1` comparisons): if it referred
                // to one of the swapped axes, redirect it to the other.
                if let QuantStore::PackedU32(packed_dim) | QuantStore::PackedNative(packed_dim) =
                    &mut scheme.store
                {
                    if *packed_dim == rank - dim0 - 1 {
                        *packed_dim = rank - dim1 - 1;
                    } else if *packed_dim == rank - dim1 - 1 {
                        *packed_dim = rank - dim0 - 1;
                    }
                }
            }
        }
    }
    /// Creates a quantized input binding from its parts.
    pub fn quantized(
        data: TensorBinding<R>,
        scale: TensorBinding<R>,
        shape: Shape,
        scheme: QuantScheme,
        data_dtype: StorageType,
        scale_dtype: StorageType,
    ) -> Self {
        Self::Quantized {
            data,
            scale,
            shape,
            scheme,
            data_dtype,
            scale_dtype,
        }
    }

    /// Returns the underlying data binding (the packed quantized values for
    /// the quantized variant).
    pub fn data(&self) -> &TensorBinding<R> {
        match self {
            InputBinding::Normal(handle, ..) => handle,
            InputBinding::Quantized { data, .. } => data,
        }
    }

    /// Size in bytes of one stored element of the data handle.
    pub fn data_elem_size(&self) -> usize {
        match self {
            InputBinding::Normal(_, ty) => ty.size(),
            InputBinding::Quantized { data_dtype, .. } => data_dtype.size(),
        }
    }

    /// Consumes `self`, returning the data binding and dropping any
    /// quantization metadata.
    pub fn into_data(self) -> TensorBinding<R> {
        match self {
            InputBinding::Normal(handle, ..) => handle,
            InputBinding::Quantized { data, .. } => data,
        }
    }

    /// Mutable access to the data binding.
    pub fn data_mut(&mut self) -> &mut TensorBinding<R> {
        match self {
            InputBinding::Normal(handle, ..) => handle,
            InputBinding::Quantized { data, .. } => data,
        }
    }

    /// The scale binding, or `None` for non-quantized inputs.
    pub fn scale(&self) -> Option<&TensorBinding<R>> {
        match self {
            InputBinding::Normal(..) => None,
            InputBinding::Quantized { scale, .. } => Some(scale),
        }
    }

    /// The quantization scheme, or `None` for non-quantized inputs.
    pub fn scheme(&self) -> Option<&QuantScheme> {
        match self {
            InputBinding::Normal(..) => None,
            InputBinding::Quantized { scheme, .. } => Some(scheme),
        }
    }

    /// The logical shape: the separately tracked `shape` for quantized
    /// inputs, the handle's own shape otherwise.
    pub fn shape(&self) -> &Shape {
        match self {
            InputBinding::Normal(handle, ..) => &handle.shape,
            InputBinding::Quantized { shape, .. } => shape,
        }
    }

    /// Rewrites this binding so its data is laid out contiguously.
    ///
    /// - `Normal` inputs go through `into_contiguous_pitched`.
    /// - Quantized inputs with a packed store (`PackedU32`, or `PackedNative`
    ///   when the value type is `E2M1`) are repacked via
    ///   `into_contiguous_packed` through a plain `u32`/`u8` view; the
    ///   handle's dtype is then restored to the original `data_dtype` and the
    ///   scheme's packed dimension is reset to `0`.
    /// - Any other quantized store falls back to `into_contiguous_pitched`.
    ///
    /// # Errors
    /// Declared to return a [`LaunchError`], though the visible body always
    /// yields `Ok` — NOTE(review): presumably kept for signature stability or
    /// future fallible paths; confirm against callers.
    pub fn into_contiguous(self, client: &ComputeClient<R>) -> Result<Self, LaunchError> {
        let val = match self {
            Self::Normal(data, dtype) => Self::Normal(
                into_contiguous_pitched(client, data, dtype).binding(),
                dtype,
            ),
            Self::Quantized {
                data,
                scale,
                shape,
                scheme,
                data_dtype,
                scale_dtype,
            } => {
                let mut scheme = scheme;
                let data = match scheme.store {
                    // Native-packed E2M1: repack through a u8 storage view,
                    // then restore the caller's original dtype on the handle.
                    QuantStore::PackedNative(packed_dim) if scheme.value == QuantValue::E2M1 => {
                        let mut data = into_contiguous_packed(
                            client,
                            data,
                            packed_dim,
                            &shape,
                            scheme.num_quants(),
                            u8::as_type_native_unchecked().storage_type(),
                        );
                        // After repacking, values are packed along dimension 0
                        // of the store's packed-dim indexing.
                        scheme = scheme.with_store(QuantStore::PackedNative(0));
                        data.dtype = data_dtype;
                        data
                    }
                    // u32-packed stores: same repack, through a u32 view.
                    QuantStore::PackedU32(packed_dim) => {
                        let mut data = into_contiguous_packed(
                            client,
                            data,
                            packed_dim,
                            &shape,
                            scheme.num_quants(),
                            u32::as_type_native_unchecked().storage_type(),
                        );
                        data.dtype = data_dtype;
                        scheme = scheme.with_store(QuantStore::PackedU32(0));
                        data
                    }
                    // Non-packed quantized data behaves like a normal tensor.
                    _ => into_contiguous_pitched(client, data, data_dtype),
                };

                Self::Quantized {
                    data: data.binding(),
                    scale,
                    shape,
                    scheme,
                    data_dtype,
                    scale_dtype,
                }
            }
        };

        Ok(val)
    }

    /// Smallest address type able to index this input.
    ///
    /// For quantized inputs this is the larger of what the packed data handle
    /// requires (element size derived from the scheme's stored bit width) and
    /// what the conceptual element count (product of the logical `shape`)
    /// requires.
    pub fn required_address_type(&self) -> AddressType {
        match self {
            InputBinding::Normal(handle, ty) => handle.required_address_type(ty.size()),
            InputBinding::Quantized {
                data,
                shape,
                scheme,
                ..
            } => {
                // NOTE(review): integer division — a stored size below 8 bits
                // would yield 0 here; presumably `size_bits_stored()` is
                // always >= 8 for stored layouts. Confirm.
                let handle_addr = data.required_address_type(scheme.size_bits_stored() / 8);
                let conceptual_addr = AddressType::from_len(shape.iter().product());
                handle_addr.max(conceptual_addr)
            }
        }
    }
}