Skip to main content

provable_contracts/kernels/
dropout.rs

//! Dropout kernel.
//!
//! Matches `dropout-v1.yaml`.
//! Train: `y = mask * x / (1 - p)` where mask ~ Bernoulli(1 - p).
//! Eval: `y = x` (identity).
//!
//! Note: The mask is pre-computed and passed in (deterministic), rather than
//! generated by an internal RNG. This makes the kernel verifiable and reproducible.
//!
//! Each operation is exposed through up to three backends:
//! - `fn dropout_{train,eval}_scalar(...)` -- pure Rust scalar reference
//! - `unsafe fn dropout_{train,eval}_avx2(...)` -- AVX2 entry points (x86_64 only)
//! - `fn dropout_ptx() -> &'static str` -- PTX assembly source string (training mode only)

// ────────────────────────────────────────────────────────────────────────────
// Scalar implementation
// ────────────────────────────────────────────────────────────────────────────

/// Dropout in training mode (scalar reference).
///
/// Computes inverted dropout, `output[i] = mask[i] * input[i] / (1 - p)`.
/// The factor `scale = 1 / (1 - p)` is computed once, and each element is then
/// evaluated as `(mask[i] * input[i]) * scale`.
///
/// * `input`  -- activations to drop out.
/// * `mask`   -- pre-computed keep mask, expected to hold `0.0` (drop) or `1.0`
///   (keep). Values are applied as plain multipliers, so a dropped element whose
///   input is infinite or NaN yields NaN rather than `0.0`.
/// * `p`      -- drop probability; must satisfy `0.0 <= p < 1.0`.
/// * `output` -- destination buffer, same length as `input`.
///
/// # Panics
/// Panics if `input`, `mask`, and `output` do not all have the same length, or if
/// `p` is outside `[0, 1)`. This covers `p >= 1.0`, negative `p`, and NaN.
pub fn dropout_train_scalar(input: &[f32], mask: &[f32], p: f32, output: &mut [f32]) {
    assert_eq!(input.len(), mask.len(), "input/mask dimension mismatch");
    assert_eq!(input.len(), output.len(), "input/output dimension mismatch");
    // `contains` is false for NaN, so a NaN `p` is rejected here too.
    assert!((0.0..1.0).contains(&p), "p must be in [0, 1), got {p}");

    let scale = 1.0 / (1.0 - p);
    // Lengths are equal (asserted above), so zipping visits every element exactly
    // once without per-index bounds checks.
    for ((out, &x), &m) in output.iter_mut().zip(input).zip(mask) {
        *out = m * x * scale;
    }
}

/// Dropout in eval mode (scalar reference).
///
/// Dropout is disabled at inference time, so this is the identity: every
/// `output[i]` receives `input[i]` unchanged.
///
/// # Panics
/// Panics when `input` and `output` differ in length.
pub fn dropout_eval_scalar(input: &[f32], output: &mut [f32]) {
    let (src_len, dst_len) = (input.len(), output.len());
    assert_eq!(src_len, dst_len, "input/output dimension mismatch");
    // With matching lengths, one bulk slice copy is the entire kernel.
    output.copy_from_slice(input);
}

// ────────────────────────────────────────────────────────────────────────────
// AVX2 implementation
// ────────────────────────────────────────────────────────────────────────────

/// AVX2 dropout (training mode).
///
/// Computes `output[i] = mask[i] * input[i] * scale` with `scale = 1 / (1 - p)`,
/// processing eight `f32` lanes per iteration. Each lane performs the same two
/// multiplies in the same order as [`dropout_train_scalar`]
/// (`(mask * input) * scale`, no fused multiply-add), so results are
/// bit-identical to the scalar reference. Any elements after the last full
/// 8-lane chunk are handled by a scalar tail loop.
///
/// # Panics
/// Panics if `input`, `mask`, and `output` do not all have the same length, or if
/// `p` is outside `[0, 1)`. The panic messages match the scalar reference.
///
/// # Safety
/// The caller must ensure the running CPU supports AVX2, for example by checking
/// `is_x86_feature_detected!("avx2")`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn dropout_train_avx2(input: &[f32], mask: &[f32], p: f32, output: &mut [f32]) {
    use std::arch::x86_64::{_mm256_loadu_ps, _mm256_mul_ps, _mm256_set1_ps, _mm256_storeu_ps};

    /// Number of `f32` lanes in a 256-bit AVX register.
    const LANES: usize = 8;

    assert_eq!(input.len(), mask.len(), "input/mask dimension mismatch");
    assert_eq!(input.len(), output.len(), "input/output dimension mismatch");
    assert!((0.0..1.0).contains(&p), "p must be in [0, 1), got {p}");

    let scale = 1.0 / (1.0 - p);
    let n = input.len();
    // End of the region that can be processed in whole 8-lane chunks.
    let simd_end = n - n % LANES;

    // SAFETY: AVX2 availability is the caller's contract (see `# Safety`). Every
    // pointer is offset by `i` with `i + LANES <= simd_end <= n`, and all three
    // slices have length `n` (asserted above), so each unaligned 8-wide load and
    // store stays in bounds. `output` is an exclusive borrow, so it cannot alias
    // `input` or `mask`.
    unsafe {
        let scale_v = _mm256_set1_ps(scale);
        let mut i = 0;
        while i < simd_end {
            let x = _mm256_loadu_ps(input.as_ptr().add(i));
            let m = _mm256_loadu_ps(mask.as_ptr().add(i));
            let y = _mm256_mul_ps(_mm256_mul_ps(m, x), scale_v);
            _mm256_storeu_ps(output.as_mut_ptr().add(i), y);
            i += LANES;
        }
    }

    // Scalar tail for the final `n % LANES` elements, using the same operation order.
    for ((out, &x), &m) in output[simd_end..]
        .iter_mut()
        .zip(&input[simd_end..])
        .zip(&mask[simd_end..])
    {
        *out = m * x * scale;
    }
}

/// AVX2 dropout (eval mode).
///
/// Eval-mode dropout is the identity, so this copies `input` into `output`
/// unchanged. The result is exactly what [`dropout_eval_scalar`] produces.
///
/// # Panics
/// Panics if `input.len() != output.len()`.
///
/// # Safety
/// The caller must ensure the running CPU supports AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn dropout_eval_avx2(input: &[f32], output: &mut [f32]) {
    assert_eq!(input.len(), output.len(), "input/output dimension mismatch");
    // A plain slice copy: there is no arithmetic for SIMD to accelerate.
    output.copy_from_slice(input);
}

// ────────────────────────────────────────────────────────────────────────────
// PTX implementation
// ────────────────────────────────────────────────────────────────────────────

/// PTX assembly for dropout (training mode).
///
/// Kernel `dropout_train_kernel(INPUT, MASK, OUT, SCALE, N)` computes
/// `OUT[i] = (MASK[i] * INPUT[i]) * SCALE` for every `i < N`, one thread per
/// element. The host is expected to pass `SCALE = 1 / (1 - p)`, the same factor
/// used by [`dropout_train_scalar`].
///
/// The global index is `ctaid.x * ntid.x + tid.x`. The block size is read at run
/// time, so any 1-D block dimension works. Launch enough blocks to cover `N`
/// (e.g. `ceil(N / block_dim)`); threads with index `>= N` exit immediately.
pub fn dropout_ptx() -> &'static str {
    r#".version 8.5
.target sm_90
.address_size 64
.visible .entry dropout_train_kernel(
    .param .u64 INPUT,
    .param .u64 MASK,
    .param .u64 OUT,
    .param .f32 SCALE,
    .param .u32 N
) {
    .reg .u32 %r_tid, %r_ctaid, %r_ntid, %n, %idx;
    .reg .u64 %in_ptr, %mask_ptr, %out_ptr, %addr, %off64;
    .reg .f32 %in_val, %mask_val, %scale, %result;
    .reg .pred %p_bound;

    mov.u32 %r_tid, %tid.x;
    mov.u32 %r_ctaid, %ctaid.x;
    mov.u32 %r_ntid, %ntid.x;

    ld.param.u32 %n, [N];
    ld.param.f32 %scale, [SCALE];
    ld.param.u64 %in_ptr, [INPUT];
    ld.param.u64 %mask_ptr, [MASK];
    ld.param.u64 %out_ptr, [OUT];

    // Pointer params hold generic addresses; convert them for ld/st.global.
    cvta.to.global.u64 %in_ptr, %in_ptr;
    cvta.to.global.u64 %mask_ptr, %mask_ptr;
    cvta.to.global.u64 %out_ptr, %out_ptr;

    // Global thread index = ctaid.x * ntid.x + tid.x (block size read at run time)
    mad.lo.u32 %idx, %r_ctaid, %r_ntid, %r_tid;

    setp.ge.u32 %p_bound, %idx, %n;
    @%p_bound bra EXIT;

    mul.wide.u32 %off64, %idx, 4;

    // Load input[idx]
    add.u64 %addr, %in_ptr, %off64;
    ld.global.f32 %in_val, [%addr];

    // Load mask[idx]
    add.u64 %addr, %mask_ptr, %off64;
    ld.global.f32 %mask_val, [%addr];

    // result = (mask * input) * scale, same operation order as the Rust backends
    mul.f32 %result, %mask_val, %in_val;
    mul.f32 %result, %result, %scale;

    // Store output[idx]
    add.u64 %addr, %out_ptr, %off64;
    st.global.f32 [%addr], %result;

EXIT:
    ret;
}
"#
}

// ────────────────────────────────────────────────────────────────────────────
// Tests
// ────────────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    #[test]
    fn test_dropout_eval_is_identity() {
        let input = [1.0, 2.0, 3.0, 4.0, 5.0];
        let mut output = [0.0f32; 5];
        dropout_eval_scalar(&input, &mut output);
        assert_eq!(&output, &input);
    }

    #[test]
    fn test_dropout_train_all_kept() {
        let input = [1.0, 2.0, 3.0];
        let mask = [1.0, 1.0, 1.0]; // all kept
        let mut output = [0.0f32; 3];

        dropout_train_scalar(&input, &mask, 0.5, &mut output);
        // scale = 1 / (1 - 0.5) = 2.0
        assert!((output[0] - 2.0).abs() < 1e-6);
        assert!((output[1] - 4.0).abs() < 1e-6);
        assert!((output[2] - 6.0).abs() < 1e-6);
    }

    #[test]
    fn test_dropout_train_all_dropped() {
        let input = [1.0, 2.0, 3.0];
        let mask = [0.0, 0.0, 0.0];
        let mut output = [99.0f32; 3];

        dropout_train_scalar(&input, &mask, 0.5, &mut output);
        assert_eq!(&output, &[0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_dropout_train_zero_p() {
        // p=0 means no dropout, scale = 1/(1-0) = 1
        let input = [1.0, 2.0, 3.0];
        let mask = [1.0, 1.0, 1.0];
        let mut output = [0.0f32; 3];

        dropout_train_scalar(&input, &mask, 0.0, &mut output);
        assert_eq!(&output, &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_dropout_dropped_units_are_zero() {
        let input = [5.0, 10.0, 15.0, 20.0];
        let mask = [1.0, 0.0, 1.0, 0.0]; // drop indices 1, 3
        let mut output = [0.0f32; 4];

        dropout_train_scalar(&input, &mask, 0.3, &mut output);
        assert_eq!(output[1], 0.0);
        assert_eq!(output[3], 0.0);
        assert!(output[0] > 0.0);
        assert!(output[2] > 0.0);
    }

    #[test]
    fn test_dropout_shape_preservation() {
        let n = 7;
        let input = vec![1.0f32; n];
        let mask = vec![1.0f32; n];
        let mut output = vec![0.0f32; n];

        dropout_train_scalar(&input, &mask, 0.1, &mut output);
        assert_eq!(output.len(), input.len());
    }

    #[test]
    fn test_dropout_train_empty_input() {
        let mut output: [f32; 0] = [];
        dropout_train_scalar(&[], &[], 0.5, &mut output);
    }

    // The `p` precondition rejects everything outside [0, 1), including NaN.
    #[test]
    #[should_panic(expected = "p must be in [0, 1)")]
    fn test_dropout_train_rejects_p_one() {
        let mut output = [0.0f32; 1];
        dropout_train_scalar(&[1.0], &[1.0], 1.0, &mut output);
    }

    #[test]
    #[should_panic(expected = "p must be in [0, 1)")]
    fn test_dropout_train_rejects_negative_p() {
        let mut output = [0.0f32; 1];
        dropout_train_scalar(&[1.0], &[1.0], -0.1, &mut output);
    }

    #[test]
    #[should_panic(expected = "p must be in [0, 1)")]
    fn test_dropout_train_rejects_nan_p() {
        let mut output = [0.0f32; 1];
        dropout_train_scalar(&[1.0], &[1.0], f32::NAN, &mut output);
    }

    #[test]
    #[should_panic(expected = "input/mask dimension mismatch")]
    fn test_dropout_train_mask_length_mismatch() {
        let mut output = [0.0f32; 2];
        dropout_train_scalar(&[1.0, 2.0], &[1.0], 0.5, &mut output);
    }

    proptest! {
        #[test]
        fn prop_dropout_eval_identity(n in 1usize..16) {
            let input: Vec<f32> = (0..n).map(|i| (i as f32) * 0.3).collect();
            let mut output = vec![0.0f32; n];
            dropout_eval_scalar(&input, &mut output);

            for (i, (&a, &b)) in input.iter().zip(output.iter()).enumerate() {
                prop_assert_eq!(a, b, "eval not identity at {}", i);
            }
        }

        #[test]
        fn prop_dropout_train_finite(
            n in 1usize..10,
            p_int in 0u32..99,
        ) {
            let p = p_int as f32 / 100.0;
            let input: Vec<f32> = (0..n).map(|i| (i as f32) * 0.1).collect();
            let mask: Vec<f32> = (0..n).map(|i| if i % 2 == 0 { 1.0 } else { 0.0 }).collect();
            let mut output = vec![0.0f32; n];

            dropout_train_scalar(&input, &mask, p, &mut output);

            for (idx, &val) in output.iter().enumerate() {
                prop_assert!(val.is_finite(), "output[{idx}] = {val} not finite");
            }
        }
    }

    #[test]
    fn test_dropout_ptx_structure() {
        let ptx = dropout_ptx();
        assert!(ptx.contains(".entry dropout_train_kernel"));
        assert!(ptx.contains("mul.f32"));
        assert!(ptx.contains("ret;"));
    }

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_dropout_avx2_parity() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        // Lengths 0..=19 cover: empty input, sub-vector tails, exactly one and two
        // full 8-lane vectors, and full vectors followed by a tail.
        for n in 0..=19usize {
            let input: Vec<f32> = (0..n).map(|i| i as f32 * 0.75 - 3.0).collect();
            let mask: Vec<f32> = (0..n).map(|i| if i % 3 == 1 { 0.0 } else { 1.0 }).collect();
            let mut scalar_out = vec![0.0f32; n];
            // NaN-filled so any element the AVX2 path fails to write is caught.
            let mut avx2_out = vec![f32::NAN; n];
            dropout_train_scalar(&input, &mask, 0.5, &mut scalar_out);
            unsafe { dropout_train_avx2(&input, &mask, 0.5, &mut avx2_out) };
            assert_eq!(scalar_out, avx2_out, "train parity mismatch at n = {n}");
        }
    }

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_dropout_eval_avx2_parity() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        for n in 0..=19usize {
            let input: Vec<f32> = (0..n).map(|i| i as f32 * -0.5 + 1.25).collect();
            let mut scalar_out = vec![0.0f32; n];
            let mut avx2_out = vec![f32::NAN; n];
            dropout_eval_scalar(&input, &mut scalar_out);
            unsafe { dropout_eval_avx2(&input, &mut avx2_out) };
            assert_eq!(scalar_out, avx2_out, "eval parity mismatch at n = {n}");
        }
    }
}