use crate::error::ModelError;
use crate::neural_network::Tensor;
/// Propagates the upstream gradient back through a dropout layer.
///
/// * When `training` is false or `rate` is exactly `0.0`, the gradient passes through
///   unchanged (returned as a clone).
/// * When `rate` is exactly `1.0`, a zero tensor shaped like `grad_output` is returned.
/// * Otherwise the result is `grad_output * mask * (1 / (1 - rate))`.
///
/// `mask` is presumably the keep-mask recorded by the forward pass; it is multiplied
/// elementwise with `grad_output` (NOTE(review): shape compatibility is not checked here).
///
/// # Errors
///
/// Returns `ModelError::ProcessingError` when a gradient must be scaled but `mask` is
/// `None`, i.e. no forward pass has stored a mask yet.
fn dropout_backward(
    grad_output: &Tensor,
    mask: &Option<Tensor>,
    training: bool,
    rate: f32,
) -> Result<Tensor, ModelError> {
    // Anything other than an active, non-zero-rate training pass is the identity map.
    let drops_something = training && rate != 0.0;
    if !drops_something {
        return Ok(grad_output.clone());
    }

    // Every unit was dropped, so no gradient flows back.
    if rate == 1.0 {
        return Ok(Tensor::zeros(grad_output.raw_dim()));
    }

    // Inverted-dropout rescaling factor, matching 1 / keep-probability.
    let inverse_keep_prob = 1.0 / (1.0 - rate);
    mask.as_ref()
        .map(|keep_mask| grad_output * keep_mask * inverse_keep_prob)
        .ok_or_else(|| {
            ModelError::ProcessingError("Forward pass has not been run".to_string())
        })
}
/// Renders a layer's output shape for display.
///
/// Dropout does not change the shape of its input, so the given shape is formatted as
/// `"(d0, d1, ...)"`. An empty slice yields the literal `"Unknown"`.
fn dropout_output_shape(input_shape: &[usize]) -> String {
    // An empty shape means nothing is known about the input yet.
    if input_shape.is_empty() {
        return String::from("Unknown");
    }

    let dims: Vec<String> = input_shape.iter().map(usize::to_string).collect();
    format!("({})", dims.join(", "))
}
/// Converts the sampled values in `mask_2d` into binary keep/drop flags, in place.
///
/// Each element `x` becomes `1.0` when `x >= rate` and `0.0` otherwise (including NaN).
/// Presumably `mask_2d` holds uniform random samples, so the keep probability is
/// `1 - rate` (TODO confirm against the callers that fill it).
///
/// When the tensor has at least `parallel_threshold` elements, the mapping runs via
/// `par_mapv_inplace`. Smaller tensors use the sequential `mapv_inplace` instead.
fn apply_spatial_dropout_threshold(mask_2d: &mut Tensor, rate: f32, parallel_threshold: usize) {
    // Captures only `rate` (Copy), so the closure itself is Copy and can feed either branch.
    let to_keep_flag = move |sample: f32| if sample >= rate { 1.0 } else { 0.0 };

    if mask_2d.len() < parallel_threshold {
        mask_2d.mapv_inplace(to_keep_flag);
    } else {
        mask_2d.par_mapv_inplace(to_keep_flag);
    }
}
pub mod dropout;
pub mod spatial_dropout_1d;
pub mod spatial_dropout_2d;
pub mod spatial_dropout_3d;
pub use dropout::Dropout;
pub use spatial_dropout_1d::SpatialDropout1D;
pub use spatial_dropout_2d::SpatialDropout2D;
pub use spatial_dropout_3d::SpatialDropout3D;