baracuda_kernels/embedding/
embedding_backward.rs1use core::ffi::c_void;
14use core::marker::PhantomData;
15
16use baracuda_cutlass::{Error, Result};
17use baracuda_driver::Stream;
18use baracuda_kernels_types::{
19 ArchSku, BackendKind, Element, ElementKind, EmbeddingKind, IndexElement, IndexElementKind,
20 KernelSku, MathPrecision, OpCategory, PlanPreference, PrecisionGuarantee, TensorMut,
21 TensorRef, Workspace,
22};
23
24use crate::indexing::gather::map_status;
25
26use super::PADDING_DISABLED;
27
28#[derive(Copy, Clone, Debug)]
30pub struct EmbeddingBackwardDescriptor {
31 pub num_embeddings: i32,
33 pub embedding_dim: i32,
35 pub num_indices: i32,
37 pub padding_idx: Option<i32>,
40 pub element: ElementKind,
42}
43
44pub struct EmbeddingBackwardArgs<'a, T: Element, I: IndexElement = i32> {
48 pub dout: TensorRef<'a, T, 2>,
50 pub indices: TensorRef<'a, I, 1>,
52 pub dweight: TensorMut<'a, T, 2>,
54}
55
56pub struct EmbeddingBackwardPlan<T: Element> {
75 desc: EmbeddingBackwardDescriptor,
76 sku: KernelSku,
77 _marker: PhantomData<T>,
78}
79
80impl<T: Element> EmbeddingBackwardPlan<T> {
81 pub fn select(
83 _stream: &Stream,
84 desc: &EmbeddingBackwardDescriptor,
85 _pref: PlanPreference,
86 ) -> Result<Self> {
87 if desc.element != T::KIND {
88 return Err(Error::Unsupported(
89 "baracuda-kernels::EmbeddingBackwardPlan: descriptor element != type parameter T",
90 ));
91 }
92 if desc.num_embeddings < 0
93 || desc.embedding_dim < 0
94 || desc.num_indices < 0
95 {
96 return Err(Error::InvalidProblem(
97 "baracuda-kernels::EmbeddingBackwardPlan: num_embeddings / embedding_dim / \
98 num_indices must be non-negative",
99 ));
100 }
101 let supported = matches!(T::KIND, ElementKind::F32 | ElementKind::F64);
102 if !supported {
103 return Err(Error::Unsupported(
104 "baracuda-kernels::EmbeddingBackwardPlan: today only `f32`, `f64` wired \
105 (BW uses atomicAdd)",
106 ));
107 }
108 let precision_guarantee = PrecisionGuarantee {
109 math_precision: if T::KIND == ElementKind::F64 {
110 MathPrecision::F64
111 } else {
112 MathPrecision::F32
113 },
114 accumulator: T::KIND,
115 bit_stable_on_same_hardware: false,
117 deterministic: false,
118 };
119 let sku = KernelSku {
120 category: OpCategory::Embedding,
121 op: EmbeddingKind::EmbeddingBackward as u16,
122 element: T::KIND,
123 aux_element: Some(ElementKind::I32),
124 layout: None,
125 epilogue: None,
126 arch: ArchSku::Sm80,
127 backend: BackendKind::Bespoke,
128 precision_guarantee,
129 };
130 Ok(Self {
131 desc: *desc,
132 sku,
133 _marker: PhantomData,
134 })
135 }
136
137 pub fn can_implement<I: IndexElement>(&self, args: &EmbeddingBackwardArgs<'_, T, I>) -> Result<()> {
139 if args.dout.shape[0] != self.desc.num_indices
140 || args.dout.shape[1] != self.desc.embedding_dim
141 {
142 return Err(Error::InvalidProblem(
143 "baracuda-kernels::EmbeddingBackwardPlan: dout shape must be \
144 [num_indices, embedding_dim]",
145 ));
146 }
147 if args.indices.shape[0] != self.desc.num_indices {
148 return Err(Error::InvalidProblem(
149 "baracuda-kernels::EmbeddingBackwardPlan: indices.shape[0] mismatch with descriptor",
150 ));
151 }
152 if args.dweight.shape[0] != self.desc.num_embeddings
153 || args.dweight.shape[1] != self.desc.embedding_dim
154 {
155 return Err(Error::InvalidProblem(
156 "baracuda-kernels::EmbeddingBackwardPlan: dweight shape must be \
157 [num_embeddings, embedding_dim]",
158 ));
159 }
160 let dout_len = args.dout.data.len() as i64;
161 let idx_len = args.indices.data.len() as i64;
162 let dw_len = args.dweight.data.len() as i64;
163 let dout_numel = args.dout.numel();
164 let idx_numel = args.indices.numel();
165 let dw_numel = args.dweight.numel();
166 if dout_len < dout_numel {
167 return Err(Error::BufferTooSmall {
168 needed: dout_numel as usize,
169 got: dout_len as usize,
170 });
171 }
172 if idx_len < idx_numel {
173 return Err(Error::BufferTooSmall {
174 needed: idx_numel as usize,
175 got: idx_len as usize,
176 });
177 }
178 if dw_len < dw_numel {
179 return Err(Error::BufferTooSmall {
180 needed: dw_numel as usize,
181 got: dw_len as usize,
182 });
183 }
184 Ok(())
185 }
186
187 #[inline]
189 pub fn workspace_size(&self) -> usize {
190 0
191 }
192
193 #[inline]
195 pub fn sku(&self) -> KernelSku {
196 self.sku
197 }
198
199 #[inline]
201 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
202 self.sku.precision_guarantee
203 }
204
205 pub fn run<I: IndexElement>(
209 &self,
210 stream: &Stream,
211 _workspace: Workspace<'_>,
212 args: EmbeddingBackwardArgs<'_, T, I>,
213 ) -> Result<()> {
214 self.can_implement(&args)?;
215 let num_indices = self.desc.num_indices as i64;
216 if num_indices == 0 || self.desc.embedding_dim == 0 {
217 return Ok(());
218 }
219 let dout_ptr = args.dout.data.as_raw().0 as *const c_void;
220 let idx_ptr = args.indices.data.as_raw().0 as *const c_void;
221 let dweight_ptr = args.dweight.data.as_raw().0 as *mut c_void;
222 let stream_ptr = stream.as_raw() as *mut c_void;
223 let padding_idx: i64 = self.desc.padding_idx.unwrap_or(PADDING_DISABLED) as i64;
225
226 let status = match (T::KIND, I::KIND) {
227 (ElementKind::F32, IndexElementKind::I32) => unsafe {
228 baracuda_kernels_sys::baracuda_kernels_embedding_backward_f32_run(
229 num_indices, self.desc.num_embeddings, self.desc.embedding_dim,
230 padding_idx, dout_ptr, idx_ptr, dweight_ptr,
231 core::ptr::null_mut(), 0, stream_ptr,
232 )
233 },
234 (ElementKind::F64, IndexElementKind::I32) => unsafe {
235 baracuda_kernels_sys::baracuda_kernels_embedding_backward_f64_run(
236 num_indices, self.desc.num_embeddings, self.desc.embedding_dim,
237 padding_idx, dout_ptr, idx_ptr, dweight_ptr,
238 core::ptr::null_mut(), 0, stream_ptr,
239 )
240 },
241 (ElementKind::F32, IndexElementKind::I64) => unsafe {
242 baracuda_kernels_sys::baracuda_kernels_embedding_backward_i64idx_f32_run(
243 num_indices, self.desc.num_embeddings, self.desc.embedding_dim,
244 padding_idx, dout_ptr, idx_ptr, dweight_ptr,
245 core::ptr::null_mut(), 0, stream_ptr,
246 )
247 },
248 (ElementKind::F64, IndexElementKind::I64) => unsafe {
249 baracuda_kernels_sys::baracuda_kernels_embedding_backward_i64idx_f64_run(
250 num_indices, self.desc.num_embeddings, self.desc.embedding_dim,
251 padding_idx, dout_ptr, idx_ptr, dweight_ptr,
252 core::ptr::null_mut(), 0, stream_ptr,
253 )
254 },
255 _ => {
256 return Err(Error::Unsupported(
257 "baracuda-kernels::EmbeddingBackwardPlan::run reached an unimplemented dtype \
258 — select() should have caught this",
259 ));
260 }
261 };
262 map_status(status)
263 }
264}