use super::{Regularization, RegularizationClone};
use crate::common::matrix::DMat;
use crate::common::random::Randomizer;
use crate::error::NetworkError;
use serde::{Deserialize, Serialize};
use typetag;
/// Regularizer that randomly zeroes entries of the parameter matrices.
///
/// Instances are produced by the [`Dropout`] builder. With serde's default
/// derive, the field names below are the serialized keys.
#[derive(Clone, Serialize, Deserialize)]
pub(crate) struct DropoutRegularization {
    /// Per-entry drop probability; the builder checks it lies in `[0.0, 1.0]`.
    dropout_rate: f32,
    /// Random source consulted once per matrix entry on every `apply` call.
    randomizer: Randomizer,
}
#[typetag::serde]
impl Regularization for DropoutRegularization {
    /// Zeroes entries of every parameter matrix at random; gradients are ignored.
    ///
    /// For each entry, exactly one value is drawn from `self.randomizer.float32()`
    /// (presumably uniform in `[0, 1)` — confirm against `Randomizer`). The entry
    /// is overwritten with `0.0` when that draw is strictly below `dropout_rate`.
    ///
    /// NOTE(review): zeros are written directly into the parameter matrices (not
    /// into activations), and surviving entries are not rescaled by
    /// `1 / (1 - dropout_rate)`.
    fn apply(&self, params: &mut [&mut DMat], _grads: &mut [&mut DMat]) {
        let rate = self.dropout_rate;
        let rng = &self.randomizer;

        params.iter_mut().for_each(|matrix| {
            matrix.apply_with_indices(|_i, _j, entry| {
                let dropped = rng.float32() < rate;
                if dropped {
                    *entry = 0.0;
                }
            });
        });
    }

    /// Exposes `self` as `Any` so callers can downcast to the concrete type.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}
impl RegularizationClone for DropoutRegularization {
    /// Returns an owned, boxed copy of this regularizer (the randomizer is
    /// copied through its own `Clone` impl).
    fn clone_box(&self) -> Box<dyn Regularization> {
        let copy: DropoutRegularization = Clone::clone(self);
        Box::new(copy)
    }
}
/// Builder for a dropout regularizer.
///
/// Defaults: `dropout_rate = 0.5` and no explicit seed (the seed is passed to
/// `Randomizer::new` as `None`).
#[derive(Debug, Clone)]
pub struct Dropout {
    /// Probability that any single parameter entry is zeroed; must be in `[0.0, 1.0]`.
    dropout_rate: f32,
    /// Optional seed forwarded to `Randomizer::new`.
    seed: Option<u64>,
}

impl Dropout {
    /// Creates a builder with a dropout rate of `0.5` and no seed.
    fn new() -> Self {
        Self {
            dropout_rate: 0.5,
            seed: None,
        }
    }

    /// Sets the per-entry drop probability. It is validated in [`Dropout::build`].
    #[must_use]
    pub fn dropout_rate(mut self, dropout_rate: f32) -> Self {
        self.dropout_rate = dropout_rate;
        self
    }

    /// Sets the seed handed to the regularizer's `Randomizer`.
    #[must_use]
    pub fn seed(mut self, seed: u64) -> Self {
        self.seed = Some(seed);
        self
    }

    /// Checks that the configured dropout rate is a number in `[0.0, 1.0]`.
    ///
    /// # Errors
    ///
    /// Returns `NetworkError::ConfigError` if the rate is below `0.0`, above
    /// `1.0`, or NaN.
    fn validate(&self) -> Result<(), NetworkError> {
        // `RangeInclusive::contains` is false for NaN. The previous
        // `rate < 0.0 || rate > 1.0` test was also false for NaN, so NaN used
        // to be accepted.
        if !(0.0..=1.0).contains(&self.dropout_rate) {
            return Err(NetworkError::ConfigError(format!(
                "Dropout rate must be in the range [0.0, 1.0], but was {}",
                self.dropout_rate
            )));
        }
        Ok(())
    }

    /// Validates the configuration and builds the boxed regularizer.
    ///
    /// # Errors
    ///
    /// Propagates the `NetworkError::ConfigError` from validation when the
    /// dropout rate is out of range or NaN.
    pub fn build(self) -> Result<Box<dyn Regularization>, NetworkError> {
        self.validate()?;
        Ok(Box::new(DropoutRegularization {
            dropout_rate: self.dropout_rate,
            randomizer: Randomizer::new(self.seed),
        }))
    }
}

impl Default for Dropout {
    /// Same as the private `Dropout::new`: rate `0.5`, no seed.
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::common::matrix::DMat;
    use crate::util;

    /// Applies seeded 50% dropout to a 2x2 matrix and checks the effect on
    /// parameters and gradients.
    #[test]
    fn test_dropout_regularization() {
        let mut params = [DMat::new(2, 2, &[1.0, 2.0, 3.0, 4.0])];
        // Exactly representable values, so the equality check below is exact.
        let mut grads = [DMat::new(2, 2, &[0.5, 0.5, 0.5, 0.5])];
        let dropout = Dropout::new().dropout_rate(0.5).seed(42).build().unwrap();
        let mut params_refs: Vec<&mut DMat> = params.iter_mut().collect();
        let mut grads_refs: Vec<&mut DMat> = grads.iter_mut().collect();
        dropout.apply(&mut params_refs, &mut grads_refs);

        let flattened = util::flatten(&params[0]);
        // With seed 42 at rate 0.5, at least one of the four entries is dropped.
        assert!(flattened.iter().any(|&v| v == 0.0));
        // Entries are either zeroed or left exactly as they were (no rescaling).
        assert!(flattened
            .iter()
            .all(|&v| v == 0.0 || [1.0, 2.0, 3.0, 4.0].contains(&v)));
        // Dropout does not touch the gradients.
        assert!(util::flatten(&grads[0]).iter().all(|&g| g == 0.5));
    }

    /// The builder accepts the closed interval `[0.0, 1.0]` and rejects values outside it.
    #[test]
    fn test_dropout_builder_validate() {
        let dropout = Dropout::new().dropout_rate(0.5).seed(42);
        assert!(dropout.validate().is_ok());
        assert!(Dropout::new().dropout_rate(0.0).validate().is_ok());
        assert!(Dropout::new().dropout_rate(1.0).validate().is_ok());

        let dropout_invalid = Dropout::new().dropout_rate(1.5).seed(42);
        assert!(dropout_invalid.validate().is_err());
        assert!(Dropout::new().dropout_rate(-0.1).validate().is_err());
        // `build` surfaces the same validation error.
        assert!(Dropout::new().dropout_rate(1.5).build().is_err());
    }
}