use alloc::vec::Vec;

#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
use num_traits::Float;

use burn_backend::{
    DType, ExecutionError, FloatDType, TensorData, TensorMetadata,
    ops::{IntTensorOps, QTensorOps},
    quantization::{
        QuantLevel, QuantScheme, QuantStore, QuantizationParametersPrimitive, QuantizedBytes,
    },
    tensor::{Device, FloatTensor, IntTensor, QuantizedTensor},
};
use burn_std::{Bytes, Shape, Slice, bf16, f16};

use super::float_storage_as_f32;
use crate::{Flex, FlexQTensor, FlexTensor, Layout};

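// Quantized tensor ops for the `Flex` backend: values are stored natively as `i8`
// alongside `f32` scales, either one scale for the whole tensor or one per block.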
impl QTensorOps<Flex> for Flex {
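    /// Builds a quantized tensor from serialized `TensorData`, unpacking the `i8`
    /// values and their scales.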
    fn q_from_data(data: TensorData, _device: &Device<Flex>) -> QuantizedTensor<Flex> {
        let scheme = match data.dtype {
            DType::QFloat(scheme) => scheme,
            _ => panic!("Expected quantized dtype, got {:?}", data.dtype),
        };

        let shape = data.shape.clone();
        let num_elements = data.num_elements();

        let q_bytes = QuantizedBytes {
            bytes: data.into_bytes(),
            scheme,
            num_elements,
        };

        let (values, qparams) = q_bytes.into_vec_i8();
        let tensor_data = TensorData::new(values, shape);
        let tensor = FlexTensor::from_data(tensor_data);

        let scheme = scheme.with_store(QuantStore::Native);

        FlexQTensor::new(tensor, scheme, qparams.scales)
    }

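    /// Quantizes `tensor` with scales computed on the fly: the absolute maximum of the
    /// tensor (or of each block) determines a symmetric scale over the value range.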
    fn quantize_dynamic(tensor: FloatTensor<Flex>, scheme: &QuantScheme) -> QuantizedTensor<Flex> {
        let shape = tensor.shape();
        let tensor = tensor.to_contiguous();
        let float_data = float_storage_as_f32(&tensor);
        let (a, b) = scheme.value.range();
        let range = b - a;

        let (quantized, scales) = match scheme.level {
            QuantLevel::Tensor => {
                let mut alpha: f32 = 0.0;
                for &x in &*float_data {
                    let abs = x.abs();
                    if abs > alpha {
                        alpha = abs;
                    }
                }
                let scale = validated_scale(2.0 * alpha / range);
                let inv_scale = 1.0 / scale;

                let quantized = float_data
                    .iter()
                    .map(|&x| (x * inv_scale).round().clamp(a, b) as i8)
                    .collect::<Vec<i8>>();

                (quantized, alloc::vec![scale])
            }
            QuantLevel::Block(block_size) => {
                let block_elems = block_size.num_elements();
                debug_assert!(
                    float_data.len().is_multiple_of(block_elems),
                    "tensor length {} not divisible by block size {}",
                    float_data.len(),
                    block_elems
                );
                let num_blocks = float_data.len() / block_elems;
                let mut scales = Vec::with_capacity(num_blocks);
                let mut quantized = Vec::with_capacity(float_data.len());

                for block in float_data.chunks(block_elems) {
                    let mut alpha: f32 = 0.0;
                    for &x in block {
                        let abs = x.abs();
                        if abs > alpha {
                            alpha = abs;
                        }
                    }
                    let scale = validated_scale(2.0 * alpha / range);
                    let inv_scale = 1.0 / scale;
                    scales.push(scale);

                    for &x in block {
                        quantized.push((x * inv_scale).round().clamp(a, b) as i8);
                    }
                }

                (quantized, scales)
            }
        };

        let bytes = Bytes::from_elems(quantized);
        let layout = Layout::contiguous(shape);
        let qt = FlexTensor::new(bytes, layout, DType::I8);

        FlexQTensor::new(qt, scheme.with_store(QuantStore::Native), scales)
    }

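    /// Quantizes `tensor` using caller-provided scales (one scale for the whole tensor or
    /// one per block, depending on the scheme level).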
    fn quantize(
        tensor: FloatTensor<Flex>,
        scheme: &QuantScheme,
        qparams: QuantizationParametersPrimitive<Flex>,
    ) -> QuantizedTensor<Flex> {
        let shape = tensor.shape();
        let tensor = tensor.to_contiguous();
        let float_data = float_storage_as_f32(&tensor);

        let scales_tensor = qparams.scales.to_contiguous();
        let scales_data = float_storage_as_f32(&scales_tensor);
        let scales: Vec<f32> = scales_data.iter().copied().map(validated_scale).collect();

        let (a, b) = scheme.value.range();

        let quantized = match scheme.level {
            QuantLevel::Tensor => {
                let inv_scale = 1.0 / scales[0];
                float_data
                    .iter()
                    .map(|&x| (x * inv_scale).round().clamp(a, b) as i8)
                    .collect::<Vec<i8>>()
            }
            QuantLevel::Block(block_size) => {
                let block_elems = block_size.num_elements();
                debug_assert!(
                    float_data.len().is_multiple_of(block_elems),
                    "tensor length {} not divisible by block size {}",
                    float_data.len(),
                    block_elems
                );
                let mut quantized = Vec::with_capacity(float_data.len());
                for (block, &scale) in float_data.chunks(block_elems).zip(scales.iter()) {
                    let inv_scale = 1.0 / scale;
                    for &x in block {
                        quantized.push((x * inv_scale).round().clamp(a, b) as i8);
                    }
                }
                quantized
            }
        };

        let bytes = Bytes::from_elems(quantized);
        let layout = Layout::contiguous(shape);
        let qt = FlexTensor::new(bytes, layout, DType::I8);

        FlexQTensor::new(qt, scheme.with_store(QuantStore::Native), scales)
    }

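    /// Dequantizes by multiplying each `i8` value with its scale, then casts the result
    /// to the requested float dtype.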
    fn dequantize(tensor: QuantizedTensor<Flex>, dtype: FloatDType) -> FloatTensor<Flex> {
        let shape = tensor.tensor.shape();
        let qt = tensor.tensor.to_contiguous();
        let q_data: &[i8] = qt.storage();

        let dequantized = match tensor.scheme.level {
            QuantLevel::Tensor => {
                let scale = tensor.scales[0];
                q_data
                    .iter()
                    .map(|&x_q| scale * x_q as f32)
                    .collect::<Vec<f32>>()
            }
            QuantLevel::Block(block_size) => {
                let block_elems = block_size.num_elements();
                q_data
                    .chunks(block_elems)
                    .zip(tensor.scales.iter())
                    .flat_map(|(block, &scale)| block.iter().map(move |&x_q| scale * x_q as f32))
                    .collect::<Vec<f32>>()
            }
        };

        let layout = Layout::contiguous(shape);
        match dtype {
            FloatDType::F32 | FloatDType::Flex32 => {
                FlexTensor::new(Bytes::from_elems(dequantized), layout, DType::F32)
            }
            FloatDType::F64 => {
                let data: Vec<f64> = dequantized.iter().map(|&v| v as f64).collect();
                FlexTensor::new(Bytes::from_elems(data), layout, DType::F64)
            }
            FloatDType::F16 => {
                let data: Vec<f16> = dequantized.iter().map(|&v| f16::from_f32(v)).collect();
                FlexTensor::new(Bytes::from_elems(data), layout, DType::F16)
            }
            FloatDType::BF16 => {
                let data: Vec<bf16> = dequantized.iter().map(|&v| bf16::from_f32(v)).collect();
                FlexTensor::new(Bytes::from_elems(data), layout, DType::BF16)
            }
        }
    }

    fn q_device(_tensor: &QuantizedTensor<Flex>) -> Device<Flex> {
        Default::default()
    }

    fn q_to_device(tensor: QuantizedTensor<Flex>, _device: &Device<Flex>) -> QuantizedTensor<Flex> {
        tensor
    }

    fn q_reshape(tensor: QuantizedTensor<Flex>, shape: Shape) -> QuantizedTensor<Flex> {
        block_safe_layout_op(tensor, |t| t.reshape(shape))
    }

    async fn q_into_data(tensor: QuantizedTensor<Flex>) -> Result<TensorData, ExecutionError> {
        let shape = tensor.tensor.shape();
        let scheme = tensor.scheme;
        let qt = tensor.tensor.to_contiguous();
        let values: Vec<i8> = qt.storage::<i8>().to_vec();

        Ok(TensorData::quantized(
            values,
            shape.to_vec(),
            scheme,
            &tensor.scales,
        ))
    }

    fn q_swap_dims(
        tensor: QuantizedTensor<Flex>,
        dim1: usize,
        dim2: usize,
    ) -> QuantizedTensor<Flex> {
        block_safe_layout_op(tensor, |t| t.transpose(dim1, dim2))
    }

    fn q_permute(tensor: QuantizedTensor<Flex>, axes: &[usize]) -> QuantizedTensor<Flex> {
        block_safe_layout_op(tensor, |t| t.permute(axes))
    }

    fn q_flip(tensor: QuantizedTensor<Flex>, axes: &[usize]) -> QuantizedTensor<Flex> {
        block_safe_layout_op(tensor, |t| crate::ops::flip::flip(t, axes))
    }

    fn q_expand(tensor: QuantizedTensor<Flex>, shape: Shape) -> QuantizedTensor<Flex> {
        block_safe_layout_op(tensor, |t| crate::ops::expand::expand(t, shape))
    }

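    /// Per-tensor schemes select directly on the `i8` data; block schemes take a
    /// dequantize/select/requantize round trip, since the selected rows no longer
    /// line up with the original blocks.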
    fn q_select(
        tensor: QuantizedTensor<Flex>,
        dim: usize,
        indices: IntTensor<Flex>,
    ) -> QuantizedTensor<Flex> {
        match tensor.scheme.level {
            QuantLevel::Tensor => FlexQTensor::new(
                crate::ops::gather_scatter::select::<i8>(tensor.tensor, dim, indices),
                tensor.scheme,
                tensor.scales,
            ),
            QuantLevel::Block(_) => {
                let scheme = tensor.scheme;
                let float_tensor = Flex::dequantize(tensor, FloatDType::F32);
                let result = crate::ops::gather_scatter::select::<f32>(float_tensor, dim, indices);
                Flex::quantize_dynamic(result, &scheme)
            }
        }
    }

    fn q_slice(tensor: QuantizedTensor<Flex>, slices: &[Slice]) -> QuantizedTensor<Flex> {
        block_safe_layout_op(tensor, |t| crate::ops::slice::slice(t, slices))
    }

    fn q_argmax(
        tensor: QuantizedTensor<Flex>,
        dim: usize,
        out_dtype: burn_std::IntDType,
    ) -> IntTensor<Flex> {
        let result = crate::ops::reduce::argmax(tensor.tensor, dim);
        if result.dtype() != DType::from(out_dtype) {
            Flex::int_cast(result, out_dtype)
        } else {
            result
        }
    }

    fn q_argmin(
        tensor: QuantizedTensor<Flex>,
        dim: usize,
        out_dtype: burn_std::IntDType,
    ) -> IntTensor<Flex> {
        let result = crate::ops::reduce::argmin(tensor.tensor, dim);
        if result.dtype() != DType::from(out_dtype) {
            Flex::int_cast(result, out_dtype)
        } else {
            result
        }
    }

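    /// Like `q_select`: gather directly on the `i8` data for per-tensor schemes, and via
    /// a dequantize/requantize round trip for block schemes.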
    fn q_gather(
        dim: usize,
        tensor: QuantizedTensor<Flex>,
        indices: IntTensor<Flex>,
    ) -> QuantizedTensor<Flex> {
        match tensor.scheme.level {
            QuantLevel::Tensor => FlexQTensor::new(
                crate::ops::gather_scatter::gather::<i8>(tensor.tensor, dim, indices),
                tensor.scheme,
                tensor.scales,
            ),
            QuantLevel::Block(_) => {
                let scheme = tensor.scheme;
                let float_tensor = Flex::dequantize(tensor, FloatDType::F32);
                let result = crate::ops::gather_scatter::gather::<f32>(float_tensor, dim, indices);
                Flex::quantize_dynamic(result, &scheme)
            }
        }
    }
}

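/// Applies a layout-changing op to a quantized tensor. Per-tensor schemes apply it to the
/// `i8` data directly; block schemes dequantize first, so the op cannot mix values across
/// block boundaries, and then requantize the result.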
fn block_safe_layout_op(
    qtensor: FlexQTensor,
    op: impl FnOnce(FlexTensor) -> FlexTensor,
) -> FlexQTensor {
    match qtensor.scheme.level {
        QuantLevel::Tensor => FlexQTensor::new(op(qtensor.tensor), qtensor.scheme, qtensor.scales),
        QuantLevel::Block(_) => {
            let scheme = qtensor.scheme;
            let float_tensor = Flex::dequantize(qtensor, FloatDType::F32);
            let result = op(float_tensor);
            Flex::quantize_dynamic(result, &scheme)
        }
    }
}

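/// Guards against degenerate scales: zero, subnormal, infinite, or NaN scales are replaced
/// with the smallest positive normal `f32` so the inverse stays finite.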
fn validated_scale(scale: f32) -> f32 {
    if scale.is_normal() {
        scale
    } else {
        f32::MIN_POSITIVE
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use burn_backend::{TensorMetadata, quantization::QuantValue};

    #[test]
    fn test_quantize_dequantize_roundtrip() {
        let values = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0];
        let tensor = FlexTensor::from_data(TensorData::new(values.clone(), [2, 3]));

        let scheme = QuantScheme::default()
            .with_value(QuantValue::Q8S)
            .with_store(QuantStore::Native);

        let scale: f32 = 2.0 * 5.0 / 254.0;
        let scales_tensor = FlexTensor::from_data(TensorData::new(vec![scale], [1]));

        let qparams = QuantizationParametersPrimitive {
            scales: scales_tensor,
        };

        let qtensor = Flex::quantize(tensor, &scheme, qparams);
        assert_eq!(qtensor.tensor.shape().to_vec(), vec![2, 3]);
        assert_eq!(qtensor.tensor.dtype(), DType::I8);

        let q_vals: &[i8] = qtensor.tensor.storage();
        assert_eq!(q_vals[0], 0);
        assert_eq!(q_vals[1], 25);
        assert_eq!(q_vals[5], 127);

        let result = Flex::dequantize(qtensor, FloatDType::F32);
        assert_eq!(result.shape().to_vec(), vec![2, 3]);
        assert_eq!(result.dtype(), DType::F32);

        let result_vals: &[f32] = result.storage();
        for (orig, deq) in values.iter().zip(result_vals.iter()) {
            assert!((orig - deq).abs() < 0.05, "orig={orig}, dequantized={deq}");
        }
    }

    #[test]
    fn test_quantize_dequantize_negative_values() {
        let values = vec![-3.0f32, -1.5, 0.0, 1.5, 3.0];
        let tensor = FlexTensor::from_data(TensorData::new(values.clone(), [5]));

        let scheme = QuantScheme::default()
            .with_value(QuantValue::Q8S)
            .with_store(QuantStore::Native);

        let scale: f32 = 2.0 * 3.0 / 254.0;
        let scales_tensor = FlexTensor::from_data(TensorData::new(vec![scale], [1]));

        let qparams = QuantizationParametersPrimitive {
            scales: scales_tensor,
        };

        let qtensor = Flex::quantize(tensor, &scheme, qparams);
        let result = Flex::dequantize(qtensor, FloatDType::F32);
        let result_vals: &[f32] = result.storage();

        for (orig, deq) in values.iter().zip(result_vals.iter()) {
            assert!((orig - deq).abs() < 0.05, "orig={orig}, dequantized={deq}");
        }
    }

    #[test]
    fn test_q_from_data_into_data_roundtrip() {
        let values = vec![0i8, 25, 51, 76, 102, 127];
        let scale = 0.03937008f32;
        let scheme = QuantScheme::default()
            .with_value(QuantValue::Q8S)
            .with_store(QuantStore::Native);

        let data = TensorData::quantized(values.clone(), [2, 3], scheme, &[scale]);

        let qtensor = Flex::q_from_data(data, &Default::default());
        assert_eq!(qtensor.tensor.shape().to_vec(), vec![2, 3]);
        assert_eq!(qtensor.scales, vec![scale]);

        let float_tensor = Flex::dequantize(qtensor, FloatDType::F32);
        let result: &[f32] = float_tensor.storage();
        assert!((result[0]).abs() < 0.01);
        assert!((result[5] - 5.0).abs() < 0.05);
    }

    #[test]
    fn test_quantize_zero_tensor() {
        let values = vec![0.0f32; 4];
        let tensor = FlexTensor::from_data(TensorData::new(values, [4]));

        let scheme = QuantScheme::default()
            .with_value(QuantValue::Q8S)
            .with_store(QuantStore::Native);

        let scales_tensor = FlexTensor::from_data(TensorData::new(vec![0.0f32], [1]));
        let qparams = QuantizationParametersPrimitive {
            scales: scales_tensor,
        };

        let qtensor = Flex::quantize(tensor, &scheme, qparams);
        let q_vals: &[i8] = qtensor.tensor.storage();
        assert_eq!(q_vals, &[0, 0, 0, 0]);
    }

    #[test]
    fn test_quantize_dynamic_roundtrip() {
        let values = vec![-3.0f32, -1.5, 0.0, 1.5, 3.0, 4.5];
        let tensor = FlexTensor::from_data(TensorData::new(values.clone(), [2, 3]));

        let scheme = QuantScheme::default()
            .with_value(QuantValue::Q8S)
            .with_store(QuantStore::Native);

        let qtensor = Flex::quantize_dynamic(tensor, &scheme);
        assert_eq!(qtensor.tensor.shape().to_vec(), vec![2, 3]);
        assert_eq!(qtensor.scales.len(), 1);

        let expected_scale: f32 = 2.0 * 4.5 / 254.0;
        assert!(
            (qtensor.scales[0] - expected_scale).abs() < 1e-6,
            "scale={}, expected={}",
            qtensor.scales[0],
            expected_scale
        );

        let result = Flex::dequantize(qtensor, FloatDType::F32);
        let result_vals: &[f32] = result.storage();
        for (orig, deq) in values.iter().zip(result_vals.iter()) {
            assert!((orig - deq).abs() < 0.1, "orig={orig}, dequantized={deq}");
        }
    }

    #[test]
    fn test_per_block_quantize_dequantize() {
        use burn_std::quantization::BlockSize;

        let values = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0];
        let tensor = FlexTensor::from_data(TensorData::new(values.clone(), [8]));

        let block_size = BlockSize::new([4]);
        let scheme = QuantScheme::default()
            .with_value(QuantValue::Q8S)
            .with_level(QuantLevel::Block(block_size))
            .with_store(QuantStore::Native);

        let scale_1: f32 = 2.0 * 3.0 / 254.0;
        let scale_2: f32 = 2.0 * 7.0 / 254.0;
        let scales_tensor = FlexTensor::from_data(TensorData::new(vec![scale_1, scale_2], [2]));

        let qparams = QuantizationParametersPrimitive {
            scales: scales_tensor,
        };

        let qtensor = Flex::quantize(tensor, &scheme, qparams);
        assert_eq!(qtensor.scales.len(), 2);

        let result = Flex::dequantize(qtensor, FloatDType::F32);
        let result_vals: &[f32] = result.storage();

        for (orig, deq) in values.iter().zip(result_vals.iter()) {
            assert!((orig - deq).abs() < 0.1, "orig={orig}, dequantized={deq}");
        }
    }

    #[test]
    fn test_quantize_dynamic_block() {
        use burn_std::quantization::BlockSize;

        let values = vec![-2.0f32, -1.0, 0.0, 1.0, 4.0, 5.0, 6.0, 7.0];
        let tensor = FlexTensor::from_data(TensorData::new(values.clone(), [8]));

        let block_size = BlockSize::new([4]);
        let scheme = QuantScheme::default()
            .with_value(QuantValue::Q8S)
            .with_level(QuantLevel::Block(block_size))
            .with_store(QuantStore::Native);

        let qtensor = Flex::quantize_dynamic(tensor, &scheme);
        assert_eq!(qtensor.scales.len(), 2);

        let expected_scale_1: f32 = 2.0 * 2.0 / 254.0;
        let expected_scale_2: f32 = 2.0 * 7.0 / 254.0;
        assert!((qtensor.scales[0] - expected_scale_1).abs() < 1e-6);
        assert!((qtensor.scales[1] - expected_scale_2).abs() < 1e-6);

        let result = Flex::dequantize(qtensor, FloatDType::F32);
        let result_vals: &[f32] = result.storage();
        for (orig, deq) in values.iter().zip(result_vals.iter()) {
            assert!((orig - deq).abs() < 0.1, "orig={orig}, dequantized={deq}");
        }
    }

    #[test]
    fn test_quantize_dynamic_q8f() {
        let values = vec![-5.0f32, -2.5, 0.0, 2.5, 5.0, 7.5];
        let tensor = FlexTensor::from_data(TensorData::new(values.clone(), [6]));

        let scheme = QuantScheme::default()
            .with_value(QuantValue::Q8F)
            .with_store(QuantStore::Native);

        let qtensor = Flex::quantize_dynamic(tensor, &scheme);

        let expected_scale: f32 = 2.0 * 7.5 / 255.0;
        assert!(
            (qtensor.scales[0] - expected_scale).abs() < 1e-6,
            "scale={}, expected={}",
            qtensor.scales[0],
            expected_scale
        );

        let result = Flex::dequantize(qtensor, FloatDType::F32);
        let result_vals: &[f32] = result.storage();
        for (orig, deq) in values.iter().zip(result_vals.iter()) {
            assert!((orig - deq).abs() < 0.1, "orig={orig}, dequantized={deq}");
        }
    }

    #[test]
    fn test_block_quantized_transpose_dequantize() {
        use burn_std::quantization::BlockSize;

        let values = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let tensor = FlexTensor::from_data(TensorData::new(values, [2, 4]));

        let block_size = BlockSize::new([4]);
        let scheme = QuantScheme::default()
            .with_value(QuantValue::Q8S)
            .with_level(QuantLevel::Block(block_size))
            .with_store(QuantStore::Native);

        let qtensor = Flex::quantize_dynamic(tensor, &scheme);

        let transposed = Flex::q_swap_dims(qtensor, 0, 1);
        assert_eq!(transposed.tensor.shape().to_vec(), vec![4, 2]);

        let result = Flex::dequantize(transposed, FloatDType::F32);
        let result_vals: &[f32] = result.storage();

        let expected = [1.0f32, 5.0, 2.0, 6.0, 3.0, 7.0, 4.0, 8.0];
        for (exp, deq) in expected.iter().zip(result_vals.iter()) {
            assert!(
                (exp - deq).abs() < 0.15,
                "expected={exp}, dequantized={deq}"
            );
        }
    }

    #[test]
    fn test_block_quantized_select() {
        use burn_std::quantization::BlockSize;

        let values = vec![1.0f32, 2.0, 3.0, 4.0, 10.0, 20.0, 30.0, 40.0];
        let tensor = FlexTensor::from_data(TensorData::new(values, [2, 4]));

        let block_size = BlockSize::new([4]);
        let scheme = QuantScheme::default()
            .with_value(QuantValue::Q8S)
            .with_level(QuantLevel::Block(block_size))
            .with_store(QuantStore::Native);

        let qtensor = Flex::quantize_dynamic(tensor, &scheme);

        let indices = FlexTensor::from_data(TensorData::new(vec![1i64], [1]));
        let selected = Flex::q_select(qtensor, 0, indices);
        assert_eq!(selected.tensor.shape().to_vec(), vec![1, 4]);

        let result = Flex::dequantize(selected, FloatDType::F32);
        let result_vals: &[f32] = result.storage();
        let expected = [10.0f32, 20.0, 30.0, 40.0];
        for (exp, deq) in expected.iter().zip(result_vals.iter()) {
            assert!((exp - deq).abs() < 0.5, "expected={exp}, dequantized={deq}");
        }
    }

    #[test]
    fn test_block_quantized_flip_dequantize() {
        use burn_std::quantization::BlockSize;

        let values = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let tensor = FlexTensor::from_data(TensorData::new(values, [2, 4]));

        let block_size = BlockSize::new([4]);
        let scheme = QuantScheme::default()
            .with_value(QuantValue::Q8S)
            .with_level(QuantLevel::Block(block_size))
            .with_store(QuantStore::Native);

        let qtensor = Flex::quantize_dynamic(tensor, &scheme);

        let flipped = Flex::q_flip(qtensor, &[0]);
        assert_eq!(flipped.tensor.shape().to_vec(), vec![2, 4]);

        let result = Flex::dequantize(flipped, FloatDType::F32);
        let result_vals: &[f32] = result.storage();
        let expected = [5.0f32, 6.0, 7.0, 8.0, 1.0, 2.0, 3.0, 4.0];
        for (exp, deq) in expected.iter().zip(result_vals.iter()) {
            assert!(
                (exp - deq).abs() < 0.15,
                "expected={exp}, dequantized={deq}"
            );
        }
    }

    #[test]
    fn test_quantize_dynamic_f64_tensor() {
        use burn_backend::quantization::QuantValue;

        let values = vec![0.0f64, 1.0, 2.0, 3.0, 4.0, 5.0];
        let tensor = FlexTensor::new(
            Bytes::from_elems(values),
            Layout::contiguous([6].into()),
            DType::F64,
        );
        assert_eq!(tensor.dtype(), DType::F64);

        let scheme = QuantScheme::default()
            .with_value(QuantValue::Q8S)
            .with_store(QuantStore::Native);

        let qtensor = Flex::quantize_dynamic(tensor, &scheme);
        assert_eq!(qtensor.tensor.dtype(), DType::I8);

        let result = Flex::dequantize(qtensor, FloatDType::F32);
        let result_vals: &[f32] = result.storage();
        let expected = [0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0];
        for (exp, deq) in expected.iter().zip(result_vals.iter()) {
            assert!(
                (exp - deq).abs() < 0.15,
                "expected={exp}, dequantized={deq}"
            );
        }
    }

    #[test]
    fn test_dequantize_f64() {
        let values = vec![0.0f32, 1.0, 2.0, 3.0];
        let tensor = FlexTensor::from_data(TensorData::new(values.clone(), [4]));

        let scheme = QuantScheme::default()
            .with_value(QuantValue::Q8S)
            .with_store(QuantStore::Native);

        let qtensor = Flex::quantize_dynamic(tensor, &scheme);
        let result = Flex::dequantize(qtensor, FloatDType::F64);
        assert_eq!(result.dtype(), DType::F64);
        let result_vals: &[f64] = result.storage();
        for (orig, deq) in values.iter().zip(result_vals.iter()) {
            assert!(
                (*orig as f64 - deq).abs() < 0.05,
                "orig={orig}, dequantized={deq}"
            );
        }
    }

    #[test]
    fn test_dequantize_f16() {
        let values = vec![0.0f32, 1.0, 2.0, 3.0];
        let tensor = FlexTensor::from_data(TensorData::new(values.clone(), [4]));

        let scheme = QuantScheme::default()
            .with_value(QuantValue::Q8S)
            .with_store(QuantStore::Native);

        let qtensor = Flex::quantize_dynamic(tensor, &scheme);
        let result = Flex::dequantize(qtensor, FloatDType::F16);
        assert_eq!(result.dtype(), DType::F16);
        let result_vals: &[f16] = result.storage();
        for (orig, deq) in values.iter().zip(result_vals.iter()) {
            assert!(
                (*orig - f32::from(*deq)).abs() < 0.05,
                "orig={orig}, dequantized={deq}"
            );
        }
    }
}