// ferrotorch_nn/embedding.rs

//! Embedding layer: a lookup table of fixed-size vectors.
//!
//! Maps integer indices (stored as `T` values and cast to `usize`) to
//! dense vectors. This is the standard way to represent discrete tokens
//! (words, subwords, categorical features) as continuous vectors for
//! gradient-based learning.
//!
//! The backward pass implements a sparse scatter-add: only the rows that
//! were accessed receive gradient, and duplicate indices accumulate.
//!
//! ## REQ status (per `.design/ferrotorch-nn/embedding.md`)
//!
//! | REQ | Status | Evidence |
//! |---|---|---|
//! | REQ-1 | SHIPPED | impl: `pub struct Embedding<T: Float>` here with `weight` / `num_embeddings` / `embedding_dim` / `padding_idx` / `sparse` fields, mirroring `torch/nn/modules/sparse.py:37-50`; non-test consumer: `ferrotorch-llama/src/model.rs` declares `pub embed_tokens: Embedding<T>` as a model field. |
//! | REQ-2 | SHIPPED | impl: the `Embedding::new` constructor here (with `padding_idx` validation + N(0,1) init + padding-row zero); non-test consumer: `Embedding::new(cfg.vocab_size, cfg.hidden_size, None)?` in `ferrotorch-llama/src/model.rs` is the Llama model's token-embedding constructor. |
//! | REQ-3 | SHIPPED | impl: `<Embedding as Module>::forward` body here (gather + grad-attach); non-test consumer: `ferrotorch-llama` model's forward path calls `self.embed_tokens.forward(input_ids)` on every training step and inference token. |
//! | REQ-4 | SHIPPED | impl: `pub struct EmbeddingBackward<T>` and its `GradFn::backward` body here; non-test consumer: every `loss.backward()` call in the Llama training scaffolding traverses `EmbeddingBackward` nodes via `ferrotorch_core::autograd::engine`. |
//! | REQ-5 | SHIPPED | impl: `grad_output.is_cuda()` + `scatter_add_rows_f32/f64` dispatch inside `EmbeddingBackward::backward` here; non-test consumer: `ferrotorch-gpu/src/backend_impl.rs` exposes `Backend::scatter_add_rows_f32`; GPU training-loop runs on the Llama model trigger this on every embedding backward. |
//! | REQ-6 | SHIPPED | impl: the `Embedding::sparse_grad` accessor here returning a `SparseGrad<T>`; non-test consumer: `ferrotorch_optim::SparseAdam::collect_sparse_grad_from_embedding` (`ferrotorch-optim/src/sparse_adam.rs`) calls `Embedding::sparse_grad` and registers it via `set_sparse_grad`, then `SparseAdam::step` applies the masked sparse-Adam update — the wired `nn.Embedding(sparse=True)` → `torch.optim.SparseAdam` flow (`torch/optim/sparse_adam.py:132-161`). |
//! | REQ-7 | SHIPPED | impl: `pub struct EmbeddingBag<T: Float>` + `pub enum EmbeddingBagMode` + `Module` impl here; non-test consumer: `pub use embedding::{EmbeddingBag, EmbeddingBagMode}` in `lib.rs` exposes the type for downstream models. |
//! | REQ-8 | SHIPPED | impl: both `Module<T> for Embedding<T>` and `Module<T> for EmbeddingBag<T>` impl blocks here; non-test consumer: `ferrotorch_optim::Optimizer` iterates `model.parameters_mut()` which surfaces the embedding's weight parameter for every step. |
//! | REQ-9 | SHIPPED | impl: free fn `renorm_weight_rows_in_place` here (faithful translation of `embedding_renorm_cpu_` at `aten/src/ATen/native/Embedding.cpp:181-212` — sort+dedup touched rows, row norm via `at::norm` special-cased per `aten/src/ATen/native/cpu/ReduceOpsKernel.cpp:191-203` for `norm_type` 0/+inf/-inf, scale rows with norm > max_norm by `max_norm/(norm+1e-7)`, persist via `Tensor::update_data`), called by `Embedding::renorm_weight_in_place` and `EmbeddingBag::forward_bag`. L2 PRECISION (#1614): the default `norm_type == 2.0` f32 row reduces via `ferrotorch_core::simd_reduce::l2_norm_f32_torch` (torch's vectorized last-dim L2 kernel model, `ReduceOpsKernel.cpp:222-255`, f32 accumulator) so the `norm > max_norm` boundary decision matches torch byte-for-byte (closing the powf-vs-`v*v` summation-method gap #1612 left open); f64 rows and finite `p != 2` keep the generic `(Σ|x|^p)^(1/p)` arm. `with_max_norm`/`with_norm_type`/`with_scale_grad_by_freq` builders on `Embedding<T>`, plus `EmbeddingBag::new_with` + `with_max_norm`/`with_norm_type`/`with_scale_grad_by_freq`/`with_sparse`/`with_include_last_offset` and `padding_idx` exclusion in `forward_bag`. `EmbeddingBackward::scale_grad_by_freq` divides each touched row's grad by its forward count (`torch/nn/functional.py:2499-2500`). Renorm runs BEFORE the gather, matching `F.embedding`/`F.embedding_bag` (`functional.py:2561-2573`, `2766-2771`). Consumer surface: per goal.md S5, `Embedding`/`EmbeddingBag` ARE boundary public API (the module mirrors `torch.nn.Embedding`/`torch.nn.EmbeddingBag` field-for-field — the user-facing kwargs ARE the deliverable), grandfathered SHIPPED with no further downstream caller required. The renorm is on the live forward path: `<Embedding as Module>::forward` here calls `self.renorm_weight_in_place(&indices)?` on every forward (no-op when `max_norm` unset), and `EmbeddingBag::forward_bag` / `<EmbeddingBag as Module>::forward` consume the bag kwargs; both types are re-exported via `pub use embedding::{Embedding, EmbeddingBag, EmbeddingBagMode}` in `lib.rs` as the public consumer surface. (NB #1566: the prior cite to `ferrotorch-llama/src/model.rs embed_tokens` as the renorm consumer was FALSE — `model.rs` constructs `Embedding::new(.., None)` with no `max_norm`/`EmbeddingBag`; corrected to the S5 boundary-API rationale.) |
//! | REQ-10 | NOT-STARTED | blocker #1441 (umbrella) — parity-sweep runner arms absent for both `nn.functional.embedding` and `nn.functional.embedding_bag`. Lib tests verify the impl end-to-end. |
//! | REQ-11 | SHIPPED | impl: `pub fn forward_bag_weighted` + `struct EmbeddingBagSumWeightedBackward` here — `per_sample_weights` (#1610): sum-mode-only per-sample scaling before the bag reduction (`aten/src/ATen/native/EmbeddingBag.cpp:537-543`), grad to BOTH the embedding table (`grad[bag]*psw`, `EmbeddingBag.cpp:1564-1582`) AND `per_sample_weights` (`dot(grad[bag], weight[idx])`, `EmbeddingBag.cpp:1716-1724`); `mode!='sum'` returns torch's exact `NotImplementedError` text (`torch/nn/functional.py:2773-2778`), shape-mismatch matches `functional.py:2698-2702`, `padding_idx` samples contribute 0 to both grads. Non-test production consumer: the 2-arg `EmbeddingBag::forward_bag` (called by the parity-sweep `embedding_bag` runner arm + boundary public API re-exported at `lib.rs`) delegates to `forward_bag_weighted(.., None)`, so the new pub method has an in-production caller (R-DEFER-1). Verified by the `test_bag_psw_*` live-torch-2.11 oracle lib tests. |

use std::any::TypeId;
use std::sync::Arc;

use ferrotorch_core::autograd::no_grad::is_grad_enabled;
use ferrotorch_core::device::Device;
use ferrotorch_core::dtype::DType;
use ferrotorch_core::gpu_dispatch::{GpuBufferHandle, gpu_backend};
use ferrotorch_core::tensor::GradFn;
use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor, TensorStorage};

use crate::init;
use crate::module::Module;
use crate::parameter::Parameter;

/// Returns `true` if `T` is `f32`.
#[inline]
fn is_f32<T: Float>() -> bool {
    TypeId::of::<T>() == TypeId::of::<f32>()
}

/// Returns `true` if `T` is `f64`.
#[inline]
fn is_f64<T: Float>() -> bool {
    TypeId::of::<T>() == TypeId::of::<f64>()
}

/// Upload a CPU `&[f32]` slice to a GPU buffer on the given device ordinal.
fn upload_f32_to_gpu(data: &[f32], ordinal: usize) -> FerrotorchResult<GpuBufferHandle> {
    let backend = gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
    // SAFETY: `data` is a live `&[f32]` borrow, so its memory is valid for reads
    // of `size_of_val(data)` bytes (`data.len() * size_of::<f32>()`, and
    // `size_of::<f32>() == 4`). Casting `*const f32` to `*const u8` cannot
    // violate alignment: `u8` has alignment 1, weaker than `f32`'s 4. `bytes`
    // reborrows `data`, so it cannot outlive the source slice, and it is
    // consumed by `backend.cpu_to_gpu` before this function returns. `data` is
    // a shared reference, `f32` has no interior mutability and no padding
    // bytes, so every byte read is initialized.
    let bytes: &[u8] = unsafe {
        std::slice::from_raw_parts(data.as_ptr().cast::<u8>(), std::mem::size_of_val(data))
    };
    backend.cpu_to_gpu(bytes, DType::F32, ordinal)
}
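
// Sketch (hypothetical helper, not part of ferrotorch's API): a safe,
// copying alternative to the `from_raw_parts` reinterpretation in
// `upload_f32_to_gpu`. `f32::to_ne_bytes` yields the same native-endian
// bytes that the unsafe cast exposes, at the cost of allocating
// `4 * data.len()` bytes; the zero-copy cast above avoids that allocation,
// which is presumably why the module takes the `unsafe` route.
#[allow(dead_code)]
fn sketch_f32_slice_to_ne_bytes(data: &[f32]) -> Vec<u8> {
    // Each f32 contributes exactly `size_of::<f32>()` (= 4) bytes, in order.
    data.iter().flat_map(|v| v.to_ne_bytes()).collect()
}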

/// Renormalise the rows of `weight` touched by `indices`, IN PLACE, so each
/// touched row's `norm_type`-norm is at most `max_norm`.
///
/// Faithful translation of `embedding_renorm_cpu_`
/// (`aten/src/ATen/native/Embedding.cpp:181-212`): indices are sorted and
/// de-duplicated; each unique row whose norm exceeds `max_norm` is scaled by
/// `max_norm / (norm + 1e-7)`. Rows within `max_norm`, and rows never indexed
/// this forward, are left untouched. PyTorch runs this BEFORE the gather under
/// `torch.no_grad()` (`torch/nn/functional.py:2561-2573`), mutating the
/// persisted `weight`, so the change survives across forward calls — this
/// function matches that by writing the renormed rows back via
/// [`Tensor::update_data`].
///
/// Shared by `Embedding` and `EmbeddingBag` so both layers' `max_norm`
/// semantics stay byte-identical. CUDA weights have no on-device renorm kernel
/// yet, so this returns `NotImplementedOnCuda` rather than silently skipping.
fn renorm_weight_rows_in_place<T: Float>(
    weight: &Tensor<T>,
    indices: &[usize],
    dim: usize,
    max_norm: f64,
    norm_type: f64,
    op: &'static str,
) -> FerrotorchResult<()> {
    if weight.is_cuda() {
        return Err(FerrotorchError::NotImplementedOnCuda { op });
    }

    // Sort + dedup, mirroring `std::sort` + the `sorted[i]==sorted[i-1]` skip
    // at Embedding.cpp:193-201. Visiting each unique row once is required:
    // re-scaling an already-clipped row would shrink it below max_norm.
    let mut sorted: Vec<usize> = indices.to_vec();
    sorted.sort_unstable();
    sorted.dedup();

    let weight_data = weight.data()?;
    let mut new_data: Option<Vec<T>> = None;
    for &idx in &sorted {
        let row_start = idx * dim;
        let row = &weight_data[row_start..row_start + dim];
        // `row.norm(norm_type)` = `at::norm`, which special-cases the
        // non-finite / degenerate orders rather than evaluating the generic
        // `(Σ|x|^p)^(1/p)` formula — that formula gives `inf^0 = 1` for
        // `p = +inf` and `x^0 = 1` for `p = 0`, both wrong. Mirror the kernel
        // dispatch at `aten/src/ATen/native/cpu/ReduceOpsKernel.cpp:191-203`:
        //   p == 0     -> NormZeroOps : count of nonzero elements (L0)
        //   p == +inf  -> AbsMaxOps   : max_i |x_i|  (infinity norm)
        //   p == -inf  -> AbsMinOps   : min_i |x_i|  (acc seeded +inf)
        //   else       -> NormOps     : (Σ|x|^p)^(1/p)
        // (p == 1 and p == 2 are mathematically correct under the generic
        // formula; the separate f32 L2 arm below exists only to reproduce
        // torch's vectorized accumulation order, not for correctness.)
        //
        // PRECISION (#1612): the norm is accumulated and rooted in the WEIGHT'S
        // NATIVE dtype `T`, then widened to f64 only for the `> max_norm`
        // compare and the scale. This mirrors `row.norm(norm_type).item<double>()`
        // at `Embedding.cpp:202-203` byte-for-byte: `at::norm`'s accumulator is
        // `at::opmath_type<scalar_t>` (`ReduceOpsKernel.cpp:190`), which is
        // `float` for an f32 row and `double` for an f64 row, and the result is
        // stored back as `scalar_t` (`result_data[0] = scalar_t(std::sqrt(..))`,
        // `ReduceOpsKernel.cpp:253`); `.item<double>()` widens that already-`T`-
        // rounded scalar AFTER the fact. Accumulating in f64 for an f32 weight
        // would make the clip DECISION on a value torch never sees — at the
        // boundary (f32 norm == max_norm) torch does NOT clip but an f64 norm
        // can land just above, wrongly scaling the row (#1612).
        let norm_t: T = if norm_type == 0.0 {
            // NormZeroOps (`SharedReduceOps.h:285`): count of nonzeros.
            T::from(
                row.iter()
                    .filter(|&&v| v != <T as num_traits::Zero>::zero())
                    .count(),
            )
            .unwrap_or_else(<T as num_traits::Zero>::zero)
        } else if norm_type == f64::INFINITY {
            // AbsMaxOps (`SharedReduceOps.h:216`): max_i |x_i|, in `T`.
            row.iter().fold(<T as num_traits::Zero>::zero(), |acc, &v| {
                let av = v.abs();
                if av > acc { av } else { acc }
            })
        } else if norm_type == f64::NEG_INFINITY {
            // AbsMinOps (`SharedReduceOps.h:186`): min_i |x_i|, acc seeded +inf,
            // in `T`.
            row.iter().fold(T::infinity(), |acc, &v| {
                let av = v.abs();
                if av < acc { av } else { acc }
            })
        } else if norm_type == 2.0 && is_f32::<T>() {
            // L2 FAST PATH (#1614): the default `norm_type == 2.0` over a
            // contiguous f32 row is what torch's `at::norm(2.0)` evaluates via
            // its VECTORIZED last-dim L2 kernel (`ReduceOpsKernel.cpp:222-255`):
            // a width-8 lane accumulate of `v*v` + a naive left-fold + a scalar
            // FMA tail + `sqrt`, all in an f32 (NOT f64) accumulator. A scalar
            // `Σ |v|.powf(2)` then `.powf(0.5)` (the generic arm below) lands up
            // to one ULP off that value, flipping the `norm > max_norm` boundary
            // decision (#1612 / #1614). Route the f32 L2 row through the shared
            // `ferrotorch_core::simd_reduce::l2_norm_f32_torch` primitive so the
            // renorm decision matches torch byte-for-byte (modulo the documented
            // ~3% one-ULP residual; the #1614 boundary row IS matched).
            //
            // `row: &[T]` is f32 here (guarded by `is_f32::<T>()`); collect it as
            // `&[f32]` via the exact identity `ToPrimitive::to_f32`.
            let mut row_f32: Vec<f32> = Vec::with_capacity(row.len());
            for &v in row {
                row_f32.push(num_traits::ToPrimitive::to_f32(&v).unwrap_or(0.0));
            }
            let n_f32 = ferrotorch_core::simd_reduce::l2_norm_f32_torch(&row_f32);
            // Lift the f32 norm back into `T` (== f32). The fallback is
            // unreachable: the f32 -> f32 `NumCast` is the identity and never
            // fails, including for inf and NaN.
            T::from(n_f32).unwrap_or_else(<T as num_traits::Zero>::zero)
        } else {
            // NormOps: generic finite p-norm `(Σ|x|^p)^(1/p)`, accumulated and
            // rooted in `T` (f32 for an f32 weight) to match `at::norm`. (Used
            // for f64 rows, and for finite p != 2; the f32 L2 case is handled
            // by the byte-exact `simd_reduce` arm above.)
            let p_t = T::from(norm_type).unwrap_or_else(<T as num_traits::One>::one);
            let mut acc = <T as num_traits::Zero>::zero();
            for &v in row {
                acc += v.abs().powf(p_t);
            }
            let inv_p = T::from(1.0 / norm_type).unwrap_or_else(<T as num_traits::One>::one);
            acc.powf(inv_p)
        };
        // Widen the native-precision norm to f64 exactly as `.item<double>()`
        // does (`Embedding.cpp:203`) — only NOW does the value become f64.
        let norm = num_traits::ToPrimitive::to_f64(&norm_t).unwrap_or(0.0);
        if norm > max_norm {
            // Lazily materialise the mutable copy only when a row needs
            // clipping, so the no-clip case never touches the buffer.
            let buf = new_data.get_or_insert_with(|| weight_data.to_vec());
            let scale = max_norm / (norm + 1e-7);
            let scale_t = T::from(scale).unwrap();
            for v in &mut buf[row_start..row_start + dim] {
                *v = *v * scale_t;
            }
        }
    }

    if let Some(buf) = new_data {
        // SAFETY: `update_data` requires exclusive access to the weight's
        // storage for the duration of the write. The only other view of that
        // storage held here is `weight_data` (a `&[T]` over the same Arc), and
        // it is not read after the loop above, so that borrow has ended before
        // this call. No backward node captures a mutable view, and the
        // autograd engine is not concurrently reading the weight: PyTorch
        // performs this exact mutation under `torch.no_grad()`
        // (`functional.py:2567-2572`), a grad-disabled, single-threaded
        // in-place edit of the persisted weight. `buf` is a full copy of the
        // weight, so it has exactly `num_embeddings * dim` elements, matching
        // the tensor's numel.
        #[allow(
            clippy::undocumented_unsafe_blocks,
            reason = "SAFETY comment above documents the exclusive-access invariant; torch embedding_renorm_ mutates weight in place under no_grad (functional.py:2567-2572), matching the optimizer step()'s update_data contract"
        )]
        unsafe {
            weight.update_data(&buf)?;
        }
    }

    Ok(())
}
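
// Sketch (hypothetical helpers, std-only, f32 rows stored row-major in a flat
// slice): the `norm_type` dispatch and the clip rule used by
// `renorm_weight_rows_in_place`. For finite p it uses the generic scalar
// `powf` formula, so it does NOT reproduce the bit-exact f32 L2 path
// (`l2_norm_f32_torch`); it only illustrates which rows get scaled and by how
// much.
#[allow(dead_code)]
fn sketch_row_norm(row: &[f32], p: f32) -> f32 {
    if p == 0.0 {
        // L0 "norm": number of nonzero entries.
        row.iter().filter(|&&v| v != 0.0).count() as f32
    } else if p == f32::INFINITY {
        // Largest absolute value.
        row.iter().fold(0.0f32, |acc, &v| acc.max(v.abs()))
    } else if p == f32::NEG_INFINITY {
        // Smallest absolute value; the accumulator starts at +inf.
        row.iter().fold(f32::INFINITY, |acc, &v| acc.min(v.abs()))
    } else {
        // Generic finite p-norm, accumulated in the row's own dtype.
        row.iter().map(|v| v.abs().powf(p)).sum::<f32>().powf(1.0 / p)
    }
}

#[allow(dead_code)]
fn sketch_renorm_rows(weight: &mut [f32], dim: usize, indices: &[usize], max_norm: f64, p: f32) {
    // Visit each touched row once: clipping a row twice would shrink it
    // below max_norm.
    let mut rows = indices.to_vec();
    rows.sort_unstable();
    rows.dedup();
    for idx in rows {
        let row = &mut weight[idx * dim..(idx + 1) * dim];
        // Norm in f32, widened to f64 only for the compare and the scale.
        let norm = f64::from(sketch_row_norm(row, p));
        if norm > max_norm {
            let scale = (max_norm / (norm + 1e-7)) as f32;
            for v in row.iter_mut() {
                *v *= scale;
            }
        }
    }
}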

// ---------------------------------------------------------------------------
// EmbeddingBackward
// ---------------------------------------------------------------------------

/// Backward function for the embedding lookup.
///
/// Forward: `output[i, :] = weight[indices[i], :]`
///
/// VJP: `grad_weight = zeros(num_embeddings, embedding_dim);`
///       `for i, idx in indices: grad_weight[idx, :] += grad_output[i, :]`
///
/// The gradient is stored densely (`[num_embeddings, embedding_dim]`) but is
/// sparse in content: only accessed rows can be non-zero. Duplicate indices
/// accumulate their corresponding `grad_output` rows.
#[derive(Debug)]
pub struct EmbeddingBackward<T: Float> {
    /// The weight tensor (needed for graph traversal and shape).
    weight: Tensor<T>,
    /// Indices used in the forward pass.
    indices: Vec<usize>,
    /// Total number of embedding rows.
    num_embeddings: usize,
    /// Width of each embedding vector.
    embedding_dim: usize,
    /// If set, this row's gradient is always zero.
    padding_idx: Option<usize>,
    /// If `true`, divide each row's accumulated gradient by the number of
    /// times the index appeared in the forward pass — mirrors
    /// `torch/nn/functional.py:2374-2388`'s `scale_grad_by_freq=True`
    /// branch. (Closes #1445.)
    scale_grad_by_freq: bool,
}

impl<T: Float> GradFn<T> for EmbeddingBackward<T> {
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !is_grad_enabled() {
            return Ok(vec![None]);
        }

        let dim = self.embedding_dim;

        // GPU fast path: scatter-add rows entirely on GPU for f32/f64 tensors.
        if grad_output.is_cuda() && (is_f32::<T>() || is_f64::<T>()) {
            // This path has no frequency-scaling step; fail loudly rather than
            // return an unscaled gradient.
            if self.scale_grad_by_freq {
                return Err(FerrotorchError::NotImplementedOnCuda {
                    op: "EmbeddingBackward(scale_grad_by_freq)",
                });
            }
            let backend = gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
            let ordinal = match self.weight.device() {
                Device::Cuda(o) => o,
                _ => unreachable!("EmbeddingBackward: a CUDA grad_output implies a CUDA weight"),
            };

            // The kernel takes indices as f32, which represents every integer
            // exactly only up to 2^24 (16_777_216); larger vocabularies would
            // lose index precision on this path.
            let indices_f32: Vec<f32> = self.indices.iter().map(|&i| i as f32).collect();
            let idx_handle = upload_f32_to_gpu(&indices_f32, ordinal)?;
            let go_handle = grad_output.gpu_handle()?;
            let f64_path = is_f64::<T>();
            let elem_size: usize = if f64_path { 8 } else { 4 };

            let mut gw_handle = if f64_path {
                backend.scatter_add_rows_f64(go_handle, &idx_handle, self.num_embeddings, dim)?
            } else {
                backend.scatter_add_rows_f32(go_handle, &idx_handle, self.num_embeddings, dim)?
            };

            if let Some(pad_idx) = self.padding_idx {
                // Zero the padding row via a host round-trip: download the
                // gradient, clear that row's bytes, re-upload.
                let mut gw_bytes = backend.gpu_to_cpu(&gw_handle)?;
                let start_byte = pad_idx * dim * elem_size;
                let end_byte = start_byte + dim * elem_size;
                for b in &mut gw_bytes[start_byte..end_byte] {
                    *b = 0;
                }
                let gw_dtype = if f64_path { DType::F64 } else { DType::F32 };
                gw_handle = backend.cpu_to_gpu(&gw_bytes, gw_dtype, ordinal)?;
            }

            let grad_tensor = Tensor::from_storage(
                TensorStorage::gpu(gw_handle),
                vec![self.num_embeddings, dim],
                false,
            )?;
            return Ok(vec![Some(grad_tensor)]);
        }

        if grad_output.is_cuda() {
            return Err(FerrotorchError::NotImplementedOnCuda {
                op: "EmbeddingBackward",
            });
        }

        let go_data = grad_output.data()?;

        // Allocate a full-size gradient for the weight matrix, initialized to zero.
        let mut grad_weight = vec![<T as num_traits::Zero>::zero(); self.num_embeddings * dim];

        // Scatter-add: for each index position, accumulate the corresponding
        // grad_output row into the weight gradient at the accessed index.
        for (i, &idx) in self.indices.iter().enumerate() {
            let go_row = &go_data[i * dim..(i + 1) * dim];
            let gw_row = &mut grad_weight[idx * dim..(idx + 1) * dim];
            for (gw, &go) in gw_row.iter_mut().zip(go_row.iter()) {
                *gw += go;
            }
        }

        // scale_grad_by_freq: divide each touched row by its appearance
        // count in the forward pass (mirrors
        // `torch/nn/functional.py:2374-2388`). Untouched rows have grad
        // identically zero, so the divide is a no-op there.
        if self.scale_grad_by_freq {
            let mut counts: std::collections::HashMap<usize, usize> =
                std::collections::HashMap::new();
            for &idx in &self.indices {
                *counts.entry(idx).or_insert(0) += 1;
            }
            for (&idx, &cnt) in &counts {
                if cnt <= 1 {
                    continue;
                }
                let scale = T::from(1.0 / cnt as f64).unwrap();
                let row_start = idx * dim;
                for v in &mut grad_weight[row_start..row_start + dim] {
                    *v = *v * scale;
                }
            }
        }

        // If padding_idx is set, zero that row's gradient unconditionally.
        if let Some(pad_idx) = self.padding_idx {
            let start = pad_idx * dim;
            for v in &mut grad_weight[start..start + dim] {
                *v = <T as num_traits::Zero>::zero();
            }
        }

        Ok(vec![Some(Tensor::from_storage(
            TensorStorage::cpu(grad_weight),
            vec![self.num_embeddings, dim],
            false,
        )?)])
    }

    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.weight]
    }

    fn name(&self) -> &'static str {
        "EmbeddingBackward"
    }
}
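
// Sketch (hypothetical helpers, std-only, f32, flat row-major buffers): the
// lookup that `EmbeddingBackward` differentiates, and a CPU VJP with the same
// ordering as `EmbeddingBackward::backward` above: scatter-add (duplicates
// accumulate), then the optional `scale_grad_by_freq` division, then zeroing
// of the padding row. It divides by the count directly where the module
// multiplies by `1 / count`.
#[allow(dead_code)]
fn sketch_embedding_gather(weight: &[f32], dim: usize, indices: &[usize]) -> Vec<f32> {
    let mut out = Vec::with_capacity(indices.len() * dim);
    for &idx in indices {
        // output[i, :] = weight[indices[i], :]
        out.extend_from_slice(&weight[idx * dim..(idx + 1) * dim]);
    }
    out
}

#[allow(dead_code)]
fn sketch_embedding_vjp(
    grad_output: &[f32],
    indices: &[usize],
    num_embeddings: usize,
    dim: usize,
    padding_idx: Option<usize>,
    scale_grad_by_freq: bool,
) -> Vec<f32> {
    let mut grad_weight = vec![0.0f32; num_embeddings * dim];
    let mut counts = vec![0usize; num_embeddings];
    // Scatter-add: duplicate indices accumulate into the same row.
    for (i, &idx) in indices.iter().enumerate() {
        counts[idx] += 1;
        let go_row = &grad_output[i * dim..(i + 1) * dim];
        for (g, &go) in grad_weight[idx * dim..(idx + 1) * dim].iter_mut().zip(go_row) {
            *g += go;
        }
    }
    // scale_grad_by_freq: divide each repeated row by its forward count.
    if scale_grad_by_freq {
        for (idx, &cnt) in counts.iter().enumerate() {
            if cnt > 1 {
                for g in &mut grad_weight[idx * dim..(idx + 1) * dim] {
                    *g /= cnt as f32;
                }
            }
        }
    }
    // The padding row never receives gradient.
    if let Some(p) = padding_idx {
        for g in &mut grad_weight[p * dim..(p + 1) * dim] {
            *g = 0.0;
        }
    }
    grad_weight
}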

// ---------------------------------------------------------------------------
// EmbeddingBagSumWeightedBackward — sum-mode bag with per_sample_weights
// ---------------------------------------------------------------------------

/// Backward function for `EmbeddingBag::forward_bag_weighted` in `sum` mode
/// with `per_sample_weights` supplied. The forward is the scaled
/// index-select-add (`aten/src/ATen/native/EmbeddingBag.cpp:537-543`):
///
/// `output[bag(i)][:] += weight[idx[i]][:] * psw[i]`  (padding samples skipped)
///
/// Gradient flows to BOTH the embedding table AND `per_sample_weights`, matching
/// torch's autograd (`per_sample_weights.requires_grad` is honored at
/// `EmbeddingBag.cpp:1248-1250`):
///
/// - `grad_weight[idx[i]][:] += grad_output[bag(i)][:] * psw[i]`
///   — the sum-mode `scale = per_sample_weights_data[..]` axpy at
///   `EmbeddingBag.cpp:1564-1582` (`scale_grad_by_freq` divides by the index
///   frequency; `mode == SUM` never divides by bag size).
/// - `grad_psw[i] = dot(grad_output[bag(i)][:], weight[idx[i]][:])`
///   — `_embedding_bag_per_sample_weights_backward_cpu_template`'s per-sample
///   `dot_impl(grad[bag], weight[idx])` at `EmbeddingBag.cpp:1716-1724`.
///
/// Padding samples (`idx[i] == padding_idx`) contribute 0 to BOTH gradients:
/// they are skipped in the weight-grad loop (`EmbeddingBag.cpp:1561`) and their
/// `grad_psw` entry stays at the zero-init (`EmbeddingBag.cpp:1671`, `:1720`).
#[derive(Debug)]
struct EmbeddingBagSumWeightedBackward<T: Float> {
    /// The embedding table (input 0; receives the scatter-add grad).
    weight: Tensor<T>,
    /// The per-sample weights (input 1; receives the per-sample dot grad).
    per_sample_weights: Tensor<T>,
    /// Flattened embedding indices, one per sample, in forward order.
    indices: Vec<usize>,
    /// Bag id for each sample (`offset2bag`): `bag_of[i]` is the output row that
    /// sample `i` accumulates into.
    bag_of: Vec<usize>,
    /// Total number of embedding rows.
    num_embeddings: usize,
    /// Width of each embedding vector.
    embedding_dim: usize,
    /// If set, samples whose index equals this contribute no gradient.
    padding_idx: Option<usize>,
    /// If `true`, each touched weight-row grad is divided by the number of
    /// times that index appeared in the forward (`EmbeddingBag.cpp:1569-1571`).
    scale_grad_by_freq: bool,
}
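
// Sketch (hypothetical helper, std-only, f32, `scale_grad_by_freq` off): the
// sum-mode forward `out[bag_of[i]] += weight[idx[i]] * psw[i]` together with
// the two gradients that `EmbeddingBagSumWeightedBackward` produces. Padding
// samples are skipped everywhere, so they contribute 0 to the output and to
// both grads. The number of bags is assumed to be `grad_output.len() / dim`.
#[allow(dead_code)]
fn sketch_bag_sum_weighted(
    weight: &[f32],
    dim: usize,
    indices: &[usize],
    bag_of: &[usize],
    psw: &[f32],
    padding_idx: Option<usize>,
    grad_output: &[f32],
) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
    let num_bags = grad_output.len() / dim;
    let mut out = vec![0.0f32; num_bags * dim];
    let mut grad_weight = vec![0.0f32; weight.len()];
    let mut grad_psw = vec![0.0f32; indices.len()];
    for (i, (&idx, &bag)) in indices.iter().zip(bag_of).enumerate() {
        if padding_idx == Some(idx) {
            continue;
        }
        let w_row = &weight[idx * dim..(idx + 1) * dim];
        let go_row = &grad_output[bag * dim..(bag + 1) * dim];
        // Forward: out[bag] += weight[idx] * psw[i].
        for (o, &w) in out[bag * dim..(bag + 1) * dim].iter_mut().zip(w_row) {
            *o += w * psw[i];
        }
        // Table grad: grad_weight[idx] += grad_output[bag] * psw[i].
        for (g, &go) in grad_weight[idx * dim..(idx + 1) * dim].iter_mut().zip(go_row) {
            *g += go * psw[i];
        }
        // Per-sample-weight grad: dot(grad_output[bag], weight[idx]).
        grad_psw[i] = go_row.iter().zip(w_row).map(|(g, w)| g * w).sum();
    }
    (out, grad_weight, grad_psw)
}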

impl<T: Float> GradFn<T> for EmbeddingBagSumWeightedBackward<T> {
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !is_grad_enabled() {
            return Ok(vec![None, None]);
        }

        if grad_output.is_cuda() {
            return Err(FerrotorchError::NotImplementedOnCuda {
                op: "EmbeddingBagSumWeightedBackward",
            });
        }

        let dim = self.embedding_dim;
        let go_data = grad_output.data()?;
        let weight_data = self.weight.data()?;
        let psw_data = self.per_sample_weights.data()?;
        let n = self.indices.len();

        // grad to the embedding table: scatter-add the bag's grad row scaled by
        // the sample's per_sample_weight (EmbeddingBag.cpp:1564-1582).
        let mut grad_weight = vec![<T as num_traits::Zero>::zero(); self.num_embeddings * dim];
        // grad to per_sample_weights: dot(grad[bag], weight[idx]) per sample,
        // zero for padding samples (EmbeddingBag.cpp:1716-1724).
        let mut grad_psw = vec![<T as num_traits::Zero>::zero(); n];

        // scale_grad_by_freq divisor map (EmbeddingBag.cpp:1522,1569-1571).
        //
        // Torch's dense sum/mean backward operates on SORTED indices. It builds
        // `counts[v]` = occurrences of value `v` over the full index array
        // (`EmbeddingBag.cpp:1475-1478`), sorts the indices (`:1522`), then walks
        // the sorted array one UNIQUE index at a time with the stride
        // `i += counts[sorted[i]]` (`:1499`). For the k-th unique step (counter
        // `i` in torch's loop, here `k`) it divides that index's grad by
        // `counts[indices_data[i]]` — i.e. `counts[sorted[k]]`, NOT
        // `counts[that index's own value]` (`:1569-1571`). Because `k` counts
        // unique steps while the sorted position advances by each value's count,
        // `k` falls behind as soon as any value repeats, and every later unique
        // index is then divided by the frequency of a smaller index value. The
        // result depends only on the multiset of indices, not on their input
        // order. We replicate this exactly: `divisor_of_index[v]` =
        // `counts[sorted[k]]` for the unique step `k` that lands on value `v`.
        // Padding indices participate in `counts`, the sort, and the unique-step
        // counter `k` (only their grad scatter is skipped at `:1561`), so they
        // are included here too.
        let divisor_of_index: Option<std::collections::HashMap<usize, usize>> =
            if self.scale_grad_by_freq {
                let mut counts: std::collections::HashMap<usize, usize> =
                    std::collections::HashMap::new();
                for &idx in &self.indices {
                    *counts.entry(idx).or_insert(0) += 1;
                }
                let mut sorted = self.indices.clone();
                sorted.sort_unstable();
                let mut divisor: std::collections::HashMap<usize, usize> =
                    std::collections::HashMap::new();
                let mut i = 0usize; // position in the sorted array
                let mut k = 0usize; // unique-step counter (torch's loop index `i`)
                while i < sorted.len() {
                    let index = sorted[i]; // the value this unique step owns
                    // The quirk: torch divides by counts[indices_data[k]] = counts[sorted[k]].
                    let div = counts.get(&sorted[k]).copied().unwrap_or(1);
                    divisor.insert(index, div);
                    let stride = counts.get(&index).copied().unwrap_or(1).max(1);
                    i += stride;
                    k += 1;
                }
                Some(divisor)
            } else {
                None
            };

        for i in 0..n {
            let idx = self.indices[i];
            // Padding samples are excluded from BOTH grads (EmbeddingBag.cpp:1561,
            // :1720): grad_psw[i] stays 0 and no weight-row update happens.
            if self.padding_idx == Some(idx) {
                continue;
            }
            let bag = self.bag_of[i];
            let go_row = &go_data[bag * dim..(bag + 1) * dim];
            let w_row = &weight_data[idx * dim..(idx + 1) * dim];
            let psw_i = psw_data[i];

            // weight-grad scale: psw, optionally divided by torch's sorted-neighbour
            // frequency divisor (EmbeddingBag.cpp:1569-1571).
            let mut w_scale = psw_i;
            if let Some(d) = &divisor_of_index {
                if let Some(&div) = d.get(&idx) {
                    if div > 0 {
                        w_scale =
                            w_scale / T::from(div).unwrap_or_else(<T as num_traits::One>::one);
                    }
                }
            }
            let gw_row = &mut grad_weight[idx * dim..(idx + 1) * dim];
            for (gw, &go) in gw_row.iter_mut().zip(go_row.iter()) {
                *gw += go * w_scale;
            }

            // per_sample_weight grad: dot(grad[bag], weight[idx]). This is the
            // UNSCALED bag grad against the embedding row — scale_grad_by_freq
            // only weights the table grad, not the psw grad (it is absent from
            // the psw-backward kernel at EmbeddingBag.cpp:1716-1724).
            let mut dot = <T as num_traits::Zero>::zero();
            for (&go, &w) in go_row.iter().zip(w_row.iter()) {
                dot += go * w;
            }
            grad_psw[i] = dot;
        }

        let grad_weight_t = Tensor::from_storage(
            TensorStorage::cpu(grad_weight),
            vec![self.num_embeddings, dim],
            false,
        )?;
        let grad_psw_t = Tensor::from_storage(
            TensorStorage::cpu(grad_psw),
            self.per_sample_weights.shape().to_vec(),
            false,
        )?;
        Ok(vec![Some(grad_weight_t), Some(grad_psw_t)])
    }

    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.weight, &self.per_sample_weights]
    }

    fn name(&self) -> &'static str {
        "EmbeddingBagSumWeightedBackward"
    }
}
549
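The loop above produces two gradients for the weighted-sum bag: a table gradient scaled by the per-sample weight, and a per-sample-weight gradient that is a dot product of the bag gradient with the embedding row. The sketch below cross-checks that math on plain `f64` slices rather than the crate's `Tensor` API. The function `weighted_sum_bag_backward` and its flat-slice arguments are hypothetical. It also assumes the frequency divisor is the total number of occurrences of each index in the batch, which is our reading of the `divisor_of_index` map built earlier in this function.

```rust
use std::collections::HashMap;

/// Hypothetical standalone mirror of the weighted-sum EmbeddingBag backward.
/// Returns `(grad_weight, grad_psw)` as flat row-major vectors.
#[allow(clippy::too_many_arguments)]
pub fn weighted_sum_bag_backward(
    grad_out: &[f64],  // [num_bags * dim], upstream gradient per bag
    weight: &[f64],    // [num_embeddings * dim], embedding table
    indices: &[usize], // flattened bag contents
    bag_of: &[usize],  // bag id for each position in `indices`
    psw: &[f64],       // one per-sample weight per position
    num_embeddings: usize,
    dim: usize,
    padding_idx: Option<usize>,
    scale_grad_by_freq: bool,
) -> (Vec<f64>, Vec<f64>) {
    let mut grad_weight = vec![0.0; num_embeddings * dim];
    let mut grad_psw = vec![0.0; indices.len()];

    // Assumed divisor: how often each index occurs in the whole batch.
    let mut counts: HashMap<usize, usize> = HashMap::new();
    if scale_grad_by_freq {
        for &idx in indices {
            *counts.entry(idx).or_insert(0) += 1;
        }
    }

    for (i, &idx) in indices.iter().enumerate() {
        // Padding positions contribute to neither gradient.
        if padding_idx == Some(idx) {
            continue;
        }
        let bag = bag_of[i];
        let go_row = &grad_out[bag * dim..(bag + 1) * dim];
        let w_row = &weight[idx * dim..(idx + 1) * dim];

        // Table gradient: psw, optionally divided by the index frequency.
        let mut scale = psw[i];
        if let Some(&c) = counts.get(&idx) {
            scale /= c as f64;
        }
        for (gw, &go) in grad_weight[idx * dim..(idx + 1) * dim].iter_mut().zip(go_row) {
            *gw += go * scale;
        }

        // psw gradient: the UNSCALED dot(grad_out[bag], weight[idx]).
        grad_psw[i] = go_row.iter().zip(w_row).map(|(g, w)| g * w).sum::<f64>();
    }
    (grad_weight, grad_psw)
}
```

Turning on `scale_grad_by_freq` changes only the first returned vector. This matches the comment above: the frequency divisor is absent from the psw-backward kernel.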
// ---------------------------------------------------------------------------
// Embedding layer
// ---------------------------------------------------------------------------

/// A simple lookup table that stores embeddings of a fixed dictionary.
///
/// Given an index tensor of any shape whose values are integers stored as
/// `T` (each cast to `usize`), returns a tensor of shape
/// `(*input.shape(), embedding_dim)` by gathering the corresponding rows
/// from the weight matrix. A 1-D input of length `len` yields
/// `[len, embedding_dim]`.
///
/// # Padding index
///
/// If `padding_idx` is set, that row receives no gradient updates.
/// [`Embedding::new`] zeroes the row at construction
/// ([`Embedding::from_pretrained`] keeps the pretrained values), and the CPU
/// forward path zeroes every output row gathered from it. This is commonly
/// used to represent a padding token.
///
/// # Example
///
/// ```ignore
/// // `forward` is a `Module` trait method, so the trait must be in scope.
/// let emb = Embedding::<f32>::new(1000, 64, None)?;
/// let indices = ferrotorch_core::tensor(&[1.0, 5.0, 3.0])?;
/// let output = emb.forward(&indices)?;
/// assert_eq!(output.shape(), &[3, 64]);
/// ```
#[derive(Debug)]
pub struct Embedding<T: Float> {
    /// The learnable weight matrix, shape `[num_embeddings, embedding_dim]`.
    pub weight: Parameter<T>,
    /// Number of entries in the lookup table.
    pub num_embeddings: usize,
    /// Dimensionality of each embedding vector.
    pub embedding_dim: usize,
    /// If set, this row receives no gradient. [`Embedding::new`] also zeroes
    /// it at construction.
    pub padding_idx: Option<usize>,
    /// If set, every row touched by a forward call is renormalised in place,
    /// before the gather, so its `norm_type`-norm is at most `max_norm`,
    /// mirroring `torch/nn/functional.py:2306-2370` (`_no_grad_embedding_renorm_`).
    /// Carried as `f64` to match the upstream scalar type (the kwarg is a
    /// `float`). (Closes #1445.)
    pub max_norm: Option<f64>,
    /// Order of the row-norm used when `max_norm` is active. Defaults to
    /// `2.0` (Euclidean) per `torch/nn/functional.py:2316`. (Closes #1445.)
    pub norm_type: f64,
    /// If `true`, `EmbeddingBackward` divides each accumulated row gradient
    /// by the number of times that index appeared in the forward pass,
    /// matching `torch/nn/functional.py:2374-2388`. (Closes #1445.)
    pub scale_grad_by_freq: bool,
    /// Whether the module is in training mode.
    training: bool,
    /// If true, advertise a sparse gradient pattern (the only rows touched
    /// are the ones actually indexed in the most recent forward call).
    /// This is purely a flag: autograd still populates a dense grad on
    /// the weight; callers can extract a `SparseGrad` view via
    /// [`Self::sparse_grad`] to feed `optim::SparseAdam` or
    /// `SparseGrad::apply_sgd` without scanning the full dense matrix.
    /// Mirrors `torch.nn.Embedding(sparse=True)`. (#623)
    pub sparse: bool,
    /// Sorted, deduplicated row indices touched by the most recent forward
    /// pass. `None` if `sparse == false` or no forward has run yet. They are
    /// stored deduplicated so callers don't have to coalesce the
    /// `SparseGrad` themselves.
    last_indices: std::sync::Mutex<Option<Vec<usize>>>,
}

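The `max_norm` and `norm_type` fields above describe a clip-by-norm step applied to touched rows. The sketch below implements that rule on a flat `f64` buffer. It follows the behaviour documented on `renorm_weight_in_place`: sort and de-duplicate the touched indices, then scale any row whose norm exceeds `max_norm` by `max_norm / (norm + 1e-7)`. `renorm_rows` is a hypothetical helper, not the crate's `renorm_weight_rows_in_place`, and it only handles a finite positive `norm_type` or infinity.

```rust
/// Hypothetical sketch of max-norm row renormalisation on a flat
/// `[num_embeddings * dim]` buffer. Returns how many rows were rescaled.
pub fn renorm_rows(
    weight: &mut [f64],
    dim: usize,
    indices: &[usize],
    max_norm: f64,
    norm_type: f64,
) -> usize {
    // Visit each touched row once, in sorted order.
    let mut uniq = indices.to_vec();
    uniq.sort_unstable();
    uniq.dedup();

    let mut clipped = 0;
    for idx in uniq {
        let row = &mut weight[idx * dim..(idx + 1) * dim];
        let norm = if norm_type.is_infinite() {
            // Infinity norm: largest absolute entry.
            row.iter().fold(0.0_f64, |m, x| m.max(x.abs()))
        } else {
            // Finite p-norm: (sum |x|^p)^(1/p).
            row.iter()
                .map(|x| x.abs().powf(norm_type))
                .sum::<f64>()
                .powf(1.0 / norm_type)
        };
        // Rows already within max_norm are left untouched.
        if norm > max_norm {
            let scale = max_norm / (norm + 1e-7);
            for x in row.iter_mut() {
                *x *= scale;
            }
            clipped += 1;
        }
    }
    clipped
}
```

The `1e-7` term keeps a clipped row slightly under `max_norm`. A second call with the same indices therefore leaves the rows unchanged, which matches the "no-op on repeat" behaviour the CPU forward path describes.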
impl<T: Float> Embedding<T> {
    /// Create a new embedding layer.
    ///
    /// Weight is initialized from N(0, 1). If `padding_idx` is set, that
    /// row is zeroed after initialization.
    ///
    /// # Errors
    ///
    /// Returns an error if `padding_idx >= num_embeddings`.
    pub fn new(
        num_embeddings: usize,
        embedding_dim: usize,
        padding_idx: Option<usize>,
    ) -> FerrotorchResult<Self> {
        // Validate padding_idx.
        if let Some(idx) = padding_idx {
            if idx >= num_embeddings {
                return Err(FerrotorchError::InvalidArgument {
                    message: format!(
                        "padding_idx {idx} is out of range for num_embeddings {num_embeddings}"
                    ),
                });
            }
        }

        // Initialize weight from N(0, 1).
        let mut weight = Parameter::zeros(&[num_embeddings, embedding_dim])?;
        init::normal(&mut weight, 0.0, 1.0)?;

        // Zero the padding row if requested.
        if let Some(idx) = padding_idx {
            let data = weight.data()?.to_vec();
            let mut new_data = data;
            let start = idx * embedding_dim;
            for v in &mut new_data[start..start + embedding_dim] {
                *v = <T as num_traits::Zero>::zero();
            }
            weight = Parameter::new(Tensor::from_storage(
                TensorStorage::cpu(new_data),
                vec![num_embeddings, embedding_dim],
                true,
            )?);
        }

        Ok(Self {
            weight,
            num_embeddings,
            embedding_dim,
            padding_idx,
            max_norm: None,
            norm_type: 2.0,
            scale_grad_by_freq: false,
            training: true,
            sparse: false,
            last_indices: std::sync::Mutex::new(None),
        })
    }

    /// Create an embedding layer from an existing weight tensor.
    ///
    /// The tensor must have shape `[num_embeddings, embedding_dim]`. Unlike
    /// [`Self::new`], this does not zero the `padding_idx` row: the
    /// pretrained values are kept as given.
    ///
    /// # Errors
    ///
    /// Returns an error if `weight` is not 2-D or if
    /// `padding_idx >= num_embeddings`.
    pub fn from_pretrained(
        weight: Tensor<T>,
        padding_idx: Option<usize>,
    ) -> FerrotorchResult<Self> {
        if weight.ndim() != 2 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "Embedding weight must be 2-D, got shape {:?}",
                    weight.shape()
                ),
            });
        }
        let num_embeddings = weight.shape()[0];
        let embedding_dim = weight.shape()[1];

        if let Some(idx) = padding_idx {
            if idx >= num_embeddings {
                return Err(FerrotorchError::InvalidArgument {
                    message: format!(
                        "padding_idx {idx} is out of range for num_embeddings {num_embeddings}"
                    ),
                });
            }
        }

        Ok(Self {
            weight: Parameter::new(weight),
            num_embeddings,
            embedding_dim,
            padding_idx,
            max_norm: None,
            norm_type: 2.0,
            scale_grad_by_freq: false,
            training: true,
            sparse: false,
            last_indices: std::sync::Mutex::new(None),
        })
    }

    /// Builder: set the maximum row norm. On every forward pass, before the
    /// gather, the rows of `weight` touched by the input have their
    /// `norm_type`-norm clipped to `max_norm` via in-place renormalisation,
    /// matching `torch.nn.Embedding(max_norm=...)`. Closes #1445.
    pub fn with_max_norm(mut self, max_norm: f64) -> Self {
        self.max_norm = Some(max_norm);
        self
    }

    /// Builder: set the order of the row-norm used by `max_norm` (default
    /// `2.0`). Closes #1445.
    pub fn with_norm_type(mut self, norm_type: f64) -> Self {
        self.norm_type = norm_type;
        self
    }

    /// Builder: if `true`, `EmbeddingBackward` divides each touched row's
    /// gradient by the number of times the index appeared in the forward
    /// pass (`torch.nn.Embedding(scale_grad_by_freq=True)`). Closes #1445.
    pub fn with_scale_grad_by_freq(mut self, scale: bool) -> Self {
        self.scale_grad_by_freq = scale;
        self
    }

    /// Renormalise the rows of `self.weight` that `indices` touched, IN
    /// PLACE, so each touched row's `norm_type`-norm is at most `max_norm`.
    ///
    /// This is a faithful translation of the aten kernel
    /// `embedding_renorm_cpu_` (`aten/src/ATen/native/Embedding.cpp:181-212`):
    /// the touched indices are sorted and de-duplicated, and for each unique
    /// row whose current norm exceeds `max_norm` the row is scaled by
    /// `max_norm / (norm + 1e-7)`. Rows already within `max_norm` are left
    /// untouched, and rows never indexed in this forward are not visited.
    ///
    /// PyTorch's `F.embedding` (`torch/nn/functional.py:2561-2573`) runs this
    /// renorm BEFORE the gather, under `torch.no_grad()`, mutating the
    /// persisted `weight` tensor, so the change survives across forward
    /// calls. We match that by writing the renormed rows back into
    /// `self.weight` via [`Tensor::update_data`], the same in-place storage
    /// mutation the optimizer `step()` uses. The write is performed only when
    /// at least one row actually exceeded `max_norm`, keeping the common
    /// "nothing to clip" path allocation-free on the weight buffer.
    ///
    /// Returns `Ok(())` when `max_norm` is unset (no-op) or after the
    /// in-place mutation completes.
    fn renorm_weight_in_place(&self, indices: &[usize]) -> FerrotorchResult<()> {
        let Some(max_norm) = self.max_norm else {
            return Ok(());
        };
        renorm_weight_rows_in_place(
            self.weight.tensor(),
            indices,
            self.embedding_dim,
            max_norm,
            self.norm_type,
            "Embedding(max_norm) weight renorm",
        )
    }

    /// Toggle the sparse-grad mode. When enabled, [`Self::sparse_grad`]
    /// returns a `SparseGrad<T>` populated only with the rows actually
    /// touched by the most recent forward pass. Off by default. Consumes and
    /// returns `self` for builder-style chaining.
    pub fn with_sparse(mut self, sparse: bool) -> Self {
        self.sparse = sparse;
        self
    }

    /// Record the unique row indices touched by the most recent forward pass.
    /// No-op when sparse mode is off, which keeps the hot path zero-overhead
    /// for the common dense-grad case.
    fn cache_touched_rows(&self, indices: &[usize]) {
        if !self.sparse {
            return;
        }
        // Dedupe (sorted) so callers don't have to coalesce later.
        let mut uniq: Vec<usize> = indices.to_vec();
        uniq.sort_unstable();
        uniq.dedup();
        if let Ok(mut g) = self.last_indices.lock() {
            *g = Some(uniq);
        }
    }

    /// Materialize a [`SparseGrad`] from the current dense weight gradient,
    /// keyed on the indices touched by the most recent forward pass.
    ///
    /// Returns `None` when sparse mode is off, no forward has been run yet,
    /// the index cache's mutex is poisoned, or the parameter has no gradient
    /// (e.g. before the first backward call). The returned grad is already
    /// coalesced (each touched row appears once with its full gradient
    /// slab); feed it directly into [`SparseGrad::apply_sgd`] or
    /// `optim::SparseAdam`.
    ///
    /// Mirrors PyTorch's `torch.nn.Embedding(sparse=True)` → `SparseAdam`
    /// flow. The dense grad is unchanged; `sparse_grad` just provides a
    /// compact view for optimizers that benefit from skipping zero rows.
    pub fn sparse_grad(&self) -> FerrotorchResult<Option<ferrotorch_core::SparseGrad<T>>> {
        if !self.sparse {
            return Ok(None);
        }
        let last = match self.last_indices.lock() {
            Ok(g) => g,
            Err(_) => return Ok(None),
        };
        let indices = match last.as_ref() {
            Some(v) => v.clone(),
            None => return Ok(None),
        };
        let grad = match self.weight.tensor().grad()? {
            Some(g) => g,
            None => return Ok(None),
        };
        let grad_data = grad.data_vec()?;
        let dim = self.embedding_dim;
        let mut values = Vec::with_capacity(indices.len() * dim);
        for &idx in &indices {
            let row_start = idx * dim;
            let row_end = row_start + dim;
            values.extend_from_slice(&grad_data[row_start..row_end]);
        }
        let sg = ferrotorch_core::SparseGrad::new(indices, values, vec![dim])?;
        Ok(Some(sg))
    }
}

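`cache_touched_rows` and `sparse_grad` together turn a dense weight gradient into a coalesced list of rows. The sketch below reproduces that idea on plain vectors, together with the row-wise SGD update such a view makes possible. `SparseRows`, `sparse_rows_from_dense`, and `apply_sgd_rows` are hypothetical stand-ins. They are not `ferrotorch_core::SparseGrad` or its `apply_sgd` method, whose exact signatures are not shown in this file.

```rust
/// Hypothetical coalesced row-sparse gradient: `values` holds one
/// `dim`-wide slab per entry of `indices`, in the same order.
#[derive(Debug, PartialEq)]
pub struct SparseRows {
    pub indices: Vec<usize>,
    pub values: Vec<f64>,
    pub dim: usize,
}

/// Gather the rows named in `touched` (sorted and deduplicated first)
/// out of a dense `[num_embeddings * dim]` gradient.
pub fn sparse_rows_from_dense(dense_grad: &[f64], dim: usize, touched: &[usize]) -> SparseRows {
    let mut indices = touched.to_vec();
    indices.sort_unstable();
    indices.dedup();
    let mut values = Vec::with_capacity(indices.len() * dim);
    for &idx in &indices {
        values.extend_from_slice(&dense_grad[idx * dim..(idx + 1) * dim]);
    }
    SparseRows { indices, values, dim }
}

/// Plain SGD restricted to the rows present in `grad`: w[row] -= lr * g[row].
pub fn apply_sgd_rows(weight: &mut [f64], grad: &SparseRows, lr: f64) {
    for (k, &idx) in grad.indices.iter().enumerate() {
        let g = &grad.values[k * grad.dim..(k + 1) * grad.dim];
        let w = &mut weight[idx * grad.dim..(idx + 1) * grad.dim];
        for (wi, gi) in w.iter_mut().zip(g) {
            *wi -= lr * gi;
        }
    }
}
```

The update loop visits only the touched rows, so its cost scales with the number of distinct indices in the batch rather than with `num_embeddings`.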
impl<T: Float> Module<T> for Embedding<T> {
    /// Forward pass: look up embedding vectors for the given indices.
    ///
    /// `input` is an index tensor of ANY shape whose values are non-negative
    /// integers stored as floats. Each value is cast to `usize` and used to
    /// index into the weight matrix. The lookup operates on the flattened
    /// indices (row-major), exactly mirroring upstream `embedding_symint`
    /// (`aten/src/ATen/native/Embedding.cpp:43-53`):
    /// `weight.index_select(0, indices.reshape(-1)).view_symint(size)` where
    /// `size = (*indices.sizes(), weight.size(1))`.
    ///
    /// Returns a tensor of shape `(*input.shape(), embedding_dim)`. A 1-D
    /// index of length `n` therefore yields `[n, embedding_dim]`, and a 2-D
    /// index `[a, b]` yields `[a, b, embedding_dim]`, matching `F.embedding`.
    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let dim = self.embedding_dim;

        // Output shape is the index shape with `embedding_dim` appended, per
        // upstream `embedding_symint` (`Embedding.cpp:48-53`): the gather runs
        // over the flattened indices and the result is viewed back to
        // `(*indices.sizes(), weight.size(1))`. A 1-D input keeps the
        // `[n, dim]` shape, and an empty index shape (a 0-D scalar index)
        // yields `[dim]`, so no special case is needed.
        let mut output_shape: Vec<usize> = input.shape().to_vec();
        output_shape.push(dim);

        // GPU fast path for f32/f64 CUDA weights: indices are validated on the
        // host, uploaded as f32, and the row gather runs on the GPU. f32
        // represents every integer exactly only up to 2^24 (16_777_216), so
        // larger indices may be rounded before they reach the kernel.
        if self.weight.tensor().is_cuda() && (is_f32::<T>() || is_f64::<T>()) {
            let backend = gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
            let device = self.weight.tensor().device();
            let ordinal = match device {
                Device::Cuda(o) => o,
                _ => unreachable!(),
            };

            let input_data = input.data_vec()?;
            let n = input_data.len();

            let mut indices = Vec::with_capacity(n);
            let mut indices_f32 = Vec::with_capacity(n);
            for (i, &val) in input_data.iter().enumerate() {
                let idx = num_traits::ToPrimitive::to_usize(&val).ok_or_else(|| {
                    FerrotorchError::InvalidArgument {
                        message: format!(
                            "Embedding index at position {i} cannot be converted to usize: {val:?}"
                        ),
                    }
                })?;
                if idx >= self.num_embeddings {
                    return Err(FerrotorchError::IndexOutOfBounds {
                        index: idx,
                        axis: 0,
                        size: self.num_embeddings,
                    });
                }
                indices.push(idx);
                indices_f32.push(idx as f32);
            }

            self.cache_touched_rows(&indices);

            // max_norm with a CUDA weight has no on-device renorm kernel yet;
            // surface that explicitly rather than silently returning
            // un-renormed rows (which would diverge from torch's in-place
            // mutation at functional.py:2561-2573). No-op when max_norm unset.
            self.renorm_weight_in_place(&indices)?;

            let idx_handle = upload_f32_to_gpu(&indices_f32, ordinal)?;
            let weight_handle = self.weight.tensor().gpu_handle()?;

            let output_handle = if is_f64::<T>() {
                backend.embed_lookup_batch_f64(&idx_handle, weight_handle, n, dim)?
            } else {
                backend.embed_lookup_batch_f32(&idx_handle, weight_handle, n, dim)?
            };

            // Padding index: unlike the CPU path below, output rows gathered
            // from `padding_idx` are NOT re-zeroed here, to avoid extra GPU
            // work. This relies on the weight row itself being zero, which
            // `Embedding::new` guarantees at construction but
            // `Embedding::from_pretrained` does not.

            let storage = TensorStorage::gpu(output_handle);

            if self.weight.requires_grad() && is_grad_enabled() {
                let grad_fn = Arc::new(EmbeddingBackward {
                    weight: self.weight.tensor().clone(),
                    indices,
                    num_embeddings: self.num_embeddings,
                    embedding_dim: dim,
                    padding_idx: self.padding_idx,
                    scale_grad_by_freq: self.scale_grad_by_freq,
                });
                return Tensor::from_operation(storage, output_shape, grad_fn);
            } else {
                return Tensor::from_storage(storage, output_shape, false);
            }
        }

        // CPU path. A CUDA weight whose dtype is not f32/f64 has no GPU
        // kernel, so error out.
        if self.weight.tensor().is_cuda() {
            return Err(FerrotorchError::NotImplementedOnCuda { op: "Embedding" });
        }
        let input_data = input.data_vec()?;
        let n = input_data.len();

        // Convert float indices to usize and validate bounds.
        let mut indices = Vec::with_capacity(n);
        for (i, &val) in input_data.iter().enumerate() {
            let idx = num_traits::ToPrimitive::to_usize(&val).ok_or_else(|| {
                FerrotorchError::InvalidArgument {
                    message: format!(
                        "Embedding index at position {i} cannot be converted to usize: {val:?}"
                    ),
                }
            })?;
            if idx >= self.num_embeddings {
                return Err(FerrotorchError::IndexOutOfBounds {
                    index: idx,
                    axis: 0,
                    size: self.num_embeddings,
                });
            }
            indices.push(idx);
        }

        self.cache_touched_rows(&indices);

        // max_norm: renormalise the touched rows of the PERSISTED weight
        // IN PLACE, BEFORE the gather. This mirrors
        // `torch/nn/functional.py:2561-2573`, where `F.embedding` calls
        // `_no_grad_embedding_renorm_(weight, ...)` (which mutates `weight`
        // via `torch.embedding_renorm_`) and only THEN does the lookup.
        // The mutation persists across forward calls: a second forward with
        // the same indices is a no-op because the rows now satisfy max_norm.
        // Closes #1445 (CPU path).
        self.renorm_weight_in_place(&indices)?;

        // Re-read the (possibly mutated) weight buffer for the gather.
        let cpu_weight = self.weight.tensor().clone();
        let weight_data = cpu_weight.data()?;

        // Gather rows from weight.
        let mut output_data = Vec::with_capacity(n * dim);
        for &idx in &indices {
            let row_start = idx * dim;
            output_data.extend_from_slice(&weight_data[row_start..row_start + dim]);
        }

        // If padding_idx is set, force those output rows to zero. `new`
        // zeroes the padding weight row, but `from_pretrained` does not.
        if let Some(pad_idx) = self.padding_idx {
            for (i, &idx) in indices.iter().enumerate() {
                if idx == pad_idx {
                    let start = i * dim;
                    for v in &mut output_data[start..start + dim] {
                        *v = <T as num_traits::Zero>::zero();
                    }
                }
            }
        }

        // Output device matches the weight's device. CUDA weights never reach
        // this point (they either took the fast path above or returned
        // `NotImplementedOnCuda`), so the CUDA branch below is defensive only.
        let device = self.weight.tensor().device();

        // Build the storage on the target device first, then attach grad_fn:
        // moving a finished tensor with `to()` would create a leaf tensor and
        // strip the grad_fn.
        let storage = if device.is_cuda() {
            let backend = gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
            let ordinal = match device {
                Device::Cuda(o) => o,
                _ => unreachable!(),
            };
            // SAFETY: `output_data` is a live owned `Vec<T>` whose contents we borrow
            // shared for the duration of this expression. Its underlying buffer is valid
            // for reads of `output_data.len() * size_of::<T>()` bytes. `T: Float` is
            // one of f32/f64/bf16/f16, none of which contain padding (uninitialized)
            // bytes, so every byte in the range is initialized. The cast
            // `*const T` -> `*const u8` does not violate alignment because `u8`'s
            // alignment (1) is at most `T`'s alignment. The resulting `&[u8]` is
            // consumed by `backend.cpu_to_gpu` within this branch, and `output_data`
            // is only moved into `TensorStorage::cpu` on the mutually exclusive else
            // branch, so the borrow never outlives the source.
            let bytes: &[u8] = unsafe {
                std::slice::from_raw_parts(
                    output_data.as_ptr() as *const u8,
                    output_data.len() * std::mem::size_of::<T>(),
                )
            };
            let handle = backend.cpu_to_gpu(bytes, T::dtype(), ordinal)?;
            TensorStorage::gpu(handle)
        } else {
            TensorStorage::cpu(output_data)
        };

        if self.weight.requires_grad() && is_grad_enabled() {
            let grad_fn = Arc::new(EmbeddingBackward {
                weight: self.weight.tensor().clone(),
                indices,
                num_embeddings: self.num_embeddings,
                embedding_dim: dim,
                padding_idx: self.padding_idx,
                scale_grad_by_freq: self.scale_grad_by_freq,
            });
            Tensor::from_operation(storage, output_shape, grad_fn)
        } else {
            Tensor::from_storage(storage, output_shape, false)
        }
    }

    fn parameters(&self) -> Vec<&Parameter<T>> {
        vec![&self.weight]
    }

    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
        vec![&mut self.weight]
    }

    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
        vec![("weight".to_string(), &self.weight)]
    }

    fn train(&mut self) {
        self.training = true;
    }

    fn eval(&mut self) {
        self.training = false;
    }

    fn is_training(&self) -> bool {
        self.training
    }
}

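Stripped of autograd, device dispatch, float-to-`usize` conversion, and `max_norm`, the CPU forward path above comes down to four steps:

- validate each index;
- gather rows over the flattened indices;
- zero the rows gathered from `padding_idx`;
- append `embedding_dim` to the index shape.

The sketch below implements only those semantics on plain slices. `embedding_forward` is a hypothetical helper, and its errors are simplified to `String`.

```rust
/// Hypothetical sketch of Embedding's CPU forward semantics.
/// Returns `(output_shape, output_data)`.
pub fn embedding_forward(
    weight: &[f32],        // [num_embeddings * dim]
    num_embeddings: usize,
    dim: usize,
    input_shape: &[usize], // any rank, including 0-D (empty shape)
    indices: &[usize],     // row-major flattening of the index tensor
    padding_idx: Option<usize>,
) -> Result<(Vec<usize>, Vec<f32>), String> {
    let expected: usize = input_shape.iter().product(); // product of [] is 1
    if indices.len() != expected {
        return Err(format!("expected {expected} indices, got {}", indices.len()));
    }
    let mut out = Vec::with_capacity(indices.len() * dim);
    for (pos, &idx) in indices.iter().enumerate() {
        if idx >= num_embeddings {
            return Err(format!(
                "index {idx} at position {pos} is out of range for {num_embeddings} rows"
            ));
        }
        if padding_idx == Some(idx) {
            // Rows gathered from padding_idx are forced to zero.
            out.extend(std::iter::repeat(0.0).take(dim));
        } else {
            out.extend_from_slice(&weight[idx * dim..(idx + 1) * dim]);
        }
    }
    // Output shape: (*input_shape, dim).
    let mut shape = input_shape.to_vec();
    shape.push(dim);
    Ok((shape, out))
}
```

For example, a `[2, 2]` index tensor over a 3-row, 2-wide table produces a `[2, 2, 2]` output, and an empty (0-D) index shape produces a single `[dim]` row.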
// ---------------------------------------------------------------------------
// EmbeddingBag: fused lookup + reduce
// ---------------------------------------------------------------------------

/// Reduction mode for [`EmbeddingBag`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EmbeddingBagMode {
    /// Sum all embeddings in each bag.
    Sum,
    /// Mean of all embeddings in each bag.
    Mean,
    /// Element-wise max across embeddings in each bag.
    Max,
}

/// Computes sums, means, or element-wise maxima of bags of embeddings
/// without instantiating the full intermediate embeddings. For
/// variable-length sequences this is more efficient than `Embedding`
/// followed by a reduction.
///
/// # Input format
///
/// - `input`: 1-D tensor of indices `[total_indices]`.
/// - `offsets`: `&[usize]` giving the start position of each bag in
///   `input`. It has `num_bags` entries, or `num_bags + 1` when
///   `include_last_offset` is set, and must be sorted in non-decreasing
///   order. Example: if `input` holds indices for 3 bags with lengths
///   `[2, 3, 1]`, then `offsets = [0, 2, 5]`.
///
/// # Modes
///
/// Bag `b` covers `input[offsets[b]..offsets[b + 1]]`; when
/// `include_last_offset` is unset, the last bag runs to the end of `input`.
/// Indices equal to `padding_idx` are skipped.
///
/// - `Sum`: `output[b]` is the sum of the gathered weight rows.
/// - `Mean`: `output[b]` is the mean of the gathered weight rows.
/// - `Max`: `output[b]` is the element-wise max of the gathered weight rows.
#[derive(Debug)]
#[allow(
    clippy::struct_excessive_bools,
    reason = "scale_grad_by_freq/sparse/include_last_offset/training each mirror a distinct torch.nn.EmbeddingBag kwarg (sparse.py:376-380) — R-DEV-2 requires matching the upstream Python API surface field-for-field, so collapsing them into a flags enum would diverge from the user-facing kwarg contract"
)]
pub struct EmbeddingBag<T: Float> {
    weight: Parameter<T>,
    num_embeddings: usize,
    embedding_dim: usize,
    mode: EmbeddingBagMode,
    training: bool,
    /// If set, each touched weight row is renormalised in place to at most
    /// `max_norm` under the `norm_type`-norm before the bag reduction,
    /// mirroring `torch.nn.EmbeddingBag(max_norm=...)`
    /// (`torch/nn/modules/sparse.py:374`, `functional.py:2766-2771`).
    pub max_norm: Option<f64>,
    /// Order of the row-norm used when `max_norm` is active. Defaults to
    /// `2.0` per `torch/nn/modules/sparse.py:375`.
    pub norm_type: f64,
    /// If `true`, weight-gradient accumulation scales each row by the inverse
    /// frequency of its index in the mini-batch. Carried to mirror the
    /// upstream kwarg (`sparse.py:376`); `max` mode forbids it
    /// (`functional.py:2755-2758`).
    pub scale_grad_by_freq: bool,
    /// Advertises a sparse-gradient pattern, mirroring
    /// `torch.nn.EmbeddingBag(sparse=True)` (`sparse.py:378`). `max` mode
    /// forbids it (`functional.py:2760-2761`).
    pub sparse: bool,
    /// When `true`, `offsets` has `num_bags + 1` entries and its last entry
    /// is the total index count (CSR-style), mirroring
    /// `torch.nn.EmbeddingBag(include_last_offset=True)` (`sparse.py:380`,
    /// `functional.py:2621-2624`).
    pub include_last_offset: bool,
    /// If set, indices equal to `padding_idx` are excluded from each bag's
    /// reduction (and the mean divisor), and the corresponding weight row is
    /// zeroed at construction, matching `torch.nn.EmbeddingBag(padding_idx)`
    /// (`sparse.py:381`, `aten/src/ATen/native/EmbeddingBag.cpp:140-156`).
    pub padding_idx: Option<usize>,
}

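The `offsets` convention and the `max`-mode restrictions documented on the struct can be captured in two small helpers:

- `lengths_to_offsets` turns per-bag lengths into start offsets, appending the total when `include_last_offset` is set.
- `check_bag_kwargs` rejects `scale_grad_by_freq` and `sparse` in `max` mode, as the field docs describe.

Both helpers are hypothetical illustrations, and their error strings are invented. The crate performs its own checks in `forward_bag`.

```rust
/// Hypothetical: convert per-bag lengths into EmbeddingBag offsets.
/// With `include_last_offset`, the total index count is appended (CSR style).
pub fn lengths_to_offsets(lengths: &[usize], include_last_offset: bool) -> Vec<usize> {
    let mut offsets = Vec::with_capacity(lengths.len() + 1);
    let mut running = 0;
    for &len in lengths {
        offsets.push(running); // each bag starts where the previous one ended
        running += len;
    }
    if include_last_offset {
        offsets.push(running);
    }
    offsets
}

/// Hypothetical mirror of the `max`-mode kwarg restrictions.
pub fn check_bag_kwargs(
    is_max_mode: bool,
    scale_grad_by_freq: bool,
    sparse: bool,
) -> Result<(), &'static str> {
    if is_max_mode && scale_grad_by_freq {
        return Err("max mode does not support scale_grad_by_freq");
    }
    if is_max_mode && sparse {
        return Err("max mode does not support sparse");
    }
    Ok(())
}
```

Applied to the doc example, the lengths `[2, 3, 1]` give `[0, 2, 5]`, or `[0, 2, 5, 6]` in the CSR form.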
impl<T: Float> EmbeddingBag<T> {
    /// Create a new EmbeddingBag with default kwargs (no `max_norm`,
    /// `norm_type = 2.0`, `scale_grad_by_freq = false`, `sparse = false`,
    /// `include_last_offset = false`, no `padding_idx`), matching the
    /// `torch.nn.EmbeddingBag(num_embeddings, embedding_dim, mode=...)`
    /// defaults at `torch/nn/modules/sparse.py:370-381`.
    pub fn new(
        num_embeddings: usize,
        embedding_dim: usize,
        mode: EmbeddingBagMode,
    ) -> FerrotorchResult<Self> {
        Self::new_with(num_embeddings, embedding_dim, mode, None)
    }

    /// Create a new EmbeddingBag, optionally with a `padding_idx`.
    ///
    /// Mirrors the `padding_idx` validation + zero-fill in
    /// `torch.nn.EmbeddingBag.__init__` / `_fill_padding_idx_with_zero`
    /// (`torch/nn/modules/sparse.py:392-423`): `padding_idx` must be within
    /// `num_embeddings`, and that weight row is zeroed after init.
    ///
    /// # Errors
    ///
    /// Returns an error if `padding_idx >= num_embeddings`.
    pub fn new_with(
        num_embeddings: usize,
        embedding_dim: usize,
        mode: EmbeddingBagMode,
        padding_idx: Option<usize>,
    ) -> FerrotorchResult<Self> {
        if let Some(idx) = padding_idx {
            if idx >= num_embeddings {
                return Err(FerrotorchError::InvalidArgument {
                    message: format!(
                        "padding_idx {idx} must be within num_embeddings {num_embeddings}"
                    ),
                });
            }
        }

        let mut weight = Parameter::zeros(&[num_embeddings, embedding_dim])?;
        init::normal(&mut weight, 0.0, 1.0)?;

        // Zero the padding row if requested (mirrors
        // `_fill_padding_idx_with_zero`, sparse.py:420-423).
        if let Some(idx) = padding_idx {
            let data = weight.data()?.to_vec();
            let mut new_data = data;
            let start = idx * embedding_dim;
            for v in &mut new_data[start..start + embedding_dim] {
                *v = <T as num_traits::Zero>::zero();
            }
            weight = Parameter::new(Tensor::from_storage(
                TensorStorage::cpu(new_data),
                vec![num_embeddings, embedding_dim],
                true,
            )?);
        }

        Ok(Self {
            weight,
            num_embeddings,
            embedding_dim,
            mode,
            training: true,
            max_norm: None,
            norm_type: 2.0,
            scale_grad_by_freq: false,
            sparse: false,
            include_last_offset: false,
            padding_idx,
        })
    }

    /// Builder: set the maximum row norm. Touched rows of `weight` have their
    /// `norm_type`-norm clipped to `max_norm` in place before each bag
    /// reduction, mirroring `torch.nn.EmbeddingBag(max_norm=...)`. Closes #1445.
    pub fn with_max_norm(mut self, max_norm: f64) -> Self {
        self.max_norm = Some(max_norm);
        self
    }

    /// Builder: set the order of the row-norm used by `max_norm` (default
    /// `2.0`, `sparse.py:375`). Closes #1445.
    pub fn with_norm_type(mut self, norm_type: f64) -> Self {
        self.norm_type = norm_type;
        self
    }

    /// Builder: set `scale_grad_by_freq` (`sparse.py:376`). Rejected for
    /// `max` mode by [`Self::forward_bag`], matching `functional.py:2755-2758`.
    /// Closes #1445.
    pub fn with_scale_grad_by_freq(mut self, scale: bool) -> Self {
        self.scale_grad_by_freq = scale;
        self
    }

    /// Builder: set `sparse` (`sparse.py:378`). Rejected for `max` mode by
    /// [`Self::forward_bag`], matching `functional.py:2760-2761`. Closes #1445.
    pub fn with_sparse(mut self, sparse: bool) -> Self {
        self.sparse = sparse;
        self
    }

    /// Builder: set `include_last_offset` (`sparse.py:380`). When `true`,
    /// `offsets` carries `num_bags + 1` entries with the last being the total
    /// index count, matching the CSR convention in `functional.py:2621-2624`.
    /// Closes #1445.
    pub fn with_include_last_offset(mut self, include_last_offset: bool) -> Self {
        self.include_last_offset = include_last_offset;
        self
    }

1256    /// Forward pass: compute bag-reduced embeddings.
1257    ///
1258    /// `input`: 1-D tensor of indices `[total_indices]`.
1259    /// `offsets`: bag start offsets. When `include_last_offset == false`,
1260    /// this has `num_bags` entries (bag `b` spans `offsets[b]..offsets[b+1]`,
1261    /// the last bag running to the end of `input`). When
1262    /// `include_last_offset == true`, it has `num_bags + 1` entries with the
1263    /// final entry being the total index count (CSR style), matching
1264    /// `torch/nn/functional.py:2621-2624`.
1265    ///
1266    /// Honors `max_norm` (in-place weight renorm before the reduction,
1267    /// mirroring `functional.py:2766-2771`) and `padding_idx` (indices equal
1268    /// to it are excluded from both the reduction and the mean divisor,
1269    /// mirroring `aten/src/ATen/native/EmbeddingBag.cpp:140-156`). `max` mode
1270    /// rejects `scale_grad_by_freq` / `sparse`
1271    /// (`functional.py:2755-2761`).
1272    ///
1273    /// This is the unweighted path (`per_sample_weights = None`); it delegates
1274    /// to [`Self::forward_bag_weighted`] so the two share a single reduction
1275    /// body. The result is never grad-tracked: this path does not yet attach a
1276    /// backward for the plain sum / mean / max reductions.
1277    pub fn forward_bag(&self, input: &Tensor<T>, offsets: &[usize]) -> FerrotorchResult<Tensor<T>> {
1278        self.forward_bag_weighted(input, offsets, None)
1279    }
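// The two `offsets` conventions described above can be sketched as a
// standalone function that turns offsets into one index range per bag.
// `bag_ranges` is hypothetical and not part of ferrotorch. Unlike the loop in
// `forward_bag_weighted`, which indexes `indices` directly and would panic on
// an offset past the end of `input`, this sketch assumes malformed offsets
// (decreasing, or greater than `total`) should be rejected up front.
```rust
use std::ops::Range;

// Hypothetical helper: map `offsets` to per-bag index ranges, honoring the
// `include_last_offset` (CSR) convention. Returns None for malformed offsets.
fn bag_ranges(offsets: &[usize], total: usize, include_last_offset: bool) -> Option<Vec<Range<usize>>> {
    if offsets.windows(2).any(|w| w[0] > w[1]) || offsets.iter().any(|&o| o > total) {
        return None;
    }
    // With CSR layout the trailing entry is the total count, not a bag start.
    let num_bags = if include_last_offset {
        offsets.len().saturating_sub(1)
    } else {
        offsets.len()
    };
    let ranges = (0..num_bags)
        .map(|b| {
            // CSR reads every end from offsets[b + 1]; otherwise the final
            // bag runs to the end of the input.
            let end = if include_last_offset || b + 1 < num_bags {
                offsets[b + 1]
            } else {
                total
            };
            offsets[b]..end
        })
        .collect();
    Some(ranges)
}
```
// Both conventions describe the same bags: `[0, 2]` without the flag and
// `[0, 2, 5]` with it both split a 5-index input into `0..2` and `2..5`.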
1280
1281    /// Forward pass with optional `per_sample_weights`, mirroring
1282    /// `F.embedding_bag(input, weight, offsets, ..., per_sample_weights=...)`
1283    /// (`torch/nn/functional.py:2576-2791`).
1284    ///
1285    /// When `per_sample_weights` is `Some(psw)`:
1286    /// - It is ONLY valid for `mode == Sum`; any other mode returns an `InvalidArgument`
1287    ///   error worded like torch's `NotImplementedError` (`functional.py:2773-2778`).
1288    /// - `psw` must have the same shape as `input` (`functional.py:2698-2702`).
1289    /// - Each gathered embedding row is scaled by its sample weight BEFORE the
1290    ///   sum reduction (`output[bag][:] += weight[idx][:] * psw[i]`,
1291    ///   `EmbeddingBag.cpp:537-543`).
1292    /// - The output is grad-tracked: gradient flows to BOTH `weight` and
1293    ///   `psw` via [`EmbeddingBagSumWeightedBackward`], matching torch's
1294    ///   autograd. `padding_idx` samples contribute 0 to the reduction and to
1295    ///   both gradients.
1296    ///
1297    /// When `per_sample_weights` is `None` this is the plain unweighted
1298    /// reduction (sum / mean / max) and returns a non-grad tensor — identical
1299    /// to the historical [`Self::forward_bag`] behavior.
1300    pub fn forward_bag_weighted(
1301        &self,
1302        input: &Tensor<T>,
1303        offsets: &[usize],
1304        per_sample_weights: Option<&Tensor<T>>,
1305    ) -> FerrotorchResult<Tensor<T>> {
1306        if input.ndim() != 1 {
1307            return Err(FerrotorchError::InvalidArgument {
1308                message: format!("EmbeddingBag input must be 1-D, got {:?}", input.shape()),
1309            });
1310        }
1311
1312        // per_sample_weights is only supported for mode='sum'; torch raises a
1313        // NotImplementedError with essentially this text (functional.py:2773-2778).
1314        // Checking the mode before the shape only loosely follows torch's
1315        // ordering, but both are user-facing errors; we surface the mode error
1316        // first since it is the dominant constraint for this feature.
1317        if let Some(psw) = per_sample_weights {
1318            if self.mode != EmbeddingBagMode::Sum {
1319                let mode_str = match self.mode {
1320                    EmbeddingBagMode::Sum => "sum",
1321                    EmbeddingBagMode::Mean => "mean",
1322                    EmbeddingBagMode::Max => "max",
1323                };
1324                return Err(FerrotorchError::InvalidArgument {
1325                    message: format!(
1326                        "embedding_bag: per_sample_weights was not None. per_sample_weights is \
1327                         only supported for mode='sum' (got mode='{mode_str}'). Please open a \
1328                         feature request on GitHub."
1329                    ),
1330                });
1331            }
1332            // psw must have exactly the same shape as input (functional.py:2698).
1333            if psw.shape() != input.shape() {
1334                return Err(FerrotorchError::InvalidArgument {
1335                    message: format!(
1336                        "embedding_bag: If per_sample_weights ({:?}) is not None, then it must \
1337                         have the same shape as the input ({:?})",
1338                        psw.shape(),
1339                        input.shape()
1340                    ),
1341                });
1342            }
1343        }
1344
1345        // mode='max' forbids scale_grad_by_freq and sparse, matching
1346        // functional.py:2755-2761.
1347        if self.mode == EmbeddingBagMode::Max {
1348            if self.scale_grad_by_freq {
1349                return Err(FerrotorchError::InvalidArgument {
1350                    message: "max mode does not support scaling the gradient by the frequency"
1351                        .into(),
1352                });
1353            }
1354            if self.sparse {
1355                return Err(FerrotorchError::InvalidArgument {
1356                    message: "max mode does not support sparse weights".into(),
1357                });
1358            }
1359        }
1360
1361        let input_data = input.data_vec()?;
1362        let dim = self.embedding_dim;
1363        let total = input_data.len();
1364
1365        // Derive the bag count from `offsets`, honoring include_last_offset
1366        // (CSR layout: the trailing entry is the total count, not a bag start).
1367        let num_bags = if self.include_last_offset {
1368            offsets.len().saturating_sub(1)
1369        } else {
1370            offsets.len()
1371        };
1372
1373        // Validate + collect indices.
1374        let mut indices = Vec::with_capacity(total);
1375        for (i, &val) in input_data.iter().enumerate() {
1376            let idx = num_traits::ToPrimitive::to_usize(&val).ok_or_else(|| {
1377                FerrotorchError::InvalidArgument {
1378                    message: format!("EmbeddingBag index {i} invalid: {val:?}"),
1379                }
1380            })?;
1381            if idx >= self.num_embeddings {
1382                return Err(FerrotorchError::IndexOutOfBounds {
1383                    index: idx,
1384                    axis: 0,
1385                    size: self.num_embeddings,
1386                });
1387            }
1388            indices.push(idx);
1389        }
1390
1391        // max_norm: renormalise the touched rows of the persisted weight IN
1392        // PLACE before the reduction (functional.py:2766-2771 runs the renorm
1393        // before torch.embedding_bag). No-op when max_norm unset.
1394        if let Some(max_norm) = self.max_norm {
1395            renorm_weight_rows_in_place(
1396                self.weight.tensor(),
1397                &indices,
1398                dim,
1399                max_norm,
1400                self.norm_type,
1401                "EmbeddingBag(max_norm) weight renorm",
1402            )?;
1403        }
1404
1405        // per_sample_weights data (read only when psw is present; psw is
1406        // sum-mode-only, validated above) and a `bag_of` map (offset2bag) used
1407        // only by the weighted backward. `bag_of[i]` is the output row sample
1408        // `i` accumulates into, mirroring torch's `offset2bag` (`EmbeddingBag.cpp:1563`).
1409        let psw_data: Option<Vec<T>> = match per_sample_weights {
1410            Some(psw) => Some(psw.data_vec()?),
1411            None => None,
1412        };
1413        let mut bag_of: Vec<usize> = vec![0; total];
1414
1415        // Re-read the (possibly renormed) weight for the reduction.
1416        let weight_data = self.weight.tensor().data()?;
1417
1418        let mut output = vec![<T as num_traits::Zero>::zero(); num_bags * dim];
1419
1420        for b in 0..num_bags {
1421            let start = offsets[b];
1422            // With include_last_offset, every bag (including the last) reads
1423            // its end from offsets[b+1]; otherwise the final bag runs to total.
1424            let end = if self.include_last_offset || b + 1 < num_bags {
1425                offsets[b + 1]
1426            } else {
1427                total
1428            };
1429
1430            // Record offset2bag for every sample in this bag so the weighted
1431            // backward (when psw is present) can map sample -> bag grad row.
1432            for s in bag_of.iter_mut().take(end).skip(start) {
1433                *s = b;
1434            }
1435
1436            match self.mode {
1437                EmbeddingBagMode::Sum | EmbeddingBagMode::Mean => {
1438                    // Count of non-padding entries; the mean divides by this,
1439                    // mirroring the bag_size decrement at EmbeddingBag.cpp:151-156.
1440                    let mut count: usize = 0;
1441                    let out_start = b * dim;
1442                    for s in start..end {
1443                        let idx = indices[s];
1444                        // padding_idx entries are excluded from the reduction
1445                        // (EmbeddingBag.cpp:147 `if (idx != padding_idx)`).
1446                        if self.padding_idx == Some(idx) {
1447                            continue;
1448                        }
1449                        let row_start = idx * dim;
1450                        // per_sample_weights scale (sum-mode only): each gathered
1451                        // row is multiplied by its sample weight BEFORE the sum
1452                        // (EmbeddingBag.cpp:540-543). `None` => scale of 1.
1453                        match &psw_data {
1454                            Some(pw) => {
1455                                let scale = pw[s];
1456                                for d in 0..dim {
1457                                    output[out_start + d] += weight_data[row_start + d] * scale;
1458                                }
1459                            }
1460                            None => {
1461                                for d in 0..dim {
1462                                    output[out_start + d] += weight_data[row_start + d];
1463                                }
1464                            }
1465                        }
1466                        count += 1;
1467                    }
1468                    if self.mode == EmbeddingBagMode::Mean && count > 0 {
1469                        let scale = T::from(count).unwrap();
1470                        for d in 0..dim {
1471                            output[out_start + d] = output[out_start + d] / scale;
1472                        }
1473                    }
1474                }
1475                EmbeddingBagMode::Max => {
1476                    let out_start = b * dim;
1477                    // Start from -inf so any real value wins; an empty or all-padding
1478                    // bag is reset to zero below (torch also returns zeros for empty bags).
1479                    let mut any = false;
1480                    for d in 0..dim {
1481                        output[out_start + d] = T::neg_infinity();
1482                    }
1483                    for &idx in &indices[start..end] {
1484                        if self.padding_idx == Some(idx) {
1485                            continue;
1486                        }
1487                        any = true;
1488                        let row_start = idx * dim;
1489                        for d in 0..dim {
1490                            let val = weight_data[row_start + d];
1491                            if val > output[out_start + d] {
1492                                output[out_start + d] = val;
1493                            }
1494                        }
1495                    }
1496                    if !any {
1497                        for d in 0..dim {
1498                            output[out_start + d] = <T as num_traits::Zero>::zero();
1499                        }
1500                    }
1501                }
1502            }
1503        }
1504
1505        let storage = TensorStorage::cpu(output);
1506        let out_shape = vec![num_bags, dim];
1507
1508        // When per_sample_weights is supplied (sum mode, validated above) and
1509        // either the weight or the psw requires grad, attach the weighted
1510        // backward so gradient flows to BOTH inputs (EmbeddingBag.cpp:1248-1250
1511        // honors per_sample_weights.requires_grad).
1512        if let Some(psw) = per_sample_weights {
1513            let weight_t = self.weight.tensor();
1514            let needs_grad = is_grad_enabled() && (weight_t.requires_grad() || psw.requires_grad());
1515            if needs_grad {
1516                let grad_fn = Arc::new(EmbeddingBagSumWeightedBackward {
1517                    weight: weight_t.clone(),
1518                    per_sample_weights: psw.clone(),
1519                    indices,
1520                    bag_of,
1521                    num_embeddings: self.num_embeddings,
1522                    embedding_dim: dim,
1523                    padding_idx: self.padding_idx,
1524                    scale_grad_by_freq: self.scale_grad_by_freq,
1525                });
1526                return Tensor::from_operation(storage, out_shape, grad_fn);
1527            }
1528        }
1529
1530        // Unweighted path, or nothing requires grad: return a non-grad tensor,
1531        // matching the historical forward_bag behavior.
1532        Tensor::from_storage(storage, out_shape, false)
1533    }
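// A standalone sketch of the per-bag reduction performed above, to make the
// padding and mean-divisor rules concrete. `reduce_bag` and `Mode` are
// hypothetical names; the sketch assumes a row-major `f64` table, that the
// caller has already rejected `per_sample_weights` outside sum mode, and it
// uses `f64::max` where the original compares with `>`.
```rust
#[derive(Clone, Copy, PartialEq, Debug)]
enum Mode {
    Sum,
    Mean,
    Max,
}

// Hypothetical reduction of one bag: padding entries are skipped entirely,
// psw scales each row before summing, and the mean divides by the number of
// non-padding entries.
fn reduce_bag(
    weight: &[f64],
    dim: usize,
    indices: &[usize],
    psw: Option<&[f64]>,
    padding_idx: Option<usize>,
    mode: Mode,
) -> Vec<f64> {
    let init = if mode == Mode::Max { f64::NEG_INFINITY } else { 0.0 };
    let mut out = vec![init; dim];
    let mut count = 0usize;
    for (i, &idx) in indices.iter().enumerate() {
        if padding_idx == Some(idx) {
            continue; // excluded from the reduction and from the mean divisor
        }
        count += 1;
        let scale = psw.map_or(1.0, |p| p[i]);
        for d in 0..dim {
            let v = weight[idx * dim + d];
            match mode {
                Mode::Max => out[d] = out[d].max(v),
                _ => out[d] += v * scale,
            }
        }
    }
    if count == 0 {
        return vec![0.0; dim]; // empty or all-padding bag yields zeros
    }
    if mode == Mode::Mean {
        for x in out.iter_mut() {
            *x /= count as f64;
        }
    }
    out
}
```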
1534
1535    /// Number of embeddings in the table.
1536    pub fn num_embeddings(&self) -> usize {
1537        self.num_embeddings
1538    }
1539
1540    /// Dimension of each embedding vector.
1541    pub fn embedding_dim(&self) -> usize {
1542        self.embedding_dim
1543    }
1544
1545    /// The reduction mode.
1546    pub fn mode(&self) -> EmbeddingBagMode {
1547        self.mode
1548    }
1549
1550    /// The `padding_idx`, if set. Indices equal to it are excluded from each
1551    /// bag's reduction. Mirrors `torch.nn.EmbeddingBag.padding_idx`.
1552    pub fn padding_idx(&self) -> Option<usize> {
1553        self.padding_idx
1554    }
1555}
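// `forward_bag_weighted` documents that gradient flows to both `weight` and
// `per_sample_weights`. For `out[bag_of[i]] += weight[indices[i]] * psw[i]`,
// the weight gradient is a scatter-add of `grad_out[bag] * psw[i]` and the psw
// gradient is the dot product of `grad_out[bag]` with the gathered row. The
// sketch below is a hypothetical standalone version of that math, not
// ferrotorch's `EmbeddingBagSumWeightedBackward`; it assumes
// `scale_grad_by_freq` is off and uses row-major `f64` buffers.
```rust
// Hypothetical gradients for the weighted sum. Returns (grad_weight, grad_psw);
// padding samples contribute zero to both.
fn weighted_sum_backward(
    grad_out: &[f64], // [num_bags, dim]
    weight: &[f64],   // [num_rows, dim]
    dim: usize,
    indices: &[usize],
    bag_of: &[usize],
    psw: &[f64],
    padding_idx: Option<usize>,
) -> (Vec<f64>, Vec<f64>) {
    let mut grad_weight = vec![0.0; weight.len()];
    let mut grad_psw = vec![0.0; indices.len()];
    for (i, &idx) in indices.iter().enumerate() {
        if padding_idx == Some(idx) {
            continue;
        }
        let g = &grad_out[bag_of[i] * dim..(bag_of[i] + 1) * dim];
        let w = &weight[idx * dim..(idx + 1) * dim];
        // d out / d weight[idx] = psw[i]; scatter-add so duplicates accumulate.
        for d in 0..dim {
            grad_weight[idx * dim + d] += g[d] * psw[i];
        }
        // d out / d psw[i] = weight[idx], so grad_psw[i] = <g, w>.
        grad_psw[i] = g.iter().zip(w).map(|(a, b)| a * b).sum();
    }
    (grad_weight, grad_psw)
}
```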
1556
1557impl<T: Float> Module<T> for EmbeddingBag<T> {
1558    /// Forward pass that derives bag offsets from the input's shape.
1559    ///
1560    /// If input is 2-D `[num_bags, bag_size]`, each row is a bag.
1561    /// If input is 1-D, the entire input is treated as a single bag.
1562    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1563        if input.ndim() == 2 {
1564            // 2D input: [num_bags, bag_size] — each row is a fixed-length bag.
1565            // torch forces include_last_offset=False for 2D input
1566            // (functional.py:2735); we build offsets in whichever convention
1567            // `forward_bag` will read, so a configured `include_last_offset`
1568            // flag stays consistent here too.
1569            let shape = input.shape();
1570            let num_bags = shape[0];
1571            let bag_size = shape[1];
1572            let mut offsets: Vec<usize> = (0..num_bags).map(|b| b * bag_size).collect();
1573            if self.include_last_offset {
1574                // CSR layout: trailing entry is the total index count.
1575                offsets.push(num_bags * bag_size);
1576            }
1577            let flat = input.view_reshape(vec![num_bags * bag_size])?;
1578            self.forward_bag(&flat, &offsets)
1579        } else if input.ndim() == 1 {
1580            // 1D input: single bag. With include_last_offset the CSR boundary
1581            // is [0, total]; otherwise a single [0] start offset.
1582            if self.include_last_offset {
1583                let total = input.shape()[0];
1584                self.forward_bag(input, &[0, total])
1585            } else {
1586                self.forward_bag(input, &[0])
1587            }
1588        } else {
1589            Err(FerrotorchError::InvalidArgument {
1590                message: format!(
1591                    "EmbeddingBag input must be 1-D or 2-D, got {:?}",
1592                    input.shape()
1593                ),
1594            })
1595        }
1596    }
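// The 2-D branch above builds fixed-stride offsets before flattening the
// input. A hypothetical standalone version of that step (the name
// `fixed_size_offsets` is illustrative only) shows both conventions:
```rust
// Hypothetical helper: offsets for `num_bags` rows of `bag_size` indices each.
fn fixed_size_offsets(num_bags: usize, bag_size: usize, include_last_offset: bool) -> Vec<usize> {
    // Row b of the flattened input starts at b * bag_size.
    let mut offsets: Vec<usize> = (0..num_bags).map(|b| b * bag_size).collect();
    if include_last_offset {
        // CSR layout: trailing entry is the total index count.
        offsets.push(num_bags * bag_size);
    }
    offsets
}
```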
1597
1598    fn parameters(&self) -> Vec<&Parameter<T>> {
1599        vec![&self.weight]
1600    }
1601
1602    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
1603        vec![&mut self.weight]
1604    }
1605
1606    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
1607        vec![("weight".to_string(), &self.weight)]
1608    }
1609
1610    fn train(&mut self) {
1611        self.training = true;
1612    }
1613
1614    fn eval(&mut self) {
1615        self.training = false;
1616    }
1617
1618    fn is_training(&self) -> bool {
1619        self.training
1620    }
1621}
1622
1623// ---------------------------------------------------------------------------
1624// Tests
1625// ---------------------------------------------------------------------------
1626
1627#[cfg(test)]
1628mod tests {
1629    use super::*;
1630    use ferrotorch_core::autograd::graph::backward;
1631    use ferrotorch_core::storage::TensorStorage;
1632
1633    /// Helper: create a 1-D tensor of float indices.
1634    fn index_tensor(indices: &[f32]) -> Tensor<f32> {
1635        Tensor::from_storage(
1636            TensorStorage::cpu(indices.to_vec()),
1637            vec![indices.len()],
1638            false,
1639        )
1640        .unwrap()
1641    }
1642
1643    // --- Forward tests ---
1644
1645    #[test]
1646    fn test_forward_shape() {
1647        let emb = Embedding::<f32>::new(10, 4, None).unwrap();
1648        let indices = index_tensor(&[0.0, 3.0, 7.0]);
1649        let output = emb.forward(&indices).unwrap();
1650        assert_eq!(output.shape(), &[3, 4]);
1651    }
1652
1653    #[test]
1654    fn test_forward_correct_values() {
1655        // Build an embedding with known weights.
1656        let weight_data: Vec<f32> = (0..12).map(|i| i as f32).collect();
1657        let weight =
1658            Tensor::from_storage(TensorStorage::cpu(weight_data), vec![4, 3], true).unwrap();
1659        let emb = Embedding::from_pretrained(weight, None).unwrap();
1660
1661        // Look up rows 2 and 0.
1662        let indices = index_tensor(&[2.0, 0.0]);
1663        let output = emb.forward(&indices).unwrap();
1664        let data = output.data().unwrap();
1665
1666        // Row 2 = [6, 7, 8], Row 0 = [0, 1, 2]
1667        assert_eq!(data.len(), 6);
1668        assert!((data[0] - 6.0).abs() < 1e-6);
1669        assert!((data[1] - 7.0).abs() < 1e-6);
1670        assert!((data[2] - 8.0).abs() < 1e-6);
1671        assert!((data[3] - 0.0).abs() < 1e-6);
1672        assert!((data[4] - 1.0).abs() < 1e-6);
1673        assert!((data[5] - 2.0).abs() < 1e-6);
1674    }
1675
1676    #[test]
1677    fn test_forward_single_index() {
1678        let emb = Embedding::<f32>::new(5, 8, None).unwrap();
1679        let indices = index_tensor(&[3.0]);
1680        let output = emb.forward(&indices).unwrap();
1681        assert_eq!(output.shape(), &[1, 8]);
1682    }
1683
1684    // --- Padding index tests ---
1685
1686    #[test]
1687    #[allow(clippy::needless_range_loop)]
1688    fn test_padding_idx_zeros() {
1689        let emb = Embedding::<f32>::new(5, 3, Some(2)).unwrap();
1690
1691        // The padding row in the weight should be zero.
1692        let w_data = emb.weight.data().unwrap();
1693        let pad_start = 2 * 3;
1694        for i in 0..3 {
1695            assert!(
1696                (w_data[pad_start + i] - 0.0).abs() < 1e-6,
1697                "padding row weight[2][{i}] should be 0, got {}",
1698                w_data[pad_start + i]
1699            );
1700        }
1701
1702        // Forward with the padding index should return zeros.
1703        let indices = index_tensor(&[2.0]);
1704        let output = emb.forward(&indices).unwrap();
1705        let data = output.data().unwrap();
1706        for i in 0..3 {
1707            assert!(
1708                (data[i] - 0.0).abs() < 1e-6,
1709                "padding output[0][{i}] should be 0, got {}",
1710                data[i]
1711            );
1712        }
1713    }
1714
1715    #[test]
1716    fn test_padding_idx_mixed() {
1717        // Build known weights, set padding_idx=1.
1718        let weight_data: Vec<f32> = vec![
1719            1.0, 2.0, // row 0
1720            0.0, 0.0, // row 1 (padding — will be zeroed)
1721            5.0, 6.0, // row 2
1722        ];
1723        let weight =
1724            Tensor::from_storage(TensorStorage::cpu(weight_data), vec![3, 2], true).unwrap();
1725        let emb = Embedding::from_pretrained(weight, Some(1)).unwrap();
1726
1727        let indices = index_tensor(&[0.0, 1.0, 2.0]);
1728        let output = emb.forward(&indices).unwrap();
1729        let data = output.data().unwrap();
1730
1731        // Row 0: [1, 2]
1732        assert!((data[0] - 1.0).abs() < 1e-6);
1733        assert!((data[1] - 2.0).abs() < 1e-6);
1734        // Row 1 (padding): [0, 0]
1735        assert!((data[2] - 0.0).abs() < 1e-6);
1736        assert!((data[3] - 0.0).abs() < 1e-6);
1737        // Row 2: [5, 6]
1738        assert!((data[4] - 5.0).abs() < 1e-6);
1739        assert!((data[5] - 6.0).abs() < 1e-6);
1740    }
1741
1742    #[test]
1743    fn test_padding_idx_out_of_range() {
1744        let result = Embedding::<f32>::new(5, 3, Some(10));
1745        assert!(result.is_err());
1746    }
1747
1748    // --- Out-of-bounds error ---
1749
1750    #[test]
1751    fn test_out_of_bounds_error() {
1752        let emb = Embedding::<f32>::new(5, 3, None).unwrap();
1753        let indices = index_tensor(&[0.0, 5.0]); // 5 is out of bounds for num_embeddings=5
1754        let result = emb.forward(&indices);
1755        assert!(result.is_err());
1756    }
1757
1758    #[test]
1759    fn test_negative_index_error() {
1760        let emb = Embedding::<f32>::new(5, 3, None).unwrap();
1761        let indices = index_tensor(&[-1.0]); // negative values cannot be converted to usize
1762        let result = emb.forward(&indices);
1763        assert!(result.is_err());
1764    }
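// The negative-index test relies on float-to-`usize` conversion failing for
// `-1.0`. A stricter hypothetical check (`checked_index` is not part of
// ferrotorch) could also reject NaN, infinities, and fractional values, which
// is an assumption about desired behavior rather than what the code does:
```rust
// Hypothetical strict conversion of a float-encoded index to usize.
fn checked_index(val: f64, num_embeddings: usize) -> Result<usize, String> {
    // Reject NaN/inf, negatives, and non-integral values outright.
    if !val.is_finite() || val < 0.0 || val.fract() != 0.0 {
        return Err(format!("invalid index {val}"));
    }
    // `as` saturates huge values to usize::MAX, which fails the bound below.
    let idx = val as usize;
    if idx >= num_embeddings {
        return Err(format!("index {idx} out of bounds for size {num_embeddings}"));
    }
    Ok(idx)
}
```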
1765
1766    // --- N-D index input (matches upstream F.embedding) ---
1767
1768    #[test]
1769    fn test_2d_index_input_shape() {
1770        // Upstream `embedding_symint` (aten/src/ATen/native/Embedding.cpp:48-53)
1771        // accepts ANY index shape and returns `(*indices.sizes(), embedding_dim)`.
1772        // A [2,2] index against a [5,3] weight => output shape [2,2,3].
1773        let emb = Embedding::<f32>::new(5, 3, None).unwrap();
1774        let input = Tensor::from_storage(
1775            TensorStorage::cpu(vec![0.0f32, 1.0, 2.0, 3.0]),
1776            vec![2, 2],
1777            false,
1778        )
1779        .unwrap();
1780        let output = emb.forward(&input).unwrap();
1781        assert_eq!(output.shape(), &[2, 2, 3]);
1782    }
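// The N-D test above expects an output shape of `(*index_shape, embedding_dim)`.
// A hypothetical standalone gather over a row-major `f32` table
// (`embedding_nd` is an illustrative name) shows why: each index contributes
// one contiguous row, and the index shape is kept with `dim` appended.
```rust
// Hypothetical N-D embedding gather returning (output_shape, output_data).
fn embedding_nd(weight: &[f32], dim: usize, indices: &[usize], index_shape: &[usize]) -> (Vec<usize>, Vec<f32>) {
    // Output shape = index shape with the embedding dimension appended.
    let mut shape = index_shape.to_vec();
    shape.push(dim);
    let mut out = Vec::with_capacity(indices.len() * dim);
    for &idx in indices {
        // Copy row `idx` of the table.
        out.extend_from_slice(&weight[idx * dim..(idx + 1) * dim]);
    }
    (shape, out)
}
```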
1783
1784    // --- Backward tests ---
1785
1786    #[test]
1787    fn test_backward_simple() {
1788        // weight shape [3, 2], look up indices [0, 2]
1789        // output shape [2, 2]
1790        // grad_output = [[1, 1], [1, 1]]
1791        // grad_weight = [[1, 1], [0, 0], [1, 1]]
1792        let weight_data: Vec<f32> = vec![
1793            10.0, 20.0, // row 0
1794            30.0, 40.0, // row 1
1795            50.0, 60.0, // row 2
1796        ];
1797        let weight =
1798            Tensor::from_storage(TensorStorage::cpu(weight_data), vec![3, 2], true).unwrap();
1799        let emb = Embedding::from_pretrained(weight, None).unwrap();
1800
1801        let indices = index_tensor(&[0.0, 2.0]);
1802        let output = emb.forward(&indices).unwrap();
1803
1804        assert!(output.requires_grad());
1805        assert_eq!(output.grad_fn().unwrap().name(), "EmbeddingBackward");
1806
1807        // Manually call backward on the grad_fn.
1808        let grad_output =
1809            Tensor::from_storage(TensorStorage::cpu(vec![1.0f32; 4]), vec![2, 2], false).unwrap();
1810
1811        let grad_fn = output.grad_fn().unwrap();
1812        let grads = grad_fn.backward(&grad_output).unwrap();
1813
1814        let grad_weight = grads[0].as_ref().unwrap();
1815        assert_eq!(grad_weight.shape(), &[3, 2]);
1816        let gd = grad_weight.data().unwrap();
1817
1818        // Row 0: accessed once -> [1, 1]
1819        assert!((gd[0] - 1.0).abs() < 1e-6);
1820        assert!((gd[1] - 1.0).abs() < 1e-6);
1821        // Row 1: not accessed -> [0, 0]
1822        assert!((gd[2] - 0.0).abs() < 1e-6);
1823        assert!((gd[3] - 0.0).abs() < 1e-6);
1824        // Row 2: accessed once -> [1, 1]
1825        assert!((gd[4] - 1.0).abs() < 1e-6);
1826        assert!((gd[5] - 1.0).abs() < 1e-6);
1827    }
1828
1829    #[test]
1830    fn test_backward_duplicate_indices() {
1831        // weight shape [3, 2], look up indices [1, 1, 0, 1]
1832        // output shape [4, 2]
1833        //
1834        // grad_output = [[1, 2], [3, 4], [5, 6], [7, 8]]
1835        //
1836        // grad_weight[0] = grad_output[2] = [5, 6]       (index 0 appears once, at position 2)
1837        // grad_weight[1] = grad_output[0] + grad_output[1] + grad_output[3]
1838        //                = [1, 2] + [3, 4] + [7, 8] = [11, 14]
1839        // grad_weight[2] = [0, 0]                          (index 2 never accessed)
1840        let weight_data: Vec<f32> = vec![
1841            10.0, 20.0, // row 0
1842            30.0, 40.0, // row 1
1843            50.0, 60.0, // row 2
1844        ];
1845        let weight =
1846            Tensor::from_storage(TensorStorage::cpu(weight_data), vec![3, 2], true).unwrap();
1847        let emb = Embedding::from_pretrained(weight, None).unwrap();
1848
1849        let indices = index_tensor(&[1.0, 1.0, 0.0, 1.0]);
1850        let output = emb.forward(&indices).unwrap();
1851
1852        let grad_output = Tensor::from_storage(
1853            TensorStorage::cpu(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]),
1854            vec![4, 2],
1855            false,
1856        )
1857        .unwrap();
1858
1859        let grad_fn = output.grad_fn().unwrap();
1860        let grads = grad_fn.backward(&grad_output).unwrap();
1861
1862        let grad_weight = grads[0].as_ref().unwrap();
1863        let gd = grad_weight.data().unwrap();
1864
1865        // Row 0: [5, 6]
1866        assert!((gd[0] - 5.0).abs() < 1e-6, "gd[0] = {}, expected 5", gd[0]);
1867        assert!((gd[1] - 6.0).abs() < 1e-6, "gd[1] = {}, expected 6", gd[1]);
1868        // Row 1: [1+3+7, 2+4+8] = [11, 14]
1869        assert!(
1870            (gd[2] - 11.0).abs() < 1e-6,
1871            "gd[2] = {}, expected 11",
1872            gd[2]
1873        );
1874        assert!(
1875            (gd[3] - 14.0).abs() < 1e-6,
1876            "gd[3] = {}, expected 14",
1877            gd[3]
1878        );
1879        // Row 2: [0, 0]
1880        assert!((gd[4] - 0.0).abs() < 1e-6, "gd[4] = {}, expected 0", gd[4]);
1881        assert!((gd[5] - 0.0).abs() < 1e-6, "gd[5] = {}, expected 0", gd[5]);
1882    }
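// The duplicate-index and padding tests check the sparse scatter-add
// described in the module docs. A hypothetical standalone version
// (`embedding_backward` is illustrative, not `EmbeddingBackward` itself)
// reproduces the expected values from the test comments:
```rust
// Hypothetical embedding backward: scatter-add each output row's gradient
// into the weight row it was gathered from. Duplicates accumulate; the
// padding row keeps a zero gradient.
fn embedding_backward(
    grad_out: &[f32],
    indices: &[usize],
    num_embeddings: usize,
    dim: usize,
    padding_idx: Option<usize>,
) -> Vec<f32> {
    let mut grad_weight = vec![0.0f32; num_embeddings * dim];
    for (pos, &idx) in indices.iter().enumerate() {
        if padding_idx == Some(idx) {
            continue;
        }
        for d in 0..dim {
            grad_weight[idx * dim + d] += grad_out[pos * dim + d];
        }
    }
    grad_weight
}
```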
1883
1884    #[test]
1885    fn test_backward_padding_idx_zeroed() {
1886        // Even if padding_idx is accessed, its gradient should be zero.
1887        let weight_data: Vec<f32> = vec![
1888            1.0, 2.0, // row 0
1889            0.0, 0.0, // row 1 (padding)
1890            5.0, 6.0, // row 2
1891        ];
1892        let weight =
1893            Tensor::from_storage(TensorStorage::cpu(weight_data), vec![3, 2], true).unwrap();
1894        let emb = Embedding::from_pretrained(weight, Some(1)).unwrap();
1895
1896        let indices = index_tensor(&[0.0, 1.0, 2.0]);
1897        let output = emb.forward(&indices).unwrap();
1898
1899        let grad_output =
1900            Tensor::from_storage(TensorStorage::cpu(vec![1.0f32; 6]), vec![3, 2], false).unwrap();
1901
1902        let grad_fn = output.grad_fn().unwrap();
1903        let grads = grad_fn.backward(&grad_output).unwrap();
1904
1905        let grad_weight = grads[0].as_ref().unwrap();
1906        let gd = grad_weight.data().unwrap();
1907
1908        // Row 0: [1, 1]
1909        assert!((gd[0] - 1.0).abs() < 1e-6);
1910        assert!((gd[1] - 1.0).abs() < 1e-6);
1911        // Row 1 (padding): must be [0, 0] even though it was accessed
1912        assert!((gd[2] - 0.0).abs() < 1e-6, "padding grad[1][0] should be 0");
1913        assert!((gd[3] - 0.0).abs() < 1e-6, "padding grad[1][1] should be 0");
1914        // Row 2: [1, 1]
1915        assert!((gd[4] - 1.0).abs() < 1e-6);
1916        assert!((gd[5] - 1.0).abs() < 1e-6);
1917    }
1918
1919    #[test]
1920    fn test_backward_end_to_end() {
1921        // End-to-end test: use the autograd engine to verify gradients
1922        // flow all the way to the weight parameter.
1923        let weight_data: Vec<f32> = vec![
1924            1.0, 2.0, // row 0
1925            3.0, 4.0, // row 1
1926            5.0, 6.0, // row 2
1927        ];
1928        let weight =
1929            Tensor::from_storage(TensorStorage::cpu(weight_data), vec![3, 2], true).unwrap();
1930        let emb = Embedding::from_pretrained(weight, None).unwrap();
1931
1932        let indices = index_tensor(&[1.0, 0.0]);
1933        let output = emb.forward(&indices).unwrap();
1934        // output = [[3, 4], [1, 2]], shape [2, 2]
1935
1936        // Sum all elements to get a scalar for backward.
        let out_data = output.data().unwrap();
        let total: f32 = out_data.iter().sum();

        // Build a SumBackward that broadcasts scalar grad to output shape.
        #[derive(Debug)]
        struct SumBackward<T: Float> {
            input: Tensor<T>,
        }
        impl<T: Float> GradFn<T> for SumBackward<T> {
            fn backward(
                &self,
                grad_output: &Tensor<T>,
            ) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
                let go_val = grad_output.data()?[0];
                let grad = vec![go_val; self.input.numel()];
                let t = Tensor::from_storage(
                    TensorStorage::cpu(grad),
                    self.input.shape().to_vec(),
                    false,
                )?;
                Ok(vec![Some(t)])
            }
            fn inputs(&self) -> Vec<&Tensor<T>> {
                vec![&self.input]
            }
            fn name(&self) -> &'static str {
                "SumBackward"
            }
        }

        let loss = Tensor::from_operation(
            TensorStorage::cpu(vec![total]),
            vec![],
            Arc::new(SumBackward {
                input: output.clone(),
            }),
        )
        .unwrap();

        backward(&loss).unwrap();

        // The weight should now have a gradient.
        let grad = emb.weight.tensor().grad().unwrap().unwrap();
        let gd = grad.data().unwrap();
        assert_eq!(gd.len(), 6);

        // Row 0 accessed once (position 1): grad = [1, 1]
        assert!((gd[0] - 1.0).abs() < 1e-6, "grad[0][0] = {}", gd[0]);
        assert!((gd[1] - 1.0).abs() < 1e-6, "grad[0][1] = {}", gd[1]);
        // Row 1 accessed once (position 0): grad = [1, 1]
        assert!((gd[2] - 1.0).abs() < 1e-6, "grad[1][0] = {}", gd[2]);
        assert!((gd[3] - 1.0).abs() < 1e-6, "grad[1][1] = {}", gd[3]);
        // Row 2 not accessed: grad = [0, 0]
        assert!((gd[4] - 0.0).abs() < 1e-6, "grad[2][0] = {}", gd[4]);
        assert!((gd[5] - 0.0).abs() < 1e-6, "grad[2][1] = {}", gd[5]);
    }

    // --- Module trait tests ---

    #[test]
    fn test_module_parameters() {
        let emb = Embedding::<f32>::new(10, 4, None).unwrap();
        assert_eq!(emb.parameters().len(), 1);
        assert_eq!(emb.parameters()[0].shape(), &[10, 4]);
    }

    #[test]
    fn test_module_named_parameters() {
        let emb = Embedding::<f32>::new(5, 3, None).unwrap();
        let named = emb.named_parameters();
        assert_eq!(named.len(), 1);
        assert_eq!(named[0].0, "weight");
    }

    #[test]
    fn test_module_train_eval() {
        let mut emb = Embedding::<f32>::new(5, 3, None).unwrap();
        assert!(emb.is_training());
        emb.eval();
        assert!(!emb.is_training());
        emb.train();
        assert!(emb.is_training());
    }

    #[test]
    fn test_embedding_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<Embedding<f32>>();
        assert_send_sync::<Embedding<f64>>();
    }

    #[test]
    fn test_f64_embedding() {
        let emb = Embedding::<f64>::new(5, 3, None).unwrap();
        let indices =
            Tensor::from_storage(TensorStorage::cpu(vec![0.0f64, 2.0, 4.0]), vec![3], false)
                .unwrap();
        let output = emb.forward(&indices).unwrap();
        assert_eq!(output.shape(), &[3, 3]);
    }

    // -------------------------------------------------------------------
    // SparseGrad integration (#623)
    // -------------------------------------------------------------------

    #[test]
    fn sparse_grad_returns_none_when_sparse_off() {
        let emb = Embedding::<f32>::new(8, 4, None).unwrap();
        // Default ctor leaves sparse off.
        assert!(!emb.sparse);
        let idx =
            Tensor::from_storage(TensorStorage::cpu(vec![0.0f32, 1.0]), vec![2], false).unwrap();
        let _ = emb.forward(&idx).unwrap();
        assert!(emb.sparse_grad().unwrap().is_none());
    }

    #[test]
    fn sparse_grad_returns_none_before_first_forward() {
        let emb = Embedding::<f32>::new(8, 4, None).unwrap().with_sparse(true);
        // No forward run yet -> no last_indices recorded.
        assert!(emb.sparse_grad().unwrap().is_none());
    }

    #[test]
    fn sparse_grad_emits_only_touched_rows() {
        // Vocabulary 8, dim 4. Touch only rows 1, 3, 5.
        let emb = Embedding::<f32>::new(8, 4, None).unwrap().with_sparse(true);
        let idx = Tensor::from_storage(
            TensorStorage::cpu(vec![1.0f32, 3.0, 5.0, 1.0]),
            vec![4],
            false,
        )
        .unwrap();
        let _out = emb.forward(&idx).unwrap();

        // Manually attach a synthetic dense gradient to weight, simulating
        // post-backward state. The gradient has known per-row values so we
        // can verify slab extraction.
        let grad_data: Vec<f32> = (0..8 * 4).map(|i| i as f32).collect();
        let grad_tensor =
            Tensor::from_storage(TensorStorage::cpu(grad_data), vec![8, 4], false).unwrap();
        emb.weight.tensor().set_grad(Some(grad_tensor)).unwrap();

        let sg = emb.sparse_grad().unwrap().expect("sparse grad");
        // Touched rows are deduped + sorted: {1, 3, 5}.
        assert_eq!(sg.indices(), &[1, 3, 5]);
        assert_eq!(sg.slab_shape(), &[4]);
        // Row 1 of grad: indices 4..8 -> values [4,5,6,7]
        // Row 3 -> [12,13,14,15]
        // Row 5 -> [20,21,22,23]
        assert_eq!(
            sg.values(),
            &[
                4.0, 5.0, 6.0, 7.0, 12.0, 13.0, 14.0, 15.0, 20.0, 21.0, 22.0, 23.0
            ]
        );
    }

    #[test]
    fn sparse_grad_apply_sgd_updates_only_touched_rows() {
        // End-to-end: forward → set synthetic grad → sparse_grad → apply_sgd.
        // Verifies that untouched rows stay at their original values.
        let mut emb = Embedding::<f32>::new(4, 2, None).unwrap().with_sparse(true);
        // Pin weight to known values for a tractable assertion.
        let init: Vec<f32> = (0..4 * 2).map(|i| i as f32 * 10.0).collect();
        emb.weight = Parameter::new(
            Tensor::from_storage(TensorStorage::cpu(init.clone()), vec![4, 2], true).unwrap(),
        );

        let idx =
            Tensor::from_storage(TensorStorage::cpu(vec![0.0f32, 2.0]), vec![2], false).unwrap();
        let _ = emb.forward(&idx).unwrap();

        // Synthetic gradient: each row is its index repeated.
        let grad_vec: Vec<f32> = (0..4_usize)
            .flat_map(|r| vec![r as f32, r as f32])
            .collect();
        let grad_tensor =
            Tensor::from_storage(TensorStorage::cpu(grad_vec), vec![4, 2], false).unwrap();
        emb.weight.tensor().set_grad(Some(grad_tensor)).unwrap();

        let sg = emb.sparse_grad().unwrap().unwrap();
        let mut weight = emb.weight.tensor().clone();
        sg.apply_sgd(&mut weight, 0.5_f32).unwrap();

        // init pattern is `i*10` row-major over [4, 2] → rows
        //   r0=[0, 10], r1=[20, 30], r2=[40, 50], r3=[60, 70].
        // Touched rows: {0, 2} (deduped). Synthetic per-row grad slabs:
        //   r0=[0,0], r1=[1,1], r2=[2,2], r3=[3,3].
        // SparseGrad pulls only touched rows -> {0: [0,0], 2: [2,2]}.
        // apply_sgd(lr=0.5):
        //   r0 -= 0.5 * [0, 0]  → [0, 10]      (no change, grad zero)
        //   r1                  → [20, 30]     (untouched, no update)
        //   r2 -= 0.5 * [2, 2]  → [40-1, 50-1] = [39, 49]
        //   r3                  → [60, 70]     (untouched)
        let updated = weight.data().unwrap().to_vec();
        assert_eq!(updated, vec![0.0, 10.0, 20.0, 30.0, 39.0, 49.0, 60.0, 70.0]);
    }
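
    // Hypothetical std-only reference model (a sketch; NOT ferrotorch's
    // SparseGrad API) of what the two tests above check. The touched rows are
    // the recorded indices, deduplicated and sorted. Each touched row
    // contributes one `dim`-wide slab of the dense gradient, and the SGD step
    // writes only those rows. `reference_sparse_rows` and
    // `reference_sparse_sgd` are illustrative names introduced here.
    fn reference_sparse_rows(
        indices: &[usize],
        dense_grad: &[f32],
        dim: usize,
    ) -> (Vec<usize>, Vec<f32>) {
        // A BTreeSet both deduplicates and yields ascending order.
        let rows: Vec<usize> = indices
            .iter()
            .copied()
            .collect::<std::collections::BTreeSet<usize>>()
            .into_iter()
            .collect();
        let mut values = Vec::with_capacity(rows.len() * dim);
        for &r in &rows {
            values.extend_from_slice(&dense_grad[r * dim..(r + 1) * dim]);
        }
        (rows, values)
    }

    fn reference_sparse_sgd(weight: &mut [f32], rows: &[usize], values: &[f32], dim: usize, lr: f32) {
        // Slab k belongs to row rows[k]; untouched rows are never written.
        for (k, &r) in rows.iter().enumerate() {
            for d in 0..dim {
                weight[r * dim + d] -= lr * values[k * dim + d];
            }
        }
    }

    #[test]
    fn reference_sparse_rows_matches_oracles() {
        // Same data as sparse_grad_emits_only_touched_rows.
        let grad: Vec<f32> = (0..8 * 4).map(|i| i as f32).collect();
        let (rows, values) = reference_sparse_rows(&[1, 3, 5, 1], &grad, 4);
        assert_eq!(rows, vec![1, 3, 5]);
        assert_eq!(
            values,
            vec![4.0, 5.0, 6.0, 7.0, 12.0, 13.0, 14.0, 15.0, 20.0, 21.0, 22.0, 23.0]
        );

        // Same data as sparse_grad_apply_sgd_updates_only_touched_rows.
        let mut weight: Vec<f32> = (0..4 * 2).map(|i| i as f32 * 10.0).collect();
        let grad: Vec<f32> = (0..4_usize).flat_map(|r| [r as f32, r as f32]).collect();
        let (rows, values) = reference_sparse_rows(&[0, 2], &grad, 2);
        reference_sparse_sgd(&mut weight, &rows, &values, 2, 0.5);
        assert_eq!(weight, vec![0.0, 10.0, 20.0, 30.0, 39.0, 49.0, 60.0, 70.0]);
    }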

    // -------------------------------------------------------------------
    // #1445: max_norm persisted in-place weight renorm (Embedding)
    // -------------------------------------------------------------------
    //
    // Oracle (live torch 2.11.0):
    //   w = [[3,4],[0,0.5],[6,8],[1,1]]
    //   F.embedding(torch.tensor([0,2]), w, max_norm=5.0, norm_type=2.0)
    //   -> w mutated to [[3,4],[0,0.5],[3,4],[1,1]]
    //   row0 norm == 5.0 == max_norm (not strictly greater, so untouched)
    //   row2 norm == 10 > 5 -> scale 5/(10+1e-7) ≈ 0.5 -> [3,4]
    //   A second forward leaves w unchanged (the looked-up rows now satisfy max_norm).
    //   See torch/nn/functional.py:2561-2573 (renorm before gather),
    //   aten/src/ATen/native/Embedding.cpp:181-212 (embedding_renorm_cpu_).

    fn pretrained_embedding(rows: &[[f32; 2]]) -> Embedding<f32> {
        let mut data = Vec::with_capacity(rows.len() * 2);
        for r in rows {
            data.extend_from_slice(r);
        }
        let w = Tensor::from_storage(TensorStorage::cpu(data), vec![rows.len(), 2], true).unwrap();
        Embedding::from_pretrained(w, None).unwrap()
    }
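
    // Hypothetical std-only reference model of the renorm rule in the oracle
    // above (a sketch, NOT ferrotorch's code path; torch computes the scale in
    // double, this sketch stays in f32). A looked-up row whose p-norm is
    // strictly greater than `max_norm` is scaled by max_norm / (norm + 1e-7).
    // Rows at or under the limit are left alone. `reference_renorm_row` is an
    // illustrative name introduced here.
    fn reference_renorm_row(row: &mut [f32], max_norm: f32, norm_type: f32) {
        let norm = if norm_type == 2.0 {
            // sqrt is correctly rounded, so [3, 4] yields exactly 5.0.
            row.iter().map(|v| v * v).sum::<f32>().sqrt()
        } else {
            row.iter()
                .map(|v| v.abs().powf(norm_type))
                .sum::<f32>()
                .powf(1.0 / norm_type)
        };
        if norm > max_norm {
            let scale = max_norm / (norm + 1e-7);
            for v in row.iter_mut() {
                *v *= scale;
            }
        }
    }

    #[test]
    fn reference_renorm_row_matches_oracle() {
        let mut row0 = [3.0f32, 4.0];
        reference_renorm_row(&mut row0, 5.0, 2.0);
        assert_eq!(row0, [3.0, 4.0]); // norm == max_norm: not clipped
        let mut row2 = [6.0f32, 8.0];
        reference_renorm_row(&mut row2, 5.0, 2.0);
        assert!((row2[0] - 3.0).abs() < 1e-4 && (row2[1] - 4.0).abs() < 1e-4);
        // A second pass over the already-clipped row is a no-op.
        let before = row2;
        reference_renorm_row(&mut row2, 5.0, 2.0);
        assert_eq!(row2, before);
    }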

    #[test]
    fn test_max_norm_mutates_persisted_weight() {
        let emb = pretrained_embedding(&[[3.0, 4.0], [0.0, 0.5], [6.0, 8.0], [1.0, 1.0]])
            .with_max_norm(5.0)
            .with_norm_type(2.0);

        // Look up rows 0 and 2.
        let idx = index_tensor(&[0.0, 2.0]);
        let out = emb.forward(&idx).unwrap();
        let od = out.data().unwrap();
        // Row 0 untouched ([3,4], norm exactly 5); row 2 clipped to [3,4].
        assert!((od[0] - 3.0).abs() < 1e-4, "out r0[0]={}", od[0]);
        assert!((od[1] - 4.0).abs() < 1e-4, "out r0[1]={}", od[1]);
        assert!((od[2] - 3.0).abs() < 1e-4, "out r2[0]={}", od[2]);
        assert!((od[3] - 4.0).abs() < 1e-4, "out r2[1]={}", od[3]);

        // The PERSISTED weight must be mutated in place, not just the output.
        let w_after = emb.weight.data().unwrap().to_vec();
        // [[3,4],[0,0.5],[3,4],[1,1]]: only the touched, over-norm row 2 moved.
        assert!((w_after[0] - 3.0).abs() < 1e-4); // row0 untouched
        assert!((w_after[1] - 4.0).abs() < 1e-4);
        assert!((w_after[2] - 0.0).abs() < 1e-4); // row1 not indexed, untouched
        assert!((w_after[3] - 0.5).abs() < 1e-4);
        assert!(
            (w_after[4] - 3.0).abs() < 1e-4,
            "row2[0] persisted={}",
            w_after[4]
        );
        assert!(
            (w_after[5] - 4.0).abs() < 1e-4,
            "row2[1] persisted={}",
            w_after[5]
        );
        assert!((w_after[6] - 1.0).abs() < 1e-4); // row3 not indexed
        assert!((w_after[7] - 1.0).abs() < 1e-4);
    }

    #[test]
    fn test_max_norm_second_forward_is_stable() {
        // Forward twice: the first call clips the over-norm row in place; the
        // second call sees the already-clipped weight and is a no-op on it.
        let emb = pretrained_embedding(&[[3.0, 4.0], [0.0, 0.5], [6.0, 8.0], [1.0, 1.0]])
            .with_max_norm(5.0)
            .with_norm_type(2.0);
        let idx = index_tensor(&[0.0, 2.0]);

        let _ = emb.forward(&idx).unwrap();
        let w_after_first = emb.weight.data().unwrap().to_vec();

        let _ = emb.forward(&idx).unwrap();
        let w_after_second = emb.weight.data().unwrap().to_vec();

        // Stable: a second renorm of already-clipped rows changes nothing.
        for (a, b) in w_after_first.iter().zip(w_after_second.iter()) {
            assert!((a - b).abs() < 1e-7, "weight drifted: {a} vs {b}");
        }
        // And the clipped row really did change relative to the original [6,8].
        assert!((w_after_first[4] - 3.0).abs() < 1e-4);
        assert!((w_after_first[5] - 4.0).abs() < 1e-4);
    }

    #[test]
    fn test_max_norm_untouched_rows_not_renormed() {
        // Row 2 ([6,8], norm 10) exceeds max_norm but is NOT indexed this
        // forward; only row 0 is looked up. The persisted weight's row 2 must
        // stay at its original over-norm value (renorm visits only touched
        // rows; see Embedding.cpp:198-202).
        let emb = pretrained_embedding(&[[3.0, 4.0], [0.0, 0.5], [6.0, 8.0], [1.0, 1.0]])
            .with_max_norm(5.0);
        let idx = index_tensor(&[0.0]);
        let _ = emb.forward(&idx).unwrap();
        let w = emb.weight.data().unwrap().to_vec();
        assert!(
            (w[4] - 6.0).abs() < 1e-6,
            "row2 should be untouched: {}",
            w[4]
        );
        assert!(
            (w[5] - 8.0).abs() < 1e-6,
            "row2 should be untouched: {}",
            w[5]
        );
    }

    #[test]
    fn test_max_norm_f32_vs_f64_norm_boundary_unchanged() {
        // #1612: the renorm clip DECISION must be made in the weight's native
        // dtype (f32 here), matching torch's `row.norm(norm_type).item<double>()`
        // (Embedding.cpp:202-203): `at::norm` accumulates in `opmath_type<f32>`
        // == f32, stores back as f32, and only THEN widens to double.
        //
        // #1614 NOTE: the f32 L2 norm is now computed via
        // `ferrotorch_core::simd_reduce::l2_norm_f32_torch` (torch's vectorized
        // last-dim L2 kernel model), not the old scalar `Σ powf(|v|,2)`. The
        // row below was re-selected (live torch 2.11.0+cu130, 2026-05-28) so
        // that BOTH torch AND the SIMD primitive give the exact same f32 norm
        // 151.10968017578125 (bits 0x43171c14), preserving this test's intent
        // (f32-boundary, row unchanged) on a row where ferrotorch matches torch
        // byte-for-byte. (The previous row `[-5.0920777, -9.034002, -99.06734,
        // -8.838612]`, whose torch f32 norm is 100.0, hits the known ~3% one-ULP
        // residual of the SIMD primitive: torch gives 0x42c80000 but the
        // portable model gives 0x42c80001; that residual is documented in
        // `simd_reduce.rs` / `.design/ferrotorch-core/simd_reduce.md`. Re-rowing
        // here keeps this test pinning the f32-vs-f64 decision, not the residual.)
        //
        // Oracle: torch f32 norm of this row is 151.10968017578125 (== max_norm
        // below), so `F.embedding([0], w, max_norm=151.10968017578125,
        // norm_type=2.0)` leaves the row UNCHANGED (norm > max_norm is false,
        // verified live). Its f64 norm is 151.10968198544464 > the f32 norm,
        // which the OLD f64-accumulate path treated as "exceeds" and wrongly
        // scaled the row down. That is exactly the #1612 distinction this test pins.
        let row: [f32; 4] = [-92.500_87, -13.270_86, -86.028_92, -81.857_4];
        let emb = {
            let mut data = row.to_vec();
            data.extend_from_slice(&[0.1f32, 0.2, 0.3, 0.4]);
            let w = Tensor::from_storage(TensorStorage::cpu(data), vec![2, 4], true).unwrap();
            Embedding::from_pretrained(w, None)
                .unwrap()
                .with_max_norm(151.109_680_175_781_25)
                .with_norm_type(2.0)
        };

        let idx = index_tensor(&[0.0]);
        let _ = emb.forward(&idx).unwrap();

        // Persisted weight row 0 must be byte-identical to the input: torch's
        // f32 norm == max_norm, so it does NOT clip.
        let w = emb.weight.data().unwrap().to_vec();
        for (i, &orig) in row.iter().enumerate() {
            assert_eq!(
                w[i], orig,
                "row[{i}] must stay byte-for-byte unchanged at the f32-norm==max_norm \
                 boundary (torch F.embedding leaves it intact); got {} expected {orig}",
                w[i]
            );
        }
    }

    // -------------------------------------------------------------------
    // #1445: scale_grad_by_freq divides duplicate-index grad rows
    // -------------------------------------------------------------------
    //
    // Oracle (live torch 2.11.0): indices [1,1,0], grad_output ones[3,2].
    //   scale_grad_by_freq=True  -> grad rows: r0=[1,1], r1=[1,1], r2=[0,0]
    //   scale_grad_by_freq=False -> grad rows: r0=[1,1], r1=[2,2], r2=[0,0]
    //   torch/nn/functional.py:2499-2500 + aten embedding_dense_backward.

    #[test]
    fn test_scale_grad_by_freq_divides_duplicates() {
        let weight =
            Tensor::from_storage(TensorStorage::cpu(vec![0.0f32; 6]), vec![3, 2], true).unwrap();
        let emb = Embedding::from_pretrained(weight, None)
            .unwrap()
            .with_scale_grad_by_freq(true);

        // Index 1 appears twice, index 0 once.
        let idx = index_tensor(&[1.0, 1.0, 0.0]);
        let out = emb.forward(&idx).unwrap();

        let grad_output =
            Tensor::from_storage(TensorStorage::cpu(vec![1.0f32; 6]), vec![3, 2], false).unwrap();
        let grads = out.grad_fn().unwrap().backward(&grad_output).unwrap();
        let gd = grads[0].as_ref().unwrap().data().unwrap();

        // Row 0 (1 occurrence): [1,1]; row 1 (2 occurrences, each scaled by 1/2):
        // [1,1]; row 2 (never indexed): [0,0].
        assert!((gd[0] - 1.0).abs() < 1e-6, "r0[0]={}", gd[0]);
        assert!((gd[1] - 1.0).abs() < 1e-6);
        assert!(
            (gd[2] - 1.0).abs() < 1e-6,
            "r1[0]={} (should be 1, scaled)",
            gd[2]
        );
        assert!((gd[3] - 1.0).abs() < 1e-6);
        assert!((gd[4] - 0.0).abs() < 1e-6);
        assert!((gd[5] - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_scale_grad_by_freq_off_accumulates() {
        // Same indices, flag OFF: row 1's grad accumulates to [2,2].
        let weight =
            Tensor::from_storage(TensorStorage::cpu(vec![0.0f32; 6]), vec![3, 2], true).unwrap();
        let emb = Embedding::from_pretrained(weight, None).unwrap();
        let idx = index_tensor(&[1.0, 1.0, 0.0]);
        let out = emb.forward(&idx).unwrap();
        let grad_output =
            Tensor::from_storage(TensorStorage::cpu(vec![1.0f32; 6]), vec![3, 2], false).unwrap();
        let grads = out.grad_fn().unwrap().backward(&grad_output).unwrap();
        let gd = grads[0].as_ref().unwrap().data().unwrap();
        assert!(
            (gd[2] - 2.0).abs() < 1e-6,
            "r1[0]={} (should be 2, unscaled)",
            gd[2]
        );
        assert!((gd[3] - 2.0).abs() < 1e-6);
    }
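
    // Hypothetical std-only reference model of the dense scatter-add backward
    // pinned by the two oracles above (a sketch, NOT ferrotorch's
    // EmbeddingBackward). Position `pos` adds grad_output[pos] into weight row
    // indices[pos], so duplicate indices accumulate. With scale_grad_by_freq,
    // each contribution is first divided by how often that index occurs in the
    // batch. `reference_embedding_backward` is an illustrative name introduced
    // here.
    fn reference_embedding_backward(
        indices: &[usize],
        grad_output: &[f32],
        num_embeddings: usize,
        dim: usize,
        scale_grad_by_freq: bool,
    ) -> Vec<f32> {
        // Per-index occurrence counts for this batch.
        let mut counts = vec![0usize; num_embeddings];
        for &i in indices {
            counts[i] += 1;
        }
        let mut grad = vec![0.0f32; num_embeddings * dim];
        for (pos, &i) in indices.iter().enumerate() {
            let scale = if scale_grad_by_freq { 1.0 / counts[i] as f32 } else { 1.0 };
            for d in 0..dim {
                grad[i * dim + d] += grad_output[pos * dim + d] * scale;
            }
        }
        grad
    }

    #[test]
    fn reference_embedding_backward_matches_oracle() {
        let ones = [1.0f32; 6];
        let scaled = reference_embedding_backward(&[1, 1, 0], &ones, 3, 2, true);
        assert_eq!(scaled, vec![1.0, 1.0, 1.0, 1.0, 0.0, 0.0]);
        let unscaled = reference_embedding_backward(&[1, 1, 0], &ones, 3, 2, false);
        assert_eq!(unscaled, vec![1.0, 1.0, 2.0, 2.0, 0.0, 0.0]);
    }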

    // -------------------------------------------------------------------
    // #1445: EmbeddingBag kwargs (max_norm / padding_idx / include_last_offset)
    // -------------------------------------------------------------------

    fn pretrained_bag(rows: &[Vec<f32>], mode: EmbeddingBagMode) -> EmbeddingBag<f32> {
        let dim = rows[0].len();
        let mut data = Vec::new();
        for r in rows {
            data.extend_from_slice(r);
        }
        let mut bag = EmbeddingBag::<f32>::new(rows.len(), dim, mode).unwrap();
        bag.weight = Parameter::new(
            Tensor::from_storage(TensorStorage::cpu(data), vec![rows.len(), dim], true).unwrap(),
        );
        bag
    }

    #[test]
    fn test_bag_modes_match_torch() {
        // Oracle (torch 2.11.0): W=[[1,2,3],[4,5,6],[7,8,9],[10,11,12]],
        // input [0,1,2,3], offsets [0,2] -> bag0=rows{0,1}, bag1=rows{2,3}.
        //   sum:  [[5,7,9],[17,19,21]]
        //   mean: [[2.5,3.5,4.5],[8.5,9.5,10.5]]
        //   max:  [[4,5,6],[10,11,12]]
        let rows = vec![
            vec![1.0, 2.0, 3.0],
            vec![4.0, 5.0, 6.0],
            vec![7.0, 8.0, 9.0],
            vec![10.0, 11.0, 12.0],
        ];
        let inp = index_tensor(&[0.0, 1.0, 2.0, 3.0]);
        let offs = [0usize, 2];

        let sum = pretrained_bag(&rows, EmbeddingBagMode::Sum)
            .forward_bag(&inp, &offs)
            .unwrap();
        assert_eq!(sum.data().unwrap(), &[5.0, 7.0, 9.0, 17.0, 19.0, 21.0]);

        let mean = pretrained_bag(&rows, EmbeddingBagMode::Mean)
            .forward_bag(&inp, &offs)
            .unwrap();
        assert_eq!(mean.data().unwrap(), &[2.5, 3.5, 4.5, 8.5, 9.5, 10.5]);

        let max = pretrained_bag(&rows, EmbeddingBagMode::Max)
            .forward_bag(&inp, &offs)
            .unwrap();
        assert_eq!(max.data().unwrap(), &[4.0, 5.0, 6.0, 10.0, 11.0, 12.0]);
    }

    #[test]
    fn test_bag_max_norm_mutates_weight() {
        // Oracle (torch 2.11.0): W=[[1,2,3],[4,5,6],[7,8,9],[10,11,12]],
        // input [0,1,2,3], offsets [0,2], mode=sum, max_norm=5.0.
        // row0 norm sqrt(14)≈3.74 < 5 untouched; rows 1,2,3 over-norm scaled.
        // Persisted weight row1 = [4,5,6] * 5/sqrt(77) ≈ [2.279212, 2.849014, 3.418817].
        let rows = vec![
            vec![1.0, 2.0, 3.0],
            vec![4.0, 5.0, 6.0],
            vec![7.0, 8.0, 9.0],
            vec![10.0, 11.0, 12.0],
        ];
        let bag = pretrained_bag(&rows, EmbeddingBagMode::Sum)
            .with_max_norm(5.0)
            .with_norm_type(2.0);
        let inp = index_tensor(&[0.0, 1.0, 2.0, 3.0]);
        let offs = [0usize, 2];
        let out = bag.forward_bag(&inp, &offs).unwrap();
        // bag0 = row0 (under max_norm, unchanged) + renormed row1
        //      = [1,2,3] + [2.279212,2.849014,3.418817]
        let od = out.data().unwrap();
        assert!((od[0] - 3.279212).abs() < 1e-4, "bag0[0]={}", od[0]);
        assert!((od[1] - 4.849014).abs() < 1e-4, "bag0[1]={}", od[1]);
        assert!((od[2] - 6.418818).abs() < 1e-4, "bag0[2]={}", od[2]);

        // Persisted weight row0 untouched (under norm), row1 renormed.
        let w = bag.weight.data().unwrap().to_vec();
        assert!((w[0] - 1.0).abs() < 1e-6, "row0 untouched");
        assert!(
            (w[3] - 2.279212).abs() < 1e-4,
            "row1 persisted renorm: {}",
            w[3]
        );
        assert!((w[4] - 2.849014).abs() < 1e-4);
        assert!((w[5] - 3.418817).abs() < 1e-4);
    }

    #[test]
    fn test_bag_padding_idx_excluded_from_reduction() {
        // Oracle (torch 2.11.0): W=[[1,1],[2,2],[4,4],[8,8]], padding_idx=1,
        // single bag input [0,1,2]. idx 1 is padding -> excluded.
        //   mean: ([1,1]+[4,4])/2 = [2.5,2.5]   (divides by non-pad count 2)
        //   sum:  [1,1]+[4,4]      = [5,5]
        let rows = vec![
            vec![1.0, 1.0],
            vec![2.0, 2.0],
            vec![4.0, 4.0],
            vec![8.0, 8.0],
        ];
        let inp = index_tensor(&[0.0, 1.0, 2.0]);
        let offs = [0usize];

        let mut mean = pretrained_bag(&rows, EmbeddingBagMode::Mean);
        mean.padding_idx = Some(1);
        let mo = mean.forward_bag(&inp, &offs).unwrap();
        assert_eq!(mo.data().unwrap(), &[2.5, 2.5]);

        let mut sum = pretrained_bag(&rows, EmbeddingBagMode::Sum);
        sum.padding_idx = Some(1);
        let so = sum.forward_bag(&inp, &offs).unwrap();
        assert_eq!(so.data().unwrap(), &[5.0, 5.0]);
    }

    #[test]
    fn test_bag_include_last_offset() {
        // Oracle (torch 2.11.0): W=[[1,2],[3,4],[5,6],[7,8]], input [0,1,2,3],
        // offsets [0,2,4], include_last_offset=True, mode=sum.
        //   bag0 = row0+row1 = [4,6]; bag1 = row2+row3 = [12,14].
        let rows = vec![
            vec![1.0, 2.0],
            vec![3.0, 4.0],
            vec![5.0, 6.0],
            vec![7.0, 8.0],
        ];
        let bag = pretrained_bag(&rows, EmbeddingBagMode::Sum).with_include_last_offset(true);
        let inp = index_tensor(&[0.0, 1.0, 2.0, 3.0]);
        let offs = [0usize, 2, 4];
        let out = bag.forward_bag(&inp, &offs).unwrap();
        assert_eq!(out.shape(), &[2, 2]);
        assert_eq!(out.data().unwrap(), &[4.0, 6.0, 12.0, 14.0]);
    }
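
    // Hypothetical std-only reference model of the bag reduction pinned by the
    // oracles above (a sketch, NOT ferrotorch's forward_bag; Max mode omitted
    // for brevity). Bag b covers input[offsets[b]..offsets[b + 1]]. The last
    // bag runs to the end of `input` unless include_last_offset marks the
    // final offset as an end marker only. padding_idx entries are skipped, and
    // Mean divides by the non-padding count. `reference_bag_forward` is an
    // illustrative name introduced here.
    fn reference_bag_forward(
        weight: &[f32],
        dim: usize,
        input: &[usize],
        offsets: &[usize],
        mean: bool,
        padding_idx: Option<usize>,
        include_last_offset: bool,
    ) -> Vec<f32> {
        let num_bags = if include_last_offset {
            offsets.len().saturating_sub(1)
        } else {
            offsets.len()
        };
        let mut out = vec![0.0f32; num_bags * dim];
        for b in 0..num_bags {
            let end = offsets.get(b + 1).copied().unwrap_or(input.len());
            let mut count = 0usize;
            for &idx in &input[offsets[b]..end] {
                if Some(idx) == padding_idx {
                    continue; // padding rows never enter the reduction
                }
                count += 1;
                for d in 0..dim {
                    out[b * dim + d] += weight[idx * dim + d];
                }
            }
            if mean && count > 0 {
                for d in 0..dim {
                    out[b * dim + d] /= count as f32;
                }
            }
        }
        out
    }

    #[test]
    fn reference_bag_forward_matches_oracles() {
        let w: Vec<f32> = (1..=12).map(|v| v as f32).collect(); // [[1,2,3],..,[10,11,12]]
        let inp: [usize; 4] = [0, 1, 2, 3];
        let sum = reference_bag_forward(&w, 3, &inp, &[0, 2], false, None, false);
        assert_eq!(sum, vec![5.0, 7.0, 9.0, 17.0, 19.0, 21.0]);
        let mean = reference_bag_forward(&w, 3, &inp, &[0, 2], true, None, false);
        assert_eq!(mean, vec![2.5, 3.5, 4.5, 8.5, 9.5, 10.5]);
        let padded = [1.0f32, 1.0, 2.0, 2.0, 4.0, 4.0, 8.0, 8.0];
        let pad_mean = reference_bag_forward(&padded, 2, &[0, 1, 2], &[0], true, Some(1), false);
        assert_eq!(pad_mean, vec![2.5, 2.5]);
        let w2: Vec<f32> = (1..=8).map(|v| v as f32).collect(); // [[1,2],..,[7,8]]
        let last = reference_bag_forward(&w2, 2, &inp, &[0, 2, 4], false, None, true);
        assert_eq!(last, vec![4.0, 6.0, 12.0, 14.0]);
    }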

    #[test]
    fn test_bag_max_mode_rejects_sparse_and_scale_grad() {
        let rows = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let inp = index_tensor(&[0.0, 1.0]);
        let offs = [0usize];

        let scaled = pretrained_bag(&rows, EmbeddingBagMode::Max).with_scale_grad_by_freq(true);
        assert!(scaled.forward_bag(&inp, &offs).is_err());

        let sparse = pretrained_bag(&rows, EmbeddingBagMode::Max).with_sparse(true);
        assert!(sparse.forward_bag(&inp, &offs).is_err());
    }

    #[test]
    fn test_bag_padding_idx_validated_and_zeroed() {
        // padding_idx out of range rejected; in range -> that row zeroed.
        assert!(EmbeddingBag::<f32>::new_with(3, 2, EmbeddingBagMode::Sum, Some(5)).is_err());

        let bag = EmbeddingBag::<f32>::new_with(4, 3, EmbeddingBagMode::Sum, Some(2)).unwrap();
        let w = bag.weight.data().unwrap();
        let pad_start = 2 * 3;
        for i in 0..3 {
            assert!(
                w[pad_start + i].abs() < 1e-6,
                "padding row not zeroed at {i}: {}",
                w[pad_start + i]
            );
        }
        assert_eq!(bag.padding_idx(), Some(2));
    }

    // -------------------------------------------------------------------
    // #1610: EmbeddingBag per_sample_weights (sum-mode-only scaling +
    // gradient to BOTH the embedding table AND per_sample_weights).
    // -------------------------------------------------------------------
    //
    // All oracle values were computed with live torch 2.11.0+cu130
    // (2026-05-28) via `torch.nn.functional.embedding_bag(...,
    // per_sample_weights=...)` with `.backward()`:
    //   torch/nn/functional.py:2576-2791 (psw handling + mode='sum'-only
    //   check at :2773-2778; shape check at :2698-2702);
    //   aten/src/ATen/native/EmbeddingBag.cpp:537-543 (forward scale),
    //   :1564-1582 (grad to weight = grad[bag]*psw), :1716-1724
    //   (grad to psw = dot(grad[bag], weight[idx])).

    /// Helper: build a `per_sample_weights` tensor with `requires_grad`.
    fn psw_tensor(w: &[f32]) -> Tensor<f32> {
        Tensor::from_storage(TensorStorage::cpu(w.to_vec()), vec![w.len()], true).unwrap()
    }

    #[test]
    fn test_bag_psw_sum_forward_single_bag() {
        // Oracle (torch 2.11.0): W=[[1,2],[3,4],[5,6]], input [0,1,2],
        // offsets [0], mode=sum, per_sample_weights=[0.5,2.0,1.0].
        //   out = 0.5*[1,2] + 2*[3,4] + 1*[5,6] = [11.5, 15.0]
        let rows = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
        let bag = pretrained_bag(&rows, EmbeddingBagMode::Sum);
        let inp = index_tensor(&[0.0, 1.0, 2.0]);
        let offs = [0usize];
        let psw = psw_tensor(&[0.5, 2.0, 1.0]);
        let out = bag.forward_bag_weighted(&inp, &offs, Some(&psw)).unwrap();
        let od = out.data().unwrap();
        assert!((od[0] - 11.5).abs() < 1e-5, "out[0]={}", od[0]);
        assert!((od[1] - 15.0).abs() < 1e-5, "out[1]={}", od[1]);
    }

    #[test]
    fn test_bag_psw_sum_grad_to_weight_and_psw() {
        // Same setup as the forward test; grad_output = ones[1,2].
        // Oracle (torch 2.11.0):
        //   grad_W   = [[0.5,0.5],[2,2],[1,1]]   (= grad[bag]*psw per row)
        //   grad_psw = [3.0, 7.0, 11.0]          (= dot(grad[bag], weight[idx]))
        let rows = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
        let bag = pretrained_bag(&rows, EmbeddingBagMode::Sum);
        let inp = index_tensor(&[0.0, 1.0, 2.0]);
        let offs = [0usize];
        let psw = psw_tensor(&[0.5, 2.0, 1.0]);
        let out = bag.forward_bag_weighted(&inp, &offs, Some(&psw)).unwrap();

        assert!(out.requires_grad());
        assert_eq!(
            out.grad_fn().unwrap().name(),
            "EmbeddingBagSumWeightedBackward"
        );

        let grad_output =
            Tensor::from_storage(TensorStorage::cpu(vec![1.0f32, 1.0]), vec![1, 2], false).unwrap();
        let grads = out.grad_fn().unwrap().backward(&grad_output).unwrap();

        // grads[0] -> weight (input 0), grads[1] -> psw (input 1).
        let gw = grads[0].as_ref().unwrap().data().unwrap();
        assert_eq!(grads[0].as_ref().unwrap().shape(), &[3, 2]);
        let expect_w = [0.5, 0.5, 2.0, 2.0, 1.0, 1.0];
        for (i, &e) in expect_w.iter().enumerate() {
            assert!((gw[i] - e).abs() < 1e-5, "grad_W[{i}]={} exp {e}", gw[i]);
        }

        let gp = grads[1].as_ref().unwrap().data().unwrap();
        assert_eq!(grads[1].as_ref().unwrap().shape(), &[3]);
        let expect_psw = [3.0, 7.0, 11.0];
        for (i, &e) in expect_psw.iter().enumerate() {
            assert!((gp[i] - e).abs() < 1e-5, "grad_psw[{i}]={} exp {e}", gp[i]);
        }
    }
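
    // Hypothetical std-only reference model of sum-mode per_sample_weights for
    // a single bag, following the formulas cited in the #1610 header above (a
    // sketch, NOT ferrotorch's EmbeddingBagSumWeightedBackward):
    //   out          = sum_j psw[j] * W[idx_j]
    //   grad_W[idx_j] += psw[j] * g
    //   grad_psw[j]   = dot(g, W[idx_j])
    // where g is the bag's grad_output row. Padding handling is omitted.
    // `reference_psw_single_bag` is an illustrative name introduced here.
    fn reference_psw_single_bag(
        weight: &[f32],
        dim: usize,
        input: &[usize],
        psw: &[f32],
        grad: &[f32],
    ) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
        let rows = weight.len() / dim;
        let mut out = vec![0.0f32; dim];
        let mut grad_w = vec![0.0f32; rows * dim];
        let mut grad_psw = vec![0.0f32; input.len()];
        for (j, &idx) in input.iter().enumerate() {
            let row = &weight[idx * dim..(idx + 1) * dim];
            for d in 0..dim {
                out[d] += psw[j] * row[d];
                grad_w[idx * dim + d] += psw[j] * grad[d];
                grad_psw[j] += grad[d] * row[d];
            }
        }
        (out, grad_w, grad_psw)
    }

    #[test]
    fn reference_psw_single_bag_matches_oracle() {
        let w = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let (out, gw, gp) =
            reference_psw_single_bag(&w, 2, &[0, 1, 2], &[0.5, 2.0, 1.0], &[1.0, 1.0]);
        assert_eq!(out, vec![11.5, 15.0]);
        assert_eq!(gw, vec![0.5, 0.5, 2.0, 2.0, 1.0, 1.0]);
        assert_eq!(gp, vec![3.0, 7.0, 11.0]);
    }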

    #[test]
    fn test_bag_psw_sum_two_bags_offsets() {
        // Oracle (torch 2.11.0): W=[[1,2,3],[4,5,6],[7,8,9],[10,11,12]],
        // input [0,1,2,3], offsets [0,2], mode=sum, psw=[2,0.5,1.5,3].
        //   bag0 = 2*[1,2,3] + 0.5*[4,5,6]   = [4, 6.5, 9]
        //   bag1 = 1.5*[7,8,9] + 3*[10,11,12] = [40.5, 45, 49.5]
        // grad_output = [[1,1,1],[2,2,2]]:
        //   grad_W   = [[2,2,2],[0.5,0.5,0.5],[3,3,3],[6,6,6]]
        //   grad_psw = [6, 15, 48, 66]
        let rows = vec![
            vec![1.0, 2.0, 3.0],
            vec![4.0, 5.0, 6.0],
            vec![7.0, 8.0, 9.0],
            vec![10.0, 11.0, 12.0],
        ];
        let bag = pretrained_bag(&rows, EmbeddingBagMode::Sum);
        let inp = index_tensor(&[0.0, 1.0, 2.0, 3.0]);
        let offs = [0usize, 2];
        let psw = psw_tensor(&[2.0, 0.5, 1.5, 3.0]);
        let out = bag.forward_bag_weighted(&inp, &offs, Some(&psw)).unwrap();
        let od = out.data().unwrap();
        let expect_out = [4.0, 6.5, 9.0, 40.5, 45.0, 49.5];
        for (i, &e) in expect_out.iter().enumerate() {
            assert!((od[i] - e).abs() < 1e-4, "out[{i}]={} exp {e}", od[i]);
        }

        let grad_output = Tensor::from_storage(
            TensorStorage::cpu(vec![1.0f32, 1.0, 1.0, 2.0, 2.0, 2.0]),
            vec![2, 3],
            false,
        )
        .unwrap();
        let grads = out.grad_fn().unwrap().backward(&grad_output).unwrap();
        let gw = grads[0].as_ref().unwrap().data().unwrap();
        let expect_w = [2.0, 2.0, 2.0, 0.5, 0.5, 0.5, 3.0, 3.0, 3.0, 6.0, 6.0, 6.0];
        for (i, &e) in expect_w.iter().enumerate() {
            assert!((gw[i] - e).abs() < 1e-4, "grad_W[{i}]={} exp {e}", gw[i]);
        }
        let gp = grads[1].as_ref().unwrap().data().unwrap();
        let expect_psw = [6.0, 15.0, 48.0, 66.0];
        for (i, &e) in expect_psw.iter().enumerate() {
            assert!((gp[i] - e).abs() < 1e-4, "grad_psw[{i}]={} exp {e}", gp[i]);
        }
    }

2636    #[test]
2637    fn test_bag_psw_with_padding_idx() {
2638        // Oracle (torch 2.11.0): W=[[1,1],[2,2],[4,4],[8,8]], padding_idx=1,
2639        // single bag input [0,1,2], mode=sum, psw=[2,5,3].
2640        // idx 1 is padding -> excluded from the bag AND from both grads.
2641        //   out = 2*[1,1] + 3*[4,4] = [14, 14]
2642        //   grad_W   (g=ones) = [[2,2],[0,0],[3,3],[0,0]]
2643        //   grad_psw           = [2.0, 0.0, 8.0]   (padding sample's psw grad 0)
2644        let rows = vec![
2645            vec![1.0, 1.0],
2646            vec![2.0, 2.0],
2647            vec![4.0, 4.0],
2648            vec![8.0, 8.0],
2649        ];
2650        let mut bag = pretrained_bag(&rows, EmbeddingBagMode::Sum);
2651        bag.padding_idx = Some(1);
2652        let inp = index_tensor(&[0.0, 1.0, 2.0]);
2653        let offs = [0usize];
2654        let psw = psw_tensor(&[2.0, 5.0, 3.0]);
2655        let out = bag.forward_bag_weighted(&inp, &offs, Some(&psw)).unwrap();
2656        let od = out.data().unwrap();
2657        assert!((od[0] - 14.0).abs() < 1e-5, "out[0]={}", od[0]);
2658        assert!((od[1] - 14.0).abs() < 1e-5, "out[1]={}", od[1]);
2659
2660        let grad_output =
2661            Tensor::from_storage(TensorStorage::cpu(vec![1.0f32, 1.0]), vec![1, 2], false).unwrap();
2662        let grads = out.grad_fn().unwrap().backward(&grad_output).unwrap();
2663        let gw = grads[0].as_ref().unwrap().data().unwrap();
2664        let expect_w = [2.0, 2.0, 0.0, 0.0, 3.0, 3.0, 0.0, 0.0];
2665        for (i, &e) in expect_w.iter().enumerate() {
2666            assert!((gw[i] - e).abs() < 1e-5, "grad_W[{i}]={} exp {e}", gw[i]);
2667        }
2668        let gp = grads[1].as_ref().unwrap().data().unwrap();
2669        let expect_psw = [2.0, 0.0, 8.0];
2670        for (i, &e) in expect_psw.iter().enumerate() {
2671            assert!((gp[i] - e).abs() < 1e-5, "grad_psw[{i}]={} exp {e}", gp[i]);
2672        }
2673    }
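
    // Reference oracle for the weighted sum-mode tests in this module (a
    // sketch, not a ferrotorch API; `oracle_bag_sum_psw` is a hypothetical
    // helper). It recomputes forward and backward on plain `Vec`s, assuming
    // `offsets` holds sorted bag starts with the last bag running to the end
    // of `idx`. Padding positions are skipped in `out` and contribute zero to
    // both `grad_w` and `grad_psw`, as `test_bag_psw_with_padding_idx` expects.
    #[allow(clippy::type_complexity, clippy::needless_range_loop)]
    fn oracle_bag_sum_psw(
        w: &[Vec<f32>],
        idx: &[usize],
        offsets: &[usize],
        psw: &[f32],
        padding_idx: Option<usize>,
        grad_out: &[Vec<f32>],
    ) -> (Vec<Vec<f32>>, Vec<Vec<f32>>, Vec<f32>) {
        let dim = w[0].len();
        let mut out = vec![vec![0.0f32; dim]; offsets.len()];
        let mut grad_w = vec![vec![0.0f32; dim]; w.len()];
        let mut grad_psw = vec![0.0f32; idx.len()];
        for (b, &start) in offsets.iter().enumerate() {
            // Bag b covers idx[start..end]; the last bag runs to idx.len().
            let end = offsets.get(b + 1).copied().unwrap_or(idx.len());
            for j in start..end {
                let row = idx[j];
                if Some(row) == padding_idx {
                    continue; // padding: excluded from out and from both grads
                }
                for d in 0..dim {
                    out[b][d] += psw[j] * w[row][d]; // out[b] = sum_j psw[j] * W[idx[j]]
                    grad_w[row][d] += psw[j] * grad_out[b][d]; // scatter-add into W's grad
                    grad_psw[j] += grad_out[b][d] * w[row][d]; // <g[b], W[idx[j]]>
                }
            }
        }
        (out, grad_w, grad_psw)
    }

    #[test]
    fn test_oracle_bag_sum_psw_matches_padding_expectations() {
        let w = vec![vec![1.0f32, 1.0], vec![2.0, 2.0], vec![4.0, 4.0], vec![8.0, 8.0]];
        let ones = vec![vec![1.0f32, 1.0]];
        let (out, gw, gp) =
            oracle_bag_sum_psw(&w, &[0, 1, 2], &[0], &[2.0, 5.0, 3.0], Some(1), &ones);
        assert_eq!(out, vec![vec![14.0f32, 14.0]]);
        assert_eq!(gw, vec![vec![2.0f32, 2.0], vec![0.0, 0.0], vec![3.0, 3.0], vec![0.0, 0.0]]);
        assert_eq!(gp, vec![2.0f32, 0.0, 8.0]);
    }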

    #[test]
    fn test_bag_psw_end_to_end_autograd() {
        // End-to-end via the autograd engine: a scalar loss = sum(out) should
        // populate grads on BOTH the weight parameter and the psw leaf.
        // Reuses the single-bag oracle (grad_output = ones).
        let rows = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
        let bag = pretrained_bag(&rows, EmbeddingBagMode::Sum);
        let inp = index_tensor(&[0.0, 1.0, 2.0]);
        let offs = [0usize];
        let psw = psw_tensor(&[0.5, 2.0, 1.0]);
        let out = bag.forward_bag_weighted(&inp, &offs, Some(&psw)).unwrap();

        // loss = sum(out). SumBackward broadcasts the scalar upstream grad
        // (1.0 when `backward` seeds the root) to every element of `out`.
        let out_data = out.data().unwrap();
        let total: f32 = out_data.iter().sum();
        #[derive(Debug)]
        struct SumBackward<T: Float> {
            input: Tensor<T>,
        }
        impl<T: Float> GradFn<T> for SumBackward<T> {
            fn backward(
                &self,
                grad_output: &Tensor<T>,
            ) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
                let go_val = grad_output.data()?[0];
                let grad = vec![go_val; self.input.numel()];
                Ok(vec![Some(Tensor::from_storage(
                    TensorStorage::cpu(grad),
                    self.input.shape().to_vec(),
                    false,
                )?)])
            }
            fn inputs(&self) -> Vec<&Tensor<T>> {
                vec![&self.input]
            }
            fn name(&self) -> &'static str {
                "SumBackward"
            }
        }
        let loss = Tensor::from_operation(
            TensorStorage::cpu(vec![total]),
            vec![],
            Arc::new(SumBackward { input: out.clone() }),
        )
        .unwrap();
        backward(&loss).unwrap();

        // Weight grad = [[0.5,0.5],[2,2],[1,1]].
        let wg = bag.weight.tensor().grad().unwrap().unwrap();
        let wgd = wg.data().unwrap();
        let expect_w = [0.5, 0.5, 2.0, 2.0, 1.0, 1.0];
        for (i, &e) in expect_w.iter().enumerate() {
            assert!((wgd[i] - e).abs() < 1e-5, "W.grad[{i}]={} exp {e}", wgd[i]);
        }
        // psw grad = [3,7,11].
        let pg = psw.grad().unwrap().unwrap();
        let pgd = pg.data().unwrap();
        let expect_psw = [3.0, 7.0, 11.0];
        for (i, &e) in expect_psw.iter().enumerate() {
            assert!(
                (pgd[i] - e).abs() < 1e-5,
                "psw.grad[{i}]={} exp {e}",
                pgd[i]
            );
        }
    }
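
    // Finite-difference cross-check of the `psw.grad = [3, 7, 11]` expectation
    // above (a sketch; `oracle_sum_loss` and `oracle_fd_grad` are hypothetical
    // helpers, not ferrotorch APIs). With loss = sum(out) and
    // out = sum_j psw[j] * W[idx[j]], dloss/dpsw[j] is the row sum of W[idx[j]].
    // The loss is linear in psw, so a central difference in f64 recovers that
    // value up to rounding error.
    fn oracle_sum_loss(w: &[Vec<f64>], idx: &[usize], psw: &[f64]) -> f64 {
        idx.iter()
            .zip(psw)
            .map(|(&row, &s)| s * w[row].iter().sum::<f64>())
            .sum()
    }

    fn oracle_fd_grad(w: &[Vec<f64>], idx: &[usize], psw: &[f64], eps: f64) -> Vec<f64> {
        (0..psw.len())
            .map(|j| {
                // Perturb only psw[j] in each direction and difference the losses.
                let (mut plus, mut minus) = (psw.to_vec(), psw.to_vec());
                plus[j] += eps;
                minus[j] -= eps;
                (oracle_sum_loss(w, idx, &plus) - oracle_sum_loss(w, idx, &minus)) / (2.0 * eps)
            })
            .collect()
    }

    #[test]
    fn test_oracle_fd_grad_matches_psw_grad() {
        let w = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
        let g = oracle_fd_grad(&w, &[0, 1, 2], &[0.5, 2.0, 1.0], 1e-3);
        for (got, want) in g.iter().zip([3.0, 7.0, 11.0]) {
            assert!((got - want).abs() < 1e-6, "fd grad {got} vs {want}");
        }
    }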

    #[test]
    fn test_bag_psw_rejects_mean_and_max_modes() {
        // torch raises NotImplementedError with this exact text for non-sum
        // modes (functional.py:2773-2778). ferrotorch returns Err with the
        // byte-identical message.
        let rows = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let inp = index_tensor(&[0.0, 1.0]);
        let offs = [0usize];
        let psw = psw_tensor(&[1.0, 1.0]);

        for (mode, mode_str) in [
            (EmbeddingBagMode::Mean, "mean"),
            (EmbeddingBagMode::Max, "max"),
        ] {
            let bag = pretrained_bag(&rows, mode);
            let err = bag
                .forward_bag_weighted(&inp, &offs, Some(&psw))
                .unwrap_err();
            let msg = err.to_string();
            let expected = format!(
                "embedding_bag: per_sample_weights was not None. per_sample_weights is only \
                 supported for mode='sum' (got mode='{mode_str}'). Please open a feature request \
                 on GitHub."
            );
            assert!(
                msg.contains(&expected),
                "mode={mode_str}: error message must contain torch's exact text.\n got: {msg}\n want: {expected}"
            );
        }
    }
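
    // Sketch of the mode gate that `test_bag_psw_rejects_mean_and_max_modes`
    // pins down (hypothetical: `OracleBagMode` and `oracle_check_psw_mode` are
    // not ferrotorch items, and `forward_bag_weighted` may structure the check
    // differently). Only sum mode accepts per-sample weights; the other modes
    // return torch's NotImplementedError text verbatim.
    #[derive(Clone, Copy, Debug)]
    enum OracleBagMode {
        Sum,
        Mean,
        Max,
    }

    fn oracle_check_psw_mode(mode: OracleBagMode, has_psw: bool) -> Result<(), String> {
        if !has_psw {
            return Ok(()); // no per-sample weights: every mode is allowed
        }
        let mode_str = match mode {
            OracleBagMode::Sum => return Ok(()),
            OracleBagMode::Mean => "mean",
            OracleBagMode::Max => "max",
        };
        Err(format!(
            "embedding_bag: per_sample_weights was not None. per_sample_weights is only \
             supported for mode='sum' (got mode='{mode_str}'). Please open a feature request \
             on GitHub."
        ))
    }

    #[test]
    fn test_oracle_check_psw_mode() {
        assert!(oracle_check_psw_mode(OracleBagMode::Sum, true).is_ok());
        assert!(oracle_check_psw_mode(OracleBagMode::Mean, false).is_ok());
        let err = oracle_check_psw_mode(OracleBagMode::Max, true).unwrap_err();
        assert!(err.contains("(got mode='max')"));
    }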

    #[test]
    fn test_bag_psw_rejects_shape_mismatch() {
        // psw must have the same shape as input (functional.py:2698-2702).
        let rows = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let bag = pretrained_bag(&rows, EmbeddingBagMode::Sum);
        let inp = index_tensor(&[0.0, 1.0]);
        let offs = [0usize];
        let psw = psw_tensor(&[1.0]); // wrong length
        assert!(bag.forward_bag_weighted(&inp, &offs, Some(&psw)).is_err());
    }
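
    // Sketch of the shape rule `test_bag_psw_rejects_shape_mismatch` checks
    // (hypothetical helper `oracle_check_psw_shape`; the error wording is
    // illustrative, not ferrotorch's or torch's exact text). torch requires
    // `per_sample_weights` to have exactly the same shape as `input`, so a
    // length-1 psw against a length-2 input must be rejected before any gather.
    fn oracle_check_psw_shape(input_shape: &[usize], psw_shape: &[usize]) -> Result<(), String> {
        if input_shape == psw_shape {
            Ok(())
        } else {
            Err(format!(
                "embedding_bag: per_sample_weights shape {psw_shape:?} must equal input shape {input_shape:?}"
            ))
        }
    }

    #[test]
    fn test_oracle_check_psw_shape() {
        assert!(oracle_check_psw_shape(&[2], &[2]).is_ok());
        assert!(oracle_check_psw_shape(&[2], &[1]).is_err());
    }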

    #[test]
    fn test_bag_forward_bag_unweighted_unchanged() {
        // The 2-arg forward_bag (a psw=None delegate) must produce the same
        // unweighted sum as before per-sample weights were added, with no grad
        // attached (`requires_grad()` is false).
        let rows = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
        let bag = pretrained_bag(&rows, EmbeddingBagMode::Sum);
        let inp = index_tensor(&[0.0, 1.0, 2.0]);
        let offs = [0usize];
        let out = bag.forward_bag(&inp, &offs).unwrap();
        // sum = [1+3+5, 2+4+6] = [9, 12]
        assert_eq!(out.data().unwrap(), &[9.0, 12.0]);
        assert!(!out.requires_grad());
    }
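
    // Plain-Rust sketch of the offsets convention behind `forward_bag`
    // (hypothetical helper `oracle_bag_sum`, not a ferrotorch API): bag `b`
    // sums `W[idx[j]]` for j in `offsets[b]..offsets[b + 1]`, and the last bag
    // runs to the end of `idx`, as with torch's `include_last_offset=False`
    // default. A single bag at offset 0 reproduces the [9, 12] expected above;
    // an empty bag sums to zeros.
    fn oracle_bag_sum(w: &[Vec<f32>], idx: &[usize], offsets: &[usize]) -> Vec<Vec<f32>> {
        let dim = w[0].len();
        offsets
            .iter()
            .enumerate()
            .map(|(b, &start)| {
                let end = offsets.get(b + 1).copied().unwrap_or(idx.len());
                let mut acc = vec![0.0f32; dim];
                for &row in &idx[start..end] {
                    // Accumulate row `row` of W into this bag's output.
                    for (a, &x) in acc.iter_mut().zip(&w[row]) {
                        *a += x;
                    }
                }
                acc
            })
            .collect()
    }

    #[test]
    fn test_oracle_bag_sum_offsets() {
        let w = vec![vec![1.0f32, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
        assert_eq!(oracle_bag_sum(&w, &[0, 1, 2], &[0]), vec![vec![9.0f32, 12.0]]);
        assert_eq!(
            oracle_bag_sum(&w, &[0, 1, 2], &[0, 2]),
            vec![vec![4.0f32, 6.0], vec![5.0, 6.0]]
        );
    }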
}