1use core::mem;
2
3use burn_tensor::{
4 DType, Element, Shape, TensorData, TensorMetadata,
5 quantization::{
6 QParams, QTensorPrimitive, QuantizationMode, QuantizationScheme, QuantizationStrategy,
7 QuantizationType, SymmetricQuantization,
8 },
9};
10
11use alloc::vec::Vec;
12use ndarray::{ArcArray, ArrayD, IxDyn};
13
14use crate::element::QuantElement;
15
/// Tensor primitive for the ndarray backend: a reference-counted,
/// dynamically-dimensioned array of elements `E`.
#[derive(new, Debug, Clone)]
pub struct NdArrayTensor<E> {
    /// Underlying shared array (`ArcArray` clones share the buffer;
    /// mutation triggers copy-on-write).
    pub array: ArcArray<E, IxDyn>,
}
22
23impl<E: Element> TensorMetadata for NdArrayTensor<E> {
24 fn dtype(&self) -> DType {
25 E::dtype()
26 }
27
28 fn shape(&self) -> Shape {
29 Shape::from(self.array.shape().to_vec())
30 }
31}
32
/// Float tensor primitive for the ndarray backend, selecting the float
/// precision at runtime.
#[derive(Debug, Clone)]
pub enum NdArrayTensorFloat {
    // 32-bit float tensor.
    F32(NdArrayTensor<f32>),
    // 64-bit float tensor.
    F64(NdArrayTensor<f64>),
}
41
42impl From<NdArrayTensor<f32>> for NdArrayTensorFloat {
43 fn from(value: NdArrayTensor<f32>) -> Self {
44 NdArrayTensorFloat::F32(value)
45 }
46}
47
48impl From<NdArrayTensor<f64>> for NdArrayTensorFloat {
49 fn from(value: NdArrayTensor<f64>) -> Self {
50 NdArrayTensorFloat::F64(value)
51 }
52}
53
54impl TensorMetadata for NdArrayTensorFloat {
55 fn dtype(&self) -> DType {
56 match self {
57 NdArrayTensorFloat::F32(tensor) => tensor.dtype(),
58 NdArrayTensorFloat::F64(tensor) => tensor.dtype(),
59 }
60 }
61
62 fn shape(&self) -> Shape {
63 match self {
64 NdArrayTensorFloat::F32(tensor) => tensor.shape(),
65 NdArrayTensorFloat::F64(tensor) => tensor.shape(),
66 }
67 }
68}
69
/// Wraps a raw tensor expression into the [`NdArrayTensorFloat`] variant
/// matching the float element type `E`.
///
/// NOTE: relies on a type parameter named `E` (implementing `Element`)
/// being in scope at the call site.
///
/// Panics for any dtype other than `F32`/`F64`.
#[macro_export]
macro_rules! new_tensor_float {
    ($tensor:expr) => {{
        match E::dtype() {
            burn_tensor::DType::F64 => $crate::NdArrayTensorFloat::F64($tensor),
            burn_tensor::DType::F32 => $crate::NdArrayTensorFloat::F32($tensor),
            _ => unimplemented!("Unsupported dtype"),
        }
    }};
}
83
/// Executes an operation on float tensor(s), dispatching on the runtime
/// dtype of [`NdArrayTensorFloat`].
///
/// Supported forms:
/// * `(lhs, rhs), op` — binary op; result re-wrapped into the enum.
/// * `(lhs, rhs), E, op` — binary op with the element type aliased as `E`.
/// * `(lhs, rhs) => op` — binary op; raw (un-wrapped) result.
/// * `tensor, op` — unary op; result re-wrapped into the enum.
/// * `tensor, E, op` — unary op with the element type aliased as `E`.
/// * `tensor => op` — unary op; raw (un-wrapped) result.
///
/// All binary forms panic when the operands' dtypes differ.
#[macro_export]
macro_rules! execute_with_float_dtype {
    // Binary op, re-wrapped result.
    (($lhs:expr, $rhs:expr), $op:expr) => {{
        // Capture the dtypes before the match moves the tensors, so a
        // mismatch panic can still report them.
        let lhs_dtype = burn_tensor::TensorMetadata::dtype(&$lhs);
        let rhs_dtype = burn_tensor::TensorMetadata::dtype(&$rhs);
        match ($lhs, $rhs) {
            ($crate::NdArrayTensorFloat::F64(lhs), $crate::NdArrayTensorFloat::F64(rhs)) => {
                $crate::NdArrayTensorFloat::F64($op(lhs, rhs))
            }
            ($crate::NdArrayTensorFloat::F32(lhs), $crate::NdArrayTensorFloat::F32(rhs)) => {
                $crate::NdArrayTensorFloat::F32($op(lhs, rhs))
            }
            _ => panic!(
                "Data type mismatch (lhs: {:?}, rhs: {:?})",
                lhs_dtype, rhs_dtype
            ),
        }
    }};

    // Binary op with the concrete element type bound to `$element`, so the
    // `$op` body can name it (e.g. for generic helper calls).
    (($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{
        let lhs_dtype = burn_tensor::TensorMetadata::dtype(&$lhs);
        let rhs_dtype = burn_tensor::TensorMetadata::dtype(&$rhs);
        match ($lhs, $rhs) {
            ($crate::NdArrayTensorFloat::F64(lhs), $crate::NdArrayTensorFloat::F64(rhs)) => {
                type $element = f64;
                $crate::NdArrayTensorFloat::F64($op(lhs, rhs))
            }
            ($crate::NdArrayTensorFloat::F32(lhs), $crate::NdArrayTensorFloat::F32(rhs)) => {
                type $element = f32;
                $crate::NdArrayTensorFloat::F32($op(lhs, rhs))
            }
            _ => panic!(
                "Data type mismatch (lhs: {:?}, rhs: {:?})",
                lhs_dtype, rhs_dtype
            ),
        }
    }};

    // Binary op whose result is returned as-is (not re-wrapped), e.g. for
    // ops producing a bool/int tensor or scalar.
    (($lhs:expr, $rhs:expr) => $op:expr) => {{
        let lhs_dtype = burn_tensor::TensorMetadata::dtype(&$lhs);
        let rhs_dtype = burn_tensor::TensorMetadata::dtype(&$rhs);
        match ($lhs, $rhs) {
            ($crate::NdArrayTensorFloat::F64(lhs), $crate::NdArrayTensorFloat::F64(rhs)) => {
                $op(lhs, rhs)
            }
            ($crate::NdArrayTensorFloat::F32(lhs), $crate::NdArrayTensorFloat::F32(rhs)) => {
                $op(lhs, rhs)
            }
            _ => panic!(
                "Data type mismatch (lhs: {:?}, rhs: {:?})",
                lhs_dtype, rhs_dtype
            ),
        }
    }};

    // Unary op, re-wrapped result.
    ($tensor:expr, $op:expr) => {{
        match $tensor {
            $crate::NdArrayTensorFloat::F64(tensor) => $crate::NdArrayTensorFloat::F64($op(tensor)),
            $crate::NdArrayTensorFloat::F32(tensor) => $crate::NdArrayTensorFloat::F32($op(tensor)),
        }
    }};

    // Unary op with the concrete element type bound to `$element`.
    ($tensor:expr, $element:ident, $op:expr) => {{
        match $tensor {
            $crate::NdArrayTensorFloat::F64(tensor) => {
                type $element = f64;
                $crate::NdArrayTensorFloat::F64($op(tensor))
            }
            $crate::NdArrayTensorFloat::F32(tensor) => {
                type $element = f32;
                $crate::NdArrayTensorFloat::F32($op(tensor))
            }
        }
    }};

    // Unary op whose result is returned as-is (not re-wrapped).
    ($tensor:expr => $op:expr) => {{
        match $tensor {
            $crate::NdArrayTensorFloat::F64(tensor) => $op(tensor),
            $crate::NdArrayTensorFloat::F32(tensor) => $op(tensor),
        }
    }};

    // Unary raw-result op with the concrete element type bound to `$element`.
    ($tensor:expr, $element:ident => $op:expr) => {{
        match $tensor {
            $crate::NdArrayTensorFloat::F64(tensor) => {
                type $element = f64;
                $op(tensor)
            }
            $crate::NdArrayTensorFloat::F32(tensor) => {
                type $element = f32;
                $op(tensor)
            }
        }
    }};
}
191
mod utils {
    use super::*;

    impl<E> NdArrayTensor<E>
    where
        E: Element,
    {
        /// Converts the tensor into a [`TensorData`], reusing the backing
        /// buffer when possible and copying element-by-element otherwise.
        pub(crate) fn into_data(self) -> TensorData {
            let shape = self.shape();

            // Fast path: a contiguous, uniquely-owned array can hand over
            // its backing `Vec` without copying. Shared or non-contiguous
            // arrays fall back to an element-wise copy via the iterator.
            let vec = if self.is_contiguous() {
                match self.array.try_into_owned_nocopy() {
                    Ok(owned) => {
                        let (mut vec, offset) = owned.into_raw_vec_and_offset();
                        // The logical data may start past the front of the
                        // allocation; drop the leading gap so index 0 of the
                        // vec is element 0 of the tensor.
                        if let Some(offset) = offset {
                            vec.drain(..offset);
                        }
                        vec
                    }
                    // The buffer is shared (refcount > 1): copy instead.
                    Err(array) => array.into_iter().collect(),
                }
            } else {
                self.array.into_iter().collect()
            };

            TensorData::new(vec, shape)
        }

        /// Returns `true` when the array is laid out contiguously in
        /// row-major (C) order, i.e. the backing buffer can be reused as-is.
        pub(crate) fn is_contiguous(&self) -> bool {
            let shape = self.array.shape();
            let strides = self.array.strides();

            // 0-d (scalar) arrays are trivially contiguous.
            if shape.is_empty() {
                return true;
            }

            if shape.len() == 1 {
                // NOTE(review): a single-element 1-D array with stride != 1
                // is reported as non-contiguous here; that is conservative
                // and only costs an extra copy in `into_data`.
                return strides[0] == 1;
            }

            let mut prev_stride = 1;
            let mut current_num_elems_shape = 1;

            // Walk dimensions from innermost to outermost, verifying the
            // row-major invariant: stride(d) equals the product of the
            // sizes of all dimensions inner to d.
            for (i, (stride, shape)) in strides.iter().zip(shape).rev().enumerate() {
                // Negative strides (reversed views) and zero strides
                // (broadcast views) are never contiguous.
                let stride = if *stride <= 0 {
                    return false;
                } else {
                    *stride as usize
                };
                if i > 0 {
                    if current_num_elems_shape != stride {
                        return false;
                    }

                    if prev_stride > stride {
                        return false;
                    }
                }

                current_num_elems_shape *= shape;
                prev_stride = stride;
            }

            true
        }
    }
}
259
/// Converts a runtime dimension slice into a statically-sized ndarray
/// `Dim<[usize; $n]>`.
///
/// `$dims` must hold at least `$n` entries; only the first `$n` are used.
#[macro_export(local_inner_macros)]
macro_rules! to_typed_dims {
    (
        $n:expr,
        $dims:expr,
        justdim
    ) => {{
        // Copy the runtime dims into a fixed-size array so they can be
        // promoted to a typed dimension.
        let mut dims = [0; $n];
        for i in 0..$n {
            dims[i] = $dims[i];
        }
        let dim: Dim<[usize; $n]> = Dim(dims);
        dim
    }};
}
276
/// Reshapes an ndarray into an [`NdArrayTensor`] with the requested shape.
///
/// First form: reshape to a statically-known rank `$n`.
/// Second form: dispatch a runtime rank `$D` (1..=6) to the first form.
///
/// Panics when the shape is incompatible with the element count, or when
/// `$D` exceeds 6 dimensions.
#[macro_export(local_inner_macros)]
macro_rules! reshape {
    (
        ty $ty:ty,
        n $n:expr,
        shape $shape:expr,
        array $array:expr
    ) => {{
        let dim = $crate::to_typed_dims!($n, $shape.dims, justdim);
        // A standard-layout (row-major contiguous) array can be reshaped
        // without copying; otherwise the data is first materialized in
        // standard layout.
        let array: ndarray::ArcArray<$ty, Dim<[usize; $n]>> = match $array.is_standard_layout() {
            true => {
                match $array.to_shape(dim) {
                    Ok(val) => val.into_shared(),
                    Err(err) => {
                        core::panic!("Shape should be compatible shape={dim:?}: {err:?}");
                    }
                }
            },
            false => $array.to_shape(dim).unwrap().as_standard_layout().into_shared(),
        };
        // Erase the static rank again for storage in `NdArrayTensor`.
        let array = array.into_dyn();

        NdArrayTensor::new(array)
    }};
    (
        ty $ty:ty,
        shape $shape:expr,
        array $array:expr,
        d $D:expr
    ) => {{
        // Dispatch the runtime rank to a statically-ranked reshape.
        match $D {
            1 => reshape!(ty $ty, n 1, shape $shape, array $array),
            2 => reshape!(ty $ty, n 2, shape $shape, array $array),
            3 => reshape!(ty $ty, n 3, shape $shape, array $array),
            4 => reshape!(ty $ty, n 4, shape $shape, array $array),
            5 => reshape!(ty $ty, n 5, shape $shape, array $array),
            6 => reshape!(ty $ty, n 6, shape $shape, array $array),
            _ => core::panic!("NdArray supports arrays up to 6 dimensions, received: {}", $D),
        }
    }};
}
319
320impl<E> NdArrayTensor<E>
321where
322 E: Element,
323{
324 pub fn from_data(mut data: TensorData) -> NdArrayTensor<E> {
326 let shape = mem::take(&mut data.shape);
327
328 let array = match data.into_vec::<E>() {
329 Ok(vec) => unsafe { ArrayD::from_shape_vec_unchecked(shape, vec) }.into_shared(),
331 Err(err) => panic!("Data should have the same element type as the tensor {err:?}"),
332 };
333
334 NdArrayTensor::new(array)
335 }
336}
337
/// A quantized tensor for the ndarray backend.
#[derive(Clone, Debug)]
pub struct NdArrayQTensor<Q: QuantElement> {
    /// The quantized tensor values.
    pub qtensor: NdArrayTensor<Q>,
    /// The quantization scheme.
    pub scheme: QuantizationScheme,
    /// The quantization parameters; `strategy` reads entry 0, so per-tensor
    /// quantization stores a single entry (presumably one per block for
    /// finer-grained schemes — confirm when such schemes are added).
    pub qparams: Vec<QParams<f32, Q>>,
}
348
349impl<Q: QuantElement> NdArrayQTensor<Q> {
350 pub fn strategy(&self) -> QuantizationStrategy {
352 match self.scheme {
353 QuantizationScheme::PerTensor(QuantizationMode::Symmetric, QuantizationType::QInt8) => {
354 QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(
355 self.qparams[0].scale,
356 ))
357 }
358 }
359 }
360}
361
362impl<Q: QuantElement> QTensorPrimitive for NdArrayQTensor<Q> {
363 fn scheme(&self) -> &QuantizationScheme {
364 &self.scheme
365 }
366}
367
368impl<Q: QuantElement> TensorMetadata for NdArrayQTensor<Q> {
369 fn dtype(&self) -> DType {
370 DType::QFloat(self.scheme)
371 }
372
373 fn shape(&self) -> Shape {
374 self.qtensor.shape()
375 }
376}
377
#[cfg(test)]
mod tests {
    use crate::NdArray;

    use super::*;
    use burn_common::rand::get_seeded_rng;
    use burn_tensor::{
        Distribution,
        ops::{FloatTensorOps, QTensorOps},
        quantization::{QuantizationParametersPrimitive, QuantizationType},
    };

    /// Asserts that random data of the given shape survives a
    /// `from_data` -> `into_data` round trip unchanged.
    ///
    /// Shared by the 1-d..4-d tests below, which previously duplicated
    /// this body verbatim.
    fn assert_into_from_data_roundtrip(shape: Shape) {
        let data_expected = TensorData::random::<f32, _, _>(
            shape,
            Distribution::Default,
            &mut get_seeded_rng(),
        );
        let tensor = NdArrayTensor::<f32>::from_data(data_expected.clone());

        let data_actual = tensor.into_data();

        assert_eq!(data_expected, data_actual);
    }

    #[test]
    fn should_support_into_and_from_data_1d() {
        assert_into_from_data_roundtrip(Shape::new([3]));
    }

    #[test]
    fn should_support_into_and_from_data_2d() {
        assert_into_from_data_roundtrip(Shape::new([2, 3]));
    }

    #[test]
    fn should_support_into_and_from_data_3d() {
        assert_into_from_data_roundtrip(Shape::new([2, 3, 4]));
    }

    #[test]
    fn should_support_into_and_from_data_4d() {
        assert_into_from_data_roundtrip(Shape::new([2, 3, 4, 2]));
    }

    /// Quantizing with a per-tensor symmetric int8 scheme should expose a
    /// matching strategy initialized with the same scale.
    #[test]
    fn should_support_qtensor_strategy() {
        type B = NdArray<f32, i64, i8>;
        let scale: f32 = 0.009_019_608;
        let device = Default::default();

        let tensor = B::float_from_data(TensorData::from([-1.8f32, -1.0, 0.0, 0.5]), &device);
        let scheme =
            QuantizationScheme::PerTensor(QuantizationMode::Symmetric, QuantizationType::QInt8);
        let qparams = QuantizationParametersPrimitive {
            scale: B::float_from_data(TensorData::from([scale]), &device),
            offset: None,
        };
        let qtensor: NdArrayQTensor<i8> = B::quantize(tensor, &scheme, qparams);

        assert_eq!(qtensor.scheme(), &scheme);
        assert_eq!(
            qtensor.strategy(),
            QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(scale))
        );
    }
}