Skip to main content

baracuda_kernels/quantize/
dequantize_per_tensor.rs

1//! `dequantize_per_tensor` forward plan.
2//!
3//! `x = scale * (q - zero_point)`. Linear; exactly invertible (up to
4//! FW rounding) against [`super::QuantizePerTensorPlan`]. Output is FP-
5//! typed (`TIn`); the int input is `TOut` (the same int dtype the FW
6//! quantized into).
7
8use core::ffi::c_void;
9use core::marker::PhantomData;
10
11use baracuda_cutlass::{Error, Result};
12use baracuda_driver::Stream;
13use baracuda_kernels_types::{
14    Element, ElementKind, IntElement, KernelSku, PlanPreference, PrecisionGuarantee, QuantizeKind,
15    ScalarType, TensorMut, TensorRef, Workspace,
16};
17
18use super::map_status;
19use super::per_tensor::build_sku;
20use super::{validate_input_element, validate_output_element};
21
22/// Descriptor for a `dequantize_per_tensor` op.
23#[derive(Copy, Clone, Debug)]
24pub struct DequantizePerTensorDescriptor {
25    /// Total element count.
26    pub numel: i32,
27    /// Output FP element kind (same as FW input).
28    pub input_element: ElementKind,
29    /// Input int element kind (s8 or u8 — the FW output dtype).
30    pub output_element: ElementKind,
31}
32
33/// Args bundle for a `dequantize_per_tensor` launch.
34pub struct DequantizePerTensorArgs<'a, TIn: Element, TOut: IntElement> {
35    /// Input int tensor `[numel]`.
36    pub input: TensorRef<'a, TOut, 1>,
37    /// Scalar scale (FP).
38    pub scale: <TIn as Element>::Scalar,
39    /// Scalar zero point.
40    pub zero_point: i32,
41    /// Output FP tensor `[numel]`.
42    pub output: TensorMut<'a, TIn, 1>,
43}
44
45/// `dequantize_per_tensor` plan.
46///
47/// `x = scale * (q - zero_point)`. Linear; exactly invertible (up to
48/// FW rounding) against [`QuantizePerTensorPlan`](crate::QuantizePerTensorPlan).
49///
50/// **When to use**: FP recovery from a per-tensor-quantized buffer.
51/// Pair with [`DequantizePerTensorBackwardPlan`](crate::DequantizePerTensorBackwardPlan)
52/// for autograd through the dequant op.
53///
54/// **Dtypes**: input int `{s8, u8}` (= `TOut`); output FP
55/// `{f32, f64, f16, bf16}` (= `TIn`).
56///
57/// **Shape limits**: flat `[numel]`.
58///
59/// **Workspace**: none.
60///
61/// **Precision guarantee**: deterministic, bit-stable.
62pub struct DequantizePerTensorPlan<TIn: Element, TOut: IntElement> {
63    desc: DequantizePerTensorDescriptor,
64    sku: KernelSku,
65    _marker: PhantomData<(TIn, TOut)>,
66}
67
68impl<TIn: Element, TOut: IntElement> DequantizePerTensorPlan<TIn, TOut> {
69    /// Pick a kernel.
70    pub fn select(
71        _stream: &Stream,
72        desc: &DequantizePerTensorDescriptor,
73        _pref: PlanPreference,
74    ) -> Result<Self> {
75        if desc.input_element != TIn::KIND {
76            return Err(Error::Unsupported(
77                "DequantizePerTensorPlan: descriptor input_element != TIn",
78            ));
79        }
80        if desc.output_element != TOut::KIND {
81            return Err(Error::Unsupported(
82                "DequantizePerTensorPlan: descriptor output_element != TOut",
83            ));
84        }
85        validate_input_element(TIn::KIND, "DequantizePerTensorPlan: unsupported TIn dtype")?;
86        validate_output_element(TOut::KIND, "DequantizePerTensorPlan: unsupported TOut dtype")?;
87        if desc.numel < 0 {
88            return Err(Error::InvalidProblem(
89                "DequantizePerTensorPlan: numel must be non-negative",
90            ));
91        }
92        let sku = build_sku::<TIn, TOut>(QuantizeKind::DequantizePerTensor);
93        Ok(Self {
94            desc: *desc,
95            sku,
96            _marker: PhantomData,
97        })
98    }
99
100    /// Validate args.
101    pub fn can_implement(&self, args: &DequantizePerTensorArgs<'_, TIn, TOut>) -> Result<()> {
102        let expected = [self.desc.numel];
103        if args.input.shape != expected || args.output.shape != expected {
104            return Err(Error::InvalidProblem(
105                "DequantizePerTensorPlan: tensor shape != [numel]",
106            ));
107        }
108        Ok(())
109    }
110
111    /// Workspace bytes.
112    #[inline]
113    pub fn workspace_size(&self) -> usize {
114        0
115    }
116
117    /// Identity.
118    #[inline]
119    pub fn sku(&self) -> KernelSku {
120        self.sku
121    }
122
123    /// Numerical guarantees.
124    #[inline]
125    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
126        self.sku.precision_guarantee
127    }
128
129    /// Launch.
130    pub fn run(
131        &self,
132        stream: &Stream,
133        _workspace: Workspace<'_>,
134        args: DequantizePerTensorArgs<'_, TIn, TOut>,
135    ) -> Result<()> {
136        self.can_implement(&args)?;
137        let numel = self.desc.numel as i64;
138        if numel == 0 {
139            return Ok(());
140        }
141        let q_ptr = args.input.data.as_raw().0 as *const c_void;
142        let x_ptr = args.output.data.as_raw().0 as *mut c_void;
143        let stream_ptr = stream.as_raw() as *mut c_void;
144        let zp = args.zero_point;
145
146        let status = if <TIn::Scalar as ScalarType>::IS_F64 {
147            let scale_f64 = args.scale.to_f64();
148            match TOut::KIND {
149                ElementKind::S8 => unsafe {
150                    baracuda_kernels_sys::baracuda_kernels_dequantize_per_tensor_f64_s8_run(
151                        numel, scale_f64, zp, q_ptr, x_ptr,
152                        core::ptr::null_mut(), 0, stream_ptr,
153                    )
154                },
155                ElementKind::U8 => unsafe {
156                    baracuda_kernels_sys::baracuda_kernels_dequantize_per_tensor_f64_u8_run(
157                        numel, scale_f64, zp, q_ptr, x_ptr,
158                        core::ptr::null_mut(), 0, stream_ptr,
159                    )
160                },
161                _ => return Err(Error::Unsupported(
162                    "DequantizePerTensorPlan: unsupported TOut at run()",
163                )),
164            }
165        } else {
166            let scale_f32 = args.scale.to_f32();
167            match (TIn::KIND, TOut::KIND) {
168                (ElementKind::F32, ElementKind::S8) => unsafe {
169                    baracuda_kernels_sys::baracuda_kernels_dequantize_per_tensor_f32_s8_run(
170                        numel, scale_f32, zp, q_ptr, x_ptr,
171                        core::ptr::null_mut(), 0, stream_ptr,
172                    )
173                },
174                (ElementKind::F32, ElementKind::U8) => unsafe {
175                    baracuda_kernels_sys::baracuda_kernels_dequantize_per_tensor_f32_u8_run(
176                        numel, scale_f32, zp, q_ptr, x_ptr,
177                        core::ptr::null_mut(), 0, stream_ptr,
178                    )
179                },
180                (ElementKind::F16, ElementKind::S8) => unsafe {
181                    baracuda_kernels_sys::baracuda_kernels_dequantize_per_tensor_f16_s8_run(
182                        numel, scale_f32, zp, q_ptr, x_ptr,
183                        core::ptr::null_mut(), 0, stream_ptr,
184                    )
185                },
186                (ElementKind::F16, ElementKind::U8) => unsafe {
187                    baracuda_kernels_sys::baracuda_kernels_dequantize_per_tensor_f16_u8_run(
188                        numel, scale_f32, zp, q_ptr, x_ptr,
189                        core::ptr::null_mut(), 0, stream_ptr,
190                    )
191                },
192                (ElementKind::Bf16, ElementKind::S8) => unsafe {
193                    baracuda_kernels_sys::baracuda_kernels_dequantize_per_tensor_bf16_s8_run(
194                        numel, scale_f32, zp, q_ptr, x_ptr,
195                        core::ptr::null_mut(), 0, stream_ptr,
196                    )
197                },
198                (ElementKind::Bf16, ElementKind::U8) => unsafe {
199                    baracuda_kernels_sys::baracuda_kernels_dequantize_per_tensor_bf16_u8_run(
200                        numel, scale_f32, zp, q_ptr, x_ptr,
201                        core::ptr::null_mut(), 0, stream_ptr,
202                    )
203                },
204                _ => return Err(Error::Unsupported(
205                    "DequantizePerTensorPlan: unsupported (TIn, TOut) at run()",
206                )),
207            }
208        };
209        map_status(status)
210    }
211}