1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
use core::future::Future;

use crate::{
    backend::Backend,
    quantization::{QuantizationParametersPrimitive, QuantizationScheme},
    Device, Shape, TensorData,
};

use super::{FloatTensor, QuantizedTensor};

/// Quantized Tensor API for basic operations, see [tensor](crate::Tensor)
/// for documentation on each function.
pub trait QTensorOps<B: Backend> {
    /// Creates a new tensor from the data structure.
    ///
    /// # Arguments
    ///
    /// * `data` - The data structure.
    /// * `device` - The device to create the tensor on.
    ///
    /// # Returns
    ///
    /// The tensor with the given data.
    fn q_from_data<const D: usize>(data: TensorData, device: &Device<B>) -> QuantizedTensor<B, D>;

    /// Convert the tensor to a lower precision data type based on the quantization scheme and parameters.
    fn quantize<const D: usize>(
        tensor: FloatTensor<B, D>,
        scheme: &QuantizationScheme,
        qparams: QuantizationParametersPrimitive<B>,
    ) -> QuantizedTensor<B, D>;

    /// Convert the tensor back to a higher precision data type.
    fn dequantize<const D: usize>(tensor: QuantizedTensor<B, D>) -> FloatTensor<B, D>;

    /// Gets the shape of the tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// The shape of the tensor.
    fn q_shape<const D: usize>(tensor: &QuantizedTensor<B, D>) -> Shape<D>;

    /// Gets the device of the tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// The device of the tensor.
    fn q_device<const D: usize>(tensor: &QuantizedTensor<B, D>) -> Device<B>;

    /// Reshapes a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to reshape.
    /// * `shape` - The new shape of the tensor.
    ///
    /// # Returns
    ///
    /// The tensor with the new shape.
    fn q_reshape<const D1: usize, const D2: usize>(
        tensor: QuantizedTensor<B, D1>,
        shape: Shape<D2>,
    ) -> QuantizedTensor<B, D2>;

    /// Converts the tensor to a data structure.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// The data structure with the tensor's data.
    fn q_into_data<const D: usize>(
        tensor: QuantizedTensor<B, D>,
    ) -> impl Future<Output = TensorData> + Send;

    /// Sets the `require_grad` flag of a tensor.
    fn q_set_require_grad<const D: usize>(
        tensor: QuantizedTensor<B, D>,
        _require_grad: bool,
    ) -> QuantizedTensor<B, D> {
        // Should only be overridden by autodiff backends.
        tensor
    }

    /// Returns the `require_grad` flag of a tensor.
    fn q_is_require_grad<const D: usize>(_tensor: &QuantizedTensor<B, D>) -> bool {
        // Should only be overridden by autodiff backends.
        false
    }
}