baracuda_kernels/quantize/
per_tensor_backward.rs1use core::ffi::c_void;
14use core::marker::PhantomData;
15
16use baracuda_cutlass::{Error, Result};
17use baracuda_driver::Stream;
18use baracuda_kernels_types::{
19 Element, ElementKind, IntElement, KernelSku, PlanPreference, PrecisionGuarantee, QuantizeKind,
20 ScalarType, TensorMut, TensorRef, Workspace,
21};
22
23use super::map_status;
24use super::per_tensor::build_sku;
25use super::{validate_input_element, validate_output_element};
26
27#[derive(Copy, Clone, Debug)]
33pub struct QuantizePerTensorBackwardDescriptor {
34 pub numel: i32,
36 pub q_min: i32,
38 pub q_max: i32,
40 pub input_element: ElementKind,
42 pub output_element: ElementKind,
45}
46
47pub struct QuantizePerTensorBackwardArgs<'a, TIn: Element, TOut: IntElement> {
49 pub input: TensorRef<'a, TIn, 1>,
51 pub scale: <TIn as Element>::Scalar,
53 pub zero_point: i32,
55 pub d_output: TensorRef<'a, TIn, 1>,
57 pub d_input: TensorMut<'a, TIn, 1>,
59 pub _phantom: PhantomData<TOut>,
63}
64
65pub struct QuantizePerTensorBackwardPlan<TIn: Element, TOut: IntElement> {
87 desc: QuantizePerTensorBackwardDescriptor,
88 sku: KernelSku,
89 _marker: PhantomData<(TIn, TOut)>,
90}
91
92impl<TIn: Element, TOut: IntElement> QuantizePerTensorBackwardPlan<TIn, TOut> {
93 pub fn select(
95 _stream: &Stream,
96 desc: &QuantizePerTensorBackwardDescriptor,
97 _pref: PlanPreference,
98 ) -> Result<Self> {
99 if desc.input_element != TIn::KIND {
100 return Err(Error::Unsupported(
101 "QuantizePerTensorBackwardPlan: descriptor input_element != TIn",
102 ));
103 }
104 if desc.output_element != TOut::KIND {
105 return Err(Error::Unsupported(
106 "QuantizePerTensorBackwardPlan: descriptor output_element != TOut",
107 ));
108 }
109 validate_input_element(
110 TIn::KIND,
111 "QuantizePerTensorBackwardPlan: unsupported TIn dtype",
112 )?;
113 validate_output_element(
114 TOut::KIND,
115 "QuantizePerTensorBackwardPlan: unsupported TOut dtype",
116 )?;
117 if desc.numel < 0 {
118 return Err(Error::InvalidProblem(
119 "QuantizePerTensorBackwardPlan: numel must be non-negative",
120 ));
121 }
122 let sku = build_sku::<TIn, TOut>(QuantizeKind::PerTensorBackward);
123 Ok(Self {
124 desc: *desc,
125 sku,
126 _marker: PhantomData,
127 })
128 }
129
130 pub fn can_implement(
132 &self,
133 args: &QuantizePerTensorBackwardArgs<'_, TIn, TOut>,
134 ) -> Result<()> {
135 let expected = [self.desc.numel];
136 if args.input.shape != expected
137 || args.d_output.shape != expected
138 || args.d_input.shape != expected
139 {
140 return Err(Error::InvalidProblem(
141 "QuantizePerTensorBackwardPlan: tensor shape != [numel]",
142 ));
143 }
144 Ok(())
145 }
146
147 #[inline]
149 pub fn workspace_size(&self) -> usize {
150 0
151 }
152
153 #[inline]
155 pub fn sku(&self) -> KernelSku {
156 self.sku
157 }
158
159 #[inline]
161 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
162 self.sku.precision_guarantee
163 }
164
165 pub fn run(
167 &self,
168 stream: &Stream,
169 _workspace: Workspace<'_>,
170 args: QuantizePerTensorBackwardArgs<'_, TIn, TOut>,
171 ) -> Result<()> {
172 self.can_implement(&args)?;
173 let numel = self.desc.numel as i64;
174 if numel == 0 {
175 return Ok(());
176 }
177 let x_ptr = args.input.data.as_raw().0 as *const c_void;
178 let dy_ptr = args.d_output.data.as_raw().0 as *const c_void;
179 let dx_ptr = args.d_input.data.as_raw().0 as *mut c_void;
180 let stream_ptr = stream.as_raw() as *mut c_void;
181 let zp = args.zero_point;
182 let qmin = self.desc.q_min;
183 let qmax = self.desc.q_max;
184
185 let status = if <TIn::Scalar as ScalarType>::IS_F64 {
186 let scale_f64 = args.scale.to_f64();
187 unsafe {
188 baracuda_kernels_sys::baracuda_kernels_quantize_per_tensor_backward_f64_run(
189 numel, scale_f64, zp, qmin, qmax,
190 x_ptr, dy_ptr, dx_ptr,
191 core::ptr::null_mut(), 0, stream_ptr,
192 )
193 }
194 } else {
195 let scale_f32 = args.scale.to_f32();
196 match TIn::KIND {
197 ElementKind::F32 => unsafe {
198 baracuda_kernels_sys::baracuda_kernels_quantize_per_tensor_backward_f32_run(
199 numel, scale_f32, zp, qmin, qmax,
200 x_ptr, dy_ptr, dx_ptr,
201 core::ptr::null_mut(), 0, stream_ptr,
202 )
203 },
204 ElementKind::F16 => unsafe {
205 baracuda_kernels_sys::baracuda_kernels_quantize_per_tensor_backward_f16_run(
206 numel, scale_f32, zp, qmin, qmax,
207 x_ptr, dy_ptr, dx_ptr,
208 core::ptr::null_mut(), 0, stream_ptr,
209 )
210 },
211 ElementKind::Bf16 => unsafe {
212 baracuda_kernels_sys::baracuda_kernels_quantize_per_tensor_backward_bf16_run(
213 numel, scale_f32, zp, qmin, qmax,
214 x_ptr, dy_ptr, dx_ptr,
215 core::ptr::null_mut(), 0, stream_ptr,
216 )
217 },
218 _ => return Err(Error::Unsupported(
219 "QuantizePerTensorBackwardPlan: unsupported TIn at run()",
220 )),
221 }
222 };
223 map_status(status)
224 }
225}