pub struct EmbeddingBackwardPlan<T: Element> { /* private fields */ }
Plan for the embedding_backward operation.
Adjoint of crate::EmbeddingPlan: for each n in 0..N,
dweight[indices[n], :] += dout[n, :] via atomicAdd. Rows where
indices[n] == padding_idx, or where indices[n] is negative or >= V, are skipped.
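For reference, the scatter-add above can be modeled on the CPU as a sequential loop. This is a minimal sketch, not the crate's kernel: the function name `embedding_backward_ref`, the row-major flat-slice layout, the `i64` index type, and passing `padding_idx` as an `Option<i64>` are all assumptions made for illustration.

```rust
/// Hypothetical CPU reference for the documented scatter-add:
/// dweight[indices[n], :] += dout[n, :], skipping padding and invalid rows.
/// Layout assumption: row-major, dweight holds V*D values and dout holds N*D.
fn embedding_backward_ref(
    dweight: &mut [f32],
    dout: &[f32],
    indices: &[i64],
    d: usize,
    padding_idx: Option<i64>,
) {
    if d == 0 {
        return; // nothing to accumulate when the embedding dim is zero
    }
    let v = dweight.len() / d;
    for (n, &idx) in indices.iter().enumerate() {
        // Skip the padding row and negative / out-of-range indices.
        if Some(idx) == padding_idx || idx < 0 || idx as usize >= v {
            continue;
        }
        let row = idx as usize;
        let src = &dout[n * d..(n + 1) * d];
        let dst = &mut dweight[row * d..(row + 1) * d];
        for (w, g) in dst.iter_mut().zip(src) {
            *w += *g; // the GPU kernel performs this add with atomicAdd
        }
    }
}

fn main() {
    // V = 3, D = 2, N = 4; index 1 appears twice, 0 is padding, 7 is out of range.
    let mut dweight = vec![0.0f32; 3 * 2];
    let dout = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
    let indices = [1, 0, 1, 7];
    embedding_backward_ref(&mut dweight, &dout, &indices, 2, Some(0));
    println!("{:?}", dweight);
}
```

Row 1 receives the sum of both of its dout rows; padding and out-of-range rows contribute nothing.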
When to use: the backward pass of EmbeddingPlan.
Dtypes: {f32, f64} only, since the kernel relies on native floating-point atomicAdd.
Shape limits: dweight is [V, D], dout is [N, D],
indices is [N], all extents non-negative.
Workspace: none. Caller MUST zero dweight before launch
(or pre-populate to accumulate into a running gradient).
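Because the kernel only adds into dweight, two launches over consecutive micro-batches with no zeroing in between should leave the same gradient as one launch over the concatenated batch (up to floating-point rounding). The sketch below checks this on the CPU; `scatter_add` is a hypothetical sequential stand-in for the kernel, and it omits padding and range handling for brevity.

```rust
/// Hypothetical sequential stand-in for the kernel: row-major dweight [V, D],
/// dout [N, D], indices [N]. Padding / out-of-range handling is omitted.
fn scatter_add(dweight: &mut [f32], dout: &[f32], indices: &[usize], d: usize) {
    for (n, &row) in indices.iter().enumerate() {
        for j in 0..d {
            dweight[row * d + j] += dout[n * d + j];
        }
    }
}

fn main() {
    let d = 2;
    // Fresh gradient: the caller zeroes dweight once ([V, D] = [2, 2]).
    let mut running = vec![0.0f32; 2 * d];
    // Two micro-batches accumulated into the same buffer, no zeroing in between.
    scatter_add(&mut running, &[1.0, 1.0], &[0], d);
    scatter_add(&mut running, &[2.0, 3.0, 4.0, 5.0], &[1, 0], d);

    // One launch over the concatenated batch, starting from zero.
    let mut single = vec![0.0f32; 2 * d];
    scatter_add(&mut single, &[1.0, 1.0, 2.0, 3.0, 4.0, 5.0], &[0, 1, 0], d);

    println!("{:?} {:?}", running, single);
}
```

The small integer values are exactly representable in f32, so the two buffers compare equal here; with arbitrary values they may differ in the last bits.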
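Precision guarantee: non-deterministic. The order of the atomicAdds into a row shared by several indices[n] can vary between launches, so the rounded result can differ from run to run.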
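The run-to-run variation comes from floating-point addition not being associative: when several n map to the same row, their atomicAdds into one dweight element form a sequence of additions whose order is not fixed. A minimal CPU illustration in f32; the helper `accumulate` is hypothetical and the values are chosen only to make the rounding visible.

```rust
/// Add contributions into an accumulator one at a time, the way a series of
/// atomicAdds into a single dweight element would (in some arrival order).
fn accumulate(order: &[f32]) -> f32 {
    let mut acc = 0.0f32;
    for &x in order {
        acc += x;
    }
    acc
}

fn main() {
    // The same three contributions in two arrival orders.
    let first = accumulate(&[1.0e8, -1.0e8, 1.0]);
    // Here 1e8 + 1 rounds back to 1e8, because f32 spacing near 1e8 is 8.
    let second = accumulate(&[1.0e8, 1.0, -1.0e8]);
    println!("{first} vs {second}");
}
```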
Implementations
impl<T: Element> EmbeddingBackwardPlan<T>
pub fn select(
    _stream: &Stream,
    desc: &EmbeddingBackwardDescriptor,
    _pref: PlanPreference,
) -> Result<Self>
Pick a kernel for desc.
pub fn can_implement<I: IndexElement>(
    &self,
    args: &EmbeddingBackwardArgs<'_, T, I>,
) -> Result<()>
Validate args.
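The documented shape limits (dweight is [V, D], dout is [N, D], indices is [N]) suggest what such validation has to cover. A hypothetical sketch of those checks follows; the real can_implement may check more or differently, and `check_shapes` with its String error type is an assumption, not the crate's API.

```rust
/// Hypothetical shape check mirroring the documented limits:
/// dweight is [V, D], dout is [N, D], indices is [N].
/// Extents are usize, so non-negativity holds by construction.
fn check_shapes(dweight: [usize; 2], dout: [usize; 2], indices_len: usize) -> Result<(), String> {
    let [_v, d_w] = dweight;
    let [n, d_o] = dout;
    if d_w != d_o {
        return Err(format!("embedding dim mismatch: dweight has D = {d_w}, dout has D = {d_o}"));
    }
    if indices_len != n {
        return Err(format!("indices has {indices_len} entries but dout has N = {n} rows"));
    }
    Ok(())
}

fn main() {
    println!("{:?}", check_shapes([1000, 64], [32, 64], 32));
    println!("{:?}", check_shapes([1000, 64], [32, 128], 32));
}
```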
pub fn workspace_size(&self) -> usize
Workspace size in bytes (always zero for this plan).
pub fn precision_guarantee(&self) -> PrecisionGuarantee
Numerical guarantees for this plan’s kernel.
pub fn run<I: IndexElement>(
    &self,
    stream: &Stream,
    _workspace: Workspace<'_>,
    args: EmbeddingBackwardArgs<'_, T, I>,
) -> Result<()>
Launch the kernel on stream; no workspace is needed.
Phase 11.5: generic over I: IndexElement.
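Being generic over I lets one plan accept index buffers of different integer widths. The sketch below shows one way a generic index value could be turned into a validated row number; the trait `RowIndex`, its `to_row` method, and `valid_rows` are hypothetical and are not the crate's IndexElement. Padding handling is left out.

```rust
/// Hypothetical stand-in for an index-element trait: converts an index value
/// into a valid row of a [V, D] table, or None if the row must be skipped.
trait RowIndex: Copy {
    fn to_row(self, v: usize) -> Option<usize>;
}

impl RowIndex for i32 {
    fn to_row(self, v: usize) -> Option<usize> {
        // Widen losslessly, then reuse the i64 logic.
        i64::from(self).to_row(v)
    }
}

impl RowIndex for i64 {
    fn to_row(self, v: usize) -> Option<usize> {
        // Negative values fail try_from; values >= v are out of range.
        usize::try_from(self).ok().filter(|&r| r < v)
    }
}

/// Rows that would receive a gradient, for any supported index type.
fn valid_rows<I: RowIndex>(indices: &[I], v: usize) -> Vec<usize> {
    indices.iter().filter_map(|&i| i.to_row(v)).collect()
}

fn main() {
    println!("{:?}", valid_rows(&[2i32, -1, 5], 4));
    println!("{:?}", valid_rows(&[3i64, 4], 4));
}
```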
Auto Trait Implementations
impl<T> Freeze for EmbeddingBackwardPlan<T>
impl<T> RefUnwindSafe for EmbeddingBackwardPlan<T> where
    T: RefUnwindSafe,
impl<T> Send for EmbeddingBackwardPlan<T> where
    T: Send,
impl<T> Sync for EmbeddingBackwardPlan<T> where
    T: Sync,
impl<T> Unpin for EmbeddingBackwardPlan<T> where
    T: Unpin,
impl<T> UnsafeUnpin for EmbeddingBackwardPlan<T>
impl<T> UnwindSafe for EmbeddingBackwardPlan<T> where
    T: UnwindSafe,
Blanket Implementations
impl<T> BorrowMut<T> for T where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.