baracuda_kernels/quantize/
fake_quantize.rs1use core::ffi::c_void;
16use core::marker::PhantomData;
17
18use baracuda_cutlass::{Error, Result};
19use baracuda_driver::Stream;
20use baracuda_kernels_types::{
21 ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
22 PlanPreference, PrecisionGuarantee, QuantizeKind, ScalarType, TensorMut, TensorRef, Workspace,
23};
24
25use super::{map_status, validate_input_element};
26
27#[derive(Copy, Clone, Debug)]
29pub struct FakeQuantizeDescriptor {
30 pub numel: i32,
32 pub q_min: i32,
34 pub q_max: i32,
36 pub input_element: ElementKind,
38}
39
40pub struct FakeQuantizeArgs<'a, TIn: Element> {
42 pub input: TensorRef<'a, TIn, 1>,
44 pub scale: <TIn as Element>::Scalar,
46 pub zero_point: i32,
48 pub output: TensorMut<'a, TIn, 1>,
50}
51
52pub struct FakeQuantizePlan<TIn: Element> {
72 desc: FakeQuantizeDescriptor,
73 sku: KernelSku,
74 _marker: PhantomData<TIn>,
75}
76
77impl<TIn: Element> FakeQuantizePlan<TIn> {
78 pub fn select(
80 _stream: &Stream,
81 desc: &FakeQuantizeDescriptor,
82 _pref: PlanPreference,
83 ) -> Result<Self> {
84 if desc.input_element != TIn::KIND {
85 return Err(Error::Unsupported(
86 "FakeQuantizePlan: descriptor input_element != TIn",
87 ));
88 }
89 validate_input_element(TIn::KIND, "FakeQuantizePlan: unsupported TIn dtype")?;
90 if desc.numel < 0 {
91 return Err(Error::InvalidProblem(
92 "FakeQuantizePlan: numel must be non-negative",
93 ));
94 }
95 if desc.q_max < desc.q_min {
96 return Err(Error::InvalidProblem("FakeQuantizePlan: q_max < q_min"));
97 }
98 let sku = build_sku::<TIn>(QuantizeKind::FakeQuantize);
99 Ok(Self {
100 desc: *desc,
101 sku,
102 _marker: PhantomData,
103 })
104 }
105
106 pub fn can_implement(&self, args: &FakeQuantizeArgs<'_, TIn>) -> Result<()> {
108 let expected = [self.desc.numel];
109 if args.input.shape != expected || args.output.shape != expected {
110 return Err(Error::InvalidProblem(
111 "FakeQuantizePlan: tensor shape != [numel]",
112 ));
113 }
114 Ok(())
115 }
116
117 #[inline]
119 pub fn workspace_size(&self) -> usize {
120 0
121 }
122
123 #[inline]
125 pub fn sku(&self) -> KernelSku {
126 self.sku
127 }
128
129 #[inline]
131 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
132 self.sku.precision_guarantee
133 }
134
135 pub fn run(
137 &self,
138 stream: &Stream,
139 _workspace: Workspace<'_>,
140 args: FakeQuantizeArgs<'_, TIn>,
141 ) -> Result<()> {
142 self.can_implement(&args)?;
143 let numel = self.desc.numel as i64;
144 if numel == 0 {
145 return Ok(());
146 }
147 let x_ptr = args.input.data.as_raw().0 as *const c_void;
148 let y_ptr = args.output.data.as_raw().0 as *mut c_void;
149 let stream_ptr = stream.as_raw() as *mut c_void;
150 let zp = args.zero_point;
151 let qmin = self.desc.q_min;
152 let qmax = self.desc.q_max;
153
154 let status = if <TIn::Scalar as ScalarType>::IS_F64 {
155 let scale_f64 = args.scale.to_f64();
156 unsafe {
157 baracuda_kernels_sys::baracuda_kernels_fake_quantize_f64_run(
158 numel, scale_f64, zp, qmin, qmax, x_ptr, y_ptr,
159 core::ptr::null_mut(), 0, stream_ptr,
160 )
161 }
162 } else {
163 let scale_f32 = args.scale.to_f32();
164 match TIn::KIND {
165 ElementKind::F32 => unsafe {
166 baracuda_kernels_sys::baracuda_kernels_fake_quantize_f32_run(
167 numel, scale_f32, zp, qmin, qmax, x_ptr, y_ptr,
168 core::ptr::null_mut(), 0, stream_ptr,
169 )
170 },
171 ElementKind::F16 => unsafe {
172 baracuda_kernels_sys::baracuda_kernels_fake_quantize_f16_run(
173 numel, scale_f32, zp, qmin, qmax, x_ptr, y_ptr,
174 core::ptr::null_mut(), 0, stream_ptr,
175 )
176 },
177 ElementKind::Bf16 => unsafe {
178 baracuda_kernels_sys::baracuda_kernels_fake_quantize_bf16_run(
179 numel, scale_f32, zp, qmin, qmax, x_ptr, y_ptr,
180 core::ptr::null_mut(), 0, stream_ptr,
181 )
182 },
183 _ => return Err(Error::Unsupported(
184 "FakeQuantizePlan: unsupported TIn at run()",
185 )),
186 }
187 };
188 map_status(status)
189 }
190}
191
192pub(crate) fn build_sku<TIn: Element>(op: QuantizeKind) -> KernelSku {
196 let precision_guarantee = PrecisionGuarantee {
197 math_precision: if TIn::KIND == ElementKind::F64 {
198 MathPrecision::F64
199 } else {
200 MathPrecision::F32
201 },
202 accumulator: ElementKind::F32,
203 bit_stable_on_same_hardware: true,
204 deterministic: true,
205 };
206 KernelSku {
207 category: OpCategory::Quantization,
208 op: op as u16,
209 element: TIn::KIND,
210 aux_element: None,
211 layout: None,
212 epilogue: None,
213 arch: ArchSku::Sm80,
214 backend: BackendKind::Bespoke,
215 precision_guarantee,
216 }
217}