use crate::{TrainError, TrainResult};
use scirs2_core::ndarray::{Array2, ArrayView2};
use scirs2_core::random::{RngExt, StdRng};
/// Structured dropout that zeroes whole square blocks of a 2-D activation map
/// instead of independent elements.
///
/// Construct with [`DropBlock::new`], which validates both parameters, and
/// change the probability afterwards with [`DropBlock::set_drop_prob`].
#[derive(Debug, Clone)]
pub struct DropBlock {
/// Side length of each dropped square. `new` rejects zero and even values.
pub block_size: usize,
/// Drop probability; `new` and `set_drop_prob` require it to be in `[0.0, 1.0]`.
///
/// NOTE: the field is public, so a direct write bypasses that validation and
/// leaves `keep_prob` stale.
pub drop_prob: f64,
/// Cached `1.0 - drop_prob`, refreshed by `new` and `set_drop_prob`.
keep_prob: f64,
}
impl DropBlock {
    /// Creates a DropBlock regularizer.
    ///
    /// # Errors
    ///
    /// Returns [`TrainError::InvalidParameter`] when `block_size` is zero or
    /// even, or when `drop_prob` is outside `[0.0, 1.0]` (NaN is rejected too).
    pub fn new(block_size: usize, drop_prob: f64) -> TrainResult<Self> {
        if block_size == 0 {
            return Err(TrainError::InvalidParameter(
                "block_size must be at least 1".to_string(),
            ));
        }
        // Blocks are centred on the sampled seed, which needs an odd side length.
        if block_size.is_multiple_of(2) {
            return Err(TrainError::InvalidParameter(
                "block_size must be odd".to_string(),
            ));
        }
        Self::validate_drop_prob(drop_prob)?;
        Ok(Self {
            block_size,
            drop_prob,
            keep_prob: 1.0 - drop_prob,
        })
    }

    /// Replaces the drop probability and refreshes the cached keep probability.
    ///
    /// # Errors
    ///
    /// Returns [`TrainError::InvalidParameter`] when `drop_prob` is outside
    /// `[0.0, 1.0]`; the current settings are left untouched in that case.
    pub fn set_drop_prob(&mut self, drop_prob: f64) -> TrainResult<()> {
        Self::validate_drop_prob(drop_prob)?;
        self.drop_prob = drop_prob;
        self.keep_prob = 1.0 - drop_prob;
        Ok(())
    }

    /// Returns the cached keep probability (`1.0 - drop_prob`).
    pub fn keep_prob(&self) -> f64 {
        self.keep_prob
    }

    /// Applies DropBlock to a 2-D activation map and returns the masked copy.
    ///
    /// When `training` is `false` or `drop_prob` is exactly `0.0`, the input is
    /// returned unchanged. Otherwise one uniform sample is drawn per position in
    /// row-major order; each sample below `gamma` marks that position as a block
    /// centre, and the `block_size` x `block_size` square around it (clipped at
    /// the borders) is dropped. Surviving values are multiplied by
    /// `total_cells / kept_cells`; if nothing survives, the output is all zeros.
    ///
    /// Dropped cells are written as an exact `0.0`, so non-finite activations
    /// inside a dropped block cannot leak through as `NaN`.
    ///
    /// # Errors
    ///
    /// Returns [`TrainError::InvalidParameter`] when either dimension of
    /// `activations` is smaller than `block_size`.
    pub fn apply(
        &self,
        activations: &ArrayView2<f64>,
        training: bool,
        rng: &mut StdRng,
    ) -> TrainResult<Array2<f64>> {
        if !training || self.drop_prob == 0.0 {
            return Ok(activations.to_owned());
        }

        let (height, width) = activations.dim();
        if height < self.block_size || width < self.block_size {
            return Err(TrainError::InvalidParameter(format!(
                "Activation map size ({}x{}) is smaller than block_size ({})",
                height, width, self.block_size
            )));
        }

        // Per-position seed rate: drop_prob rescaled by the ratio of all cells to
        // the positions where a full block fits, then spread over the block area.
        let valid_seed_cells =
            ((height - self.block_size + 1) * (width - self.block_size + 1)) as f64;
        let block_area = (self.block_size * self.block_size) as f64;
        let gamma = self.drop_prob * (height * width) as f64 / valid_seed_cells / block_area;

        let mut keep = Array2::from_elem((height, width), true);
        let half_block = self.block_size / 2;
        for row in 0..height {
            for col in 0..width {
                // Exactly one draw per cell keeps results reproducible for a given seed.
                if rng.random::<f64>() < gamma {
                    let row_start = row.saturating_sub(half_block);
                    let row_end = (row + half_block + 1).min(height);
                    let col_start = col.saturating_sub(half_block);
                    let col_end = (col + half_block + 1).min(width);
                    for r in row_start..row_end {
                        for c in col_start..col_end {
                            keep[[r, c]] = false;
                        }
                    }
                }
            }
        }

        let kept_cells = keep.iter().filter(|&&is_kept| is_kept).count();
        // Rescale survivors so the mean activation is preserved; with no
        // survivors there is nothing to rescale.
        let scale = if kept_cells > 0 {
            (height * width) as f64 / kept_cells as f64
        } else {
            1.0
        };

        let mut output = activations.to_owned();
        // Both iterators walk the arrays in logical (row-major) order.
        for (value, &is_kept) in output.iter_mut().zip(keep.iter()) {
            if is_kept {
                *value *= scale;
            } else {
                // Assign instead of multiplying by 0.0: `inf * 0.0` is NaN.
                *value = 0.0;
            }
        }
        Ok(output)
    }

    /// Shared range check for `drop_prob`, used by `new` and `set_drop_prob`.
    fn validate_drop_prob(drop_prob: f64) -> TrainResult<()> {
        if (0.0..=1.0).contains(&drop_prob) {
            Ok(())
        } else {
            Err(TrainError::InvalidParameter(
                "drop_prob must be between 0.0 and 1.0".to_string(),
            ))
        }
    }
}
/// Linear ramp for a drop probability.
///
/// `get_drop_prob` returns `0.0` at step 0, grows proportionally with the step
/// index, and returns `drop_prob_target` from `total_steps` onwards.
#[derive(Debug, Clone)]
pub struct LinearDropBlockScheduler {
/// Probability returned once `total_steps` is reached; `new` requires `[0.0, 1.0]`.
pub drop_prob_target: f64,
/// Number of steps the ramp spans; `new` requires at least 1.
pub total_steps: usize,
}
impl LinearDropBlockScheduler {
    /// Builds a scheduler that ramps linearly towards `drop_prob_target`.
    ///
    /// # Errors
    ///
    /// Returns [`TrainError::InvalidParameter`] when `drop_prob_target` is
    /// outside `[0.0, 1.0]` or when `total_steps` is zero. The probability is
    /// checked first, so it wins when both arguments are invalid.
    pub fn new(drop_prob_target: f64, total_steps: usize) -> TrainResult<Self> {
        let target_in_range = (0.0..=1.0).contains(&drop_prob_target);
        match (target_in_range, total_steps) {
            (false, _) => Err(TrainError::InvalidParameter(
                "drop_prob_target must be between 0.0 and 1.0".to_string(),
            )),
            (true, 0) => Err(TrainError::InvalidParameter(
                "total_steps must be at least 1".to_string(),
            )),
            (true, _) => Ok(Self {
                drop_prob_target,
                total_steps,
            }),
        }
    }

    /// Returns the drop probability scheduled for `current_step`.
    ///
    /// Before `total_steps` the result is `drop_prob_target` scaled by
    /// `current_step / total_steps`; from `total_steps` onwards it is the
    /// target itself.
    pub fn get_drop_prob(&self, current_step: usize) -> f64 {
        if current_step < self.total_steps {
            let fraction_done = current_step as f64 / self.total_steps as f64;
            self.drop_prob_target * fraction_done
        } else {
            self.drop_prob_target
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::random::SeedableRng;

    /// Absolute tolerance for probabilities derived by floating-point arithmetic.
    const EPS: f64 = 1e-12;

    /// Valid parameters are stored as given and the keep probability is derived.
    #[test]
    fn test_dropblock_creation() {
        let db = DropBlock::new(7, 0.1).expect("odd block_size with in-range drop_prob is valid");
        assert_eq!(db.block_size, 7);
        assert_eq!(db.drop_prob, 0.1);
        assert!((db.keep_prob - 0.9).abs() < EPS);
    }

    /// Zero or even block sizes and out-of-range probabilities are rejected.
    #[test]
    fn test_dropblock_invalid_params() {
        assert!(DropBlock::new(0, 0.1).is_err());
        assert!(DropBlock::new(4, 0.1).is_err());
        assert!(DropBlock::new(7, -0.1).is_err());
        assert!(DropBlock::new(7, 1.5).is_err());
    }

    /// `set_drop_prob` updates both probabilities and validates its argument.
    #[test]
    fn test_dropblock_set_drop_prob() {
        let mut db =
            DropBlock::new(7, 0.1).expect("odd block_size with in-range drop_prob is valid");
        db.set_drop_prob(0.2)
            .expect("0.2 is inside the accepted [0, 1] range");
        assert_eq!(db.drop_prob, 0.2);
        assert!((db.keep_prob - 0.8).abs() < EPS);
        assert!(db.set_drop_prob(1.5).is_err());
    }

    /// Outside training the activations pass through untouched.
    #[test]
    fn test_dropblock_inference_mode() {
        let db = DropBlock::new(3, 0.5).expect("odd block_size with in-range drop_prob is valid");
        let mut rng = StdRng::seed_from_u64(42);
        let activations = Array2::ones((10, 10));
        let output = db
            .apply(&activations.view(), false, &mut rng)
            .expect("inference mode never fails");
        assert_eq!(output, activations);
    }

    /// A zero drop probability is a no-op even in training mode.
    #[test]
    fn test_dropblock_zero_prob() {
        let db = DropBlock::new(3, 0.0).expect("odd block_size with in-range drop_prob is valid");
        let mut rng = StdRng::seed_from_u64(42);
        let activations = Array2::ones((10, 10));
        let output = db
            .apply(&activations.view(), true, &mut rng)
            .expect("zero drop_prob short-circuits before any size check");
        assert_eq!(output, activations);
    }

    /// Training mode drops some, but not all, positions and keeps the shape.
    #[test]
    fn test_dropblock_training_mode() {
        let db = DropBlock::new(3, 0.3).expect("odd block_size with in-range drop_prob is valid");
        let mut rng = StdRng::seed_from_u64(42);
        let activations = Array2::ones((20, 20));
        let output = db
            .apply(&activations.view(), true, &mut rng)
            .expect("20x20 map is large enough for 3x3 blocks");
        assert_eq!(output.shape(), activations.shape());
        let zeros_count = output.iter().filter(|&&x| x == 0.0).count();
        assert!(zeros_count > 0, "Expected some blocks to be dropped");
        assert!(zeros_count < 400, "Not all values should be dropped");
    }

    /// A map smaller than the block size is reported as an error.
    #[test]
    fn test_dropblock_small_activation_map() {
        let db = DropBlock::new(7, 0.1).expect("odd block_size with in-range drop_prob is valid");
        let mut rng = StdRng::seed_from_u64(42);
        let activations = Array2::ones((5, 5));
        let result = db.apply(&activations.view(), true, &mut rng);
        assert!(result.is_err());
    }

    /// Valid scheduler parameters are stored as given.
    #[test]
    fn test_linear_scheduler_creation() {
        let scheduler = LinearDropBlockScheduler::new(0.1, 1000)
            .expect("in-range target with non-zero total_steps is valid");
        assert_eq!(scheduler.drop_prob_target, 0.1);
        assert_eq!(scheduler.total_steps, 1000);
    }

    /// Out-of-range targets and a zero step count are rejected.
    #[test]
    fn test_linear_scheduler_invalid_params() {
        assert!(LinearDropBlockScheduler::new(-0.1, 1000).is_err());
        assert!(LinearDropBlockScheduler::new(1.5, 1000).is_err());
        assert!(LinearDropBlockScheduler::new(0.1, 0).is_err());
    }

    /// The schedule starts at zero, is linear in between, and saturates at the target.
    #[test]
    fn test_linear_scheduler_interpolation() {
        let scheduler = LinearDropBlockScheduler::new(0.1, 100)
            .expect("in-range target with non-zero total_steps is valid");
        assert_eq!(scheduler.get_drop_prob(0), 0.0);
        let mid_prob = scheduler.get_drop_prob(50);
        assert!((mid_prob - 0.05).abs() < 1e-10);
        assert_eq!(scheduler.get_drop_prob(100), 0.1);
        assert_eq!(scheduler.get_drop_prob(150), 0.1);
    }

    /// Scheduler output can be fed straight into `set_drop_prob` at every step.
    #[test]
    fn test_dropblock_with_scheduler() {
        let mut db =
            DropBlock::new(3, 0.0).expect("odd block_size with in-range drop_prob is valid");
        let scheduler = LinearDropBlockScheduler::new(0.2, 100)
            .expect("in-range target with non-zero total_steps is valid");
        let mut rng = StdRng::seed_from_u64(42);
        let activations = Array2::ones((20, 20));
        for step in [0, 50, 100] {
            let drop_prob = scheduler.get_drop_prob(step);
            db.set_drop_prob(drop_prob)
                .expect("scheduler output stays within [0, target]");
            let output = db
                .apply(&activations.view(), true, &mut rng)
                .expect("20x20 map is large enough for 3x3 blocks");
            assert_eq!(output.shape(), activations.shape());
            if step == 0 {
                // The ramp starts at exactly 0.0, which makes `apply` a no-op.
                assert_eq!(output, activations);
            }
        }
    }

    /// Rescaling the survivors preserves the total of a constant activation map.
    #[test]
    fn test_dropblock_normalization() {
        let db = DropBlock::new(3, 0.1).expect("odd block_size with in-range drop_prob is valid");
        let mut rng = StdRng::seed_from_u64(42);
        let activations = Array2::from_elem((20, 20), 1.0);
        let output = db
            .apply(&activations.view(), true, &mut rng)
            .expect("20x20 map is large enough for 3x3 blocks");
        let input_sum = activations.sum();
        let output_sum = output.sum();
        let relative_diff = (output_sum - input_sum).abs() / input_sum;
        // For an all-ones input the kept cells each hold total/kept, so the sum is
        // preserved up to rounding; a loose bound would hide a broken rescale.
        assert!(
            relative_diff < 1e-9,
            "Normalization should preserve approximate expected value"
        );
    }
}