Skip to main content

ferrotorch_nn/
dropout.rs

1//! Dropout regularization layers.
2//!
3//! [`Dropout`] randomly zeroes individual elements during training with
4//! probability `p`, scaling surviving elements by `1/(1-p)` (inverted
5//! dropout). [`Dropout1d`], [`Dropout2d`], and [`Dropout3d`] drop entire
6//! channels instead of individual elements, for 3D, 4D, and 5D inputs
7//! respectively. [`AlphaDropout`] preserves mean and variance for use
8//! with SELU activations.
9//!
10//! All six CPU forward paths draw their keep-mask from the byte-exact
11//! MT19937 `Generator` (`ferrotorch_core::rng`) with torch's exact
12//! consumption — per element ([`Dropout`], [`AlphaDropout`]) or per `[N, C]`
13//! channel ([`Dropout1d`]/[`Dropout2d`]/[`Dropout3d`],
14//! [`FeatureAlphaDropout`]) in flat order, keep iff `next_uniform_f64() <
15//! (1 - p)` — so `ferrotorch_core::manual_seed(s)` reproduces
16//! `torch.manual_seed(s); F.dropout{,1d,2d,3d}` / `nn.AlphaDropout` /
17//! `nn.FeatureAlphaDropout` byte-for-byte (#1634, #1635, #1636). The alpha
18//! variants use torch's hardcoded `alpha = 1.7580993408473766` affine
19//! (`aten/src/ATen/native/Dropout.cpp:76`).
20//!
21//! All modules are identity in eval mode and have zero learnable parameters.
22//!
23//! ## REQ status (per `.design/ferrotorch-nn/dropout.md`)
24//!
25//! | REQ | Status | Evidence |
26//! |---|---|---|
27//! | REQ-1 | SHIPPED | impl: `pub struct Dropout<T: Float>` here with `p` / `training` fields + ctor rejecting `p` outside `[0,1)`; non-test consumer: `Dropout::<T>::new(0.5)?` invoked in `ferrotorch-vision/src/models/vgg.rs` (the VGG classifier head dropout). |
28//! | REQ-2 | SHIPPED | impl: `<Dropout as Module>::forward` body with eval / `p==0` short-circuit + Bernoulli + scale here; non-test consumer: `Dropout::forward` is called on every forward pass through the VGG / Inception classifier (constructed in `vgg.rs` and `inception.rs`). |
29//! | REQ-3 | SHIPPED | impl: `input.is_cuda() && backend = ferrotorch_core::gpu_dispatch::gpu_backend()` GPU branch inside `<Dropout as Module>::forward` here; non-test consumer: any vision model run on CUDA (e.g. VGG / Inception fine-tuning with parameters on GPU) triggers this on every forward step. |
30//! | REQ-4 | SHIPPED | impl: `struct DropoutBackward<T>` + `GradFn` impl here; non-test consumer: every `loss.backward()` over a model containing `Dropout` traverses these nodes via the autograd engine. |
31//! | REQ-5 | SHIPPED | impl: `pub struct Dropout2d<T: Float>` + `Module` impl here; per-channel keep-mask drawn from the byte-exact MT19937 `Generator` (`make_feature_noise(input).bernoulli_(1-p)`, `Dropout.cpp:73-74`, keep iff `u < 1-p`), reproducing `torch.manual_seed(s); F.dropout2d` byte-for-byte (#1635, pinned by `divergence_dropout_seed_extended_and_feature_1634.rs::dropout2d_seed42_per_channel_matches_torch` vs live torch 2.11); non-test consumer: `pub use dropout::Dropout2d` in `lib.rs` exposes for downstream vision / segmentation code. |
32//! | REQ-6 | SHIPPED | impl: `pub struct Dropout1d<T: Float>` + `Module` impl here; per-channel MT19937 mask (#1635, pinned by `dropout1d_seed42_per_channel_matches_torch`); non-test consumer: `pub use dropout::Dropout1d` in `lib.rs`. |
33//! | REQ-7 | SHIPPED | impl: `pub struct Dropout3d<T: Float>` + `Module` impl here; per-channel MT19937 mask (#1635, pinned by `dropout3d_seed42_per_channel_matches_torch`); non-test consumer: `pub use dropout::Dropout3d` in `lib.rs`. |
34//! | REQ-8 | SHIPPED | impl: `struct Dropout2dBackward<T>` + `GradFn` impl here; non-test consumer: autograd engine traversal on any model using `Dropout2d` in training. |
35//! | REQ-9 | SHIPPED | impl: `pub struct AlphaDropout<T: Float>` + torch's EXACT alpha affine inside `<AlphaDropout as Module>::forward` here — per-element MT19937 keep-mask (keep iff `u < 1-p`) + `alpha = 1.7580993408473766` (`ALPHA_DROPOUT_ALPHA`, torch's hardcoded literal at `Dropout.cpp:76`, NOT recomputed `SELU_LAMBDA*SELU_ALPHA`), `a = 1/sqrt((alpha^2*p+1)*(1-p))`, kept = `a*x+alpha*a*p`, dropped = `-alpha*a+alpha*a*p` (`Dropout.cpp:74-79`), reproducing `torch.manual_seed(s); nn.AlphaDropout(p)` byte-for-byte (#1636, pinned by `divergence_dropout_seed_extended_and_feature_1634.rs::alpha_dropout_seed42_matches_torch` vs live torch 2.11); non-test consumer: `pub use dropout::AlphaDropout` in `lib.rs`. |
36//! | REQ-10 | SHIPPED | impl: `struct AlphaDropoutBackward<T>` + `GradFn` impl here; non-test consumer: autograd engine traversal on models using `AlphaDropout`. |
//! | REQ-11 | SHIPPED | impl: 6 `Module<T> for <DropoutKind><T>` impl blocks here (one per dropout struct), each returning `vec![]` for parameters; non-test consumer: `ferrotorch_optim::Optimizer` walks `Module::parameters_mut()` of containers; dropout returns an empty list (correct: dropout has no trainable parameters). |
38//! | REQ-12 | SHIPPED | impl: `with_inplace` builder + `inplace` getter + `inplace` field on all six dropout structs, the autograd-safe `apply_inplace_dropout` helper (errors on grad-requiring leaf per torch `VariableTypeUtils.h:80-84`; out-of-place fallback on grad-requiring non-leaf — R-DEV-7, ferrotorch lacks torch's version counter `saved_variable.cpp:170-186`; raw `write_inplace`/`Tensor::update_data` only on the non-grad-tracked path), and the `if self.inplace { apply_inplace_dropout(input, &output_data)? }` branch in `<Dropout/Dropout1d/Dropout2d/Dropout3d as Module>::forward` here, mirroring `_VF.dropout_`/`_VF.feature_dropout_` at `torch/nn/functional.py:1449,1516,1579,1629` on the memory-opt path; `AlphaDropout`/`FeatureAlphaDropout` carry the field for ABI parity but match torch's module forward which never forwards `inplace` (`dropout.py:265-269,319-323`). Non-test production consumer: the `if self.inplace` branch is on the live forward path of `<Dropout as Module>::forward` here, exercised by `ferrotorch-nn/src/lora.rs` (LoRA input dropout), `ferrotorch-vision/src/models/vgg.rs` / `inception.rs` (classifier head), and `ferrotorch-graph/src/gcn.rs` (inter-layer dropout). Default `inplace=false` preserves existing behavior. Closes #1446, #1580, #1581. |
39//! | REQ-13 | SHIPPED | impl: `pub struct FeatureAlphaDropout<T: Float>` + `FeatureAlphaDropoutBackward<T>` + `Module<T>` impl here — per-channel MT19937 keep-mask (`make_feature_noise` flat `[N,C]` Bernoulli, keep iff `u < 1-p`) broadcast over `[N, C, *]`, torch's EXACT alpha affine (`alpha = 1.7580993408473766`, kept = `a*x+alpha*a*p`, dropped = `-alpha*a+alpha*a*p`, `Dropout.cpp:73-79`), reproducing `torch.manual_seed(s); nn.FeatureAlphaDropout(p)` byte-for-byte (#1636, pinned by `divergence_dropout_seed_extended_and_feature_1634.rs::feature_alpha_dropout_seed42_matches_torch` vs live torch 2.11); closes #1448; non-test consumer: `pub use dropout::FeatureAlphaDropout` in `lib.rs` (re-export) exposes the layer to downstream self-normalising-network model code in `ferrotorch-vision` / `ferrotorch-llama`. |
40//! | REQ-14 | NOT-STARTED | blocker #1441 (umbrella) — `Dropout2d` / `Dropout1d` / `Dropout3d` GPU forward absent (CUDA inputs return `NotImplementedOnCuda`). Parity-sweep runner arms also absent. |
41
42use std::sync::Arc;
43
44use ferrotorch_core::autograd::no_grad::is_grad_enabled;
45use ferrotorch_core::gpu_dispatch::GpuRngState;
46use ferrotorch_core::tensor::GradFn;
47use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor, TensorStorage};
48
49use crate::module::Module;
50use crate::parameter::Parameter;
51
52// ---------------------------------------------------------------------------
53// Philox 4x32-10 for CPU-side mask regeneration
54// ---------------------------------------------------------------------------
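// The Philox algorithm is kept on the CPU side so that dropout masks for GPU
// tensors can be regenerated during backward (the forward mask is generated
// on the GPU from the Philox state). This is a copy of the core algorithm
// from ferrotorch-gpu/src/rng.rs, duplicated here to avoid a dependency on
// the GPU crate.
//
// NOTE: the Philox helpers in this section are marked `#[allow(dead_code)]`;
// the mask regeneration in `philox_dropout_mask` below hashes the saved
// (counter, seed) pair with fmix32 and does not call `philox_4x32_10`.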
59
/// Philox 4x32 multiplier applied to counter word 0.
#[allow(dead_code)]
const PHILOX_M0: u32 = 0xD2511F53;
/// Philox 4x32 multiplier applied to counter word 2.
#[allow(dead_code)]
const PHILOX_M1: u32 = 0xCD9E8D57;
/// Amount added to the low key word between rounds.
#[allow(dead_code)]
const PHILOX_W0: u32 = 0x9E3779B9;
/// Amount added to the high key word between rounds.
#[allow(dead_code)]
const PHILOX_W1: u32 = 0xBB67AE85;

/// Apply a single Philox 4x32 round to the counter words `c0..c3` using the
/// round key `(k0, k1)`.
///
/// Words 0 and 2 are each widened to 64 bits and multiplied by their
/// constant; the high halves are XOR-mixed with the odd counter words and the
/// key, while the low halves are passed through into the odd output slots.
#[allow(dead_code)]
#[inline]
fn philox_round(c0: u32, c1: u32, c2: u32, c3: u32, k0: u32, k1: u32) -> (u32, u32, u32, u32) {
    let wide0 = u64::from(PHILOX_M0) * u64::from(c0);
    let wide1 = u64::from(PHILOX_M1) * u64::from(c2);

    (
        ((wide1 >> 32) as u32) ^ c1 ^ k0,
        // Truncating cast keeps the low 32 bits of the product.
        wide1 as u32,
        ((wide0 >> 32) as u32) ^ c3 ^ k1,
        wide0 as u32,
    )
}

/// Philox 4x32-10: produces 4 uniform u32 values from (counter, key).
///
/// The 64-bit `counter` fills the two low counter words (the two high words
/// start at zero) and the 64-bit `key` is split into the two key words. Ten
/// rounds are applied; the key is advanced before every round except the
/// first, i.e. nine times in total.
#[allow(dead_code)]
fn philox_4x32_10(counter: u64, key: u64) -> [u32; 4] {
    let mut words = (counter as u32, (counter >> 32) as u32, 0u32, 0u32);
    let mut key_lo = key as u32;
    let mut key_hi = (key >> 32) as u32;

    for round in 0..10 {
        if round != 0 {
            key_lo = key_lo.wrapping_add(PHILOX_W0);
            key_hi = key_hi.wrapping_add(PHILOX_W1);
        }
        words = philox_round(words.0, words.1, words.2, words.3, key_lo, key_hi);
    }

    let (w0, w1, w2, w3) = words;
    [w0, w1, w2, w3]
}
109
110/// Regenerate the GPU dropout mask on CPU, byte-exactly mirroring the
111/// `dropout_kernel` PTX in ferrotorch-gpu/src/kernels.rs: derived u32 seed
112/// `(counter ^ seed)`, per-element hash `fmix32(i * 2654435761 ^ seed)`
113/// (murmur3 finalizer), drop when `hash < threshold`.
114///
115/// This ensures the backward mask matches the forward mask generated on
116/// GPU. The fmix32 multiplications are load-bearing: a GF(2)-linear mix
117/// here (the original xorshift) made masks invariant under consecutive
118/// Philox counter deltas, so every dropout call drew the same mask
119/// (ferrotorch-paged #43). Any change here must change the PTX in
120/// lockstep.
121fn philox_dropout_mask<T: Float>(
122    numel: usize,
123    threshold: u32,
124    scale: T,
125    rng_state: &GpuRngState,
126) -> Vec<T> {
127    let zero = <T as num_traits::Zero>::zero();
128    let derived_seed = (rng_state.counter() ^ rng_state.seed()) as u32;
129
130    (0..numel)
131        .map(|i| {
132            let mut r = (i as u32).wrapping_mul(2654435761) ^ derived_seed;
133            r ^= r >> 16;
134            r = r.wrapping_mul(0x85eb_ca6b);
135            r ^= r >> 13;
136            r = r.wrapping_mul(0xc2b2_ae35);
137            r ^= r >> 16;
138            if r < threshold { zero } else { scale }
139        })
140        .collect()
141}
142
143// ---------------------------------------------------------------------------
144// In-place storage write
145// ---------------------------------------------------------------------------
146
/// Result of running the in-place dropout policy: either the input storage
/// was really overwritten, or the write was withheld for autograd safety.
/// Produced by [`apply_inplace_dropout`].
///
/// Both variants are success cases; the policy's failure case (a
/// grad-requiring leaf) is reported as an `Err` instead.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum InplaceOutcome {
    /// The input storage was mutated in place (`inplace=true` honored).
    /// Returned when grad is disabled or the input does not require grad.
    Mutated,
    /// The mutation was suppressed for autograd safety; the caller must build
    /// the output out-of-place from the freshly-allocated `output_data` buffer.
    /// Returned for a grad-requiring non-leaf input while grad is enabled.
    FellBackToOutOfPlace,
}
157
158/// Apply the in-place dropout policy, mutating `input`'s storage where it is
159/// autograd-safe to do so and matching torch's observable error contract where
160/// ferrotorch can.
161///
162/// # Autograd safety (R-DEV-7 deviation — documented)
163///
164/// torch enforces in-place autograd correctness with two mechanisms that
165/// ferrotorch's autograd engine does NOT have:
166///
167/// 1. A **leaf in-place guard** — mutating a leaf that requires grad raises
168///    `"a leaf Variable that requires grad is being used in an in-place
169///    operation."` from `check_inplace`
170///    (`torch/csrc/autograd/VariableTypeUtils.h:61-63,80-84`).
171/// 2. A **version counter** — every saved tensor records the storage version it
172///    was saved at; if an in-place op bumps that version before backward,
173///    `SavedVariable::unpack` raises `"one of the variables needed for gradient
174///    computation has been modified by an inplace operation"`
175///    (`torch/csrc/autograd/saved_variable.cpp:170-186`).
176///
177/// ferrotorch has neither (no `version` field on `TensorInner`; `Tensor::clone`
178/// shares the `Arc<TensorInner>` storage). Without a version counter it cannot
179/// *detect* that another backward node saved the pre-mutation storage, so an
180/// unconditional in-place write silently corrupts that branch's gradient
181/// (#1580). To eliminate the corruption rather than risk it, this helper adopts
182/// a conservative policy on the grad-tracked path:
183///
184/// * **Leaf requiring grad, grad enabled** → return an `Err` mirroring torch's
185///   leaf-guard message. (Matches torch exactly; pins #1581.)
186/// * **Non-leaf requiring grad, grad enabled** → do NOT mutate; signal
187///   [`InplaceOutcome::FellBackToOutOfPlace`] so the caller builds a fresh
188///   output. The result tensor is numerically identical and the gradient is
189///   correct (no shared-storage corruption); this is *more permissive* than
190///   torch's version-counter `RuntimeError` — ferrotorch cannot prove the
191///   storage is unused by another backward without a version counter, so it
192///   declines to mutate instead of erroring. (Eliminates #1580's corruption.)
193/// * **Grad disabled, or input does not require grad** → mutate in place. This
194///   is the real memory-optimization case; no autograd node observes the
195///   storage, so it is graph-safe and matches torch's `_VF.dropout_`.
196///
197/// The deviation preserves torch's *observable result* (identical output,
198/// correct gradient) while declining to replicate torch's runtime error on the
199/// non-leaf path, because ferrotorch lacks the version-counter infrastructure
200/// that error depends on.
201fn apply_inplace_dropout<T: Float>(
202    input: &Tensor<T>,
203    new_data: &[T],
204) -> FerrotorchResult<InplaceOutcome> {
205    if is_grad_enabled() && input.requires_grad() {
206        if input.is_leaf() {
207            // Match torch's leaf in-place guard
208            // (`torch/csrc/autograd/VariableTypeUtils.h:80-84`).
209            return Err(FerrotorchError::InvalidArgument {
210                message:
211                    "a leaf Variable that requires grad is being used in an in-place operation."
212                        .to_string(),
213            });
214        }
215        // Non-leaf requiring grad: ferrotorch has no version counter to prove
216        // the shared storage is unused by another saved-for-backward node, so
217        // fall back to out-of-place rather than risk corrupting that branch's
218        // gradient (#1580). The caller builds the output from `new_data`.
219        return Ok(InplaceOutcome::FellBackToOutOfPlace);
220    }
221
222    // Grad disabled or input does not require grad: the genuine
223    // memory-optimization case. No autograd node can observe the storage, so
224    // the in-place write is graph-safe and matches torch's `_VF.dropout_`.
225    write_inplace(input, new_data)?;
226    Ok(InplaceOutcome::Mutated)
227}
228
/// Write `new_data` over `input`'s storage in place, mirroring torch's
/// `_VF.dropout_` family (`torch/nn/functional.py:1449,1516,1579,1629`)
/// which mutate the input tensor's buffer rather than allocating a fresh
/// output.
///
/// This is the raw write; the autograd-safety policy that decides *whether* a
/// write is permitted lives in [`apply_inplace_dropout`]. Callers must route
/// through that helper and never call this directly on a grad-tracked path.
///
/// `new_data` is expected to hold exactly `input.numel()` elements; this
/// function performs no length check of its own.
/// NOTE(review): presumably `Tensor::update_data` validates the length and
/// reports a mismatch as an error — confirm against its definition.
///
/// # Errors
///
/// Propagates whatever error `Tensor::update_data` returns.
fn write_inplace<T: Float>(input: &Tensor<T>, new_data: &[T]) -> FerrotorchResult<()> {
    // SAFETY: `update_data` requires exclusive access to the input's storage
    // for the duration of the write. The dropout forward holds the only live
    // borrow of the input data (consumed into `new_data` by the caller before
    // this call). The autograd-safety policy in `apply_inplace_dropout`
    // guarantees this is only reached when grad is disabled or the input does
    // not require grad, so no backward node has saved (and could later read) a
    // version of this storage. `new_data.len() == input.numel()` is guaranteed
    // by the callers (the mask and input share numel). PyTorch performs this
    // exact mutation in `_VF.dropout_` (`torch/nn/functional.py:1449`).
    #[allow(
        clippy::undocumented_unsafe_blocks,
        reason = "SAFETY comment above documents the exclusive-access invariant; apply_inplace_dropout gates this to the non-grad-tracked path where no backward node observes the storage"
    )]
    unsafe {
        input.update_data(new_data)?;
    }
    Ok(())
}
256
257// ---------------------------------------------------------------------------
258// DropoutBackward
259// ---------------------------------------------------------------------------
260
261/// Backward node for elementwise dropout.
262///
263/// Reapplies the same binary mask scaled by `1/(1-p)` to the upstream
264/// gradient, routing gradients only through surviving elements.
265///
266/// The mask is stored as a [`Tensor<T>`] on the same device as the
267/// forward input so backward reduces to a single `mul` that stays
268/// GPU-native when the input is on CUDA.
269#[derive(Debug)]
270struct DropoutBackward<T: Float> {
271    input: Tensor<T>,
272    /// Mask tensor with elements in `{0, 1/(1-p)}`. Lives on the same
273    /// device as `input`, so `mul(grad_output, scaled_mask)` in the
274    /// backward routes entirely through GPU ops when training on CUDA.
275    scaled_mask: Tensor<T>,
276}
277
278impl<T: Float> GradFn<T> for DropoutBackward<T> {
279    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
280        let da = if self.input.requires_grad() {
281            let g = ferrotorch_core::grad_fns::arithmetic::mul(grad_output, &self.scaled_mask)?;
282            Some(g)
283        } else {
284            None
285        };
286        Ok(vec![da])
287    }
288
289    fn inputs(&self) -> Vec<&Tensor<T>> {
290        vec![&self.input]
291    }
292
293    fn name(&self) -> &'static str {
294        "DropoutBackward"
295    }
296}
297
298// ---------------------------------------------------------------------------
299// Dropout2dBackward
300// ---------------------------------------------------------------------------
301
302/// Backward node for channel-wise dropout.
303///
304/// Identical to [`DropoutBackward`] — the mask shape already encodes the
305/// channel-level structure (all spatial positions in a dropped channel are 0).
306#[derive(Debug)]
307struct Dropout2dBackward<T: Float> {
308    input: Tensor<T>,
309    scaled_mask: Vec<T>,
310}
311
312impl<T: Float> GradFn<T> for Dropout2dBackward<T> {
313    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
314        if grad_output.is_cuda() {
315            return Err(FerrotorchError::NotImplementedOnCuda {
316                op: "dropout2d backward",
317            });
318        }
319        let da = if self.input.requires_grad() {
320            let go_data = grad_output.data_vec()?;
321            let grad_a: Vec<T> = go_data
322                .iter()
323                .zip(self.scaled_mask.iter())
324                .map(|(&g, &m)| g * m)
325                .collect();
326            let g = Tensor::from_storage(
327                TensorStorage::cpu(grad_a),
328                self.input.shape().to_vec(),
329                false,
330            )?;
331            Some(g)
332        } else {
333            None
334        };
335        Ok(vec![da])
336    }
337
338    fn inputs(&self) -> Vec<&Tensor<T>> {
339        vec![&self.input]
340    }
341
342    fn name(&self) -> &'static str {
343        "Dropout2dBackward"
344    }
345}
346
347// ===========================================================================
348// Dropout
349// ===========================================================================
350
/// Randomly zeroes elements with probability `p` during training.
///
/// During training, each element is independently set to zero with probability
/// `p` and scaled by `1/(1-p)` so that the expected value is preserved
/// (inverted dropout).  During evaluation (`eval()` mode), the input is
/// returned unchanged.
///
/// # Errors
///
/// The constructor does not panic; it returns
/// `FerrotorchError::InvalidArgument` if `p` is outside `[0, 1)`.
#[derive(Debug)]
pub struct Dropout<T: Float> {
    /// Probability of zeroing an element. Validated to lie in `[0, 1)` by
    /// both the constructor and `set_p`.
    p: f64,
    /// `true` while the layer is in training mode. Starts as `true` and is
    /// toggled by `Module::train` / `Module::eval`.
    training: bool,
    /// When `true`, the forward mutates the input tensor's storage in place
    /// (mask + scale written back over the input) instead of allocating a
    /// fresh output buffer. Mirrors `_DropoutNd.inplace` at
    /// `torch/nn/modules/dropout.py:29` and the `inplace` branch of
    /// `F.dropout` at `torch/nn/functional.py:1448-1450`
    /// (`_VF.dropout_(input, p, training) if inplace`).
    inplace: bool,
    /// Zero-sized marker tying the layer to its element type `T`; the layer
    /// stores no `T` values.
    _marker: std::marker::PhantomData<T>,
}
374
375impl<T: Float> Dropout<T> {
376    /// Create a new `Dropout` layer.
377    ///
378    /// `p` is the probability of an element being zeroed. Must be in `[0, 1)`.
379    pub fn new(p: f64) -> FerrotorchResult<Self> {
380        if !(0.0..1.0).contains(&p) {
381            return Err(FerrotorchError::InvalidArgument {
382                message: format!("dropout probability must be in [0, 1), got {p}"),
383            });
384        }
385        Ok(Self {
386            p,
387            training: true,
388            inplace: false,
389            _marker: std::marker::PhantomData,
390        })
391    }
392
393    /// Set the `inplace` flag, mirroring `torch.nn.Dropout(p, inplace=...)`
394    /// at `torch/nn/modules/dropout.py:22-29`. When `true`, training-mode
395    /// forward mutates the input storage instead of allocating a new buffer.
396    #[must_use]
397    pub fn with_inplace(mut self, inplace: bool) -> Self {
398        self.inplace = inplace;
399        self
400    }
401
402    /// Returns the `inplace` flag.
403    pub fn inplace(&self) -> bool {
404        self.inplace
405    }
406
407    /// Override the dropout probability after construction. Same
408    /// validation as [`Self::new`]: `p` must be in `[0, 1)`.
409    ///
410    /// Use case: MC-dropout-style inference where a model loaded with
411    /// `p=0` (eval-time default) is temporarily reactivated with a
412    /// non-zero rate to draw stochastic samples without rebuilding
413    /// the module hierarchy.
414    pub fn set_p(&mut self, p: f64) -> FerrotorchResult<()> {
415        if !(0.0..1.0).contains(&p) {
416            return Err(FerrotorchError::InvalidArgument {
417                message: format!("dropout probability must be in [0, 1), got {p}"),
418            });
419        }
420        self.p = p;
421        Ok(())
422    }
423
424    /// Read the current dropout probability.
425    pub fn p(&self) -> f64 {
426        self.p
427    }
428}
429
430impl<T: Float> Module<T> for Dropout<T> {
431    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
432        // Eval mode or p == 0: identity.
433        if !self.training || self.p == 0.0 {
434            return Ok(input.clone());
435        }
436
437        let numel = input.numel();
438        let scale = T::from(1.0 / (1.0 - self.p)).unwrap();
439        let zero = <T as num_traits::Zero>::zero();
440
441        // GPU fast path: run dropout kernel entirely on device using the
442        // Philox CBRNG. This integrates with the global GPU RNG state so
443        // that gradient checkpointing can reproduce identical masks.
444        if input.is_cuda() {
445            if let Some(backend) = ferrotorch_core::gpu_dispatch::gpu_backend() {
446                let threshold = (self.p * u32::MAX as f64) as u32;
447                let scale_f32 = 1.0f32 / (1.0 - self.p as f32);
448
449                let (handle, rng_state) =
450                    backend.dropout_philox_f32(input.gpu_handle()?, threshold, scale_f32)?;
451
452                // For backward, we need the mask. Regenerate it from the saved
453                // Philox RNG state using the same deterministic hash that the
454                // GPU kernel uses. This is reproducible across checkpoint
455                // save/restore because the Philox state is deterministic.
456                if is_grad_enabled() && input.requires_grad() {
457                    let scaled_mask_vec = philox_dropout_mask(numel, threshold, scale, &rng_state);
458                    // Upload the mask to the input's device so the
459                    // backward `mul` runs on-device without a CPU
460                    // round-trip.
461                    let mask_cpu = Tensor::from_storage(
462                        TensorStorage::cpu(scaled_mask_vec),
463                        input.shape().to_vec(),
464                        false,
465                    )?;
466                    let scaled_mask = mask_cpu.to(input.device())?;
467                    return Tensor::from_operation(
468                        TensorStorage::gpu(handle),
469                        input.shape().to_vec(),
470                        Arc::new(DropoutBackward {
471                            input: input.clone(),
472                            scaled_mask,
473                        }),
474                    );
475                } else {
476                    return Tensor::from_storage(
477                        TensorStorage::gpu(handle),
478                        input.shape().to_vec(),
479                        false,
480                    );
481                }
482            }
483        }
484
485        // CPU path — draw the keep-mask from the byte-exact MT19937
486        // `Generator` (`ferrotorch_core::rng`) using torch's EXACT CPU dropout
487        // consumption, so `ferrotorch_core::manual_seed(s); Dropout::forward`
488        // reproduces `torch.manual_seed(s); F.dropout(...)` byte-for-byte
489        // (#1634). torch draws the mask via `noise.bernoulli_(1 - p)`
490        // (`aten/src/ATen/native/Dropout.cpp:74`); the scalar bernoulli kernel
491        // (`aten/src/ATen/native/cpu/DistributionTemplates.h:388-399`)
492        // evaluates per element in flat order
493        // `transformation::bernoulli<double>(uniform_real<double>(gen->random64(), 0, 1), 1 - p)`
494        // = `uniform64 < (1 - p)` (keep == 1)
495        // (`DistributionsHelper.h:107-113,219-222`,
496        // `TransformationHelper.h:84-89,171-173`).
497        // `uniform_real<double>(random64(), 0, 1)` is exactly
498        // `Generator::next_uniform_f64` (rng.rs REQ-5, byte-exact); survivors
499        // are scaled by `1/(1-p)` (`Dropout.cpp:81` `noise.div_(1 - p)`).
500        let keep_prob = 1.0 - self.p;
501        let scaled_mask_vec: Vec<T> = ferrotorch_core::rng::with_thread_rng(|g| {
502            (0..numel)
503                .map(|_| {
504                    if g.next_uniform_f64() < keep_prob {
505                        scale
506                    } else {
507                        zero
508                    }
509                })
510                .collect()
511        });
512
513        let input_data = input.data()?;
514        let output_data: Vec<T> = input_data
515            .iter()
516            .zip(scaled_mask_vec.iter())
517            .map(|(&x, &m)| x * m)
518            .collect();
519
520        // In-place branch, mirroring `_VF.dropout_(input, p, training)` at
521        // `torch/nn/functional.py:1449`. `apply_inplace_dropout` applies the
522        // autograd-safe policy: it errors on a grad-requiring leaf (matching
523        // torch), falls back to out-of-place for a grad-requiring non-leaf
524        // (ferrotorch lacks torch's version counter, so it declines to mutate
525        // shared storage), and mutates in place only when no autograd node can
526        // observe the storage. The out-of-place output below is always built
527        // from `output_data`, so the fallback needs no special handling here.
528        if self.inplace {
529            apply_inplace_dropout(input, &output_data)?;
530        }
531
532        if is_grad_enabled() && input.requires_grad() {
533            let scaled_mask = Tensor::from_storage(
534                TensorStorage::cpu(scaled_mask_vec),
535                input.shape().to_vec(),
536                false,
537            )?;
538            Tensor::from_operation(
539                TensorStorage::cpu(output_data),
540                input.shape().to_vec(),
541                Arc::new(DropoutBackward {
542                    input: input.clone(),
543                    scaled_mask,
544                }),
545            )
546        } else {
547            Tensor::from_storage(
548                TensorStorage::cpu(output_data),
549                input.shape().to_vec(),
550                false,
551            )
552        }
553    }
554
555    fn parameters(&self) -> Vec<&Parameter<T>> {
556        vec![]
557    }
558
559    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
560        vec![]
561    }
562
563    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
564        vec![]
565    }
566
567    fn train(&mut self) {
568        self.training = true;
569    }
570
571    fn eval(&mut self) {
572        self.training = false;
573    }
574
575    fn is_training(&self) -> bool {
576        self.training
577    }
578}
579
580// ===========================================================================
581// Dropout2d
582// ===========================================================================
583
/// Randomly zeroes entire channels with probability `p` during training.
///
/// Expects input of shape `[B, C, ...]` (at least 2 dimensions). During
/// training, each channel (the entire `[H, W, ...]` slice for a given `b, c`)
/// is independently set to zero with probability `p` and surviving channels
/// are scaled by `1/(1-p)`.  During evaluation the input is returned unchanged.
///
/// # Errors
///
/// The constructor does not panic; it returns
/// `FerrotorchError::InvalidArgument` if `p` is outside `[0, 1)`.
#[derive(Debug)]
pub struct Dropout2d<T: Float> {
    /// Probability of zeroing a whole channel. Validated to lie in `[0, 1)`
    /// by the constructor.
    p: f64,
    /// `true` while the layer is in training mode; the constructor
    /// initialises it to `true`.
    training: bool,
    /// In-place flag, mirroring `_DropoutNd.inplace` at
    /// `torch/nn/modules/dropout.py:29` and the `inplace` branch of
    /// `F.dropout2d` at `torch/nn/functional.py:1578-1582`
    /// (`_VF.feature_dropout_(input, p, training) if inplace`).
    inplace: bool,
    /// Zero-sized marker tying the layer to its element type `T`; the layer
    /// stores no `T` values.
    _marker: std::marker::PhantomData<T>,
}
605
606impl<T: Float> Dropout2d<T> {
607    /// Create a new `Dropout2d` layer.
608    ///
609    /// `p` is the probability of an entire channel being zeroed. Must be in `[0, 1)`.
610    pub fn new(p: f64) -> FerrotorchResult<Self> {
611        if !(0.0..1.0).contains(&p) {
612            return Err(FerrotorchError::InvalidArgument {
613                message: format!("dropout2d probability must be in [0, 1), got {p}"),
614            });
615        }
616        Ok(Self {
617            p,
618            training: true,
619            inplace: false,
620            _marker: std::marker::PhantomData,
621        })
622    }
623
624    /// Set the `inplace` flag, mirroring `torch.nn.Dropout2d(p, inplace=...)`.
625    /// When `true`, training-mode forward mutates the input storage.
626    #[must_use]
627    pub fn with_inplace(mut self, inplace: bool) -> Self {
628        self.inplace = inplace;
629        self
630    }
631
632    /// Returns the `inplace` flag.
633    pub fn inplace(&self) -> bool {
634        self.inplace
635    }
636}
637
638impl<T: Float> Module<T> for Dropout2d<T> {
639    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
640        // Eval mode or p == 0: identity.
641        if !self.training || self.p == 0.0 {
642            return Ok(input.clone());
643        }
644
645        let shape = input.shape();
646        if shape.len() < 2 {
647            return Err(FerrotorchError::InvalidArgument {
648                message: format!(
649                    "Dropout2d expects at least 2D input [B, C, ...], got shape {:?}",
650                    shape
651                ),
652            });
653        }
654
655        let batch = shape[0];
656        let channels = shape[1];
657        // Product of empty slice is 1, so no special case needed for 2-D inputs.
658        let spatial: usize = shape[2..].iter().product();
659
660        let numel = input.numel();
661        let scale = T::from(1.0 / (1.0 - self.p)).unwrap();
662        let zero = <T as num_traits::Zero>::zero();
663
664        // GPU tensors are not yet supported for Dropout2d — needs a fused
665        // channel-broadcast dropout kernel.
666        if input.is_cuda() {
667            return Err(FerrotorchError::NotImplementedOnCuda { op: "Dropout2d" });
668        }
669
670        // CPU path — draw the per-channel keep mask from the byte-exact
671        // MT19937 `Generator` (`ferrotorch_core::rng`), matching torch's
672        // `make_feature_noise(input).bernoulli_(1 - p)`
673        // (`aten/src/ATen/native/Dropout.cpp:73-74`). torch reduces the input
674        // to a `[N, C, 1, 1...]` noise tensor and draws ONE Bernoulli per
675        // `[N, C]` entry in flat order, then broadcasts over the spatial dims
676        // and scales survivors by `1/(1-p)` (`Dropout.cpp:81` `noise.div_(1-p)`).
677        // The scalar bernoulli kernel keeps iff `next_uniform_f64() < (1 - p)`
678        // (`DistributionTemplates.h` / `TransformationHelper.h:171-173`), so a
679        // shared `ferrotorch_core::manual_seed(s)` reproduces
680        // `torch.manual_seed(s); F.dropout2d(...)` byte-for-byte (#1635).
681        let keep_prob = 1.0 - self.p;
682        let channel_mask: Vec<bool> = ferrotorch_core::rng::with_thread_rng(|g| {
683            (0..batch * channels)
684                .map(|_| g.next_uniform_f64() < keep_prob)
685                .collect()
686        });
687
688        // Expand channel mask to full element mask.
689        let scaled_mask: Vec<T> = {
690            let mut mask = Vec::with_capacity(numel);
691            for &cm in &channel_mask {
692                let val = if cm { scale } else { zero };
693                for _ in 0..spatial {
694                    mask.push(val);
695                }
696            }
697            mask
698        };
699
700        let input_data = input.data_vec()?;
701        let output_data: Vec<T> = input_data
702            .iter()
703            .zip(scaled_mask.iter())
704            .map(|(&x, &m)| x * m)
705            .collect();
706
707        // In-place branch mirrors `_VF.feature_dropout_` at
708        // `torch/nn/functional.py:1579`. Routed through the autograd-safe
709        // policy (`apply_inplace_dropout`): errors on a grad-requiring leaf,
710        // falls back to out-of-place on a grad-requiring non-leaf, mutates only
711        // when no autograd node observes the storage.
712        if self.inplace {
713            apply_inplace_dropout(input, &output_data)?;
714        }
715
716        let result = if is_grad_enabled() && input.requires_grad() {
717            Tensor::from_operation(
718                TensorStorage::cpu(output_data),
719                input.shape().to_vec(),
720                Arc::new(Dropout2dBackward {
721                    input: input.clone(),
722                    scaled_mask,
723                }),
724            )?
725        } else {
726            Tensor::from_storage(
727                TensorStorage::cpu(output_data),
728                input.shape().to_vec(),
729                false,
730            )?
731        };
732        Ok(result)
733    }
734
735    fn parameters(&self) -> Vec<&Parameter<T>> {
736        vec![]
737    }
738
739    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
740        vec![]
741    }
742
743    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
744        vec![]
745    }
746
747    fn train(&mut self) {
748        self.training = true;
749    }
750
751    fn eval(&mut self) {
752        self.training = false;
753    }
754
755    fn is_training(&self) -> bool {
756        self.training
757    }
758}
759
760// ===========================================================================
761// Dropout1d — CL-433
762// ===========================================================================
763
/// Randomly zeroes entire 1D channels with probability `p` during training.
///
/// Expects input of shape `[B, C, L]` (3 dimensions). During training,
/// each channel (the entire length-`L` slice for a given `b, c`) is
/// independently set to zero with probability `p` and surviving channels
/// are scaled by `1/(1-p)`. During evaluation the input is returned unchanged.
///
/// This is the 1D analogue of [`Dropout2d`].
///
/// Matches `torch.nn.Dropout1d`.
///
/// # Errors
///
/// The constructor returns `FerrotorchError::InvalidArgument` if `p` is
/// outside `[0, 1)`.
#[derive(Debug)]
pub struct Dropout1d<T: Float> {
    /// Probability of zeroing a channel; the constructor only accepts values
    /// in `[0, 1)`.
    p: f64,
    /// `true` while the layer is in training mode (the constructor's default).
    training: bool,
    /// In-place flag, mirroring `_DropoutNd.inplace` at
    /// `torch/nn/modules/dropout.py:29` and the `inplace` branch of
    /// `F.dropout1d` at `torch/nn/functional.py:1515-1519`
    /// (`_VF.feature_dropout_(input, p, training) if inplace`).
    inplace: bool,
    /// Ties the layer to its element type `T` without storing a `T`.
    _marker: std::marker::PhantomData<T>,
}
785
786impl<T: Float> Dropout1d<T> {
787    /// Create a new `Dropout1d` layer.
788    ///
789    /// `p` is the probability of an entire channel being zeroed. Must be in `[0, 1)`.
790    pub fn new(p: f64) -> FerrotorchResult<Self> {
791        if !(0.0..1.0).contains(&p) {
792            return Err(FerrotorchError::InvalidArgument {
793                message: format!("dropout1d probability must be in [0, 1), got {p}"),
794            });
795        }
796        Ok(Self {
797            p,
798            training: true,
799            inplace: false,
800            _marker: std::marker::PhantomData,
801        })
802    }
803
804    /// Set the `inplace` flag, mirroring `torch.nn.Dropout1d(p, inplace=...)`.
805    /// When `true`, training-mode forward mutates the input storage.
806    #[must_use]
807    pub fn with_inplace(mut self, inplace: bool) -> Self {
808        self.inplace = inplace;
809        self
810    }
811
812    /// Returns the `inplace` flag.
813    pub fn inplace(&self) -> bool {
814        self.inplace
815    }
816}
817
818impl<T: Float> Module<T> for Dropout1d<T> {
819    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
820        if !self.training || self.p == 0.0 {
821            return Ok(input.clone());
822        }
823
824        let shape = input.shape();
825        if shape.len() != 3 {
826            return Err(FerrotorchError::InvalidArgument {
827                message: format!(
828                    "Dropout1d expects 3D input [B, C, L], got shape {:?}",
829                    shape
830                ),
831            });
832        }
833
834        let batch = shape[0];
835        let channels = shape[1];
836        let length = shape[2];
837
838        let numel = input.numel();
839        let scale = T::from(1.0 / (1.0 - self.p)).unwrap();
840        let zero = <T as num_traits::Zero>::zero();
841
842        if input.is_cuda() {
843            return Err(FerrotorchError::NotImplementedOnCuda { op: "Dropout1d" });
844        }
845
846        // Per-channel keep mask from the byte-exact MT19937 `Generator`,
847        // matching torch's `make_feature_noise(input).bernoulli_(1 - p)`
848        // (`aten/src/ATen/native/Dropout.cpp:73-74`): one Bernoulli draw per
849        // `[N, C]` channel in flat order, keep iff `next_uniform_f64() < (1-p)`,
850        // broadcast over the length-`L` dim, survivors scaled by `1/(1-p)`.
851        // Reproducible under `ferrotorch_core::manual_seed` (#1635).
852        let keep_prob = 1.0 - self.p;
853        let channel_mask: Vec<bool> = ferrotorch_core::rng::with_thread_rng(|g| {
854            (0..batch * channels)
855                .map(|_| g.next_uniform_f64() < keep_prob)
856                .collect()
857        });
858
859        let scaled_mask: Vec<T> = {
860            let mut mask = Vec::with_capacity(numel);
861            for &cm in &channel_mask {
862                let val = if cm { scale } else { zero };
863                for _ in 0..length {
864                    mask.push(val);
865                }
866            }
867            mask
868        };
869
870        let input_data = input.data_vec()?;
871        let output_data: Vec<T> = input_data
872            .iter()
873            .zip(scaled_mask.iter())
874            .map(|(&x, &m)| x * m)
875            .collect();
876
877        // In-place branch mirrors `_VF.feature_dropout_` at
878        // `torch/nn/functional.py:1516`. Routed through the autograd-safe
879        // policy (`apply_inplace_dropout`): errors on a grad-requiring leaf,
880        // falls back to out-of-place on a grad-requiring non-leaf, mutates only
881        // when no autograd node observes the storage.
882        if self.inplace {
883            apply_inplace_dropout(input, &output_data)?;
884        }
885
886        let result = if is_grad_enabled() && input.requires_grad() {
887            Tensor::from_operation(
888                TensorStorage::cpu(output_data),
889                input.shape().to_vec(),
890                Arc::new(Dropout2dBackward {
891                    input: input.clone(),
892                    scaled_mask,
893                }),
894            )?
895        } else {
896            Tensor::from_storage(
897                TensorStorage::cpu(output_data),
898                input.shape().to_vec(),
899                false,
900            )?
901        };
902        Ok(result)
903    }
904
905    fn parameters(&self) -> Vec<&Parameter<T>> {
906        vec![]
907    }
908
909    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
910        vec![]
911    }
912
913    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
914        vec![]
915    }
916
917    fn train(&mut self) {
918        self.training = true;
919    }
920
921    fn eval(&mut self) {
922        self.training = false;
923    }
924
925    fn is_training(&self) -> bool {
926        self.training
927    }
928}
929
930// ===========================================================================
931// Dropout3d — CL-433
932// ===========================================================================
933
/// Randomly zeroes entire 3D channels with probability `p` during training.
///
/// Expects input of shape `[B, C, D, H, W]` (5 dimensions). During training,
/// each channel (the entire `D * H * W` volume for a given `b, c`) is
/// independently set to zero with probability `p` and surviving channels
/// are scaled by `1/(1-p)`. During evaluation the input is returned unchanged.
///
/// Matches `torch.nn.Dropout3d`.
///
/// # Errors
///
/// The constructor returns `FerrotorchError::InvalidArgument` if `p` is
/// outside `[0, 1)`.
#[derive(Debug)]
pub struct Dropout3d<T: Float> {
    /// Probability of zeroing a channel; the constructor only accepts values
    /// in `[0, 1)`.
    p: f64,
    /// `true` while the layer is in training mode (the constructor's default).
    training: bool,
    /// In-place flag, mirroring `_DropoutNd.inplace` at
    /// `torch/nn/modules/dropout.py:29` and the `inplace` branch of
    /// `F.dropout3d` at `torch/nn/functional.py:1628-1632`
    /// (`_VF.feature_dropout_(input, p, training) if inplace`).
    inplace: bool,
    /// Ties the layer to its element type `T` without storing a `T`.
    _marker: std::marker::PhantomData<T>,
}
953
954impl<T: Float> Dropout3d<T> {
955    /// Create a new `Dropout3d` layer.
956    ///
957    /// `p` is the probability of an entire channel being zeroed. Must be in `[0, 1)`.
958    pub fn new(p: f64) -> FerrotorchResult<Self> {
959        if !(0.0..1.0).contains(&p) {
960            return Err(FerrotorchError::InvalidArgument {
961                message: format!("dropout3d probability must be in [0, 1), got {p}"),
962            });
963        }
964        Ok(Self {
965            p,
966            training: true,
967            inplace: false,
968            _marker: std::marker::PhantomData,
969        })
970    }
971
972    /// Set the `inplace` flag, mirroring `torch.nn.Dropout3d(p, inplace=...)`.
973    /// When `true`, training-mode forward mutates the input storage.
974    #[must_use]
975    pub fn with_inplace(mut self, inplace: bool) -> Self {
976        self.inplace = inplace;
977        self
978    }
979
980    /// Returns the `inplace` flag.
981    pub fn inplace(&self) -> bool {
982        self.inplace
983    }
984}
985
986impl<T: Float> Module<T> for Dropout3d<T> {
987    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
988        if !self.training || self.p == 0.0 {
989            return Ok(input.clone());
990        }
991
992        let shape = input.shape();
993        if shape.len() != 5 {
994            return Err(FerrotorchError::InvalidArgument {
995                message: format!(
996                    "Dropout3d expects 5D input [B, C, D, H, W], got shape {:?}",
997                    shape
998                ),
999            });
1000        }
1001
1002        let batch = shape[0];
1003        let channels = shape[1];
1004        let spatial: usize = shape[2..].iter().product();
1005
1006        let numel = input.numel();
1007        let scale = T::from(1.0 / (1.0 - self.p)).unwrap();
1008        let zero = <T as num_traits::Zero>::zero();
1009
1010        if input.is_cuda() {
1011            return Err(FerrotorchError::NotImplementedOnCuda { op: "Dropout3d" });
1012        }
1013
1014        // Per-channel keep mask from the byte-exact MT19937 `Generator`,
1015        // matching torch's `make_feature_noise(input).bernoulli_(1 - p)`
1016        // (`aten/src/ATen/native/Dropout.cpp:73-74`): one Bernoulli draw per
1017        // `[N, C]` channel in flat order, keep iff `next_uniform_f64() < (1-p)`,
1018        // broadcast over the `D*H*W` volume, survivors scaled by `1/(1-p)`.
1019        // Reproducible under `ferrotorch_core::manual_seed` (#1635).
1020        let keep_prob = 1.0 - self.p;
1021        let channel_mask: Vec<bool> = ferrotorch_core::rng::with_thread_rng(|g| {
1022            (0..batch * channels)
1023                .map(|_| g.next_uniform_f64() < keep_prob)
1024                .collect()
1025        });
1026
1027        let scaled_mask: Vec<T> = {
1028            let mut mask = Vec::with_capacity(numel);
1029            for &cm in &channel_mask {
1030                let val = if cm { scale } else { zero };
1031                for _ in 0..spatial {
1032                    mask.push(val);
1033                }
1034            }
1035            mask
1036        };
1037
1038        let input_data = input.data_vec()?;
1039        let output_data: Vec<T> = input_data
1040            .iter()
1041            .zip(scaled_mask.iter())
1042            .map(|(&x, &m)| x * m)
1043            .collect();
1044
1045        // In-place branch mirrors `_VF.feature_dropout_` at
1046        // `torch/nn/functional.py:1629`. Routed through the autograd-safe
1047        // policy (`apply_inplace_dropout`): errors on a grad-requiring leaf,
1048        // falls back to out-of-place on a grad-requiring non-leaf, mutates only
1049        // when no autograd node observes the storage.
1050        if self.inplace {
1051            apply_inplace_dropout(input, &output_data)?;
1052        }
1053
1054        let result = if is_grad_enabled() && input.requires_grad() {
1055            Tensor::from_operation(
1056                TensorStorage::cpu(output_data),
1057                input.shape().to_vec(),
1058                Arc::new(Dropout2dBackward {
1059                    input: input.clone(),
1060                    scaled_mask,
1061                }),
1062            )?
1063        } else {
1064            Tensor::from_storage(
1065                TensorStorage::cpu(output_data),
1066                input.shape().to_vec(),
1067                false,
1068            )?
1069        };
1070        Ok(result)
1071    }
1072
1073    fn parameters(&self) -> Vec<&Parameter<T>> {
1074        vec![]
1075    }
1076
1077    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
1078        vec![]
1079    }
1080
1081    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
1082        vec![]
1083    }
1084
1085    fn train(&mut self) {
1086        self.training = true;
1087    }
1088
1089    fn eval(&mut self) {
1090        self.training = false;
1091    }
1092
1093    fn is_training(&self) -> bool {
1094        self.training
1095    }
1096}
1097
1098// ===========================================================================
1099// AlphaDropout — CL-433
1100// ===========================================================================
1101
/// Alpha Dropout for use with SELU activations.
///
/// Unlike standard dropout, `AlphaDropout` preserves the self-normalizing
/// property of SELU by maintaining the mean and variance of the input.
/// Dropped elements are set to the SELU saturation value rather than zero,
/// and the output is affinely transformed to restore the original mean and
/// variance.
///
/// During training, mirroring `aten/src/ATen/native/Dropout.cpp:74-79`:
/// 1. A per-element Bernoulli keep-mask is drawn at probability `1 - p` from
///    the byte-exact MT19937 `Generator` (keep iff `next_uniform_f64() < 1-p`).
/// 2. With `alpha = 1.7580993408473766` and
///    `a = 1/sqrt((alpha^2 * p + 1) * (1 - p))`:
///    - kept elements map to `a*x + alpha*a*p`,
///    - dropped elements map to the constant `-alpha*a + alpha*a*p`.
///
/// During evaluation, the input is returned unchanged.
///
/// Matches `torch.nn.AlphaDropout`. Reproducible under
/// `ferrotorch_core::manual_seed` (#1636).
///
/// # Errors
///
/// The constructor returns `FerrotorchError::InvalidArgument` if `p` is
/// outside `[0, 1)`.
#[derive(Debug)]
pub struct AlphaDropout<T: Float> {
    /// Probability of dropping an element; the constructor only accepts
    /// values in `[0, 1)`.
    p: f64,
    /// `true` while the layer is in training mode (the constructor's default).
    training: bool,
    /// In-place flag, carried for API parity with `_DropoutNd.inplace`
    /// (`torch/nn/modules/dropout.py:29`).
    ///
    /// NOTE — faithful upstream behaviour: `AlphaDropout.forward` at
    /// `torch/nn/modules/dropout.py:265-269` calls
    /// `F.alpha_dropout(input, self.p, self.training)` and does **not** pass
    /// `self.inplace`, so torch's `nn.AlphaDropout(p, inplace=True)` does NOT
    /// mutate in place at the module level — the `inplace` field exists on the
    /// struct (inherited from `_DropoutNd.__init__`) but the module forward
    /// drops it. We mirror that exactly: the field is stored for API parity,
    /// but [`AlphaDropout::forward`] never mutates the input. (The functional
    /// `F.alpha_dropout` does accept `inplace`, but the module never forwards
    /// it.)
    inplace: bool,
    /// Ties the layer to its element type `T` without storing a `T`.
    _marker: std::marker::PhantomData<T>,
}
1142
/// The alpha-dropout affine constant that torch hardcodes at
/// `aten/src/ATen/native/Dropout.cpp:76`
/// (`constexpr double alpha = 1.7580993408473766;`).
///
/// This is the SELU-derived `lambda * alpha` magnitude, but it is used
/// VERBATIM as torch's literal — NOT recomputed as `SELU_LAMBDA * SELU_ALPHA`,
/// which differs in the last ULPs and would shift the affine away from torch
/// byte-for-byte (#1636).
///
/// `AlphaDropout::forward` derives its scale `a` and both affine offsets from
/// this value.
const ALPHA_DROPOUT_ALPHA: f64 = 1.7580993408473766;
1150
1151impl<T: Float> AlphaDropout<T> {
1152    /// Create a new `AlphaDropout` layer.
1153    ///
1154    /// `p` is the probability of an element being dropped. Must be in `[0, 1)`.
1155    pub fn new(p: f64) -> FerrotorchResult<Self> {
1156        if !(0.0..1.0).contains(&p) {
1157            return Err(FerrotorchError::InvalidArgument {
1158                message: format!("alpha_dropout probability must be in [0, 1), got {p}"),
1159            });
1160        }
1161        Ok(Self {
1162            p,
1163            training: true,
1164            inplace: false,
1165            _marker: std::marker::PhantomData,
1166        })
1167    }
1168
1169    /// Set the `inplace` flag for API parity with
1170    /// `torch.nn.AlphaDropout(p, inplace=...)`.
1171    ///
1172    /// Like upstream, the module `forward` does NOT mutate in place even when
1173    /// this is `true` — `torch.nn.AlphaDropout.forward` never forwards
1174    /// `self.inplace` to `F.alpha_dropout` (`dropout.py:265-269`). The flag is
1175    /// retained so the constructor surface matches torch field-for-field.
1176    #[must_use]
1177    pub fn with_inplace(mut self, inplace: bool) -> Self {
1178        self.inplace = inplace;
1179        self
1180    }
1181
1182    /// Returns the `inplace` flag.
1183    pub fn inplace(&self) -> bool {
1184        self.inplace
1185    }
1186}
1187
/// Backward node for AlphaDropout.
///
/// The affine correction factor `a` is baked into `grad_mask`:
/// surviving elements get `a`, dropped elements get `0`.
/// Gradient routing: `grad_input = grad_output * grad_mask` (element-wise).
#[derive(Debug)]
struct AlphaDropoutBackward<T: Float> {
    /// The forward input. A gradient is produced only when it requires grad,
    /// and that gradient takes this tensor's shape.
    input: Tensor<T>,
    /// Mask with `a` for kept elements and `0` for dropped elements.
    grad_mask: Vec<T>,
}
1199
1200impl<T: Float> GradFn<T> for AlphaDropoutBackward<T> {
1201    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
1202        if grad_output.is_cuda() {
1203            return Err(FerrotorchError::NotImplementedOnCuda {
1204                op: "AlphaDropout backward",
1205            });
1206        }
1207        let da = if self.input.requires_grad() {
1208            let go_data = grad_output.data_vec()?;
1209            let grad_a: Vec<T> = go_data
1210                .iter()
1211                .zip(self.grad_mask.iter())
1212                .map(|(&g, &m)| g * m)
1213                .collect();
1214            let g = Tensor::from_storage(
1215                TensorStorage::cpu(grad_a),
1216                self.input.shape().to_vec(),
1217                false,
1218            )?;
1219            Some(g)
1220        } else {
1221            None
1222        };
1223        Ok(vec![da])
1224    }
1225
1226    fn inputs(&self) -> Vec<&Tensor<T>> {
1227        vec![&self.input]
1228    }
1229
1230    fn name(&self) -> &'static str {
1231        "AlphaDropoutBackward"
1232    }
1233}
1234
1235impl<T: Float> Module<T> for AlphaDropout<T> {
1236    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1237        if !self.training || self.p == 0.0 {
1238            return Ok(input.clone());
1239        }
1240
1241        if input.is_cuda() {
1242            return Err(FerrotorchError::NotImplementedOnCuda { op: "AlphaDropout" });
1243        }
1244
1245        let numel = input.numel();
1246        let p = self.p;
1247
1248        // torch's EXACT alpha affine, `aten/src/ATen/native/Dropout.cpp:74-79`:
1249        //   noise.bernoulli_(1 - p)                 // 1.0 kept, 0.0 dropped
1250        //   constexpr double alpha = 1.7580993408473766;
1251        //   double a = 1. / sqrt((alpha*alpha*p + 1) * (1 - p));
1252        //   b = noise.add(-1).mul_(alpha*a).add_(alpha*a*p);
1253        //   noise.mul_(a);                          // a kept, 0 dropped
1254        //   out = input * noise + b
1255        // Folding the per-element `b = (noise-1)*alpha*a + alpha*a*p`:
1256        //   kept  (noise=1): out = a*x + alpha*a*p
1257        //   dropped(noise=0): out = -alpha*a + alpha*a*p   (constant in x)
1258        // We use torch's hardcoded `alpha` constant verbatim — NOT a recomputed
1259        // `-SELU_LAMBDA*SELU_ALPHA` (= -1.7580993..., same magnitude but the
1260        // recomputed value diverges in the last ULPs and changes the affine).
1261        let alpha = ALPHA_DROPOUT_ALPHA;
1262        let a_f64 = 1.0 / ((alpha * alpha * p + 1.0) * (1.0 - p)).sqrt();
1263        let dropped_f64 = -alpha * a_f64 + alpha * a_f64 * p;
1264        let kept_b_f64 = alpha * a_f64 * p;
1265
1266        let a = T::from(a_f64).unwrap();
1267        let kept_b = T::from(kept_b_f64).unwrap();
1268        let dropped_v = T::from(dropped_f64).unwrap();
1269        let zero = <T as num_traits::Zero>::zero();
1270
1271        // Per-element keep mask from the byte-exact MT19937 `Generator`,
1272        // matching `at::empty_like(input).bernoulli_(1 - p)` (alpha_dropout is
1273        // element-wise, NOT feature noise; `Dropout.cpp:73`). Keep iff
1274        // `next_uniform_f64() < (1 - p)`; reproducible under
1275        // `ferrotorch_core::manual_seed` (#1636).
1276        let keep_prob = 1.0 - p;
1277        let keep: Vec<bool> = ferrotorch_core::rng::with_thread_rng(|g| {
1278            (0..numel)
1279                .map(|_| g.next_uniform_f64() < keep_prob)
1280                .collect()
1281        });
1282
1283        let input_data = input.data()?;
1284        let mut output_data = Vec::with_capacity(numel);
1285        let mut grad_mask = Vec::with_capacity(numel);
1286
1287        for (i, &x) in input_data.iter().enumerate() {
1288            if keep[i] {
1289                // Kept element: a * x + alpha*a*p
1290                output_data.push(a * x + kept_b);
1291                grad_mask.push(a);
1292            } else {
1293                // Dropped element: -alpha*a + alpha*a*p (independent of x).
1294                output_data.push(dropped_v);
1295                grad_mask.push(zero);
1296            }
1297        }
1298
1299        if is_grad_enabled() && input.requires_grad() {
1300            Tensor::from_operation(
1301                TensorStorage::cpu(output_data),
1302                input.shape().to_vec(),
1303                Arc::new(AlphaDropoutBackward {
1304                    input: input.clone(),
1305                    grad_mask,
1306                }),
1307            )
1308        } else {
1309            Tensor::from_storage(
1310                TensorStorage::cpu(output_data),
1311                input.shape().to_vec(),
1312                false,
1313            )
1314        }
1315    }
1316
1317    fn parameters(&self) -> Vec<&Parameter<T>> {
1318        vec![]
1319    }
1320
1321    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
1322        vec![]
1323    }
1324
1325    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
1326        vec![]
1327    }
1328
1329    fn train(&mut self) {
1330        self.training = true;
1331    }
1332
1333    fn eval(&mut self) {
1334        self.training = false;
1335    }
1336
1337    fn is_training(&self) -> bool {
1338        self.training
1339    }
1340}
1341
1342// ===========================================================================
1343// FeatureAlphaDropout — closes #1448
1344// ===========================================================================
1345
/// Randomly masks entire feature-channels with the SELU saturation value
/// during training, mirroring `torch.nn.FeatureAlphaDropout`
/// (`torch/nn/modules/dropout.py:233-281`).
///
/// Unlike [`AlphaDropout`], which drops individual elements, this layer
/// drops every spatial position within a `(b, c)` feature-channel as a unit
/// — the dropout decision is sampled once per channel and broadcast over
/// the trailing spatial dims. Used in self-normalising convolutional
/// networks where per-feature decorrelation must be preserved while
/// maintaining mean/variance.
///
/// During training, mirroring `aten/src/ATen/native/Dropout.cpp:73-79`
/// (`_dropout_impl<feature=true, alpha=true>`):
/// 1. A per-channel Bernoulli keep-mask is drawn over the reduced
///    `[N, C, 1, 1...]` noise tensor at probability `1 - p` from the
///    byte-exact MT19937 `Generator` (keep iff `next_uniform_f64() < 1-p`),
///    in flat `[N, C]` order, then broadcast over the spatial volume.
/// 2. With `alpha = 1.7580993408473766` and
///    `a = 1/sqrt((alpha^2 * p + 1) * (1 - p))`, kept channels map to
///    `a*x + alpha*a*p` and dropped channels to `-alpha*a + alpha*a*p`.
///
/// During evaluation, the input is returned unchanged.
///
/// Expects input of shape `[N, C, *]` (at least 2-D). Reproducible under
/// `ferrotorch_core::manual_seed` (#1636).
///
/// # Errors
///
/// The constructor returns `FerrotorchError::InvalidArgument` if `p` is
/// outside `[0, 1)`.
#[derive(Debug)]
pub struct FeatureAlphaDropout<T: Float> {
    /// Probability of dropping a feature-channel; the constructor only
    /// accepts values in `[0, 1)`.
    p: f64,
    /// `true` while the layer is in training mode (the constructor's default).
    training: bool,
    /// In-place flag, carried for API parity with `_DropoutNd.inplace`
    /// (`torch/nn/modules/dropout.py:29`).
    ///
    /// NOTE — faithful upstream behaviour: `FeatureAlphaDropout.forward` at
    /// `torch/nn/modules/dropout.py:319-323` calls
    /// `F.feature_alpha_dropout(input, self.p, self.training)` and does **not**
    /// pass `self.inplace`, so torch's `nn.FeatureAlphaDropout(p,
    /// inplace=True)` does NOT mutate in place at the module level. We mirror
    /// that exactly: the field is stored for API parity, but
    /// [`FeatureAlphaDropout::forward`] never mutates the input.
    inplace: bool,
    /// Ties the layer to its element type `T` without storing a `T`.
    _marker: std::marker::PhantomData<T>,
}
1388
1389impl<T: Float> FeatureAlphaDropout<T> {
1390    /// Create a new `FeatureAlphaDropout` layer.
1391    ///
1392    /// `p` is the probability of an entire feature-channel being dropped.
1393    /// Must be in `[0, 1)`.
1394    pub fn new(p: f64) -> FerrotorchResult<Self> {
1395        if !(0.0..1.0).contains(&p) {
1396            return Err(FerrotorchError::InvalidArgument {
1397                message: format!("feature_alpha_dropout probability must be in [0, 1), got {p}"),
1398            });
1399        }
1400        Ok(Self {
1401            p,
1402            training: true,
1403            inplace: false,
1404            _marker: std::marker::PhantomData,
1405        })
1406    }
1407
1408    /// Set the `inplace` flag for API parity with
1409    /// `torch.nn.FeatureAlphaDropout(p, inplace=...)`.
1410    ///
1411    /// Like upstream, the module `forward` does NOT mutate in place even when
1412    /// this is `true` — `torch.nn.FeatureAlphaDropout.forward` never forwards
1413    /// `self.inplace` to `F.feature_alpha_dropout` (`dropout.py:319-323`).
1414    #[must_use]
1415    pub fn with_inplace(mut self, inplace: bool) -> Self {
1416        self.inplace = inplace;
1417        self
1418    }
1419
1420    /// Returns the `inplace` flag.
1421    pub fn inplace(&self) -> bool {
1422        self.inplace
1423    }
1424}
1425
/// Backward node for `FeatureAlphaDropout`.
///
/// The affine factor `a` is baked into the broadcast mask: kept channels
/// receive `a`, dropped channels receive `0`. Gradient routes as
/// `grad_input = grad_output * grad_mask` (element-wise).
#[derive(Debug)]
struct FeatureAlphaDropoutBackward<T: Float> {
    /// The forward input. A gradient is produced only when it requires grad,
    /// and that gradient takes this tensor's shape.
    input: Tensor<T>,
    /// Full-shape mask with `a` for kept channels, `0` for dropped.
    grad_mask: Vec<T>,
}
1437
1438impl<T: Float> GradFn<T> for FeatureAlphaDropoutBackward<T> {
1439    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
1440        if grad_output.is_cuda() {
1441            return Err(FerrotorchError::NotImplementedOnCuda {
1442                op: "FeatureAlphaDropout backward",
1443            });
1444        }
1445        let da = if self.input.requires_grad() {
1446            let go_data = grad_output.data_vec()?;
1447            let grad_a: Vec<T> = go_data
1448                .iter()
1449                .zip(self.grad_mask.iter())
1450                .map(|(&g, &m)| g * m)
1451                .collect();
1452            let g = Tensor::from_storage(
1453                TensorStorage::cpu(grad_a),
1454                self.input.shape().to_vec(),
1455                false,
1456            )?;
1457            Some(g)
1458        } else {
1459            None
1460        };
1461        Ok(vec![da])
1462    }
1463
1464    fn inputs(&self) -> Vec<&Tensor<T>> {
1465        vec![&self.input]
1466    }
1467
1468    fn name(&self) -> &'static str {
1469        "FeatureAlphaDropoutBackward"
1470    }
1471}
1472
1473impl<T: Float> Module<T> for FeatureAlphaDropout<T> {
1474    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1475        if !self.training || self.p == 0.0 {
1476            return Ok(input.clone());
1477        }
1478
1479        let shape = input.shape();
1480        if shape.len() < 2 {
1481            return Err(FerrotorchError::InvalidArgument {
1482                message: format!(
1483                    "FeatureAlphaDropout expects at least 2D input [N, C, ...], got shape {:?}",
1484                    shape
1485                ),
1486            });
1487        }
1488
1489        if input.is_cuda() {
1490            return Err(FerrotorchError::NotImplementedOnCuda {
1491                op: "FeatureAlphaDropout",
1492            });
1493        }
1494
1495        let batch = shape[0];
1496        let channels = shape[1];
1497        // Spatial dims (D, H, W, ...). For a 2-D `[N, C]` input the product
1498        // of the empty suffix is 1, matching torch's broadcast behaviour.
1499        let spatial: usize = shape[2..].iter().product();
1500
1501        let numel = input.numel();
1502        let p = self.p;
1503
1504        // torch's EXACT feature-alpha affine: `feature_alpha_dropout` calls
1505        // `_dropout_impl<feature=true, alpha=true>`, so the noise is a
1506        // PER-CHANNEL `make_feature_noise` tensor (`Dropout.cpp:73`) drawn with
1507        // `bernoulli_(1 - p)`, then the alpha affine
1508        // (`Dropout.cpp:76-79`): `alpha = 1.7580993408473766`,
1509        // `a = 1/sqrt((alpha^2*p + 1)*(1-p))`,
1510        // kept (noise=1) -> `a*x + alpha*a*p`,
1511        // dropped (noise=0) -> `-alpha*a + alpha*a*p` (constant in x).
1512        let alpha = ALPHA_DROPOUT_ALPHA;
1513        let a_f64 = 1.0 / ((alpha * alpha * p + 1.0) * (1.0 - p)).sqrt();
1514        let dropped_f64 = -alpha * a_f64 + alpha * a_f64 * p;
1515        let kept_b_f64 = alpha * a_f64 * p;
1516
1517        let a = T::from(a_f64).unwrap();
1518        let kept_b = T::from(kept_b_f64).unwrap();
1519        let dropped_v = T::from(dropped_f64).unwrap();
1520        let zero = <T as num_traits::Zero>::zero();
1521
1522        // Per-channel keep mask: one Bernoulli draw per `[N, C]` entry in flat
1523        // order from the byte-exact MT19937 `Generator`, keep iff
1524        // `next_uniform_f64() < (1 - p)`, broadcast over the trailing spatial
1525        // volume. Reproducible under `ferrotorch_core::manual_seed` (#1636).
1526        let keep_prob = 1.0 - p;
1527        let keep_channel: Vec<bool> = ferrotorch_core::rng::with_thread_rng(|g| {
1528            (0..batch * channels)
1529                .map(|_| g.next_uniform_f64() < keep_prob)
1530                .collect()
1531        });
1532
1533        let input_data = input.data_vec()?;
1534        let mut output_data = Vec::with_capacity(numel);
1535        let mut grad_mask = Vec::with_capacity(numel);
1536
1537        // For each channel: emit `spatial` masked elements at once.
1538        for bc in 0..batch * channels {
1539            let keep = keep_channel[bc];
1540            let base = bc * spatial;
1541            for s in 0..spatial {
1542                let x = input_data[base + s];
1543                if keep {
1544                    output_data.push(a * x + kept_b);
1545                    grad_mask.push(a);
1546                } else {
1547                    output_data.push(dropped_v);
1548                    grad_mask.push(zero);
1549                }
1550            }
1551        }
1552
1553        if is_grad_enabled() && input.requires_grad() {
1554            Tensor::from_operation(
1555                TensorStorage::cpu(output_data),
1556                input.shape().to_vec(),
1557                Arc::new(FeatureAlphaDropoutBackward {
1558                    input: input.clone(),
1559                    grad_mask,
1560                }),
1561            )
1562        } else {
1563            Tensor::from_storage(
1564                TensorStorage::cpu(output_data),
1565                input.shape().to_vec(),
1566                false,
1567            )
1568        }
1569    }
1570
1571    fn parameters(&self) -> Vec<&Parameter<T>> {
1572        vec![]
1573    }
1574
1575    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
1576        vec![]
1577    }
1578
1579    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
1580        vec![]
1581    }
1582
1583    fn train(&mut self) {
1584        self.training = true;
1585    }
1586
1587    fn eval(&mut self) {
1588        self.training = false;
1589    }
1590
1591    fn is_training(&self) -> bool {
1592        self.training
1593    }
1594}
1595
1596// ===========================================================================
1597// Tests
1598// ===========================================================================
1599
1600#[cfg(test)]
1601mod tests {
1602    use super::*;
1603
1604    /// Create a leaf tensor with given data and shape.
1605    fn leaf_tensor(data: &[f32], shape: &[usize], requires_grad: bool) -> Tensor<f32> {
1606        Tensor::from_storage(
1607            TensorStorage::cpu(data.to_vec()),
1608            shape.to_vec(),
1609            requires_grad,
1610        )
1611        .unwrap()
1612    }
1613
1614    // -----------------------------------------------------------------------
1615    // Dropout
1616    // -----------------------------------------------------------------------
1617
1618    #[test]
1619    fn test_dropout_rate_approximately_correct() {
1620        let d = Dropout::<f32>::new(0.5).unwrap();
1621        let input = ferrotorch_core::ones::<f32>(&[100_000]).unwrap();
1622        let output = d.forward(&input).unwrap();
1623        let data = output.data().unwrap();
1624
1625        // Count zeros — should be roughly 50%.
1626        let zeros = data.iter().filter(|&&x| x == 0.0).count();
1627        let rate = zeros as f64 / data.len() as f64;
1628        assert!(
1629            (rate - 0.5).abs() < 0.05,
1630            "dropout rate = {rate}, expected ~0.5"
1631        );
1632
1633        // Surviving elements should be scaled by 1/(1-0.5) = 2.0.
1634        let non_zero: Vec<f32> = data.iter().copied().filter(|&x| x != 0.0).collect();
1635        assert!(!non_zero.is_empty());
1636        for &v in &non_zero {
1637            assert!(
1638                (v - 2.0).abs() < 1e-6,
1639                "surviving element = {v}, expected 2.0"
1640            );
1641        }
1642    }
1643
1644    #[test]
1645    fn test_dropout_eval_is_identity() {
1646        let mut d = Dropout::<f32>::new(0.5).unwrap();
1647        d.eval();
1648        assert!(!d.is_training());
1649
1650        let input = ferrotorch_core::ones::<f32>(&[100]).unwrap();
1651        let output = d.forward(&input).unwrap();
1652
1653        // In eval mode the output should be the exact same Arc (identity).
1654        assert!(output.is_same(&input));
1655    }
1656
1657    #[test]
1658    fn test_dropout_zero_prob_is_identity() {
1659        let d = Dropout::<f32>::new(0.0).unwrap();
1660        let input = ferrotorch_core::ones::<f32>(&[100]).unwrap();
1661        let output = d.forward(&input).unwrap();
1662        assert!(output.is_same(&input));
1663    }
1664
1665    #[test]
1666    fn test_dropout_invalid_p() {
1667        assert!(Dropout::<f32>::new(1.0).is_err());
1668        assert!(Dropout::<f32>::new(-0.1).is_err());
1669        assert!(Dropout::<f32>::new(1.5).is_err());
1670    }
1671
1672    #[test]
1673    fn test_dropout_backward_routes_through_surviving() {
1674        let d = Dropout::<f32>::new(0.5).unwrap();
1675        let input = leaf_tensor(&[1.0; 1000], &[1000], true);
1676        let output = d.forward(&input).unwrap();
1677
1678        // To backward we need a scalar loss. Sum the output manually.
1679        let out_data = output.data().unwrap().to_vec();
1680        let total: f32 = out_data.iter().sum();
1681
1682        // Build a SumBackward so we can call backward.
1683        #[derive(Debug)]
1684        struct SumBackward<T: Float> {
1685            input: Tensor<T>,
1686        }
1687        impl<T: Float> GradFn<T> for SumBackward<T> {
1688            fn backward(
1689                &self,
1690                _grad_output: &Tensor<T>,
1691            ) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
1692                let ones = vec![<T as num_traits::One>::one(); self.input.numel()];
1693                let t = Tensor::from_storage(
1694                    TensorStorage::cpu(ones),
1695                    self.input.shape().to_vec(),
1696                    false,
1697                )?;
1698                Ok(vec![Some(t)])
1699            }
1700            fn inputs(&self) -> Vec<&Tensor<T>> {
1701                vec![&self.input]
1702            }
1703            fn name(&self) -> &'static str {
1704                "SumBackward"
1705            }
1706        }
1707
1708        let loss = Tensor::from_operation(
1709            TensorStorage::cpu(vec![total]),
1710            vec![],
1711            Arc::new(SumBackward {
1712                input: output.clone(),
1713            }),
1714        )
1715        .unwrap();
1716        loss.backward().unwrap();
1717
1718        let grad = input.grad().unwrap().unwrap();
1719        let grad_data = grad.data().unwrap();
1720
1721        // Every gradient element should be either 0 (dropped) or 1/(1-p) = 2.0 (survived).
1722        for &g in grad_data {
1723            assert!(
1724                g == 0.0 || (g - 2.0).abs() < 1e-6,
1725                "gradient element = {g}, expected 0.0 or 2.0"
1726            );
1727        }
1728
1729        // The dropout mask for forward and backward should match: output zero
1730        // iff gradient zero.
1731        let out_data = output.data().unwrap();
1732        for (i, (&o, &g)) in out_data.iter().zip(grad_data.iter()).enumerate() {
1733            assert_eq!(
1734                o == 0.0,
1735                g == 0.0,
1736                "mismatch at index {i}: output={o}, grad={g}"
1737            );
1738        }
1739    }
1740
1741    #[test]
1742    fn test_dropout_no_parameters() {
1743        let d = Dropout::<f32>::new(0.3).unwrap();
1744        assert!(d.parameters().is_empty());
1745        assert!(d.named_parameters().is_empty());
1746    }
1747
1748    #[test]
1749    fn test_dropout_train_eval_toggle() {
1750        let mut d = Dropout::<f32>::new(0.5).unwrap();
1751        assert!(d.is_training());
1752        d.eval();
1753        assert!(!d.is_training());
1754        d.train();
1755        assert!(d.is_training());
1756    }
1757
1758    #[test]
1759    fn test_dropout_is_send_sync() {
1760        fn assert_send_sync<T: Send + Sync>() {}
1761        assert_send_sync::<Dropout<f32>>();
1762        assert_send_sync::<Dropout<f64>>();
1763    }
1764
1765    // -----------------------------------------------------------------------
1766    // Dropout2d
1767    // -----------------------------------------------------------------------
1768
1769    #[test]
1770    fn test_dropout2d_drops_whole_channels() {
1771        let d = Dropout2d::<f32>::new(0.5).unwrap();
1772        // Shape: [2, 10, 4, 4] — 2 batches, 10 channels, 4x4 spatial.
1773        let input = ferrotorch_core::ones::<f32>(&[2, 10, 4, 4]).unwrap();
1774        let output = d.forward(&input).unwrap();
1775        let data = output.data().unwrap();
1776
1777        let spatial = 4 * 4;
1778        // Check that each channel is either entirely zero or entirely scaled.
1779        for b in 0..2 {
1780            for c in 0..10 {
1781                let start = (b * 10 + c) * spatial;
1782                let end = start + spatial;
1783                let channel = &data[start..end];
1784
1785                let first = channel[0];
1786                assert!(
1787                    channel.iter().all(|&x| (x - first).abs() < 1e-6),
1788                    "channel (b={b}, c={c}) is not uniform: first={first}, channel={channel:?}"
1789                );
1790                // Value should be 0 or 1/(1-0.5) = 2.0.
1791                assert!(
1792                    first == 0.0 || (first - 2.0).abs() < 1e-6,
1793                    "channel value = {first}, expected 0.0 or 2.0"
1794                );
1795            }
1796        }
1797    }
1798
1799    #[test]
1800    fn test_dropout2d_rate_approximately_correct() {
1801        let d = Dropout2d::<f32>::new(0.5).unwrap();
1802        // Many channels to get a good statistical sample.
1803        let input = ferrotorch_core::ones::<f32>(&[1, 1000, 2, 2]).unwrap();
1804        let output = d.forward(&input).unwrap();
1805        let data = output.data().unwrap();
1806
1807        let spatial = 2 * 2;
1808        let mut dropped = 0;
1809        for c in 0..1000 {
1810            let start = c * spatial;
1811            if data[start] == 0.0 {
1812                dropped += 1;
1813            }
1814        }
1815        let rate = dropped as f64 / 1000.0;
1816        assert!(
1817            (rate - 0.5).abs() < 0.05,
1818            "dropout2d rate = {rate}, expected ~0.5"
1819        );
1820    }
1821
1822    #[test]
1823    fn test_dropout2d_eval_is_identity() {
1824        let mut d = Dropout2d::<f32>::new(0.5).unwrap();
1825        d.eval();
1826        let input = ferrotorch_core::ones::<f32>(&[2, 3, 4, 4]).unwrap();
1827        let output = d.forward(&input).unwrap();
1828        assert!(output.is_same(&input));
1829    }
1830
1831    #[test]
1832    fn test_dropout2d_invalid_p() {
1833        assert!(Dropout2d::<f32>::new(1.0).is_err());
1834        assert!(Dropout2d::<f32>::new(-0.1).is_err());
1835    }
1836
1837    #[test]
1838    fn test_dropout2d_requires_2d_input() {
1839        let d = Dropout2d::<f32>::new(0.3).unwrap();
1840        let input_1d = ferrotorch_core::ones::<f32>(&[10]).unwrap();
1841        assert!(d.forward(&input_1d).is_err());
1842    }
1843
1844    #[test]
1845    fn test_dropout2d_backward_routes_through_surviving_channels() {
1846        let d = Dropout2d::<f32>::new(0.5).unwrap();
1847        // [1, 20, 3, 3]
1848        let input = leaf_tensor(&[1.0; 20 * 3 * 3], &[1, 20, 3, 3], true);
1849        let output = d.forward(&input).unwrap();
1850
1851        let out_data = output.data().unwrap().to_vec();
1852        let total: f32 = out_data.iter().sum();
1853
1854        #[derive(Debug)]
1855        struct SumBackward<T: Float> {
1856            input: Tensor<T>,
1857        }
1858        impl<T: Float> GradFn<T> for SumBackward<T> {
1859            fn backward(
1860                &self,
1861                _grad_output: &Tensor<T>,
1862            ) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
1863                let ones = vec![<T as num_traits::One>::one(); self.input.numel()];
1864                let t = Tensor::from_storage(
1865                    TensorStorage::cpu(ones),
1866                    self.input.shape().to_vec(),
1867                    false,
1868                )?;
1869                Ok(vec![Some(t)])
1870            }
1871            fn inputs(&self) -> Vec<&Tensor<T>> {
1872                vec![&self.input]
1873            }
1874            fn name(&self) -> &'static str {
1875                "SumBackward"
1876            }
1877        }
1878
1879        let loss = Tensor::from_operation(
1880            TensorStorage::cpu(vec![total]),
1881            vec![],
1882            Arc::new(SumBackward {
1883                input: output.clone(),
1884            }),
1885        )
1886        .unwrap();
1887        loss.backward().unwrap();
1888
1889        let grad = input.grad().unwrap().unwrap();
1890        let grad_data = grad.data().unwrap();
1891        let out_data = output.data().unwrap();
1892
1893        // Gradient mask must match output mask.
1894        for (i, (&o, &g)) in out_data.iter().zip(grad_data.iter()).enumerate() {
1895            assert_eq!(
1896                o == 0.0,
1897                g == 0.0,
1898                "mismatch at index {i}: output={o}, grad={g}"
1899            );
1900        }
1901
1902        // Gradients should be channel-uniform.
1903        let spatial = 3 * 3;
1904        for c in 0..20 {
1905            let start = c * spatial;
1906            let end = start + spatial;
1907            let channel_grad = &grad_data[start..end];
1908            let first = channel_grad[0];
1909            assert!(
1910                channel_grad.iter().all(|&g| (g - first).abs() < 1e-6),
1911                "gradient channel {c} is not uniform"
1912            );
1913        }
1914    }
1915
1916    #[test]
1917    fn test_dropout2d_no_parameters() {
1918        let d = Dropout2d::<f32>::new(0.3).unwrap();
1919        assert!(d.parameters().is_empty());
1920        assert!(d.named_parameters().is_empty());
1921    }
1922
1923    #[test]
1924    fn test_dropout2d_is_send_sync() {
1925        fn assert_send_sync<T: Send + Sync>() {}
1926        assert_send_sync::<Dropout2d<f32>>();
1927        assert_send_sync::<Dropout2d<f64>>();
1928    }
1929
1930    // -----------------------------------------------------------------------
1931    // Dropout1d — CL-433
1932    // -----------------------------------------------------------------------
1933
1934    #[test]
1935    fn test_dropout1d_drops_whole_channels() {
1936        let d = Dropout1d::<f32>::new(0.5).unwrap();
1937        // Shape: [2, 10, 8] — 2 batches, 10 channels, length 8.
1938        let input = ferrotorch_core::ones::<f32>(&[2, 10, 8]).unwrap();
1939        let output = d.forward(&input).unwrap();
1940        let data = output.data().unwrap();
1941
1942        let length = 8;
1943        for b in 0..2 {
1944            for c in 0..10 {
1945                let start = (b * 10 + c) * length;
1946                let end = start + length;
1947                let channel = &data[start..end];
1948
1949                let first = channel[0];
1950                assert!(
1951                    channel.iter().all(|&x| (x - first).abs() < 1e-6),
1952                    "channel (b={b}, c={c}) is not uniform"
1953                );
1954                assert!(
1955                    first == 0.0 || (first - 2.0).abs() < 1e-6,
1956                    "channel value = {first}, expected 0.0 or 2.0"
1957                );
1958            }
1959        }
1960    }
1961
1962    #[test]
1963    fn test_dropout1d_rate_approximately_correct() {
1964        let d = Dropout1d::<f32>::new(0.5).unwrap();
1965        let input = ferrotorch_core::ones::<f32>(&[1, 1000, 4]).unwrap();
1966        let output = d.forward(&input).unwrap();
1967        let data = output.data().unwrap();
1968
1969        let length = 4;
1970        let mut dropped = 0;
1971        for c in 0..1000 {
1972            if data[c * length] == 0.0 {
1973                dropped += 1;
1974            }
1975        }
1976        let rate = dropped as f64 / 1000.0;
1977        assert!(
1978            (rate - 0.5).abs() < 0.05,
1979            "dropout1d rate = {rate}, expected ~0.5"
1980        );
1981    }
1982
1983    #[test]
1984    fn test_dropout1d_eval_is_identity() {
1985        let mut d = Dropout1d::<f32>::new(0.5).unwrap();
1986        d.eval();
1987        let input = ferrotorch_core::ones::<f32>(&[2, 3, 8]).unwrap();
1988        let output = d.forward(&input).unwrap();
1989        assert!(output.is_same(&input));
1990    }
1991
1992    #[test]
1993    fn test_dropout1d_invalid_p() {
1994        assert!(Dropout1d::<f32>::new(1.0).is_err());
1995        assert!(Dropout1d::<f32>::new(-0.1).is_err());
1996    }
1997
1998    #[test]
1999    fn test_dropout1d_requires_3d_input() {
2000        let d = Dropout1d::<f32>::new(0.3).unwrap();
2001        let input_2d = ferrotorch_core::ones::<f32>(&[10, 5]).unwrap();
2002        assert!(d.forward(&input_2d).is_err());
2003    }
2004
2005    #[test]
2006    fn test_dropout1d_no_parameters() {
2007        let d = Dropout1d::<f32>::new(0.3).unwrap();
2008        assert!(d.parameters().is_empty());
2009    }
2010
2011    #[test]
2012    fn test_dropout1d_is_send_sync() {
2013        fn assert_send_sync<T: Send + Sync>() {}
2014        assert_send_sync::<Dropout1d<f32>>();
2015        assert_send_sync::<Dropout1d<f64>>();
2016    }
2017
2018    // -----------------------------------------------------------------------
2019    // Dropout3d — CL-433
2020    // -----------------------------------------------------------------------
2021
2022    #[test]
2023    fn test_dropout3d_drops_whole_channels() {
2024        let d = Dropout3d::<f32>::new(0.5).unwrap();
2025        // Shape: [2, 10, 2, 2, 2] — 2 batches, 10 channels, 2x2x2 spatial.
2026        let input = ferrotorch_core::ones::<f32>(&[2, 10, 2, 2, 2]).unwrap();
2027        let output = d.forward(&input).unwrap();
2028        let data = output.data().unwrap();
2029
2030        let spatial = 2 * 2 * 2;
2031        for b in 0..2 {
2032            for c in 0..10 {
2033                let start = (b * 10 + c) * spatial;
2034                let end = start + spatial;
2035                let channel = &data[start..end];
2036
2037                let first = channel[0];
2038                assert!(
2039                    channel.iter().all(|&x| (x - first).abs() < 1e-6),
2040                    "channel (b={b}, c={c}) is not uniform"
2041                );
2042                assert!(
2043                    first == 0.0 || (first - 2.0).abs() < 1e-6,
2044                    "channel value = {first}, expected 0.0 or 2.0"
2045                );
2046            }
2047        }
2048    }
2049
2050    #[test]
2051    fn test_dropout3d_rate_approximately_correct() {
2052        let d = Dropout3d::<f32>::new(0.5).unwrap();
2053        let input = ferrotorch_core::ones::<f32>(&[1, 1000, 2, 2, 2]).unwrap();
2054        let output = d.forward(&input).unwrap();
2055        let data = output.data().unwrap();
2056
2057        let spatial = 2 * 2 * 2;
2058        let mut dropped = 0;
2059        for c in 0..1000 {
2060            if data[c * spatial] == 0.0 {
2061                dropped += 1;
2062            }
2063        }
2064        let rate = dropped as f64 / 1000.0;
2065        assert!(
2066            (rate - 0.5).abs() < 0.05,
2067            "dropout3d rate = {rate}, expected ~0.5"
2068        );
2069    }
2070
2071    #[test]
2072    fn test_dropout3d_eval_is_identity() {
2073        let mut d = Dropout3d::<f32>::new(0.5).unwrap();
2074        d.eval();
2075        let input = ferrotorch_core::ones::<f32>(&[2, 3, 2, 2, 2]).unwrap();
2076        let output = d.forward(&input).unwrap();
2077        assert!(output.is_same(&input));
2078    }
2079
2080    #[test]
2081    fn test_dropout3d_invalid_p() {
2082        assert!(Dropout3d::<f32>::new(1.0).is_err());
2083        assert!(Dropout3d::<f32>::new(-0.1).is_err());
2084    }
2085
2086    #[test]
2087    fn test_dropout3d_requires_5d_input() {
2088        let d = Dropout3d::<f32>::new(0.3).unwrap();
2089        let input_4d = ferrotorch_core::ones::<f32>(&[2, 3, 4, 4]).unwrap();
2090        assert!(d.forward(&input_4d).is_err());
2091    }
2092
2093    #[test]
2094    fn test_dropout3d_no_parameters() {
2095        let d = Dropout3d::<f32>::new(0.3).unwrap();
2096        assert!(d.parameters().is_empty());
2097    }
2098
2099    #[test]
2100    fn test_dropout3d_is_send_sync() {
2101        fn assert_send_sync<T: Send + Sync>() {}
2102        assert_send_sync::<Dropout3d<f32>>();
2103        assert_send_sync::<Dropout3d<f64>>();
2104    }
2105
2106    // -----------------------------------------------------------------------
2107    // AlphaDropout — CL-433
2108    // -----------------------------------------------------------------------
2109
2110    #[test]
2111    fn test_alpha_dropout_preserves_mean_approx() {
2112        // With large sample, mean should be approximately preserved.
2113        let d = AlphaDropout::<f64>::new(0.5).unwrap();
2114        // Generate input with known mean.
2115        let n = 100_000;
2116        let data: Vec<f64> = (0..n).map(|i| (i as f64 / n as f64) - 0.5).collect();
2117        let input_mean: f64 = data.iter().sum::<f64>() / n as f64;
2118
2119        let input = Tensor::from_storage(TensorStorage::cpu(data), vec![1, n], false).unwrap();
2120        let output = d.forward(&input).unwrap();
2121        let out_data = output.data().unwrap();
2122        let out_mean: f64 = out_data.iter().sum::<f64>() / n as f64;
2123
2124        // Mean should be roughly preserved (within statistical tolerance).
2125        assert!(
2126            (out_mean - input_mean).abs() < 0.05,
2127            "AlphaDropout mean = {out_mean}, input mean = {input_mean}"
2128        );
2129    }
2130
2131    #[test]
2132    fn test_alpha_dropout_eval_is_identity() {
2133        let mut d = AlphaDropout::<f32>::new(0.5).unwrap();
2134        d.eval();
2135        let input = ferrotorch_core::ones::<f32>(&[100]).unwrap();
2136        let output = d.forward(&input).unwrap();
2137        assert!(output.is_same(&input));
2138    }
2139
2140    #[test]
2141    fn test_alpha_dropout_zero_prob_is_identity() {
2142        let d = AlphaDropout::<f32>::new(0.0).unwrap();
2143        let input = ferrotorch_core::ones::<f32>(&[100]).unwrap();
2144        let output = d.forward(&input).unwrap();
2145        assert!(output.is_same(&input));
2146    }
2147
2148    #[test]
2149    fn test_alpha_dropout_invalid_p() {
2150        assert!(AlphaDropout::<f32>::new(1.0).is_err());
2151        assert!(AlphaDropout::<f32>::new(-0.1).is_err());
2152        assert!(AlphaDropout::<f32>::new(1.5).is_err());
2153    }
2154
2155    #[test]
2156    fn test_alpha_dropout_no_parameters() {
2157        let d = AlphaDropout::<f32>::new(0.3).unwrap();
2158        assert!(d.parameters().is_empty());
2159    }
2160
2161    #[test]
2162    fn test_alpha_dropout_backward_routes_gradient() {
2163        let d = AlphaDropout::<f32>::new(0.5).unwrap();
2164        let input = leaf_tensor(&[1.0; 1000], &[1000], true);
2165        let output = d.forward(&input).unwrap();
2166
2167        let out_data = output.data().unwrap().to_vec();
2168        let total: f32 = out_data.iter().sum();
2169
2170        #[derive(Debug)]
2171        struct SumBackward<T: Float> {
2172            input: Tensor<T>,
2173        }
2174        impl<T: Float> GradFn<T> for SumBackward<T> {
2175            fn backward(
2176                &self,
2177                _grad_output: &Tensor<T>,
2178            ) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
2179                let ones = vec![<T as num_traits::One>::one(); self.input.numel()];
2180                let t = Tensor::from_storage(
2181                    TensorStorage::cpu(ones),
2182                    self.input.shape().to_vec(),
2183                    false,
2184                )?;
2185                Ok(vec![Some(t)])
2186            }
2187            fn inputs(&self) -> Vec<&Tensor<T>> {
2188                vec![&self.input]
2189            }
2190            fn name(&self) -> &'static str {
2191                "SumBackward"
2192            }
2193        }
2194
2195        let loss = Tensor::from_operation(
2196            TensorStorage::cpu(vec![total]),
2197            vec![],
2198            Arc::new(SumBackward {
2199                input: output.clone(),
2200            }),
2201        )
2202        .unwrap();
2203        loss.backward().unwrap();
2204
2205        let grad = input.grad().unwrap().unwrap();
2206        let grad_data = grad.data().unwrap();
2207
2208        // Gradient should have two types of values: 0 for dropped, `a` for kept.
2209        let mut seen_zero = false;
2210        let mut seen_nonzero = false;
2211        for &g in grad_data {
2212            if g == 0.0 {
2213                seen_zero = true;
2214            } else {
2215                seen_nonzero = true;
2216            }
2217        }
2218        assert!(
2219            seen_zero,
2220            "some elements should have zero gradient (dropped)"
2221        );
2222        assert!(
2223            seen_nonzero,
2224            "some elements should have nonzero gradient (kept)"
2225        );
2226    }
2227
2228    #[test]
2229    fn test_alpha_dropout_train_eval_toggle() {
2230        let mut d = AlphaDropout::<f32>::new(0.5).unwrap();
2231        assert!(d.is_training());
2232        d.eval();
2233        assert!(!d.is_training());
2234        d.train();
2235        assert!(d.is_training());
2236    }
2237
2238    #[test]
2239    fn test_alpha_dropout_is_send_sync() {
2240        fn assert_send_sync<T: Send + Sync>() {}
2241        assert_send_sync::<AlphaDropout<f32>>();
2242        assert_send_sync::<AlphaDropout<f64>>();
2243    }
2244
2245    // -----------------------------------------------------------------------
2246    // inplace=true — blocker #1446
2247    //
2248    // Mirrors torch's `_VF.dropout_` / `_VF.feature_dropout_` family
2249    // (`torch/nn/functional.py:1449,1516,1579,1629`): with `inplace=True` and
2250    // training, the input tensor's storage is mutated (mask + scale written
2251    // back) instead of a fresh buffer being allocated. The mask-based backward
2252    // keeps autograd correct.
2253    // -----------------------------------------------------------------------
2254
2255    /// A minimal sum-reduction backward node used to drive `.backward()` in
2256    /// the in-place gradient tests below.
2257    #[derive(Debug)]
2258    struct SumBackward<T: Float> {
2259        input: Tensor<T>,
2260    }
2261    impl<T: Float> GradFn<T> for SumBackward<T> {
2262        fn backward(&self, _grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
2263            let ones = vec![<T as num_traits::One>::one(); self.input.numel()];
2264            let t =
2265                Tensor::from_storage(TensorStorage::cpu(ones), self.input.shape().to_vec(), false)?;
2266            Ok(vec![Some(t)])
2267        }
2268        fn inputs(&self) -> Vec<&Tensor<T>> {
2269            vec![&self.input]
2270        }
2271        fn name(&self) -> &'static str {
2272            "SumBackward"
2273        }
2274    }
2275
// (a) With inplace=true the forward pass writes into the caller's OWN
//     storage: a buffer that is all ones beforehand holds only the masked /
//     rescaled values {0, 2.0} afterwards. Observed by re-reading
//     `input.data()` once forward has returned.
#[test]
fn test_dropout_inplace_mutates_input_storage() {
    let layer = Dropout::<f32>::new(0.5).unwrap().with_inplace(true);
    assert!(layer.inplace());

    // A leaf that does not track grad, so its storage can be inspected
    // directly before and after the call.
    let all_ones = vec![1.0f32; 10_000];
    let input = leaf_tensor(&all_ones, &[10_000], false);
    // Precondition: nothing but 1.0 in the buffer.
    assert!(!input.data().unwrap().iter().any(|&x| x != 1.0));

    let output = layer.forward(&input).unwrap();

    // Postcondition on the INPUT storage itself: it now carries the
    // post-dropout values (0.0 for dropped, 2.0 = 1/(1-0.5) for survivors).
    // This is the observation the test exists for.
    let in_after = input.data().unwrap();
    assert!(
        in_after.iter().any(|&x| x == 0.0),
        "inplace forward must have zeroed some input elements"
    );
    in_after.iter().for_each(|&x| {
        assert!(
            x == 0.0 || (x - 2.0).abs() < 1e-6,
            "mutated input element = {x}, expected 0.0 or 2.0"
        );
    });

    // (b) The returned tensor agrees with the mutated input at every index,
    //     i.e. the in-place write and the output carry the same mask.
    let out_data = output.data().unwrap();
    assert_eq!(out_data.len(), in_after.len());
    let mut i = 0usize;
    for (&o, &x) in out_data.iter().zip(in_after.iter()) {
        assert_eq!(o, x, "output/input mismatch at {i}: out={o}, in={x}");
        i += 1;
    }
}
2315
// (d) In eval mode an inplace dropout is the identity. This mirrors torch's
//     `F.dropout(.., training=False, inplace=True)`, which hands the input
//     back untouched (the `_VF.dropout_` branch is not reached when
//     training is False; see functional.py:1448).
#[test]
fn test_dropout_inplace_eval_is_identity() {
    let mut layer = Dropout::<f32>::new(0.5).unwrap().with_inplace(true);
    layer.eval();

    let input = leaf_tensor(&[1.0; 100], &[100], false);
    let output = layer.forward(&input).unwrap();

    // The very same tensor object comes back...
    assert!(output.is_same(&input));
    // ...and its storage still holds nothing but 1.0.
    assert!(!input.data().unwrap().iter().any(|&x| x != 1.0));
}
2329
// With p == 0 an inplace dropout is likewise the identity: the input
// tensor itself is returned and its storage is left as all ones.
#[test]
fn test_dropout_inplace_p_zero_is_identity() {
    let layer = Dropout::<f32>::new(0.0).unwrap().with_inplace(true);

    let input = leaf_tensor(&[1.0; 100], &[100], false);
    let output = layer.forward(&input).unwrap();

    assert!(output.is_same(&input));
    assert!(!input.data().unwrap().iter().any(|&x| x != 1.0));
}
2339
// (c) Backward through an in-place dropout whose input is a grad-tracked
//     NON-LEAF. The autograd-safe policy falls back to out-of-place (there
//     is no version counter to prove the shared storage is unused), so the
//     input storage is NOT mutated. The gradient must still route only
//     through surviving elements (0 for dropped, 2.0 for kept), and the
//     gradient mask must coincide with the output mask, exactly as on the
//     out-of-place path.
#[test]
fn test_dropout_inplace_backward_routes_through_surviving() {
    use ferrotorch_core::grad_fns::arithmetic::mul;

    let d = Dropout::<f32>::new(0.5).unwrap().with_inplace(true);
    // Non-leaf grad-tracked input: `t = x * 1` requires grad but is not a
    // leaf, so `apply_inplace_dropout` takes the out-of-place fallback
    // rather than erroring on the leaf guard.
    let x = leaf_tensor(&[1.0; 1000], &[1000], true);
    let ones = leaf_tensor(&[1.0; 1000], &[1000], false);
    let input = mul(&x, &ones).unwrap();
    assert!(input.requires_grad() && !input.is_leaf());
    // Snapshot taken before forward so any mutation would be detectable.
    let input_before = input.data().unwrap().to_vec();

    let output = d.forward(&input).unwrap();

    // Safe fallback: the grad-tracked non-leaf storage is left UNMUTATED.
    let input_after = input.data().unwrap().to_vec();
    assert_eq!(
        input_before, input_after,
        "in-place dropout on a grad-tracked non-leaf must fall back to \
         out-of-place and leave the input storage untouched (no version \
         counter to prove the shared storage is unused)"
    );

    // Scalar loss = sum(output), wired to the graph through `SumBackward`
    // so `backward()` can propagate from it.
    let out_data = output.data().unwrap().to_vec();
    let total: f32 = out_data.iter().sum();
    let loss = Tensor::from_operation(
        TensorStorage::cpu(vec![total]),
        vec![],
        Arc::new(SumBackward {
            input: output.clone(),
        }),
    )
    .unwrap();
    loss.backward().unwrap();

    // Gradient flows back to the leaf `x` through the out-of-place dropout.
    let grad = x.grad().unwrap().unwrap();
    let grad_data = grad.data().unwrap();

    // `zip` below stops at the shorter side, so a truncated or empty
    // gradient would otherwise satisfy both loops vacuously. Pin the length
    // first.
    assert_eq!(
        grad_data.len(),
        out_data.len(),
        "gradient must have exactly one entry per output element"
    );

    for &g in grad_data {
        assert!(
            g == 0.0 || (g - 2.0).abs() < 1e-6,
            "gradient element = {g}, expected 0.0 or 2.0"
        );
    }
    // grad mask matches output mask: dropped iff zero gradient.
    for (i, (&o, &g)) in out_data.iter().zip(grad_data.iter()).enumerate() {
        assert_eq!(
            o == 0.0,
            g == 0.0,
            "mismatch at index {i}: out={o}, grad={g}"
        );
    }
}
2401
// (c2) An in-place dropout applied to a LEAF that requires grad must fail,
//      in line with torch's leaf in-place guard
//      (`torch/csrc/autograd/VariableTypeUtils.h:80-84`, "a leaf Variable
//      that requires grad is being used in an in-place operation.").
//      Pins #1581.
#[test]
fn test_dropout_inplace_on_grad_leaf_errors() {
    let layer = Dropout::<f32>::new(0.5).unwrap().with_inplace(true);

    let original = vec![1.0f32; 100];
    let input = leaf_tensor(&original, &[100], true);
    assert!(input.is_leaf() && input.requires_grad());

    // The forward call itself must be the thing that fails.
    match layer.forward(&input).unwrap_err() {
        FerrotorchError::InvalidArgument { message } => {
            assert!(
                message.contains("leaf Variable that requires grad"),
                "expected torch leaf-guard message, got: {message}"
            );
        }
        other => panic!("expected InvalidArgument leaf-guard error, got {other:?}"),
    }

    // No partial write happened before the error: storage equals the source.
    assert_eq!(input.data().unwrap().to_vec(), original);
}
2424
// (e) All four standard dropout variants honor inplace: the input storage
//     is rewritten channel-wise (element-wise in the case of `Dropout`).
//     This test covers `Dropout2d`.
#[test]
fn test_dropout2d_inplace_mutates_input_storage() {
    let layer = Dropout2d::<f32>::new(0.5).unwrap().with_inplace(true);
    assert!(layer.inplace());

    let input = leaf_tensor(&[1.0; 2 * 500 * 4], &[2, 500, 2, 2], false);
    drop(layer.forward(&input).unwrap());

    let in_after = input.data().unwrap();
    // Channel-wise: every (b, c) block of 2x2 = 4 spatial elements must be
    // uniform, either all 0.0 or all 2.0.
    let spatial = 4;
    let mut dropped_blocks = 0usize;
    for blk in in_after.chunks(spatial) {
        let first = blk[0];
        assert!(blk.iter().all(|&x| (x - first).abs() < 1e-6));
        assert!(first == 0.0 || (first - 2.0).abs() < 1e-6);
        dropped_blocks += usize::from(first == 0.0);
    }
    assert!(
        dropped_blocks > 0,
        "inplace dropout2d must have zeroed some channels"
    );
}
2450
// `Dropout1d` with inplace=true rewrites the input storage one channel at
// a time: each run of 4 consecutive values (the last dimension of the
// `[1, 500, 4]` input) must be uniformly 0.0 or 2.0.
#[test]
fn test_dropout1d_inplace_mutates_input_storage() {
    let layer = Dropout1d::<f32>::new(0.5).unwrap().with_inplace(true);
    assert!(layer.inplace());

    let input = leaf_tensor(&[1.0; 500 * 4], &[1, 500, 4], false);
    drop(layer.forward(&input).unwrap());

    let in_after = input.data().unwrap();
    let mut dropped_blocks = 0usize;
    for blk in in_after.chunks(4) {
        let first = blk[0];
        assert!(blk.iter().all(|&x| (x - first).abs() < 1e-6));
        assert!(first == 0.0 || (first - 2.0).abs() < 1e-6);
        dropped_blocks += usize::from(first == 0.0);
    }
    assert!(
        dropped_blocks > 0,
        "inplace dropout1d must have zeroed some channels"
    );
}
2472
// `Dropout3d` with inplace=true rewrites the input storage one channel at
// a time: each run of 8 consecutive values (2x2x2 in the `[1, 500, 2, 2, 2]`
// input) must be uniformly 0.0 or 2.0.
#[test]
fn test_dropout3d_inplace_mutates_input_storage() {
    let layer = Dropout3d::<f32>::new(0.5).unwrap().with_inplace(true);
    assert!(layer.inplace());

    let input = leaf_tensor(&[1.0; 500 * 8], &[1, 500, 2, 2, 2], false);
    drop(layer.forward(&input).unwrap());

    let in_after = input.data().unwrap();
    let mut dropped_blocks = 0usize;
    for blk in in_after.chunks(8) {
        let first = blk[0];
        assert!(blk.iter().all(|&x| (x - first).abs() < 1e-6));
        assert!(first == 0.0 || (first - 2.0).abs() < 1e-6);
        dropped_blocks += usize::from(first == 0.0);
    }
    assert!(
        dropped_blocks > 0,
        "inplace dropout3d must have zeroed some channels"
    );
}
2494
// A freshly constructed `Dropout` is not in-place, and its forward pass
// leaves the caller's buffer alone. Guards the pre-existing inplace=false
// behavior.
#[test]
fn test_dropout_default_is_not_inplace() {
    let layer = Dropout::<f32>::new(0.5).unwrap();
    assert!(!layer.inplace());

    let input = leaf_tensor(&[1.0; 1000], &[1000], false);
    drop(layer.forward(&input).unwrap());

    // The input still holds nothing but 1.0.
    assert!(!input.data().unwrap().iter().any(|&x| x != 1.0));
}
2506
// AlphaDropout / FeatureAlphaDropout keep an `inplace` field for ABI
// parity. Like torch's module forward (`dropout.py:265-269`, `319-323`,
// which never hand `self.inplace` to the functional), they do NOT write
// into the input even with inplace=true. The field stays readable through
// the `inplace()` getter.
#[test]
fn test_alpha_dropout_inplace_field_does_not_mutate() {
    let layer = AlphaDropout::<f32>::new(0.5).unwrap().with_inplace(true);
    assert!(layer.inplace(), "field is retained for API parity");

    let input = leaf_tensor(&[1.0; 1000], &[1000], false);
    drop(layer.forward(&input).unwrap());

    // As in torch, the module forward disregards inplace: input is intact.
    assert!(
        !input.data().unwrap().iter().any(|&x| x != 1.0),
        "AlphaDropout module forward must not mutate in place (matches torch dropout.py:265-269)"
    );
}
2524
// `FeatureAlphaDropout` retains the `inplace` flag (readable through
// `inplace()`), but its forward pass must leave the input storage as it was.
#[test]
fn test_feature_alpha_dropout_inplace_field_does_not_mutate() {
    let layer = FeatureAlphaDropout::<f32>::new(0.5)
        .unwrap()
        .with_inplace(true);
    assert!(layer.inplace(), "field is retained for API parity");

    let input = leaf_tensor(&[1.0; 1000], &[1, 1000], false);
    drop(layer.forward(&input).unwrap());

    assert!(
        !input.data().unwrap().iter().any(|&x| x != 1.0),
        "FeatureAlphaDropout module forward must not mutate in place (matches torch dropout.py:319-323)"
    );
}
2538
2539    // -----------------------------------------------------------------------
2540    // Seed-reproducible byte-match vs LIVE torch 2.11 (#1635 / #1636).
2541    //
2542    // Reference values produced by live torch under `torch.manual_seed(42)`
2543    // — NOT copied from the ferrotorch side (R-CHAR-3). The per-channel /
2544    // per-element masks come from the byte-exact MT19937 `Generator`, so a
2545    // shared `ferrotorch_core::manual_seed(42)` reproduces torch's stream.
2546    // -----------------------------------------------------------------------
2547
2548    fn ones_shape_t(shape: &[usize]) -> Tensor<f32> {
2549        let n: usize = shape.iter().product();
2550        Tensor::from_storage(TensorStorage::cpu(vec![1.0f32; n]), shape.to_vec(), false).unwrap()
2551    }
2552
/// Reference: `torch.manual_seed(42); F.dropout2d(ones(1,8,1,1),0.5,True)`.
/// Per channel, survivors are scaled by 1/(1-0.5)=2 and the MT19937 keep
/// pattern is [keep,keep,keep,keep,DROP,keep,DROP,DROP].
#[test]
fn test_dropout2d_seed42_matches_torch() {
    ferrotorch_core::rng::manual_seed(42);
    let layer = Dropout2d::<f32>::new(0.5).unwrap();
    let input = ones_shape_t(&[1, 8, 1, 1]);
    let y = layer.forward(&input).unwrap();

    let want = [2.0, 2.0, 2.0, 2.0, 0.0, 2.0, 0.0, 0.0];
    assert_eq!(y.data().unwrap(), &want);
}
2564
/// `torch.manual_seed(42); F.dropout1d(ones(1,6,3),0.5,True)` per-channel
/// -> [2,2,2,2,0,2], broadcast over the length-3 dim.
///
/// Every element of every channel is compared against the reference, so a
/// mask that is right at the first position of a channel but wrong further
/// along the length dimension is caught too.
#[test]
fn test_dropout1d_seed42_matches_torch() {
    // Number of positions per channel in the `[1, 6, 3]` input.
    const LENGTH: usize = 3;

    let want: [f32; 6] = [2.0, 2.0, 2.0, 2.0, 0.0, 2.0];
    ferrotorch_core::rng::manual_seed(42);
    let d = Dropout1d::<f32>::new(0.5).unwrap();
    let y = d.forward(&ones_shape_t(&[1, want.len(), LENGTH])).unwrap();
    let data = y.data().unwrap();

    // Pin the size first: the `zip` below would otherwise silently stop at
    // the shorter side.
    assert_eq!(
        data.len(),
        want.len() * LENGTH,
        "output must hold {} channels x {} positions",
        want.len(),
        LENGTH
    );

    // First element of each channel must match the reference pattern.
    let per_chan: Vec<f32> = data.chunks_exact(LENGTH).map(|chan| chan[0]).collect();
    assert_eq!(per_chan.as_slice(), &want);

    // The per-channel value must be broadcast over the whole length dim.
    for (c, (chan, &w)) in data.chunks_exact(LENGTH).zip(want.iter()).enumerate() {
        assert!(
            chan.iter().all(|&v| v == w),
            "channel {c}: got {chan:?}, want every element == {w}"
        );
    }
}
2577
/// Reference: `torch.manual_seed(42); F.dropout3d(ones(1,6,1,1,1),0.5,True)`.
/// Expected per-channel values: [2,2,2,2,0,2].
#[test]
fn test_dropout3d_seed42_matches_torch() {
    ferrotorch_core::rng::manual_seed(42);
    let layer = Dropout3d::<f32>::new(0.5).unwrap();
    let input = ones_shape_t(&[1, 6, 1, 1, 1]);
    let y = layer.forward(&input).unwrap();

    let want = [2.0, 2.0, 2.0, 2.0, 0.0, 2.0];
    assert_eq!(y.data().unwrap(), &want);
}
2588
/// Re-seeding with the SAME `manual_seed(42)` before each of two `Dropout2d`
/// forwards yields the SAME mask (MT19937 reset on manual_seed; no
/// system-time entropy).
#[test]
fn test_dropout2d_reproducible_under_manual_seed() {
    let layer = Dropout2d::<f32>::new(0.5).unwrap();

    ferrotorch_core::rng::manual_seed(42);
    let first_input = ones_shape_t(&[1, 64, 1, 1]);
    let y1 = layer.forward(&first_input).unwrap();

    ferrotorch_core::rng::manual_seed(42);
    let second_input = ones_shape_t(&[1, 64, 1, 1]);
    let y2 = layer.forward(&second_input).unwrap();

    assert_eq!(y1.data().unwrap(), y2.data().unwrap());
}
2600
/// `torch.manual_seed(42); nn.AlphaDropout(0.5).train()(ones(10))`
/// -> kept = 1.6655989, dropped = -0.7791939 in the MT19937 keep pattern.
/// kept/dropped values from torch's exact affine (`Dropout.cpp:74-79`),
/// alpha = 1.7580993408473766.
///
/// The output length is pinned before the element-wise comparison so that
/// a short or empty output cannot pass by way of `zip` truncation.
#[test]
fn test_alpha_dropout_seed42_matches_torch() {
    let want: [f32; 10] = [
        1.6655989, 1.6655989, 1.6655989, 1.6655989, -0.7791939, 1.6655989, -0.7791939,
        -0.7791939, 1.6655989, 1.6655989,
    ];
    ferrotorch_core::rng::manual_seed(42);
    let d = AlphaDropout::<f32>::new(0.5).unwrap();
    let y = d.forward(&ones_shape_t(&[want.len()])).unwrap();
    let got = y.data().unwrap();

    assert_eq!(
        got.len(),
        want.len(),
        "output length must match the torch reference length"
    );
    // Tolerance comparison: the reference values are printed decimals.
    for (i, (&g, &w)) in got.iter().zip(want.iter()).enumerate() {
        assert!((g - w).abs() < 1e-4, "elem {i}: got {g} want {w}");
    }
}
2619
/// `torch.manual_seed(42); nn.FeatureAlphaDropout(0.5).train()(ones(1,6,1,1))`
/// per-channel -> [1.6655989 ×4, -0.7791939, 1.6655989].
///
/// The output length is pinned before the element-wise comparison so that
/// a short or empty output cannot pass by way of `zip` truncation.
#[test]
fn test_feature_alpha_dropout_seed42_matches_torch() {
    let want: [f32; 6] = [
        1.6655989, 1.6655989, 1.6655989, 1.6655989, -0.7791939, 1.6655989,
    ];
    ferrotorch_core::rng::manual_seed(42);
    let d = FeatureAlphaDropout::<f32>::new(0.5).unwrap();
    let y = d.forward(&ones_shape_t(&[1, want.len(), 1, 1])).unwrap();
    let got = y.data().unwrap();

    assert_eq!(
        got.len(),
        want.len(),
        "output length must match the torch reference length"
    );
    // Tolerance comparison: the reference values are printed decimals.
    for (i, (&g, &w)) in got.iter().zip(want.iter()).enumerate() {
        assert!((g - w).abs() < 1e-4, "elem {i}: got {g} want {w}");
    }
}
2635}