// optirs_core/regularizers/dropconnect.rs

//! DropConnect regularization: randomly drops individual weight connections
//! during training, rescaling the surviving weights so that expected values
//! are preserved.

use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use scirs2_core::random::{thread_rng, Rng};
use std::fmt::Debug;

use crate::error::{OptimError, Result};
use crate::regularizers::Regularizer;

/// DropConnect regularizer.
///
/// Unlike dropout, which zeroes whole activations, DropConnect zeroes
/// individual weights (connections) with probability `drop_prob` during
/// training and rescales the survivors by `1 / (1 - drop_prob)`.
#[derive(Debug, Clone)]
pub struct DropConnect<A: Float> {
    /// Probability of dropping each individual weight, in `[0.0, 1.0]`.
    drop_prob: A,
}

impl<A: Float + Debug + ScalarOperand + Send + Sync> DropConnect<A> {
    /// Creates a new DropConnect regularizer.
    ///
    /// Returns an error if `drop_prob` lies outside `[0.0, 1.0]`.
    pub fn new(drop_prob: A) -> Result<Self> {
        if drop_prob < A::zero() || drop_prob > A::one() {
            return Err(OptimError::InvalidConfig(
                "Drop probability must be between 0.0 and 1.0".to_string(),
            ));
        }

        Ok(Self { drop_prob })
    }

    /// Applies DropConnect to a weight array.
    ///
    /// In training mode, each weight is independently zeroed with probability
    /// `drop_prob`, and the survivors are rescaled by `1 / keep_prob`
    /// (inverted scaling), so the expected value of every weight is
    /// unchanged. In inference mode the weights are returned as-is.
    pub fn apply_to_weights<D: Dimension>(
        &self,
        weights: &Array<A, D>,
        training: bool,
    ) -> Array<A, D> {
        if !training || self.drop_prob == A::zero() {
            return weights.clone();
        }

        let keep_prob = A::one() - self.drop_prob;
        let keep_prob_f64 = keep_prob.to_f64().unwrap();

        // Sample a Bernoulli keep/drop mask with the same shape as the weights.
        let mut rng = thread_rng();
        let mask = Array::from_shape_fn(weights.raw_dim(), |_| rng.random_bool(keep_prob_f64));

        // Zero dropped weights and rescale the survivors.
        let mut result = weights.clone();
        for (r, &m) in result.iter_mut().zip(mask.iter()) {
            if !m {
                *r = A::zero();
            } else {
                *r = *r / keep_prob;
            }
        }

        result
    }

    /// Applies a freshly sampled DropConnect mask to a gradient array of
    /// shape `weights_shape`.
    ///
    /// Note: the mask is sampled independently on every call, so it does not
    /// match any mask previously applied to the weights.
    pub fn apply_to_gradients<D: Dimension>(
        &self,
        gradients: &Array<A, D>,
        weights_shape: D,
        training: bool,
    ) -> Array<A, D> {
        if !training || self.drop_prob == A::zero() {
            return gradients.clone();
        }

        let keep_prob = A::one() - self.drop_prob;
        let keep_prob_f64 = keep_prob.to_f64().unwrap();

        let mut rng = thread_rng();
        let mask = Array::from_shape_fn(weights_shape, |_| rng.random_bool(keep_prob_f64));

        // Zero dropped gradients and rescale the survivors, mirroring
        // apply_to_weights.
        let mut result = gradients.clone();
        for (g, &m) in result.iter_mut().zip(mask.iter()) {
            if !m {
                *g = A::zero();
            } else {
                *g = *g / keep_prob;
            }
        }

        result
    }
}

impl<A: Float + Debug + ScalarOperand + Send + Sync, D: Dimension + Send + Sync> Regularizer<A, D>
    for DropConnect<A>
{
    /// Masks the gradients in place. DropConnect adds no explicit penalty
    /// term to the loss, so the returned contribution is always zero.
    fn apply(&self, params: &Array<A, D>, gradients: &mut Array<A, D>) -> Result<A> {
        let masked_gradients = self.apply_to_gradients(gradients, params.raw_dim(), true);
        gradients.assign(&masked_gradients);
        Ok(A::zero())
    }

    /// DropConnect is a structural regularizer and has no penalty term.
    fn penalty(&self, _params: &Array<A, D>) -> Result<A> {
        Ok(A::zero())
    }
}
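
// A minimal usage sketch (names like `weights` are illustrative placeholders,
// not part of this crate's API):
//
//     let dc = DropConnect::new(0.5)?;
//     // Training: each weight is zeroed with probability 0.5, and the
//     // survivors are scaled by 1 / keep_prob = 2.0.
//     let train_weights = dc.apply_to_weights(&weights, true);
//     // Inference: weights are returned unchanged.
//     let eval_weights = dc.apply_to_weights(&weights, false);
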
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_dropconnect_creation() {
        let dc = DropConnect::<f64>::new(0.5).unwrap();
        assert_eq!(dc.drop_prob, 0.5);

        assert!(DropConnect::<f64>::new(-0.1).is_err());
        assert!(DropConnect::<f64>::new(1.1).is_err());
    }

    #[test]
    fn test_dropconnect_training_mode() {
        let dc = DropConnect::new(0.5).unwrap();
        let weights = array![[1.0, 2.0], [3.0, 4.0]];

        let masked_weights = dc.apply_to_weights(&weights, true);

        // Every surviving weight must be the original scaled by
        // 1 / keep_prob = 2.0; dropped weights are exactly zero.
        for (&original, &masked) in weights.iter().zip(masked_weights.iter()) {
            if masked != 0.0 {
                assert_relative_eq!(masked, original * 2.0, epsilon = 1e-10);
            }
        }
    }

    #[test]
    fn test_dropconnect_inference_mode() {
        let dc = DropConnect::new(0.5).unwrap();
        let weights = array![[1.0, 2.0], [3.0, 4.0]];

        let inference_weights = dc.apply_to_weights(&weights, false);
        assert_eq!(weights, inference_weights);
    }

    #[test]
    fn test_dropconnect_zero_probability() {
        let dc = DropConnect::new(0.0).unwrap();
        let weights = array![[1.0, 2.0], [3.0, 4.0]];

        let result = dc.apply_to_weights(&weights, true);
        assert_eq!(weights, result);
    }

    #[test]
    fn test_dropconnect_gradients() {
        let dc = DropConnect::new(0.5).unwrap();
        let gradients = array![[1.0, 1.0], [1.0, 1.0]];
        let weights_shape = gradients.raw_dim();

        let masked_grads = dc.apply_to_gradients(&gradients, weights_shape, true);

        // Surviving unit gradients are rescaled to 1.0 / keep_prob = 2.0.
        for &grad in masked_grads.iter() {
            if grad != 0.0 {
                assert_relative_eq!(grad, 2.0, epsilon = 1e-10);
            }
        }
    }

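    // An additional sketch, mirroring test_dropconnect_inference_mode: in
    // inference mode, gradients should pass through completely unchanged.
    #[test]
    fn test_dropconnect_gradients_inference_mode() {
        let dc = DropConnect::new(0.5).unwrap();
        let gradients = array![[0.1, 0.2], [0.3, 0.4]];

        let result = dc.apply_to_gradients(&gradients, gradients.raw_dim(), false);
        assert_eq!(gradients, result);
    }
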
    #[test]
    fn test_regularizer_trait() {
        let dc = DropConnect::new(0.3).unwrap();
        let params = array![[1.0, 2.0], [3.0, 4.0]];
        let mut gradient = array![[0.1, 0.2], [0.3, 0.4]];

        // DropConnect contributes no penalty term.
        let penalty = dc.penalty(&params).unwrap();
        assert_eq!(penalty, 0.0);

        let penalty_from_apply = dc.apply(&params, &mut gradient).unwrap();
        assert_eq!(penalty_from_apply, 0.0);

        // Anywhere from zero to all four gradient entries may be dropped.
        let zeros = gradient.iter().filter(|&&x| x == 0.0).count();
        assert!(zeros <= 4);
    }
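
    // Statistical sketch: with inverted scaling, each masked weight has the
    // original weight as its expected value, so the per-trial sum, averaged
    // over many trials, should land near the original sum. The tolerance is
    // deliberately loose because the check is probabilistic (the std of the
    // trial mean is about 0.12 here).
    #[test]
    fn test_dropconnect_preserves_expectation() {
        let dc = DropConnect::new(0.5).unwrap();
        let weights = array![[1.0, 2.0], [3.0, 4.0]];

        let trials = 2000;
        let mut total = 0.0;
        for _ in 0..trials {
            total += dc.apply_to_weights(&weights, true).sum();
        }
        let mean_sum = total / trials as f64;

        assert_relative_eq!(mean_sum, weights.sum(), epsilon = 1.0);
    }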
}