use std::collections::HashMap;
use crate::autograd::AutogradError;
use crate::nn::{Module, Parameter};
use crate::tensor::Tensor;
/// Dropout layer: randomly zeroes elements of its input during training.
///
/// The layer has no learnable parameters; its only state is the drop probability
/// and whether it is in training or evaluation mode.
#[derive(Debug, Clone)]
pub struct Dropout {
    /// Probability passed to `Tensor::dropout` in training mode.
    ///
    /// Validated to lie in `[0, 1)` only by [`Dropout::new`]; because the field is
    /// public, later direct assignments are not re-checked.
    pub p: f32,
    /// `true` in training mode (dropout applied), `false` in evaluation mode
    /// (input passed through unchanged). Starts as `true`.
    is_training: bool,
}
impl Dropout {
    /// Creates a dropout layer with drop probability `p`, starting in training mode.
    ///
    /// # Panics
    ///
    /// Panics if `p` is not in the half-open range `[0, 1)`. NaN is also rejected,
    /// since it is not contained in any range.
    pub fn new(p: f32) -> Self {
        assert!((0.0..1.0).contains(&p), "Dropout: p must be in [0, 1)");
        Self {
            p,
            is_training: true,
        }
    }

    /// Applies dropout to `input`.
    ///
    /// In evaluation mode, or when `p` is exactly zero, the input is returned
    /// unchanged (as a clone). Otherwise the work is delegated to
    /// `Tensor::dropout` with probability `self.p`. NOTE(review): whether that
    /// helper rescales kept elements (inverted dropout) is defined in the tensor
    /// module; confirm there.
    pub fn forward(&self, input: &Tensor) -> Tensor {
        // Identity path: no mask is generated when not training or when nothing is dropped.
        if !self.is_training || self.p == 0.0 {
            return input.clone();
        }
        input.dropout(self.p)
    }
}
impl Module for Dropout {
    /// Switches the layer to training mode so `forward` applies dropout.
    fn train(&mut self) {
        self.is_training = true;
    }

    /// Switches the layer to evaluation mode so `forward` returns its input unchanged.
    fn eval(&mut self) {
        self.is_training = false;
    }

    /// Dropout owns no learnable parameters, so this is always empty.
    fn parameters(&self) -> Vec<Parameter> {
        Vec::new()
    }

    /// Dropout contributes no entries to a state dict, regardless of `_prefix`.
    fn state_dict(&self, _prefix: &str) -> HashMap<String, Tensor> {
        HashMap::default()
    }

    /// Nothing to restore: the dict is ignored and loading always succeeds.
    fn load_state_dict(
        &mut self,
        _dict: &HashMap<String, Tensor>,
        _prefix: &str,
    ) -> Result<(), AutogradError> {
        Ok(())
    }
}