pub fn dropout_train_scalar(input: &[f32], mask: &[f32], p: f32, output: &mut [f32]) {
assert_eq!(input.len(), mask.len(), "input/mask dimension mismatch");
assert_eq!(input.len(), output.len(), "input/output dimension mismatch");
assert!((0.0..1.0).contains(&p), "p must be in [0, 1), got {p}");
let scale = 1.0 / (1.0 - p);
for i in 0..input.len() {
output[i] = mask[i] * input[i] * scale;
}
}
pub fn dropout_eval_scalar(input: &[f32], output: &mut [f32]) {
assert_eq!(input.len(), output.len(), "input/output dimension mismatch");
output.copy_from_slice(input);
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn dropout_train_avx2(input: &[f32], mask: &[f32], p: f32, output: &mut [f32]) {
dropout_train_scalar(input, mask, p, output);
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn dropout_eval_avx2(input: &[f32], output: &mut [f32]) {
dropout_eval_scalar(input, output);
}
pub fn dropout_ptx() -> &'static str {
r#".version 8.5
.target sm_90
.address_size 64
.visible .entry dropout_train_kernel(
.param .u64 INPUT,
.param .u64 MASK,
.param .u64 OUT,
.param .f32 SCALE,
.param .u32 N
) {
.reg .u32 %tid, %bid, %n, %idx;
.reg .u64 %in_ptr, %mask_ptr, %out_ptr, %addr, %off64;
.reg .f32 %in_val, %mask_val, %scale, %result;
.reg .pred %p_bound;
mov.u32 %tid, %tid.x;
mov.u32 %bid, %ctaid.x;
ld.param.u32 %n, [N];
ld.param.f32 %scale, [SCALE];
ld.param.u64 %in_ptr, [INPUT];
ld.param.u64 %mask_ptr, [MASK];
ld.param.u64 %out_ptr, [OUT];
// Global thread index
mul.lo.u32 %idx, %bid, 256;
add.u32 %idx, %idx, %tid;
setp.ge.u32 %p_bound, %idx, %n;
@%p_bound bra EXIT;
mul.wide.u32 %off64, %idx, 4;
// Load input[idx]
add.u64 %addr, %in_ptr, %off64;
ld.global.f32 %in_val, [%addr];
// Load mask[idx]
add.u64 %addr, %mask_ptr, %off64;
ld.global.f32 %mask_val, [%addr];
// result = mask * input * scale
mul.f32 %result, %mask_val, %in_val;
mul.f32 %result, %result, %scale;
// Store output[idx]
add.u64 %addr, %out_ptr, %off64;
st.global.f32 [%addr], %result;
EXIT:
ret;
}
"#
}
#[cfg(test)]
mod tests {
use super::*;
use proptest::prelude::*;
#[test]
fn test_dropout_eval_is_identity() {
let input = [1.0, 2.0, 3.0, 4.0, 5.0];
let mut output = [0.0f32; 5];
dropout_eval_scalar(&input, &mut output);
assert_eq!(&output, &input);
}
#[test]
fn test_dropout_train_all_kept() {
let input = [1.0, 2.0, 3.0];
let mask = [1.0, 1.0, 1.0]; let mut output = [0.0f32; 3];
dropout_train_scalar(&input, &mask, 0.5, &mut output);
assert!((output[0] - 2.0).abs() < 1e-6);
assert!((output[1] - 4.0).abs() < 1e-6);
assert!((output[2] - 6.0).abs() < 1e-6);
}
#[test]
fn test_dropout_train_all_dropped() {
let input = [1.0, 2.0, 3.0];
let mask = [0.0, 0.0, 0.0];
let mut output = [99.0f32; 3];
dropout_train_scalar(&input, &mask, 0.5, &mut output);
assert_eq!(&output, &[0.0, 0.0, 0.0]);
}
#[test]
fn test_dropout_train_zero_p() {
let input = [1.0, 2.0, 3.0];
let mask = [1.0, 1.0, 1.0];
let mut output = [0.0f32; 3];
dropout_train_scalar(&input, &mask, 0.0, &mut output);
assert_eq!(&output, &[1.0, 2.0, 3.0]);
}
#[test]
fn test_dropout_dropped_units_are_zero() {
let input = [5.0, 10.0, 15.0, 20.0];
let mask = [1.0, 0.0, 1.0, 0.0]; let mut output = [0.0f32; 4];
dropout_train_scalar(&input, &mask, 0.3, &mut output);
assert_eq!(output[1], 0.0);
assert_eq!(output[3], 0.0);
assert!(output[0] > 0.0);
assert!(output[2] > 0.0);
}
#[test]
fn test_dropout_shape_preservation() {
let n = 7;
let input = vec![1.0f32; n];
let mask = vec![1.0f32; n];
let mut output = vec![0.0f32; n];
dropout_train_scalar(&input, &mask, 0.1, &mut output);
assert_eq!(output.len(), input.len());
}
proptest! {
#[test]
fn prop_dropout_eval_identity(n in 1usize..16) {
let input: Vec<f32> = (0..n).map(|i| (i as f32) * 0.3).collect();
let mut output = vec![0.0f32; n];
dropout_eval_scalar(&input, &mut output);
for (i, (&a, &b)) in input.iter().zip(output.iter()).enumerate() {
prop_assert_eq!(a, b, "eval not identity at {}", i);
}
}
#[test]
fn prop_dropout_train_finite(
n in 1usize..10,
p_int in 0u32..99,
) {
let p = p_int as f32 / 100.0;
let input: Vec<f32> = (0..n).map(|i| (i as f32) * 0.1).collect();
let mask: Vec<f32> = (0..n).map(|i| if i % 2 == 0 { 1.0 } else { 0.0 }).collect();
let mut output = vec![0.0f32; n];
dropout_train_scalar(&input, &mask, p, &mut output);
for (idx, &val) in output.iter().enumerate() {
prop_assert!(val.is_finite(), "output[{idx}] = {val} not finite");
}
}
}
#[test]
fn test_dropout_ptx_structure() {
let ptx = dropout_ptx();
assert!(ptx.contains(".entry dropout_train_kernel"));
assert!(ptx.contains("mul.f32"));
assert!(ptx.contains("ret;"));
}
#[cfg(target_arch = "x86_64")]
#[test]
fn test_dropout_avx2_parity() {
if !is_x86_feature_detected!("avx2") {
return;
}
let input = [1.0, 2.0, 3.0, 4.0];
let mask = [1.0, 0.0, 1.0, 0.0];
let mut scalar_out = [0.0f32; 4];
let mut avx2_out = [0.0f32; 4];
dropout_train_scalar(&input, &mask, 0.5, &mut scalar_out);
unsafe { dropout_train_avx2(&input, &mask, 0.5, &mut avx2_out) };
assert_eq!(scalar_out, avx2_out);
}
}