use crate::error::{NeuralError, Result};
use crate::layers::Layer;
use scirs2_core::ndarray::{Array, IxDyn, ScalarOperand};
use scirs2_core::numeric::{Float, NumAssign};
use scirs2_core::random::{Rng, RngExt, SeedableRng};
use std::fmt::Debug;
use std::marker::PhantomData;
use std::sync::{Arc, RwLock};
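/// Dropout layer for regularization.
///
/// During training, each input element is zeroed with probability `p` and the
/// surviving elements are scaled by `1 / (1 - p)` ("inverted dropout"), so no
/// extra rescaling is needed at inference time. In evaluation mode the input
/// is passed through unchanged.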
pub struct Dropout<F: Float + Debug + Send + Sync + NumAssign> {
    /// Probability of zeroing each element during training.
    p: F,
    /// Source of randomness used to sample the dropout mask.
    rng: Arc<RwLock<Box<dyn Rng + Send + Sync>>>,
    /// Whether the layer is currently in training mode.
    training: bool,
    /// Most recent input, cached by the forward pass.
    input_cache: Arc<RwLock<Option<Array<F, IxDyn>>>>,
    /// Mask sampled by the forward pass and reused by the backward pass.
    mask_cache: Arc<RwLock<Option<Array<F, IxDyn>>>>,
    _phantom: PhantomData<F>,
}
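// Manual `Debug` impl because the boxed RNG does not implement `Debug`; only
// the plain configuration fields are reported.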
impl<F: Float + Debug + Send + Sync + NumAssign> std::fmt::Debug for Dropout<F> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Dropout")
.field("p", &self.p)
.field("rng", &"<dyn Rng>")
.field("training", &self.training)
.finish()
}
}
impl<F: Float + Debug + Send + Sync + NumAssign> Clone for Dropout<F> {
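    // Cloning produces a layer with the same configuration but a fresh,
    // fixed-seed RNG and empty caches; mutable state is not shared with the
    // original layer.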
fn clone(&self) -> Self {
let rng = scirs2_core::random::rngs::SmallRng::from_seed([42; 32]);
Self {
p: self.p,
rng: Arc::new(RwLock::new(Box::new(rng))),
training: self.training,
input_cache: Arc::new(RwLock::new(None)),
mask_cache: Arc::new(RwLock::new(None)),
_phantom: PhantomData,
}
}
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static + NumAssign> Dropout<F> {
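    /// Creates a new dropout layer that zeroes elements with probability `p`.
    ///
    /// `p` must lie in `[0, 1)`; the supplied RNG is cloned and stored inside
    /// the layer. The layer starts in training mode.
    ///
    /// # Example
    ///
    /// A minimal sketch (not compiled), assuming `SmallRng` can be seeded the
    /// same way as in the `Clone` impl above:
    ///
    /// ```ignore
    /// use scirs2_core::random::{rngs::SmallRng, SeedableRng};
    ///
    /// let mut rng = SmallRng::from_seed([42; 32]);
    /// let dropout = Dropout::<f64>::new(0.5, &mut rng)?;
    /// assert!(dropout.is_training());
    /// ```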
pub fn new<R: Rng + 'static + Clone + Send + Sync>(p: f64, rng: &mut R) -> Result<Self> {
if !(0.0..1.0).contains(&p) {
return Err(NeuralError::InvalidArchitecture(
"Dropout probability must be in [0, 1)".to_string(),
));
}
let p = F::from(p).ok_or_else(|| {
NeuralError::InvalidArchitecture(
"Failed to convert dropout probability to type F".to_string(),
)
})?;
Ok(Self {
p,
rng: Arc::new(RwLock::new(Box::new(rng.clone()))),
training: true,
input_cache: Arc::new(RwLock::new(None)),
mask_cache: Arc::new(RwLock::new(None)),
_phantom: PhantomData,
})
}
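    /// Sets whether the layer is in training mode (dropout is only applied
    /// during training).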
pub fn set_training(&mut self, training: bool) {
self.training = training;
}
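    /// Returns the dropout probability as an `f64`.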
pub fn p(&self) -> f64 {
self.p.to_f64().unwrap_or(0.0)
}
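    /// Returns `true` if the layer is currently in training mode.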
pub fn is_training(&self) -> bool {
self.training
}
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static + NumAssign> Layer<F> for Dropout<F> {
fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
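        // Cache the most recent input (the backward pass itself only relies on
        // the cached mask).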
if let Ok(mut cache) = self.input_cache.write() {
*cache = Some(input.clone());
} else {
return Err(NeuralError::InferenceError(
"Failed to acquire write lock on input cache".to_string(),
));
}
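        // In evaluation mode, or when nothing is dropped, dropout is the identity.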
if !self.training || self.p == F::zero() {
return Ok(input.clone());
}
        // Build a keep/drop mask with the same shape as the input.
        let mut mask = Array::<F, IxDyn>::from_elem(input.dim(), F::one());
        let one = F::one();
        let zero = F::zero();
        // Convert the drop probability to f64 once so each uniform sample can
        // be compared without a fallible per-element conversion.
        let threshold = self.p.to_f64().ok_or_else(|| {
            NeuralError::InferenceError(
                "Failed to convert dropout probability to f64".to_string(),
            )
        })?;
        {
            let mut rng_guard = match self.rng.write() {
                Ok(guard) => guard,
                Err(_) => {
                    return Err(NeuralError::InferenceError(
                        "Failed to acquire write lock on RNG".to_string(),
                    ))
                }
            };
            // Zero each element independently with probability `p`.
            for elem in mask.iter_mut() {
                if (**rng_guard).random::<f64>() < threshold {
                    *elem = zero;
                }
            }
        }
        // Inverted dropout: scale the surviving activations by 1 / (1 - p) so
        // the expected magnitude matches evaluation mode.
        let scale = one / (one - self.p);
if let Ok(mut cache) = self.mask_cache.write() {
*cache = Some(mask.clone());
} else {
return Err(NeuralError::InferenceError(
"Failed to acquire write lock on mask cache".to_string(),
));
}
let output = input * &mask * scale;
Ok(output)
}
fn backward(
&self,
_input: &Array<F, IxDyn>,
grad_output: &Array<F, IxDyn>,
) -> Result<Array<F, IxDyn>> {
        // In evaluation mode, or when nothing was dropped, the forward pass was
        // the identity (and cached no mask), so the gradient passes through
        // unchanged.
        if !self.training || self.p == F::zero() {
            return Ok(grad_output.clone());
        }
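        // Retrieve the mask saved by the most recent forward pass.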
let mask = {
let cache = match self.mask_cache.read() {
Ok(cache) => cache,
Err(_) => {
return Err(NeuralError::InferenceError(
"Failed to acquire read lock on mask cache".to_string(),
))
}
};
match cache.as_ref() {
Some(mask) => mask.clone(),
None => {
return Err(NeuralError::InferenceError(
"No cached mask for backward pass".to_string(),
))
}
}
};
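        // Apply the same mask and inverted-dropout scaling used in the forward pass.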
let one = F::one();
let scale = one / (one - self.p);
let grad_input = grad_output * &mask * scale;
Ok(grad_input)
}
    fn update(&mut self, _learning_rate: F) -> Result<()> {
        // Dropout has no trainable parameters, so there is nothing to update.
        Ok(())
    }
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
self
}
fn set_training(&mut self, training: bool) {
self.training = training;
}
fn is_training(&self) -> bool {
self.training
}
fn layer_type(&self) -> &str {
"Dropout"
}
    fn parameter_count(&self) -> usize {
        0
    }
fn layer_description(&self) -> String {
format!("type:Dropout, p:{:.3}", self.p())
}
}
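// SAFETY: every field is itself `Send`/`Sync` (the RNG is boxed as
// `dyn Rng + Send + Sync`, the caches live behind `Arc<RwLock<_>>`, and `F` is
// required to be `Send + Sync`), so these impls do not relax any guarantees.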
unsafe impl<F: Float + Debug + Send + Sync + NumAssign> Send for Dropout<F> {}
unsafe impl<F: Float + Debug + Send + Sync + NumAssign> Sync for Dropout<F> {}