Module dropout

Dropout kernel.

Matches dropout-v1.yaml. Training mode: y = mask * x / (1 - p), where mask ~ Bernoulli(1 - p). Each element is kept with probability 1 - p and scaled by 1 / (1 - p), so E[y] = x. Eval mode: y = x (identity).
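A minimal sketch of the two formulas on f32 slices follows. It assumes the mask holds 0.0/1.0 values and that p is in [0, 1). The function names dropout_train_ref and dropout_eval_ref and their slice-based signatures are hypothetical and are not the module's actual API.

```rust
/// Hypothetical scalar sketch of training-mode dropout:
/// y[i] = mask[i] * x[i] / (1 - p), with mask[i] in {0.0, 1.0}.
fn dropout_train_ref(x: &[f32], mask: &[f32], p: f32, y: &mut [f32]) {
    assert!(x.len() == mask.len() && x.len() == y.len(), "length mismatch");
    assert!((0.0..1.0).contains(&p), "p must be in [0, 1)");
    // Inverted dropout: kept elements are scaled by 1/(1 - p) so that E[y] = x.
    let scale = 1.0 / (1.0 - p);
    for ((yi, &xi), &mi) in y.iter_mut().zip(x).zip(mask) {
        *yi = mi * xi * scale;
    }
}

/// Hypothetical eval-mode sketch: dropout is the identity.
fn dropout_eval_ref(x: &[f32], y: &mut [f32]) {
    y.copy_from_slice(x);
}

fn main() {
    let x = [1.0_f32, 2.0, 3.0, 4.0];
    let mask = [1.0_f32, 0.0, 1.0, 0.0];
    let mut y = [0.0_f32; 4];
    dropout_train_ref(&x, &mask, 0.5, &mut y);
    println!("{:?}", y); // [2.0, 0.0, 6.0, 0.0]

    let mut z = [0.0_f32; 4];
    dropout_eval_ref(&x, &mut z);
    assert_eq!(z, x);
}
```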

Note: the mask is precomputed by the caller and passed in, rather than drawn from an internal RNG. Because the kernel itself is deterministic, its output is verifiable and reproducible.
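Because the mask is an input, the caller decides how it is drawn. One way to keep it reproducible is to generate it from a seeded PRNG, as in this sketch. SplitMix64 and make_mask are illustrative choices and are not part of the module. An element is dropped when a uniform draw falls below p, so it is kept with probability 1 - p.

```rust
/// SplitMix64: a small, seedable 64-bit PRNG (illustrative choice).
struct SplitMix64(u64);

impl SplitMix64 {
    fn next_u64(&mut self) -> u64 {
        self.0 = self.0.wrapping_add(0x9E37_79B9_7F4A_7C15);
        let mut z = self.0;
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^ (z >> 31)
    }

    /// Uniform f64 in [0, 1) built from the top 53 bits.
    fn next_f64(&mut self) -> f64 {
        (self.next_u64() >> 11) as f64 / (1u64 << 53) as f64
    }
}

/// Hypothetical helper: Bernoulli(1 - p) mask of 0.0/1.0 values.
/// The same (n, p, seed) always yields the same mask.
fn make_mask(n: usize, p: f64, seed: u64) -> Vec<f32> {
    let mut rng = SplitMix64(seed);
    (0..n)
        .map(|_| if rng.next_f64() < p { 0.0 } else { 1.0 })
        .collect()
}

fn main() {
    let a = make_mask(16, 0.5, 7);
    let b = make_mask(16, 0.5, 7);
    assert_eq!(a, b); // same seed, same mask
    println!("{:?}", a);
}
```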

The module provides three backends:

  • fn dropout_{train,eval}_scalar(...) – pure Rust scalar reference
  • unsafe fn dropout_{train,eval}_avx2(...) – AVX2 entry points (currently delegate to the scalar reference)
  • fn dropout_ptx() -> &'static str – PTX assembly source string (training mode only)
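Scalar and AVX2 backends are commonly combined through runtime CPU-feature detection. The sketch below shows one such pattern, using the standard library's is_x86_feature_detected! macro and the #[target_feature(enable = "avx2")] attribute. The function names are hypothetical, and the module may select its backend differently (for example, at compile time).

```rust
// Hypothetical scalar path (not the module's real signature).
fn train_scalar_sketch(x: &[f32], mask: &[f32], p: f32, y: &mut [f32]) {
    let scale = 1.0 / (1.0 - p);
    for ((yi, &xi), &mi) in y.iter_mut().zip(x).zip(mask) {
        *yi = mi * xi * scale;
    }
}

// Hypothetical AVX2 entry point. Like the module's AVX2 functions, it only
// delegates to the scalar path; the attribute marks it as requiring AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn train_avx2_sketch(x: &[f32], mask: &[f32], p: f32, y: &mut [f32]) {
    train_scalar_sketch(x, mask, p, y)
}

// Pick a backend at runtime; falls back to scalar on other CPUs/architectures.
fn train_dispatch_sketch(x: &[f32], mask: &[f32], p: f32, y: &mut [f32]) {
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 support was confirmed at runtime just above.
            unsafe { train_avx2_sketch(x, mask, p, y) };
            return;
        }
    }
    train_scalar_sketch(x, mask, p, y);
}

fn main() {
    let x = [1.0_f32, 2.0, 3.0];
    let mask = [1.0_f32, 1.0, 0.0];
    let mut y = [0.0_f32; 3];
    train_dispatch_sketch(&x, &mask, 0.5, &mut y);
    println!("{:?}", y); // [2.0, 4.0, 0.0]
}
```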

Functions

dropout_eval_avx2
AVX2 dropout (eval) – delegates to scalar.
dropout_eval_scalar
Dropout in eval mode (scalar reference).
dropout_ptx
PTX assembly for dropout (training mode).
dropout_train_avx2
AVX2 dropout (training) – delegates to scalar.
dropout_train_scalar
Dropout in training mode (scalar reference).
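Both AVX2 entry points listed above currently delegate to the scalar reference. For comparison, a vectorized training kernel might look like the following sketch. It is not the module's code, and train_avx2, train_scalar and train_dispatch are hypothetical names. It processes eight f32 lanes per iteration with the _mm256_loadu_ps, _mm256_mul_ps, _mm256_set1_ps and _mm256_storeu_ps intrinsics from std::arch::x86_64, then finishes the remainder with the scalar loop.

```rust
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::{_mm256_loadu_ps, _mm256_mul_ps, _mm256_set1_ps, _mm256_storeu_ps};

// Scalar path: y[i] = (mask[i] * x[i]) * scale.
fn train_scalar(x: &[f32], mask: &[f32], scale: f32, y: &mut [f32]) {
    for ((yi, &xi), &mi) in y.iter_mut().zip(x).zip(mask) {
        *yi = mi * xi * scale;
    }
}

// Hypothetical vectorized kernel: 8 f32 lanes per iteration, scalar tail.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn train_avx2(x: &[f32], mask: &[f32], scale: f32, y: &mut [f32]) {
    assert!(x.len() == mask.len() && x.len() == y.len(), "length mismatch");
    let body = x.len() / 8 * 8;
    // SAFETY: the caller guarantees AVX2; each load/store covers i..i + 8 <= body <= len.
    unsafe {
        let vscale = _mm256_set1_ps(scale);
        let mut i = 0;
        while i < body {
            let vx = _mm256_loadu_ps(x.as_ptr().add(i));
            let vm = _mm256_loadu_ps(mask.as_ptr().add(i));
            // Same operation order as the scalar path: (mask * x) * scale.
            let vy = _mm256_mul_ps(_mm256_mul_ps(vm, vx), vscale);
            _mm256_storeu_ps(y.as_mut_ptr().add(i), vy);
            i += 8;
        }
    }
    // Remaining 0..=7 elements.
    train_scalar(&x[body..], &mask[body..], scale, &mut y[body..]);
}

fn train_dispatch(x: &[f32], mask: &[f32], p: f32, y: &mut [f32]) {
    let scale = 1.0 / (1.0 - p);
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 availability was just checked.
            unsafe { train_avx2(x, mask, scale, y) };
            return;
        }
    }
    train_scalar(x, mask, scale, y);
}

fn main() {
    let x: Vec<f32> = (0..11).map(|i| i as f32).collect();
    let mask: Vec<f32> = (0..11).map(|i| (i % 2) as f32).collect();
    let mut y = vec![0.0_f32; 11];
    train_dispatch(&x, &mask, 0.5, &mut y);
    println!("{:?}", y); // [0.0, 2.0, 0.0, 6.0, 0.0, 10.0, 0.0, 14.0, 0.0, 18.0, 0.0]
}
```

Because the vector loop multiplies in the same order as the scalar loop, the two paths produce bit-identical results, which keeps the scalar function usable as a reference for testing.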