Skip to main content

burn_flex/ops/
qtensor.rs

1//! Quantized tensor operations for the Flex backend.
2
3use alloc::vec::Vec;
4#[cfg(not(feature = "std"))]
5#[allow(unused_imports)]
6use num_traits::Float;
7
8use burn_backend::{
9    DType, ExecutionError, FloatDType, TensorData, TensorMetadata,
10    ops::{IntTensorOps, QTensorOps},
11    quantization::{
12        QuantLevel, QuantScheme, QuantStore, QuantizationParametersPrimitive, QuantizedBytes,
13    },
14    tensor::{Device, FloatTensor, IntTensor, QuantizedTensor},
15};
16use burn_std::{Bytes, Shape, Slice, bf16, f16};
17
18use super::float_storage_as_f32;
19use crate::{Flex, FlexQTensor, FlexTensor, Layout};
20
impl QTensorOps<Flex> for Flex {
    /// Loads a quantized tensor from serialized `TensorData`.
    ///
    /// Panics if `data.dtype` is not `DType::QFloat`. The packed bytes are
    /// unpacked into plain `i8` values via `QuantizedBytes::into_vec_i8`,
    /// which also recovers the quantization parameters (scales).
    fn q_from_data(data: TensorData, _device: &Device<Flex>) -> QuantizedTensor<Flex> {
        let scheme = match data.dtype {
            DType::QFloat(scheme) => scheme,
            _ => panic!("Expected quantized dtype, got {:?}", data.dtype),
        };

        // Capture shape/length before `data` is consumed by `into_bytes()`.
        let shape = data.shape.clone();
        let num_elements = data.num_elements();

        let q_bytes = QuantizedBytes {
            bytes: data.into_bytes(),
            scheme,
            num_elements,
        };

        let (values, qparams) = q_bytes.into_vec_i8();
        let tensor_data = TensorData::new(values, shape);
        let tensor = FlexTensor::from_data(tensor_data);

        // Use native storage since we've unpacked to i8
        let scheme = scheme.with_store(QuantStore::Native);

        FlexQTensor::new(tensor, scheme, qparams.scales)
    }

    /// Quantizes a float tensor, computing the scale(s) from the data itself.
    ///
    /// For each scope (whole tensor, or one block per `QuantLevel::Block`):
    /// scale = 2 * max|x| / (b - a) where `[a, b]` is the scheme's value
    /// range, then each element is stored as `round(x / scale)` clamped to
    /// `[a, b]` and cast to `i8`. No zero-point is computed, so the mapping
    /// is centered on zero.
    fn quantize_dynamic(tensor: FloatTensor<Flex>, scheme: &QuantScheme) -> QuantizedTensor<Flex> {
        let shape = tensor.shape();
        // Contiguous storage lets us treat the data as one flat f32 slice
        // (float_storage_as_f32 normalises f16/bf16/f64 storage to f32).
        let tensor = tensor.to_contiguous();
        let float_data = float_storage_as_f32(&tensor);
        let (a, b) = scheme.value.range();
        let range = b - a;

        let (quantized, scales) = match scheme.level {
            QuantLevel::Tensor => {
                // Pass 1: find alpha = max(|min|, |max|)
                let mut alpha: f32 = 0.0;
                for &x in &*float_data {
                    let abs = x.abs();
                    if abs > alpha {
                        alpha = abs;
                    }
                }
                // validated_scale guards against an all-zero tensor (alpha = 0).
                let scale = validated_scale(2.0 * alpha / range);
                let inv_scale = 1.0 / scale;

                // Pass 2: quantize
                let quantized = float_data
                    .iter()
                    .map(|&x| (x * inv_scale).round().clamp(a, b) as i8)
                    .collect::<Vec<i8>>();

                (quantized, alloc::vec![scale])
            }
            QuantLevel::Block(block_size) => {
                let block_elems = block_size.num_elements();
                // Blocks must tile the flattened tensor exactly; one scale
                // is emitted per block, in block order.
                debug_assert!(
                    float_data.len().is_multiple_of(block_elems),
                    "tensor length {} not divisible by block size {}",
                    float_data.len(),
                    block_elems
                );
                let num_blocks = float_data.len() / block_elems;
                let mut scales = Vec::with_capacity(num_blocks);
                let mut quantized = Vec::with_capacity(float_data.len());

                for block in float_data.chunks(block_elems) {
                    // Find alpha for this block
                    let mut alpha: f32 = 0.0;
                    for &x in block {
                        let abs = x.abs();
                        if abs > alpha {
                            alpha = abs;
                        }
                    }
                    let scale = validated_scale(2.0 * alpha / range);
                    let inv_scale = 1.0 / scale;
                    scales.push(scale);

                    // Quantize this block
                    for &x in block {
                        quantized.push((x * inv_scale).round().clamp(a, b) as i8);
                    }
                }

                (quantized, scales)
            }
        };

        let bytes = Bytes::from_elems(quantized);
        let layout = Layout::contiguous(shape);
        let qt = FlexTensor::new(bytes, layout, DType::I8);

        FlexQTensor::new(qt, scheme.with_store(QuantStore::Native), scales)
    }

    /// Quantizes a float tensor with caller-supplied scales.
    ///
    /// The scales tensor may be stored in any float dtype; it is normalised
    /// to f32 and each scale passes through `validated_scale` (so a zero /
    /// non-finite scale becomes `f32::MIN_POSITIVE` rather than producing
    /// NaN or division by zero). For `QuantLevel::Block`, scales are zipped
    /// against blocks in order — `scales[i]` applies to block `i`.
    fn quantize(
        tensor: FloatTensor<Flex>,
        scheme: &QuantScheme,
        qparams: QuantizationParametersPrimitive<Flex>,
    ) -> QuantizedTensor<Flex> {
        let shape = tensor.shape();
        let tensor = tensor.to_contiguous();
        let float_data = float_storage_as_f32(&tensor);

        // Extract and validate scales from the qparams tensor. The scales tensor
        // shares its dtype with the float element type, which can be any of
        // f32/f64/f16/bf16, so we normalise via float_storage_as_f32 instead of
        // assuming f32 storage.
        let scales_tensor = qparams.scales.to_contiguous();
        let scales_data = float_storage_as_f32(&scales_tensor);
        let scales: Vec<f32> = scales_data.iter().copied().map(validated_scale).collect();

        let (a, b) = scheme.value.range();

        let quantized = match scheme.level {
            QuantLevel::Tensor => {
                // Single scale for the whole tensor.
                let inv_scale = 1.0 / scales[0];
                float_data
                    .iter()
                    .map(|&x| (x * inv_scale).round().clamp(a, b) as i8)
                    .collect::<Vec<i8>>()
            }
            QuantLevel::Block(block_size) => {
                let block_elems = block_size.num_elements();
                debug_assert!(
                    float_data.len().is_multiple_of(block_elems),
                    "tensor length {} not divisible by block size {}",
                    float_data.len(),
                    block_elems
                );
                let mut quantized = Vec::with_capacity(float_data.len());
                for (block, &scale) in float_data.chunks(block_elems).zip(scales.iter()) {
                    let inv_scale = 1.0 / scale;
                    for &x in block {
                        quantized.push((x * inv_scale).round().clamp(a, b) as i8);
                    }
                }
                quantized
            }
        };

        let bytes = Bytes::from_elems(quantized);
        let layout = Layout::contiguous(shape);
        let qt = FlexTensor::new(bytes, layout, DType::I8);

        FlexQTensor::new(qt, scheme.with_store(QuantStore::Native), scales)
    }

    /// Reconstructs floats from quantized values: `x = scale * q`, per
    /// tensor or per block, then casts the f32 result into the requested
    /// output dtype (`Flex32` is stored as plain f32).
    fn dequantize(tensor: QuantizedTensor<Flex>, dtype: FloatDType) -> FloatTensor<Flex> {
        let shape = tensor.tensor.shape();
        let qt = tensor.tensor.to_contiguous();
        let q_data: &[i8] = qt.storage();

        let dequantized = match tensor.scheme.level {
            QuantLevel::Tensor => {
                let scale = tensor.scales[0];
                q_data
                    .iter()
                    .map(|&x_q| scale * x_q as f32)
                    .collect::<Vec<f32>>()
            }
            QuantLevel::Block(block_size) => {
                let block_elems = block_size.num_elements();
                // Each contiguous run of `block_elems` codes shares one scale.
                q_data
                    .chunks(block_elems)
                    .zip(tensor.scales.iter())
                    .flat_map(|(block, &scale)| block.iter().map(move |&x_q| scale * x_q as f32))
                    .collect::<Vec<f32>>()
            }
        };

        let layout = Layout::contiguous(shape);
        match dtype {
            FloatDType::F32 | FloatDType::Flex32 => {
                FlexTensor::new(Bytes::from_elems(dequantized), layout, DType::F32)
            }
            FloatDType::F64 => {
                let data: Vec<f64> = dequantized.iter().map(|&v| v as f64).collect();
                FlexTensor::new(Bytes::from_elems(data), layout, DType::F64)
            }
            FloatDType::F16 => {
                let data: Vec<f16> = dequantized.iter().map(|&v| f16::from_f32(v)).collect();
                FlexTensor::new(Bytes::from_elems(data), layout, DType::F16)
            }
            FloatDType::BF16 => {
                let data: Vec<bf16> = dequantized.iter().map(|&v| bf16::from_f32(v)).collect();
                FlexTensor::new(Bytes::from_elems(data), layout, DType::BF16)
            }
        }
    }

    /// The Flex backend exposes only the default device.
    fn q_device(_tensor: &QuantizedTensor<Flex>) -> Device<Flex> {
        Default::default()
    }

    /// Single-device backend: moving between devices is a no-op.
    fn q_to_device(tensor: QuantizedTensor<Flex>, _device: &Device<Flex>) -> QuantizedTensor<Flex> {
        tensor
    }

    /// Reshape; routed through `block_safe_layout_op` so block-quantized
    /// tensors keep a valid scale-to-block mapping.
    fn q_reshape(tensor: QuantizedTensor<Flex>, shape: Shape) -> QuantizedTensor<Flex> {
        block_safe_layout_op(tensor, |t| t.reshape(shape))
    }

    /// Serializes the (contiguous) i8 codes plus scheme and scales back into
    /// `TensorData::quantized`, the inverse of `q_from_data`.
    async fn q_into_data(tensor: QuantizedTensor<Flex>) -> Result<TensorData, ExecutionError> {
        let shape = tensor.tensor.shape();
        let scheme = tensor.scheme;
        let qt = tensor.tensor.to_contiguous();
        let values: Vec<i8> = qt.storage::<i8>().to_vec();

        Ok(TensorData::quantized(
            values,
            shape.to_vec(),
            scheme,
            &tensor.scales,
        ))
    }

    /// Dimension swap; block-safe (see `block_safe_layout_op`).
    fn q_swap_dims(
        tensor: QuantizedTensor<Flex>,
        dim1: usize,
        dim2: usize,
    ) -> QuantizedTensor<Flex> {
        block_safe_layout_op(tensor, |t| t.transpose(dim1, dim2))
    }

    /// Axis permutation; block-safe (see `block_safe_layout_op`).
    fn q_permute(tensor: QuantizedTensor<Flex>, axes: &[usize]) -> QuantizedTensor<Flex> {
        block_safe_layout_op(tensor, |t| t.permute(axes))
    }

    /// Flip along axes; block-safe (see `block_safe_layout_op`).
    fn q_flip(tensor: QuantizedTensor<Flex>, axes: &[usize]) -> QuantizedTensor<Flex> {
        block_safe_layout_op(tensor, |t| crate::ops::flip::flip(t, axes))
    }

    /// Broadcast-expand; block-safe (see `block_safe_layout_op`).
    fn q_expand(tensor: QuantizedTensor<Flex>, shape: Shape) -> QuantizedTensor<Flex> {
        block_safe_layout_op(tensor, |t| crate::ops::expand::expand(t, shape))
    }

    /// Index-select along `dim`.
    ///
    /// Per-tensor quantization selects directly on the i8 codes (one shared
    /// scale, so codes stay valid). Per-block quantization dequantizes,
    /// selects in f32, then requantizes dynamically — note the scales are
    /// recomputed from the selected data, not carried over.
    fn q_select(
        tensor: QuantizedTensor<Flex>,
        dim: usize,
        indices: IntTensor<Flex>,
    ) -> QuantizedTensor<Flex> {
        match tensor.scheme.level {
            QuantLevel::Tensor => FlexQTensor::new(
                crate::ops::gather_scatter::select::<i8>(tensor.tensor, dim, indices),
                tensor.scheme,
                tensor.scales,
            ),
            QuantLevel::Block(_) => {
                let scheme = tensor.scheme;
                let float_tensor = Flex::dequantize(tensor, FloatDType::F32);
                let result = crate::ops::gather_scatter::select::<f32>(float_tensor, dim, indices);
                Flex::quantize_dynamic(result, &scheme)
            }
        }
    }

    /// Slicing; block-safe (see `block_safe_layout_op`).
    fn q_slice(tensor: QuantizedTensor<Flex>, slices: &[Slice]) -> QuantizedTensor<Flex> {
        block_safe_layout_op(tensor, |t| crate::ops::slice::slice(t, slices))
    }

    /// Argmax over the raw i8 codes, cast to `out_dtype` if needed.
    ///
    /// NOTE(review): comparing raw codes is order-preserving for per-tensor
    /// quantization (single positive scale), but per-block scales differ
    /// between blocks, so the code ordering need not match the dequantized
    /// ordering there — confirm intended semantics for QuantLevel::Block.
    fn q_argmax(
        tensor: QuantizedTensor<Flex>,
        dim: usize,
        out_dtype: burn_std::IntDType,
    ) -> IntTensor<Flex> {
        let result = crate::ops::reduce::argmax(tensor.tensor, dim);
        if result.dtype() != DType::from(out_dtype) {
            Flex::int_cast(result, out_dtype)
        } else {
            result
        }
    }

    /// Argmin over the raw i8 codes, cast to `out_dtype` if needed.
    ///
    /// NOTE(review): same per-block ordering caveat as `q_argmax`.
    fn q_argmin(
        tensor: QuantizedTensor<Flex>,
        dim: usize,
        out_dtype: burn_std::IntDType,
    ) -> IntTensor<Flex> {
        let result = crate::ops::reduce::argmin(tensor.tensor, dim);
        if result.dtype() != DType::from(out_dtype) {
            Flex::int_cast(result, out_dtype)
        } else {
            result
        }
    }

    /// Gather along `dim`; same strategy as `q_select` — i8 fast path for
    /// per-tensor schemes, dequantize/requantize round-trip for per-block.
    fn q_gather(
        dim: usize,
        tensor: QuantizedTensor<Flex>,
        indices: IntTensor<Flex>,
    ) -> QuantizedTensor<Flex> {
        match tensor.scheme.level {
            QuantLevel::Tensor => FlexQTensor::new(
                crate::ops::gather_scatter::gather::<i8>(tensor.tensor, dim, indices),
                tensor.scheme,
                tensor.scales,
            ),
            QuantLevel::Block(_) => {
                let scheme = tensor.scheme;
                let float_tensor = Flex::dequantize(tensor, FloatDType::F32);
                let result = crate::ops::gather_scatter::gather::<f32>(float_tensor, dim, indices);
                Flex::quantize_dynamic(result, &scheme)
            }
        }
    }
}
329
330/// Apply a layout operation to a quantized tensor.
331/// For block-quantized tensors, dequantizes and requantizes to preserve
332/// correct scale-to-block mapping.
333fn block_safe_layout_op(
334    qtensor: FlexQTensor,
335    op: impl FnOnce(FlexTensor) -> FlexTensor,
336) -> FlexQTensor {
337    match qtensor.scheme.level {
338        QuantLevel::Tensor => FlexQTensor::new(op(qtensor.tensor), qtensor.scheme, qtensor.scales),
339        QuantLevel::Block(_) => {
340            let scheme = qtensor.scheme;
341            let float_tensor = Flex::dequantize(qtensor, FloatDType::F32);
342            let result = op(float_tensor);
343            Flex::quantize_dynamic(result, &scheme)
344        }
345    }
346}
347
/// Sanitizes a quantization scale so later divisions are well-defined.
///
/// Any non-normal value — zero, subnormal, infinity, or NaN — is replaced
/// with `f32::MIN_POSITIVE`, preventing division-by-zero and NaN
/// propagation downstream; normal scales pass through unchanged.
fn validated_scale(scale: f32) -> f32 {
    match scale.is_normal() {
        true => scale,
        false => f32::MIN_POSITIVE,
    }
}
356
357// Tests kept here exercise flex-specific behavior: quantization scheme
358// roundtrips, per-block / dynamic quantization, block-quantized layout
359// ops (transpose / select / flip dequantize), and f16/f64 dequantize
360// dtype paths. Plain layout-preservation / select / slice / argmax /
361// argmin / gather tests are covered generically in
362// crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/
363// so they run on every backend.
#[cfg(test)]
mod tests {
    use super::*;
    use burn_backend::{TensorMetadata, quantization::QuantValue};

    // Q8S is symmetric over [-127, 127] (range 254); Q8F is asymmetric over
    // [-128, 127] (range 255). The expected scales below follow directly
    // from scale = 2 * max|x| / range.

    #[test]
    fn test_quantize_dequantize_roundtrip() {
        // Create a float tensor
        let values = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0];
        let tensor = FlexTensor::from_data(TensorData::new(values.clone(), [2, 3]));

        let scheme = QuantScheme::default()
            .with_value(QuantValue::Q8S)
            .with_store(QuantStore::Native);

        // Compute scale: symmetric, so scale = 2 * max(|min|, |max|) / (b - a)
        // max_abs = 5.0, range = 127 - (-127) = 254
        // scale = 2 * 5.0 / 254 = 0.03937008
        let scale: f32 = 2.0 * 5.0 / 254.0;
        let scales_tensor = FlexTensor::from_data(TensorData::new(vec![scale], [1]));

        let qparams = QuantizationParametersPrimitive {
            scales: scales_tensor,
        };

        // Quantize
        let qtensor = Flex::quantize(tensor, &scheme, qparams);
        assert_eq!(qtensor.tensor.shape().to_vec(), vec![2, 3]);
        assert_eq!(qtensor.tensor.dtype(), DType::I8);

        // Check quantized values
        let q_vals: &[i8] = qtensor.tensor.storage();
        // 0 / 0.03937 = 0, 1 / 0.03937 = 25.4 -> 25, etc.
        assert_eq!(q_vals[0], 0);
        assert_eq!(q_vals[1], 25);
        assert_eq!(q_vals[5], 127);

        // Dequantize
        let result = Flex::dequantize(qtensor, FloatDType::F32);
        assert_eq!(result.shape().to_vec(), vec![2, 3]);
        assert_eq!(result.dtype(), DType::F32);

        let result_vals: &[f32] = result.storage();
        // Values should be approximately equal (quantization introduces small errors)
        for (orig, deq) in values.iter().zip(result_vals.iter()) {
            assert!((orig - deq).abs() < 0.05, "orig={orig}, dequantized={deq}");
        }
    }

    #[test]
    fn test_quantize_dequantize_negative_values() {
        let values = vec![-3.0f32, -1.5, 0.0, 1.5, 3.0];
        let tensor = FlexTensor::from_data(TensorData::new(values.clone(), [5]));

        let scheme = QuantScheme::default()
            .with_value(QuantValue::Q8S)
            .with_store(QuantStore::Native);

        let scale: f32 = 2.0 * 3.0 / 254.0;
        let scales_tensor = FlexTensor::from_data(TensorData::new(vec![scale], [1]));

        let qparams = QuantizationParametersPrimitive {
            scales: scales_tensor,
        };

        let qtensor = Flex::quantize(tensor, &scheme, qparams);
        let result = Flex::dequantize(qtensor, FloatDType::F32);
        let result_vals: &[f32] = result.storage();

        for (orig, deq) in values.iter().zip(result_vals.iter()) {
            assert!((orig - deq).abs() < 0.05, "orig={orig}, dequantized={deq}");
        }
    }

    #[test]
    fn test_q_from_data_into_data_roundtrip() {
        // Create quantized TensorData the standard way
        let values = vec![0i8, 25, 51, 76, 102, 127];
        let scale = 0.03937008f32;
        let scheme = QuantScheme::default()
            .with_value(QuantValue::Q8S)
            .with_store(QuantStore::Native);

        let data = TensorData::quantized(values.clone(), [2, 3], scheme, &[scale]);

        // Load into FlexQTensor
        let qtensor = Flex::q_from_data(data, &Default::default());
        assert_eq!(qtensor.tensor.shape().to_vec(), vec![2, 3]);
        assert_eq!(qtensor.scales, vec![scale]);

        // Dequantize and check values
        let float_tensor = Flex::dequantize(qtensor, FloatDType::F32);
        let result: &[f32] = float_tensor.storage();
        assert!((result[0]).abs() < 0.01); // 0 * scale ~ 0
        assert!((result[5] - 5.0).abs() < 0.05); // 127 * scale ~ 5.0
    }

    #[test]
    fn test_quantize_zero_tensor() {
        let values = vec![0.0f32; 4];
        let tensor = FlexTensor::from_data(TensorData::new(values, [4]));

        let scheme = QuantScheme::default()
            .with_value(QuantValue::Q8S)
            .with_store(QuantStore::Native);

        // Scale of 0 should be handled gracefully
        // (validated_scale substitutes f32::MIN_POSITIVE).
        let scales_tensor = FlexTensor::from_data(TensorData::new(vec![0.0f32], [1]));
        let qparams = QuantizationParametersPrimitive {
            scales: scales_tensor,
        };

        let qtensor = Flex::quantize(tensor, &scheme, qparams);
        let q_vals: &[i8] = qtensor.tensor.storage();
        assert_eq!(q_vals, &[0, 0, 0, 0]);
    }

    #[test]
    fn test_quantize_dynamic_roundtrip() {
        let values = vec![-3.0f32, -1.5, 0.0, 1.5, 3.0, 4.5];
        let tensor = FlexTensor::from_data(TensorData::new(values.clone(), [2, 3]));

        let scheme = QuantScheme::default()
            .with_value(QuantValue::Q8S)
            .with_store(QuantStore::Native);

        let qtensor = Flex::quantize_dynamic(tensor, &scheme);
        assert_eq!(qtensor.tensor.shape().to_vec(), vec![2, 3]);
        assert_eq!(qtensor.scales.len(), 1);

        // Scale should be 2 * 4.5 / 254
        let expected_scale: f32 = 2.0 * 4.5 / 254.0;
        assert!(
            (qtensor.scales[0] - expected_scale).abs() < 1e-6,
            "scale={}, expected={}",
            qtensor.scales[0],
            expected_scale
        );

        let result = Flex::dequantize(qtensor, FloatDType::F32);
        let result_vals: &[f32] = result.storage();
        for (orig, deq) in values.iter().zip(result_vals.iter()) {
            assert!((orig - deq).abs() < 0.1, "orig={orig}, dequantized={deq}");
        }
    }

    #[test]
    fn test_per_block_quantize_dequantize() {
        use burn_std::quantization::BlockSize;

        let values = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0];
        let tensor = FlexTensor::from_data(TensorData::new(values.clone(), [8]));

        let block_size = BlockSize::new([4]);
        let scheme = QuantScheme::default()
            .with_value(QuantValue::Q8S)
            .with_level(QuantLevel::Block(block_size))
            .with_store(QuantStore::Native);

        // Block 1: [0, 1, 2, 3] -> max_abs=3, scale = 6/254
        // Block 2: [4, 5, 6, 7] -> max_abs=7, scale = 14/254
        let scale_1: f32 = 2.0 * 3.0 / 254.0;
        let scale_2: f32 = 2.0 * 7.0 / 254.0;
        let scales_tensor = FlexTensor::from_data(TensorData::new(vec![scale_1, scale_2], [2]));

        let qparams = QuantizationParametersPrimitive {
            scales: scales_tensor,
        };

        let qtensor = Flex::quantize(tensor, &scheme, qparams);
        assert_eq!(qtensor.scales.len(), 2);

        let result = Flex::dequantize(qtensor, FloatDType::F32);
        let result_vals: &[f32] = result.storage();

        for (orig, deq) in values.iter().zip(result_vals.iter()) {
            assert!((orig - deq).abs() < 0.1, "orig={orig}, dequantized={deq}");
        }
    }

    #[test]
    fn test_quantize_dynamic_block() {
        use burn_std::quantization::BlockSize;

        let values = vec![-2.0f32, -1.0, 0.0, 1.0, 4.0, 5.0, 6.0, 7.0];
        let tensor = FlexTensor::from_data(TensorData::new(values.clone(), [8]));

        let block_size = BlockSize::new([4]);
        let scheme = QuantScheme::default()
            .with_value(QuantValue::Q8S)
            .with_level(QuantLevel::Block(block_size))
            .with_store(QuantStore::Native);

        let qtensor = Flex::quantize_dynamic(tensor, &scheme);
        assert_eq!(qtensor.scales.len(), 2);

        // Block 1: [-2, -1, 0, 1] -> alpha=2, scale = 4/254
        // Block 2: [4, 5, 6, 7] -> alpha=7, scale = 14/254
        let expected_scale_1: f32 = 2.0 * 2.0 / 254.0;
        let expected_scale_2: f32 = 2.0 * 7.0 / 254.0;
        assert!((qtensor.scales[0] - expected_scale_1).abs() < 1e-6);
        assert!((qtensor.scales[1] - expected_scale_2).abs() < 1e-6);

        let result = Flex::dequantize(qtensor, FloatDType::F32);
        let result_vals: &[f32] = result.storage();
        for (orig, deq) in values.iter().zip(result_vals.iter()) {
            assert!((orig - deq).abs() < 0.1, "orig={orig}, dequantized={deq}");
        }
    }

    #[test]
    fn test_quantize_dynamic_q8f() {
        // Q8F uses asymmetric range [-128, 127]
        let values = vec![-5.0f32, -2.5, 0.0, 2.5, 5.0, 7.5];
        let tensor = FlexTensor::from_data(TensorData::new(values.clone(), [6]));

        let scheme = QuantScheme::default()
            .with_value(QuantValue::Q8F)
            .with_store(QuantStore::Native);

        let qtensor = Flex::quantize_dynamic(tensor, &scheme);

        // Q8F range: [-128, 127], so range = 255
        // alpha = 7.5, scale = 2 * 7.5 / 255
        let expected_scale: f32 = 2.0 * 7.5 / 255.0;
        assert!(
            (qtensor.scales[0] - expected_scale).abs() < 1e-6,
            "scale={}, expected={}",
            qtensor.scales[0],
            expected_scale
        );

        let result = Flex::dequantize(qtensor, FloatDType::F32);
        let result_vals: &[f32] = result.storage();
        for (orig, deq) in values.iter().zip(result_vals.iter()) {
            assert!((orig - deq).abs() < 0.1, "orig={orig}, dequantized={deq}");
        }
    }

    #[test]
    fn test_block_quantized_transpose_dequantize() {
        use burn_std::quantization::BlockSize;

        // 2x4 tensor, 2 blocks of 4
        let values = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let tensor = FlexTensor::from_data(TensorData::new(values, [2, 4]));

        let block_size = BlockSize::new([4]);
        let scheme = QuantScheme::default()
            .with_value(QuantValue::Q8S)
            .with_level(QuantLevel::Block(block_size))
            .with_store(QuantStore::Native);

        let qtensor = Flex::quantize_dynamic(tensor, &scheme);

        // Transpose to [4, 2], then dequantize.
        // Block-quantized, so this exercises the dequantize/requantize
        // path in block_safe_layout_op.
        let transposed = Flex::q_swap_dims(qtensor, 0, 1);
        assert_eq!(transposed.tensor.shape().to_vec(), vec![4, 2]);

        let result = Flex::dequantize(transposed, FloatDType::F32);
        let result_vals: &[f32] = result.storage();

        // Original [[1,2,3,4],[5,6,7,8]] transposed to [[1,5],[2,6],[3,7],[4,8]]
        let expected = [1.0f32, 5.0, 2.0, 6.0, 3.0, 7.0, 4.0, 8.0];
        for (exp, deq) in expected.iter().zip(result_vals.iter()) {
            assert!(
                (exp - deq).abs() < 0.15,
                "expected={exp}, dequantized={deq}"
            );
        }
    }

    #[test]
    fn test_block_quantized_select() {
        use burn_std::quantization::BlockSize;

        // 2x4 tensor, 2 blocks of 4
        let values = vec![1.0f32, 2.0, 3.0, 4.0, 10.0, 20.0, 30.0, 40.0];
        let tensor = FlexTensor::from_data(TensorData::new(values, [2, 4]));

        let block_size = BlockSize::new([4]);
        let scheme = QuantScheme::default()
            .with_value(QuantValue::Q8S)
            .with_level(QuantLevel::Block(block_size))
            .with_store(QuantStore::Native);

        let qtensor = Flex::quantize_dynamic(tensor, &scheme);

        // Select row 1 -> [10, 20, 30, 40]
        let indices = FlexTensor::from_data(TensorData::new(vec![1i64], [1]));
        let selected = Flex::q_select(qtensor, 0, indices);
        assert_eq!(selected.tensor.shape().to_vec(), vec![1, 4]);

        let result = Flex::dequantize(selected, FloatDType::F32);
        let result_vals: &[f32] = result.storage();
        let expected = [10.0f32, 20.0, 30.0, 40.0];
        for (exp, deq) in expected.iter().zip(result_vals.iter()) {
            assert!((exp - deq).abs() < 0.5, "expected={exp}, dequantized={deq}");
        }
    }

    #[test]
    fn test_block_quantized_flip_dequantize() {
        use burn_std::quantization::BlockSize;

        let values = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let tensor = FlexTensor::from_data(TensorData::new(values, [2, 4]));

        let block_size = BlockSize::new([4]);
        let scheme = QuantScheme::default()
            .with_value(QuantValue::Q8S)
            .with_level(QuantLevel::Block(block_size))
            .with_store(QuantStore::Native);

        let qtensor = Flex::quantize_dynamic(tensor, &scheme);

        // Flip along axis 0: [[5,6,7,8],[1,2,3,4]]
        let flipped = Flex::q_flip(qtensor, &[0]);
        assert_eq!(flipped.tensor.shape().to_vec(), vec![2, 4]);

        let result = Flex::dequantize(flipped, FloatDType::F32);
        let result_vals: &[f32] = result.storage();
        let expected = [5.0f32, 6.0, 7.0, 8.0, 1.0, 2.0, 3.0, 4.0];
        for (exp, deq) in expected.iter().zip(result_vals.iter()) {
            assert!(
                (exp - deq).abs() < 0.15,
                "expected={exp}, dequantized={deq}"
            );
        }
    }

    #[test]
    fn test_quantize_dynamic_f64_tensor() {
        use burn_backend::quantization::QuantValue;

        // Build an F64-backed tensor directly to exercise the
        // float_storage_as_f32 normalisation path.
        let values = vec![0.0f64, 1.0, 2.0, 3.0, 4.0, 5.0];
        let tensor = FlexTensor::new(
            Bytes::from_elems(values),
            Layout::contiguous([6].into()),
            DType::F64,
        );
        assert_eq!(tensor.dtype(), DType::F64);

        let scheme = QuantScheme::default()
            .with_value(QuantValue::Q8S)
            .with_store(QuantStore::Native);

        let qtensor = Flex::quantize_dynamic(tensor, &scheme);
        assert_eq!(qtensor.tensor.dtype(), DType::I8);

        // Dequantize and verify round-trip accuracy
        let result = Flex::dequantize(qtensor, FloatDType::F32);
        let result_vals: &[f32] = result.storage();
        let expected = [0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0];
        for (exp, deq) in expected.iter().zip(result_vals.iter()) {
            assert!(
                (exp - deq).abs() < 0.15,
                "expected={exp}, dequantized={deq}"
            );
        }
    }

    #[test]
    fn test_dequantize_f64() {
        let values = vec![0.0f32, 1.0, 2.0, 3.0];
        let tensor = FlexTensor::from_data(TensorData::new(values.clone(), [4]));

        let scheme = QuantScheme::default()
            .with_value(QuantValue::Q8S)
            .with_store(QuantStore::Native);

        let qtensor = Flex::quantize_dynamic(tensor, &scheme);
        let result = Flex::dequantize(qtensor, FloatDType::F64);
        assert_eq!(result.dtype(), DType::F64);
        let result_vals: &[f64] = result.storage();
        for (orig, deq) in values.iter().zip(result_vals.iter()) {
            assert!(
                (*orig as f64 - deq).abs() < 0.05,
                "orig={orig}, dequantized={deq}"
            );
        }
    }

    #[test]
    fn test_dequantize_f16() {
        let values = vec![0.0f32, 1.0, 2.0, 3.0];
        let tensor = FlexTensor::from_data(TensorData::new(values.clone(), [4]));

        let scheme = QuantScheme::default()
            .with_value(QuantValue::Q8S)
            .with_store(QuantStore::Native);

        let qtensor = Flex::quantize_dynamic(tensor, &scheme);
        let result = Flex::dequantize(qtensor, FloatDType::F16);
        assert_eq!(result.dtype(), DType::F16);
        let result_vals: &[f16] = result.storage();
        for (orig, deq) in values.iter().zip(result_vals.iter()) {
            assert!(
                (*orig - f32::from(*deq)).abs() < 0.05,
                "orig={orig}, dequantized={deq}"
            );
        }
    }
}