use crate::device::Device;
use crate::errors::Result;
use crate::tensor::Tensor;
use crate::traits::Layer;
use scirs2_core::random::*;
/// Inverted-dropout regularisation layer.
///
/// In training mode, [`Layer::forward`] zeroes each element independently with
/// probability `p` and multiplies the surviving elements by `1 / (1 - p)`. Outside
/// training mode (or when `p == 0.0`) the input is passed through unchanged.
#[derive(Clone, Debug)]
pub struct Dropout {
    /// Probability of dropping an element; the constructors reject values outside `[0.0, 1.0]`.
    p: f32,
    /// `true` while the layer is in training mode (the state right after construction).
    training: bool,
    /// Device associated with this layer; the `forward` implementation in this file does not read it.
    device: Device,
}
impl Dropout {
    /// Creates a dropout layer with drop probability `p`, associated with [`Device::CPU`].
    ///
    /// The layer starts in training mode.
    ///
    /// # Panics
    ///
    /// Panics if `p` is not within `[0.0, 1.0]` (a NaN `p` is rejected as well).
    pub fn new(p: f32) -> Self {
        let default_device = Device::CPU;
        Self::new_with_device(p, default_device)
    }

    /// Creates a dropout layer with drop probability `p`, associated with `device`.
    ///
    /// The layer starts in training mode.
    ///
    /// # Panics
    ///
    /// Panics if `p` is not within `[0.0, 1.0]` (a NaN `p` is rejected as well, because
    /// `RangeInclusive::contains` returns `false` for it).
    pub fn new_with_device(p: f32, device: Device) -> Self {
        let in_range = (0.0..=1.0).contains(&p);
        if !in_range {
            panic!(
                "Dropout probability must be between 0.0 and 1.0, got {}",
                p
            );
        }
        Self {
            p,
            training: true,
            device,
        }
    }

    /// Switches the layer into training mode (`true`) or inference mode (`false`).
    pub fn set_training(&mut self, training: bool) {
        self.training = training;
    }

    /// Returns the configured drop probability `p`.
    pub fn dropout_rate(&self) -> f32 {
        self.p
    }

    /// Returns `true` when the layer is in training mode.
    pub fn is_training(&self) -> bool {
        self.training
    }

    /// Returns the device this layer is associated with.
    pub fn device(&self) -> Device {
        self.device
    }

    /// Consumes the layer and returns it re-associated with `device`.
    ///
    /// The drop probability and training flag are carried over unchanged.
    pub fn to_device(self, device: Device) -> Self {
        Self { device, ..self }
    }
}
impl Layer for Dropout {
    type Input = Tensor;
    type Output = Tensor;

    /// Applies inverted dropout to `input`.
    ///
    /// In training mode each element is independently replaced by `0.0` with probability
    /// `p`; every surviving element is multiplied by `1 / (1 - p)`, so each element's
    /// expected value matches the input. In inference mode, or when `p == 0.0`, the input
    /// tensor is returned as-is. The output has the same shape as the input.
    ///
    /// # Errors
    ///
    /// Propagates any error from reading the input's data or from building the output
    /// tensor.
    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
        if !self.training || self.p == 0.0 {
            return Ok(input);
        }
        let data = input.data()?;

        // Every element is dropped when p == 1.0. Short-circuit so the scale below is never
        // the infinite value 1.0 / 0.0 and no random numbers are drawn for a foregone result.
        if self.p >= 1.0 {
            return Tensor::from_vec(vec![0.0_f32; data.len()], &input.shape());
        }

        let keep_prob = 1.0 - self.p;
        let scale = 1.0 / keep_prob;
        let mut rng = thread_rng();
        // One uniform draw per element, in element order: keep (and rescale) when the draw
        // falls below keep_prob, otherwise drop to zero.
        let output_data: Vec<f32> = data
            .iter()
            .map(|&value| {
                if rng.random::<f32>() < keep_prob {
                    value * scale
                } else {
                    0.0
                }
            })
            .collect();
        Tensor::from_vec(output_data, &input.shape())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dropout_creation() {
        let dropout = Dropout::new(0.5);
        assert_eq!(dropout.dropout_rate(), 0.5);
        assert!(dropout.is_training());
    }

    #[test]
    fn test_dropout_with_device() {
        let dropout = Dropout::new_with_device(0.3, Device::CPU);
        assert_eq!(dropout.dropout_rate(), 0.3);
        assert_eq!(dropout.device(), Device::CPU);
    }

    #[test]
    fn test_set_training_toggles_mode() {
        let mut dropout = Dropout::new(0.5);
        dropout.set_training(false);
        assert!(!dropout.is_training());
        dropout.set_training(true);
        assert!(dropout.is_training());
    }

    #[test]
    fn test_to_device_preserves_configuration() {
        let dropout = Dropout::new(0.2).to_device(Device::CPU);
        assert_eq!(dropout.device(), Device::CPU);
        assert_eq!(dropout.dropout_rate(), 0.2);
        assert!(dropout.is_training());
    }

    #[test]
    fn test_dropout_inference_mode() {
        let mut dropout = Dropout::new(0.5);
        dropout.set_training(false);
        let input =
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).expect("Tensor from_vec failed");
        let input_data = input.data().expect("operation failed in test");
        let output = dropout.forward(input).expect("Forward pass failed");
        let output_data = output.data().expect("operation failed in test");
        assert_eq!(input_data, output_data);
    }

    #[test]
    fn test_dropout_zero_rate() {
        let dropout = Dropout::new(0.0);
        let input =
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).expect("Tensor from_vec failed");
        let input_data = input.data().expect("operation failed in test");
        let input_shape = input.shape().to_vec();
        let output = dropout.forward(input).expect("Forward pass failed");
        let output_data = output.data().expect("operation failed in test");
        let output_shape = output.shape().to_vec();
        assert_eq!(input_data, output_data);
        assert_eq!(input_shape, output_shape);
    }

    #[test]
    fn test_dropout_full_rate() {
        let dropout = Dropout::new(1.0);
        let input =
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).expect("Tensor from_vec failed");
        let input_shape = input.shape().to_vec();
        let output = dropout.forward(input).expect("Forward pass failed");
        let output_data = output.data().expect("operation failed in test");
        let output_shape = output.shape().to_vec();
        assert!(output_data.iter().all(|&x| x == 0.0));
        assert_eq!(input_shape, output_shape);
    }

    #[test]
    fn test_dropout_scales_kept_elements() {
        // keep_prob = 1 - 0.75 = 0.25 (exact in f32), so survivors become 2.0 / 0.25 = 8.0.
        let dropout = Dropout::new(0.75);
        let size = 256;
        let input = Tensor::from_vec(vec![2.0; size], &[size]).expect("Tensor from_vec failed");
        let output = dropout.forward(input).expect("Forward pass failed");
        let output_data = output.data().expect("operation failed in test");
        assert_eq!(output_data.len(), size);
        assert!(
            output_data
                .iter()
                .all(|&x| x == 0.0 || (x - 8.0).abs() < 1e-5),
            "every element must be either dropped (0.0) or rescaled to 8.0"
        );
    }

    #[test]
    fn test_dropout_statistical_properties() {
        let dropout = Dropout::new(0.5);
        let size = 1000;
        let mut zero_counts = Vec::new();
        let mut sums = Vec::new();
        for _ in 0..20 {
            let input =
                Tensor::from_vec(vec![1.0; size], &[size]).expect("Tensor from_vec failed");
            let output = dropout.forward(input).expect("Forward pass failed");
            let output_data = output.data().expect("operation failed in test");
            let zero_count = output_data.iter().filter(|&&x| x == 0.0).count();
            let sum: f32 = output_data.iter().sum();
            zero_counts.push(zero_count);
            sums.push(sum);
        }
        let avg_zero_rate = zero_counts.iter().sum::<usize>() as f32 / (20.0 * size as f32);
        assert!(
            (avg_zero_rate - 0.5).abs() < 0.1,
            "Expected ~50% zeros, got {:.1}%",
            avg_zero_rate * 100.0
        );
        let avg_sum = sums.iter().sum::<f32>() / 20.0;
        let expected_sum = size as f32;
        assert!(
            (avg_sum - expected_sum).abs() < expected_sum * 0.2,
            "Sum not preserved: expected {}, got {}",
            expected_sum,
            avg_sum
        );
    }

    #[test]
    #[should_panic(expected = "Dropout probability must be between 0.0 and 1.0")]
    fn test_invalid_dropout_rate_high() {
        Dropout::new(1.5);
    }

    #[test]
    #[should_panic(expected = "Dropout probability must be between 0.0 and 1.0")]
    fn test_invalid_dropout_rate_negative() {
        Dropout::new(-0.1);
    }
}