// provable_contracts/kernels/dropout.rs
/// Applies inverted dropout in training mode: `output[i] = mask[i] * input[i] * (1 / (1 - p))`.
///
/// `mask` is multiplied element-wise into `input`, so a mask entry of `0.0` zeroes the
/// corresponding output. Every product is then scaled by `1 / (1 - p)`. The mask is
/// assumed to hold 0/1 keep flags drawn with keep probability `1 - p`; that is not
/// checked here (TODO confirm with callers).
///
/// # Panics
///
/// Panics if `input`, `mask`, and `output` do not all have the same length, or if
/// `p` is not in `[0, 1)` (a NaN `p` is rejected too).
pub fn dropout_train_scalar(input: &[f32], mask: &[f32], p: f32, output: &mut [f32]) {
    assert_eq!(input.len(), mask.len(), "input/mask dimension mismatch");
    assert_eq!(input.len(), output.len(), "input/output dimension mismatch");
    assert!((0.0..1.0).contains(&p), "p must be in [0, 1), got {p}");

    let scale = 1.0 / (1.0 - p);
    // Zipped iteration avoids per-element bounds checks. The `(mask * input) * scale`
    // order matches the PTX kernel, so the results stay bit-comparable.
    for ((out, &x), &m) in output.iter_mut().zip(input).zip(mask) {
        *out = m * x * scale;
    }
}
36
/// Applies dropout in inference mode: `output` receives an exact copy of `input`.
///
/// No scaling is applied here. The training path already multiplies kept units by
/// `1 / (1 - p)`.
///
/// # Panics
///
/// Panics if `input` and `output` have different lengths.
pub fn dropout_eval_scalar(input: &[f32], output: &mut [f32]) {
    let expected_len = input.len();
    assert_eq!(expected_len, output.len(), "input/output dimension mismatch");
    output.copy_from_slice(input);
}
47
48#[cfg(target_arch = "x86_64")]
57#[target_feature(enable = "avx2")]
58pub unsafe fn dropout_train_avx2(input: &[f32], mask: &[f32], p: f32, output: &mut [f32]) {
59 dropout_train_scalar(input, mask, p, output);
60}
61
62#[cfg(target_arch = "x86_64")]
67#[target_feature(enable = "avx2")]
68pub unsafe fn dropout_eval_avx2(input: &[f32], output: &mut [f32]) {
69 dropout_eval_scalar(input, output);
70}
71
/// Returns PTX source (ISA 8.5, `sm_90`) for the `dropout_train_kernel` CUDA kernel.
///
/// The kernel computes `OUT[i] = MASK[i] * INPUT[i] * SCALE` for every `i < N`. Each
/// thread handles one element, in the same `(mask * input) * scale` order as the
/// scalar training path. `SCALE` is supplied by the host; it is presumably
/// `1 / (1 - p)` as in the scalar path (confirm at the launch site).
///
/// `INPUT`, `MASK` and `OUT` are assumed to be device pointers to `f32` buffers with at
/// least `N` elements. The global index is computed from `%ntid.x`, so any 1-D block
/// size works. Launch at least `ceil(N / blockDim.x)` blocks to cover every element.
pub fn dropout_ptx() -> &'static str {
    r#".version 8.5
.target sm_90
.address_size 64

.visible .entry dropout_train_kernel(
    .param .u64 INPUT,
    .param .u64 MASK,
    .param .u64 OUT,
    .param .f32 SCALE,
    .param .u32 N
) {
    .reg .u32 %r_tid, %r_bid, %r_ntid, %n, %idx;
    .reg .u64 %in_ptr, %mask_ptr, %out_ptr, %addr, %off64;
    .reg .f32 %in_val, %mask_val, %scale, %result;
    .reg .pred %p_bound;

    // User register names avoid the special registers %tid / %ntid / %ctaid.
    mov.u32 %r_tid, %tid.x;
    mov.u32 %r_bid, %ctaid.x;
    mov.u32 %r_ntid, %ntid.x;

    ld.param.u32 %n, [N];
    ld.param.f32 %scale, [SCALE];
    ld.param.u64 %in_ptr, [INPUT];
    ld.param.u64 %mask_ptr, [MASK];
    ld.param.u64 %out_ptr, [OUT];

    // Treat pointer params as generic addresses (as nvcc does) and convert them
    // for use with ld.global / st.global.
    cvta.to.global.u64 %in_ptr, %in_ptr;
    cvta.to.global.u64 %mask_ptr, %mask_ptr;
    cvta.to.global.u64 %out_ptr, %out_ptr;

    // Global thread index = ctaid.x * ntid.x + tid.x (independent of block size)
    mad.lo.u32 %idx, %r_bid, %r_ntid, %r_tid;

    setp.ge.u32 %p_bound, %idx, %n;
    @%p_bound bra EXIT;

    mul.wide.u32 %off64, %idx, 4;

    // Load input[idx]
    add.u64 %addr, %in_ptr, %off64;
    ld.global.f32 %in_val, [%addr];

    // Load mask[idx]
    add.u64 %addr, %mask_ptr, %off64;
    ld.global.f32 %mask_val, [%addr];

    // result = mask * input * scale
    mul.f32 %result, %mask_val, %in_val;
    mul.f32 %result, %result, %scale;

    // Store output[idx]
    add.u64 %addr, %out_ptr, %off64;
    st.global.f32 [%addr], %result;

EXIT:
    ret;
}
"#
}
134
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    #[test]
    fn test_dropout_eval_is_identity() {
        let input = [1.0, 2.0, 3.0, 4.0, 5.0];
        let mut output = [0.0f32; 5];
        dropout_eval_scalar(&input, &mut output);
        assert_eq!(&output, &input);
    }

    #[test]
    fn test_dropout_train_all_kept() {
        let input = [1.0, 2.0, 3.0];
        let mask = [1.0, 1.0, 1.0];
        let mut output = [0.0f32; 3];

        dropout_train_scalar(&input, &mask, 0.5, &mut output);
        assert!((output[0] - 2.0).abs() < 1e-6);
        assert!((output[1] - 4.0).abs() < 1e-6);
        assert!((output[2] - 6.0).abs() < 1e-6);
    }

    #[test]
    fn test_dropout_train_all_dropped() {
        let input = [1.0, 2.0, 3.0];
        let mask = [0.0, 0.0, 0.0];
        let mut output = [99.0f32; 3];

        dropout_train_scalar(&input, &mask, 0.5, &mut output);
        assert_eq!(&output, &[0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_dropout_train_zero_p() {
        let input = [1.0, 2.0, 3.0];
        let mask = [1.0, 1.0, 1.0];
        let mut output = [0.0f32; 3];

        dropout_train_scalar(&input, &mask, 0.0, &mut output);
        assert_eq!(&output, &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_dropout_dropped_units_are_zero() {
        let input = [5.0, 10.0, 15.0, 20.0];
        let mask = [1.0, 0.0, 1.0, 0.0];
        let mut output = [0.0f32; 4];

        dropout_train_scalar(&input, &mask, 0.3, &mut output);
        assert_eq!(output[1], 0.0);
        assert_eq!(output[3], 0.0);
        // Kept units are rescaled by 1 / (1 - p).
        assert!((output[0] - 5.0 / 0.7).abs() < 1e-4);
        assert!((output[2] - 15.0 / 0.7).abs() < 1e-4);
    }

    #[test]
    fn test_dropout_shape_preservation() {
        let n = 7;
        let input = vec![1.0f32; n];
        let mask = vec![1.0f32; n];
        let mut output = vec![0.0f32; n];

        dropout_train_scalar(&input, &mask, 0.1, &mut output);
        assert_eq!(output.len(), input.len());
        // Every element must have been written with the kept-unit scale.
        for (i, &v) in output.iter().enumerate() {
            assert!((v - 1.0 / 0.9).abs() < 1e-6, "output[{i}] = {v}");
        }
    }

    #[test]
    fn test_dropout_empty_input() {
        let mut output: [f32; 0] = [];
        dropout_train_scalar(&[], &[], 0.5, &mut output);
        dropout_eval_scalar(&[], &mut output);
        assert!(output.is_empty());
    }

    #[test]
    #[should_panic(expected = "p must be in [0, 1)")]
    fn test_dropout_train_rejects_p_one() {
        let mut output = [0.0f32; 1];
        dropout_train_scalar(&[1.0], &[1.0], 1.0, &mut output);
    }

    #[test]
    #[should_panic(expected = "p must be in [0, 1)")]
    fn test_dropout_train_rejects_negative_p() {
        let mut output = [0.0f32; 1];
        dropout_train_scalar(&[1.0], &[1.0], -0.1, &mut output);
    }

    #[test]
    #[should_panic(expected = "p must be in [0, 1)")]
    fn test_dropout_train_rejects_nan_p() {
        let mut output = [0.0f32; 1];
        dropout_train_scalar(&[1.0], &[1.0], f32::NAN, &mut output);
    }

    #[test]
    #[should_panic(expected = "input/mask dimension mismatch")]
    fn test_dropout_train_rejects_mask_len_mismatch() {
        let mut output = [0.0f32; 2];
        dropout_train_scalar(&[1.0, 2.0], &[1.0], 0.5, &mut output);
    }

    #[test]
    #[should_panic(expected = "input/output dimension mismatch")]
    fn test_dropout_train_rejects_output_len_mismatch() {
        let mut output = [0.0f32; 3];
        dropout_train_scalar(&[1.0, 2.0], &[1.0, 1.0], 0.5, &mut output);
    }

    #[test]
    #[should_panic(expected = "input/output dimension mismatch")]
    fn test_dropout_eval_rejects_len_mismatch() {
        let mut output = [0.0f32; 3];
        dropout_eval_scalar(&[1.0, 2.0], &mut output);
    }

    proptest! {
        #[test]
        fn prop_dropout_eval_identity(n in 1usize..16) {
            let input: Vec<f32> = (0..n).map(|i| (i as f32) * 0.3).collect();
            let mut output = vec![0.0f32; n];
            dropout_eval_scalar(&input, &mut output);

            for (i, (&a, &b)) in input.iter().zip(output.iter()).enumerate() {
                prop_assert_eq!(a, b, "eval not identity at {}", i);
            }
        }

        #[test]
        fn prop_dropout_train_finite(
            n in 1usize..10,
            p_int in 0u32..99,
        ) {
            let p = p_int as f32 / 100.0;
            let input: Vec<f32> = (0..n).map(|i| (i as f32) * 0.1).collect();
            let mask: Vec<f32> = (0..n).map(|i| if i % 2 == 0 { 1.0 } else { 0.0 }).collect();
            let mut output = vec![0.0f32; n];

            dropout_train_scalar(&input, &mask, p, &mut output);

            for (idx, &val) in output.iter().enumerate() {
                prop_assert!(val.is_finite(), "output[{idx}] = {val} not finite");
            }
        }
    }

    #[test]
    fn test_dropout_ptx_structure() {
        let ptx = dropout_ptx();
        assert!(ptx.contains(".entry dropout_train_kernel"));
        assert!(ptx.contains("mul.f32"));
        assert!(ptx.contains("ret;"));
    }

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_dropout_avx2_parity() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        // Length 11 is deliberately not a multiple of the 8-lane AVX2 width.
        let input = [1.0, 2.0, 3.0, 4.0, -5.0, 6.5, 7.0, 8.0, 9.0, -10.0, 11.0];
        let mask = [1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0];
        let mut scalar_out = [0.0f32; 11];
        let mut avx2_out = [0.0f32; 11];
        dropout_train_scalar(&input, &mask, 0.5, &mut scalar_out);
        // SAFETY: AVX2 support was verified at runtime just above.
        unsafe { dropout_train_avx2(&input, &mask, 0.5, &mut avx2_out) };
        assert_eq!(scalar_out, avx2_out);
    }

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_dropout_eval_avx2_parity() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        let input = [1.0, -2.0, 3.5, 0.0, 7.25, 8.0, -9.5, 10.0, 11.0];
        let mut scalar_out = [0.0f32; 9];
        let mut avx2_out = [0.0f32; 9];
        dropout_eval_scalar(&input, &mut scalar_out);
        // SAFETY: AVX2 support was verified at runtime just above.
        unsafe { dropout_eval_avx2(&input, &mut avx2_out) };
        assert_eq!(scalar_out, avx2_out);
    }
}