1use burn_tensor::{
2 quantization::{
3 AffineQuantization, QParams, QTensorPrimitive, QuantizationScheme, QuantizationStrategy,
4 QuantizationType, SymmetricQuantization,
5 },
6 DType, Element, Shape, TensorData, TensorMetadata,
7};
8
9use ndarray::{ArcArray, Array, Dim, IxDyn};
10
11use crate::element::QuantElement;
12
/// Tensor primitive wrapping a shared, dynamically-dimensioned `ndarray` array.
///
/// The `new` derive supplies the `NdArrayTensor::new(array)` constructor that the
/// `reshape!` macro in this file calls.
#[derive(new, Debug, Clone)]
pub struct NdArrayTensor<E> {
    /// Shared n-dimensional array holding the elements of type `E`.
    pub array: ArcArray<E, IxDyn>,
}
19
20impl<E: Element> TensorMetadata for NdArrayTensor<E> {
21 fn dtype(&self) -> DType {
22 E::dtype()
23 }
24
25 fn shape(&self) -> Shape {
26 Shape::from(self.array.shape().to_vec())
27 }
28}
29
/// Float tensor primitive that can hold either of the two supported precisions.
#[derive(Debug, Clone)]
pub enum NdArrayTensorFloat {
    /// Tensor whose elements are `f32`.
    F32(NdArrayTensor<f32>),
    /// Tensor whose elements are `f64`.
    F64(NdArrayTensor<f64>),
}
38
39impl From<NdArrayTensor<f32>> for NdArrayTensorFloat {
40 fn from(value: NdArrayTensor<f32>) -> Self {
41 NdArrayTensorFloat::F32(value)
42 }
43}
44
45impl From<NdArrayTensor<f64>> for NdArrayTensorFloat {
46 fn from(value: NdArrayTensor<f64>) -> Self {
47 NdArrayTensorFloat::F64(value)
48 }
49}
50
51impl TensorMetadata for NdArrayTensorFloat {
52 fn dtype(&self) -> DType {
53 match self {
54 NdArrayTensorFloat::F32(tensor) => tensor.dtype(),
55 NdArrayTensorFloat::F64(tensor) => tensor.dtype(),
56 }
57 }
58
59 fn shape(&self) -> Shape {
60 match self {
61 NdArrayTensorFloat::F32(tensor) => tensor.shape(),
62 NdArrayTensorFloat::F64(tensor) => tensor.shape(),
63 }
64 }
65}
66
/// Wraps a typed tensor in the [`NdArrayTensorFloat`](crate::NdArrayTensorFloat) variant
/// that matches `E::dtype()`.
///
/// The expansion refers to a type named `E`, which must therefore be in scope at the
/// call site.
///
/// # Panics
/// Hits `unimplemented!` when `E::dtype()` is neither `F64` nor `F32`.
#[macro_export]
macro_rules! new_tensor_float {
    // Select the variant from the call-site element type `E`.
    ($tensor:expr) => {{
        match E::dtype() {
            burn_tensor::DType::F64 => $crate::NdArrayTensorFloat::F64($tensor),
            burn_tensor::DType::F32 => $crate::NdArrayTensorFloat::F32($tensor),
            // Only the two float precisions above have a matching variant.
            _ => unimplemented!("Unsupported dtype"),
        }
    }};
}
80
/// Runs an operation on the typed tensor(s) stored inside one or two
/// [`NdArrayTensorFloat`](crate::NdArrayTensorFloat) values.
///
/// Arms that separate the operands from `$op` with `,` wrap the result of `$op` back
/// into the same variant as the input. Arms that use `=>` return the result of `$op`
/// as is. Arms that take an `$element` identifier declare a local type alias of that
/// name (`f64` or `f32`) that `$op` can refer to.
///
/// # Panics
/// The binary arms perform no type cast: when the two operands hold different
/// variants they panic with a "Data type mismatch" message.
#[macro_export]
macro_rules! execute_with_float_dtype {
    // Binary op, result wrapped in the operands' variant.
    (($lhs:expr, $rhs:expr), $op:expr) => {{
        // The operands are moved into the `match` below, so their dtypes are read
        // first for the mismatch message.
        let lhs_dtype = burn_tensor::TensorMetadata::dtype(&$lhs);
        let rhs_dtype = burn_tensor::TensorMetadata::dtype(&$rhs);
        match ($lhs, $rhs) {
            ($crate::NdArrayTensorFloat::F64(lhs), $crate::NdArrayTensorFloat::F64(rhs)) => {
                $crate::NdArrayTensorFloat::F64($op(lhs, rhs))
            }
            ($crate::NdArrayTensorFloat::F32(lhs), $crate::NdArrayTensorFloat::F32(rhs)) => {
                $crate::NdArrayTensorFloat::F32($op(lhs, rhs))
            }
            _ => panic!(
                "Data type mismatch (lhs: {:?}, rhs: {:?})",
                lhs_dtype, rhs_dtype
            ),
        }
    }};

    // Binary op with an element type alias, result wrapped in the operands' variant.
    (($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{
        // Dtypes are read before the operands are moved into the `match`.
        let lhs_dtype = burn_tensor::TensorMetadata::dtype(&$lhs);
        let rhs_dtype = burn_tensor::TensorMetadata::dtype(&$rhs);
        match ($lhs, $rhs) {
            ($crate::NdArrayTensorFloat::F64(lhs), $crate::NdArrayTensorFloat::F64(rhs)) => {
                type $element = f64;
                $crate::NdArrayTensorFloat::F64($op(lhs, rhs))
            }
            ($crate::NdArrayTensorFloat::F32(lhs), $crate::NdArrayTensorFloat::F32(rhs)) => {
                type $element = f32;
                $crate::NdArrayTensorFloat::F32($op(lhs, rhs))
            }
            _ => panic!(
                "Data type mismatch (lhs: {:?}, rhs: {:?})",
                lhs_dtype, rhs_dtype
            ),
        }
    }};

    // Binary op, result of `$op` returned as is (not wrapped).
    (($lhs:expr, $rhs:expr) => $op:expr) => {{
        // Dtypes are read before the operands are moved into the `match`.
        let lhs_dtype = burn_tensor::TensorMetadata::dtype(&$lhs);
        let rhs_dtype = burn_tensor::TensorMetadata::dtype(&$rhs);
        match ($lhs, $rhs) {
            ($crate::NdArrayTensorFloat::F64(lhs), $crate::NdArrayTensorFloat::F64(rhs)) => {
                $op(lhs, rhs)
            }
            ($crate::NdArrayTensorFloat::F32(lhs), $crate::NdArrayTensorFloat::F32(rhs)) => {
                $op(lhs, rhs)
            }
            _ => panic!(
                "Data type mismatch (lhs: {:?}, rhs: {:?})",
                lhs_dtype, rhs_dtype
            ),
        }
    }};

    // Unary op, result wrapped in the operand's variant.
    ($tensor:expr, $op:expr) => {{
        match $tensor {
            $crate::NdArrayTensorFloat::F64(tensor) => $crate::NdArrayTensorFloat::F64($op(tensor)),
            $crate::NdArrayTensorFloat::F32(tensor) => $crate::NdArrayTensorFloat::F32($op(tensor)),
        }
    }};

    // Unary op with an element type alias, result wrapped in the operand's variant.
    ($tensor:expr, $element:ident, $op:expr) => {{
        match $tensor {
            $crate::NdArrayTensorFloat::F64(tensor) => {
                type $element = f64;
                $crate::NdArrayTensorFloat::F64($op(tensor))
            }
            $crate::NdArrayTensorFloat::F32(tensor) => {
                type $element = f32;
                $crate::NdArrayTensorFloat::F32($op(tensor))
            }
        }
    }};

    // Unary op, result of `$op` returned as is (not wrapped).
    ($tensor:expr => $op:expr) => {{
        match $tensor {
            $crate::NdArrayTensorFloat::F64(tensor) => $op(tensor),
            $crate::NdArrayTensorFloat::F32(tensor) => $op(tensor),
        }
    }};

    // Unary op with an element type alias, result of `$op` returned as is (not wrapped).
    ($tensor:expr, $element:ident => $op:expr) => {{
        match $tensor {
            $crate::NdArrayTensorFloat::F64(tensor) => {
                type $element = f64;
                $op(tensor)
            }
            $crate::NdArrayTensorFloat::F32(tensor) => {
                type $element = f32;
                $op(tensor)
            }
        }
    }};
}
188
#[cfg(test)]
mod utils {
    use super::*;
    use crate::element::FloatNdArrayElement;

    impl<E> NdArrayTensor<E>
    where
        E: Default + Clone,
    {
        /// Consumes the tensor and returns its elements, in logical order, together
        /// with its shape as [`TensorData`].
        ///
        /// When [`is_contiguous`](Self::is_contiguous) holds and the backing buffer can
        /// be taken over without a copy, that buffer is reused. Otherwise the elements
        /// are collected one by one through the array iterator.
        pub(crate) fn into_data(self) -> TensorData
        where
            E: FloatNdArrayElement,
        {
            let shape = self.shape();
            // Captured up front: `self.array` is moved below.
            let num_elements = self.array.len();

            let vec = if self.is_contiguous() {
                match self.array.try_into_owned_nocopy() {
                    Ok(owned) => {
                        let (mut vec, offset) = owned.into_raw_vec_and_offset();
                        // The array may be a contiguous view into a larger buffer: its
                        // first element then sits `offset` items into the raw vector,
                        // and the vector may extend past the last element of the view.
                        // Keep exactly the window that belongs to this tensor.
                        if let Some(offset) = offset {
                            vec.drain(..offset);
                        }
                        vec.truncate(num_elements);
                        vec
                    }
                    // The buffer is shared, so the elements have to be copied out.
                    Err(array) => array.into_iter().collect(),
                }
            } else {
                self.array.into_iter().collect()
            };

            TensorData::new(vec, shape)
        }

        /// Returns `true` when the strides describe a densely packed layout whose
        /// innermost axis varies fastest.
        ///
        /// The check is conservative. Any non-positive stride yields `false`, and the
        /// stride of each axis must be strictly larger than the stride of the axis
        /// after it. As a consequence, an array of rank two or more with a
        /// unit-length axis anywhere but the outermost position is always reported as
        /// non-contiguous.
        pub(crate) fn is_contiguous(&self) -> bool {
            let shape = self.array.shape();
            let strides = self.array.strides();

            // A zero-dimensional array has nothing that could be laid out out of order.
            if shape.is_empty() {
                return true;
            }

            if shape.len() == 1 {
                return strides[0] == 1;
            }

            let mut prev_stride = 1;
            // Number of elements spanned by the axes visited so far (innermost first).
            let mut current_num_elems_shape = 1;

            // Walk the axes from the innermost to the outermost.
            for (i, (stride, shape)) in strides.iter().zip(shape).rev().enumerate() {
                let stride = if *stride <= 0 {
                    return false;
                } else {
                    *stride as usize
                };
                if i > 0 {
                    // Each outer stride must equal the number of elements covered by
                    // all inner axes.
                    if current_num_elems_shape != stride {
                        return false;
                    }

                    if prev_stride >= stride {
                        return false;
                    }
                }

                current_num_elems_shape *= shape;
                prev_stride = stride;
            }

            true
        }
    }
}
258
/// Copies the first `$n` entries of `$dims` into a fixed-size `Dim<[usize; $n]>`.
///
/// `Dim` is referenced unqualified, so it must be in scope where the macro is expanded.
/// `$dims` is indexed once per axis, so indexing it panics when it has fewer than `$n`
/// entries.
#[macro_export(local_inner_macros)]
macro_rules! to_typed_dims {
    (
        $n:expr,
        $dims:expr,
        justdim
    ) => {{
        let mut dims = [0; $n];
        for i in 0..$n {
            dims[i] = $dims[i];
        }
        let dim: Dim<[usize; $n]> = Dim(dims);
        dim
    }};
}
275
/// Reshapes an array to `$shape` and wraps the result in an `NdArrayTensor`.
///
/// `Dim` and `NdArrayTensor` are referenced unqualified, so both must be in scope where
/// the macro is expanded.
///
/// # Panics
/// The rank-dispatching arm panics when `$D` is not in `1..=6`. The fixed-rank arm
/// panics when `to_shape` fails.
#[macro_export(local_inner_macros)]
macro_rules! reshape {
    // Fixed rank `$n`: build the typed dimension, reshape, then erase the rank again.
    (
        ty $ty:ty,
        n $n:expr,
        shape $shape:expr,
        array $array:expr
    ) => {{
        let dim = $crate::to_typed_dims!($n, $shape.dims, justdim);
        let array: ndarray::ArcArray<$ty, Dim<[usize; $n]>> = match $array.is_standard_layout() {
            true => $array
                .to_shape(dim)
                .expect("Safe to change shape without relayout")
                .into_shared(),
            // Not in standard layout: reshape, then convert to standard layout.
            false => $array.to_shape(dim).unwrap().as_standard_layout().into_shared(),
        };
        let array = array.into_dyn();

        NdArrayTensor::new(array)
    }};
    // Runtime rank `$D`: dispatch to the fixed-rank arm for ranks 1 through 6.
    (
        ty $ty:ty,
        shape $shape:expr,
        array $array:expr,
        d $D:expr
    ) => {{
        match $D {
            1 => reshape!(ty $ty, n 1, shape $shape, array $array),
            2 => reshape!(ty $ty, n 2, shape $shape, array $array),
            3 => reshape!(ty $ty, n 3, shape $shape, array $array),
            4 => reshape!(ty $ty, n 4, shape $shape, array $array),
            5 => reshape!(ty $ty, n 5, shape $shape, array $array),
            6 => reshape!(ty $ty, n 6, shape $shape, array $array),
            _ => core::panic!("NdArray supports arrays up to 6 dimensions, received: {}", $D),
        }
    }};
}
314
315impl<E> NdArrayTensor<E>
316where
317 E: Element,
318{
319 pub fn from_data(data: TensorData) -> NdArrayTensor<E> {
321 let shape: Shape = data.shape.clone().into();
322 let into_array = |data: TensorData| match data.into_vec::<E>() {
323 Ok(vec) => Array::from_vec(vec).into_shared(),
324 Err(err) => panic!("Data should have the same element type as the tensor {err:?}"),
325 };
326 let to_array = |data: TensorData| Array::from_iter(data.iter::<E>()).into_shared();
327
328 let array = if data.dtype == E::dtype() {
329 into_array(data)
330 } else {
331 to_array(data)
332 };
333 let ndims = shape.num_dims();
334
335 reshape!(
336 ty E,
337 shape shape,
338 array array,
339 d ndims
340 )
341 }
342}
343
/// Quantized tensor primitive: the quantized values plus the scheme and parameters
/// that describe how they were quantized.
#[derive(Clone, Debug)]
pub struct NdArrayQTensor<Q: QuantElement> {
    /// Tensor holding the quantized values of type `Q`.
    pub qtensor: NdArrayTensor<Q>,
    /// Quantization scheme associated with the tensor.
    pub scheme: QuantizationScheme,
    /// Quantization parameters (`scale` and optional `offset`) for the tensor.
    pub qparams: QParams<f32, Q>,
}
354
355impl<Q: QuantElement> NdArrayQTensor<Q> {
356 pub fn strategy(&self) -> QuantizationStrategy {
358 match self.scheme {
359 QuantizationScheme::PerTensorAffine(QuantizationType::QInt8) => {
360 QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(
361 self.qparams.scale,
362 self.qparams.offset.unwrap().elem(),
363 ))
364 }
365 QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8) => {
366 QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(
367 self.qparams.scale,
368 ))
369 }
370 }
371 }
372}
373
impl<Q: QuantElement> QTensorPrimitive for NdArrayQTensor<Q> {
    /// Returns a reference to the quantization scheme stored with the tensor.
    fn scheme(&self) -> &QuantizationScheme {
        &self.scheme
    }
}
379
380impl<Q: QuantElement> TensorMetadata for NdArrayQTensor<Q> {
381 fn dtype(&self) -> DType {
382 DType::QFloat(self.scheme)
383 }
384
385 fn shape(&self) -> Shape {
386 self.qtensor.shape()
387 }
388}
389
#[cfg(test)]
mod tests {
    use crate::NdArray;

    use super::*;
    use burn_common::rand::get_seeded_rng;
    use burn_tensor::{
        ops::{FloatTensorOps, IntTensorOps, QTensorOps},
        quantization::{AffineQuantization, QuantizationParametersPrimitive, QuantizationType},
        Distribution,
    };

    /// Generates seeded random `f32` data with the given dimensions, loads it into an
    /// [`NdArrayTensor`] and checks that reading it back yields the same data.
    fn assert_data_round_trip<const D: usize>(dims: [usize; D]) {
        let data_expected = TensorData::random::<f32, _, _>(
            Shape::new(dims),
            Distribution::Default,
            &mut get_seeded_rng(),
        );

        let data_actual = NdArrayTensor::<f32>::from_data(data_expected.clone()).into_data();

        assert_eq!(data_expected, data_actual);
    }

    #[test]
    fn should_support_into_and_from_data_1d() {
        assert_data_round_trip([3]);
    }

    #[test]
    fn should_support_into_and_from_data_2d() {
        assert_data_round_trip([2, 3]);
    }

    #[test]
    fn should_support_into_and_from_data_3d() {
        assert_data_round_trip([2, 3, 4]);
    }

    #[test]
    fn should_support_into_and_from_data_4d() {
        assert_data_round_trip([2, 3, 4, 2]);
    }

    /// Quantizing with an affine int8 scheme must yield a tensor that reports the same
    /// scheme and a strategy carrying the scale and offset that were supplied.
    #[test]
    fn should_support_qtensor_strategy() {
        type B = NdArray<f32, i64, i8>;
        let device = Default::default();

        let input = B::float_from_data(TensorData::from([-1.8, -1.0, 0.0, 0.5]), &device);
        let scheme = QuantizationScheme::PerTensorAffine(QuantizationType::QInt8);
        let scale = B::float_from_data(TensorData::from([0.009_019_608]), &device);
        let offset = B::int_from_data(TensorData::from([72]), &device);
        let qparams = QuantizationParametersPrimitive {
            scale,
            offset: Some(offset),
        };
        let qtensor: NdArrayQTensor<i8> = B::quantize(input, &scheme, qparams);

        let strategy_expected =
            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.009_019_608, 72));

        assert_eq!(qtensor.scheme(), &scheme);
        assert_eq!(qtensor.strategy(), strategy_expected);
    }
}