use crate::random_ops::rand;
use torsh_core::{Result as TorshResult, TorshError};
use torsh_tensor::{creation::rand_like, Tensor};
/// Applies element-wise (standard) dropout.
///
/// During training every element of `input` is zeroed independently when its uniform random
/// draw falls below `p`; surviving elements are multiplied by `1 / (1 - p)` so each element's
/// expected value is unchanged. When `training` is `false` or `p == 0` the input is returned
/// unchanged.
///
/// # Arguments
/// * `input` - tensor to apply dropout to (any shape).
/// * `p` - probability of zeroing an element; must lie in `[0, 1]`.
/// * `training` - dropout is only applied when `true`.
/// * `inplace` - accepted for API compatibility; a new tensor is always returned.
///
/// # Errors
/// Returns an error if `p` is outside `[0, 1]` (or NaN), even when `training` is `false`, or if
/// creating the random/mask tensors or the multiplication fails.
pub fn dropout(input: &Tensor, p: f64, training: bool, inplace: bool) -> TorshResult<Tensor> {
    // Validate first so an invalid probability is reported in eval mode as well.
    if !(0.0..=1.0).contains(&p) {
        return Err(TorshError::invalid_argument_with_context(
            &format!("Dropout probability must be between 0 and 1, got {}", p),
            "dropout",
        ));
    }
    if !training || p == 0.0 {
        return Ok(input.clone());
    }
    // No in-place tensor primitive is used here; both modes produce a fresh tensor.
    let _ = inplace;
    let keep_prob = 1.0 - p;
    // With p == 1 every element is dropped; avoid turning 1/0 = inf into a mask value.
    let scale = if keep_prob > 0.0 {
        (1.0 / keep_prob) as f32
    } else {
        0.0
    };
    let threshold = p as f32;
    let random_tensor = rand_like(input)?;
    let random_data = random_tensor.data()?;
    let mask_data: Vec<f32> = random_data
        .iter()
        .map(|&x| if x < threshold { 0.0 } else { scale })
        .collect();
    let mask = Tensor::from_data(mask_data, input.shape().dims().to_vec(), input.device())?;
    input.mul_op(&mask)
}
/// Randomly zeroes whole channels of a 3-D `(N, C, L)` input.
///
/// One uniform draw is made per `(sample, channel)` pair; if it falls below `p`, the whole
/// length-`L` channel is zeroed, otherwise it is scaled by `1 / (1 - p)`. When `training` is
/// `false` or `p == 0` the input is returned unchanged.
///
/// # Arguments
/// * `input` - tensor of shape `(N, C, L)`.
/// * `p` - probability of zeroing a channel; must lie in `[0, 1]`.
/// * `training` - dropout is only applied when `true`.
/// * `inplace` - accepted for API compatibility; a new tensor is always returned.
///
/// # Errors
/// Returns an error if `p` is outside `[0, 1]` (or NaN), if `input` is not 3-D while training,
/// or if an underlying tensor operation fails.
pub fn dropout1d(input: &Tensor, p: f64, training: bool, inplace: bool) -> TorshResult<Tensor> {
    // Without this check p > 1 would produce negative scales and p < 0 would never drop.
    if !(0.0..=1.0).contains(&p) {
        return Err(TorshError::invalid_argument_with_context(
            &format!("Dropout probability must be between 0 and 1, got {}", p),
            "dropout1d",
        ));
    }
    if !training || p == 0.0 {
        return Ok(input.clone());
    }
    let shape = input.shape().dims().to_vec();
    if shape.len() != 3 {
        return Err(TorshError::invalid_argument_with_context(
            &format!("Expected 3D input (N, C, L), got {}D", shape.len()),
            "dropout1d",
        ));
    }
    // No in-place tensor primitive is used here; both modes produce a fresh tensor.
    let _ = inplace;
    let length = shape[2];
    let mask_shape = vec![shape[0], shape[1], 1];
    let random_tensor = rand(&mask_shape, Some(0.0), Some(1.0), None)?;
    let keep_prob = 1.0 - p;
    // With p == 1 every channel is dropped; avoid turning 1/0 = inf into a mask value.
    let scale = if keep_prob > 0.0 {
        (1.0 / keep_prob) as f32
    } else {
        0.0
    };
    let threshold = p as f32;
    // Row-major layout: channel (n, c) occupies the contiguous run [(n*C + c)*L, (n*C + c + 1)*L),
    // so repeating each per-channel value L times yields the broadcast mask directly.
    let broadcast_data: Vec<f32> = random_tensor
        .to_vec()?
        .iter()
        .map(|&x| if x < threshold { 0.0 } else { scale })
        .flat_map(|v| std::iter::repeat(v).take(length))
        .collect();
    let mask = Tensor::from_data(broadcast_data, shape, input.device())?;
    input.mul_op(&mask)
}
/// Randomly zeroes whole channels of a 4-D `(N, C, H, W)` input.
///
/// One uniform draw is made per `(sample, channel)` pair; if it falls below `p`, the whole
/// `H x W` plane is zeroed, otherwise it is scaled by `1 / (1 - p)`. When `training` is `false`
/// or `p == 0` the input is returned unchanged.
///
/// # Arguments
/// * `input` - tensor of shape `(N, C, H, W)`.
/// * `p` - probability of zeroing a channel; must lie in `[0, 1]`.
/// * `training` - dropout is only applied when `true`.
/// * `inplace` - accepted for API compatibility; a new tensor is always returned.
///
/// # Errors
/// Returns an error if `p` is outside `[0, 1]` (or NaN), if `input` is not 4-D while training,
/// or if an underlying tensor operation fails.
pub fn dropout2d(input: &Tensor, p: f64, training: bool, inplace: bool) -> TorshResult<Tensor> {
    // Without this check p > 1 would produce negative scales and p < 0 would never drop.
    if !(0.0..=1.0).contains(&p) {
        return Err(TorshError::invalid_argument_with_context(
            &format!("Dropout probability must be between 0 and 1, got {}", p),
            "dropout2d",
        ));
    }
    if !training || p == 0.0 {
        return Ok(input.clone());
    }
    let shape = input.shape().dims().to_vec();
    if shape.len() != 4 {
        return Err(TorshError::invalid_argument_with_context(
            &format!("Expected 4D input (N, C, H, W), got {}D", shape.len()),
            "dropout2d",
        ));
    }
    // No in-place tensor primitive is used here; both modes produce a fresh tensor.
    let _ = inplace;
    let plane = shape[2] * shape[3];
    let mask_shape = vec![shape[0], shape[1], 1, 1];
    let random_tensor = rand(&mask_shape, Some(0.0), Some(1.0), None)?;
    let keep_prob = 1.0 - p;
    // With p == 1 every channel is dropped; avoid turning 1/0 = inf into a mask value.
    let scale = if keep_prob > 0.0 {
        (1.0 / keep_prob) as f32
    } else {
        0.0
    };
    let threshold = p as f32;
    // Row-major layout: channel (n, c) occupies a contiguous run of H*W elements, so repeating
    // each per-channel value H*W times yields the broadcast mask directly.
    let broadcast_data: Vec<f32> = random_tensor
        .to_vec()?
        .iter()
        .map(|&x| if x < threshold { 0.0 } else { scale })
        .flat_map(|v| std::iter::repeat(v).take(plane))
        .collect();
    let mask = Tensor::from_data(broadcast_data, shape, input.device())?;
    input.mul_op(&mask)
}
/// Randomly zeroes whole channels of a 5-D `(N, C, D, H, W)` input.
///
/// One uniform draw is made per `(sample, channel)` pair; if it falls below `p`, the whole
/// `D x H x W` volume is zeroed, otherwise it is scaled by `1 / (1 - p)`. When `training` is
/// `false` or `p == 0` the input is returned unchanged.
///
/// # Arguments
/// * `input` - tensor of shape `(N, C, D, H, W)`.
/// * `p` - probability of zeroing a channel; must lie in `[0, 1]`.
/// * `training` - dropout is only applied when `true`.
/// * `inplace` - accepted for API compatibility; a new tensor is always returned.
///
/// # Errors
/// Returns an error if `p` is outside `[0, 1]` (or NaN), if `input` is not 5-D while training,
/// or if an underlying tensor operation fails.
pub fn dropout3d(input: &Tensor, p: f64, training: bool, inplace: bool) -> TorshResult<Tensor> {
    // Without this check p > 1 would produce negative scales and p < 0 would never drop.
    if !(0.0..=1.0).contains(&p) {
        return Err(TorshError::invalid_argument_with_context(
            &format!("Dropout probability must be between 0 and 1, got {}", p),
            "dropout3d",
        ));
    }
    if !training || p == 0.0 {
        return Ok(input.clone());
    }
    let shape = input.shape().dims().to_vec();
    if shape.len() != 5 {
        return Err(TorshError::invalid_argument_with_context(
            &format!("Expected 5D input (N, C, D, H, W), got {}D", shape.len()),
            "dropout3d",
        ));
    }
    // No in-place tensor primitive is used here; both modes produce a fresh tensor.
    let _ = inplace;
    let volume = shape[2] * shape[3] * shape[4];
    let mask_shape = vec![shape[0], shape[1], 1, 1, 1];
    let random_tensor = rand(&mask_shape, Some(0.0), Some(1.0), None)?;
    let keep_prob = 1.0 - p;
    // With p == 1 every channel is dropped; avoid turning 1/0 = inf into a mask value.
    let scale = if keep_prob > 0.0 {
        (1.0 / keep_prob) as f32
    } else {
        0.0
    };
    let threshold = p as f32;
    // Row-major layout: channel (n, c) occupies a contiguous run of D*H*W elements, so repeating
    // each per-channel value D*H*W times yields the broadcast mask directly.
    let broadcast_data: Vec<f32> = random_tensor
        .to_vec()?
        .iter()
        .map(|&x| if x < threshold { 0.0 } else { scale })
        .flat_map(|v| std::iter::repeat(v).take(volume))
        .collect();
    let mask = Tensor::from_data(broadcast_data, shape, input.device())?;
    input.mul_op(&mask)
}
/// Applies element-wise alpha dropout (for SELU networks).
///
/// Each element is kept when its uniform random draw exceeds `p`. Otherwise it is replaced by
/// the SELU saturation value `α' = -λα`. The result is then transformed affinely as `a*x + b`,
/// with
///
/// `a = ((1 - p)(1 + p α'²))^(-1/2)` and `b = -a α' p`.
///
/// These coefficients restore zero mean / unit variance for self-normalised inputs
/// (Klambauer et al., 2017). With `p == 1` every element is dropped and a zero tensor is
/// returned. When `training` is `false` or `p == 0` the input is returned unchanged.
///
/// # Arguments
/// * `input` - tensor to apply alpha dropout to (any shape).
/// * `p` - probability of dropping an element; must lie in `[0, 1]`.
/// * `training` - the operation is only applied when `true`.
/// * `inplace` - accepted for API compatibility; a new tensor is always returned.
///
/// # Errors
/// Returns an error if `p` is outside `[0, 1]` (or NaN), even when `training` is `false`, or if
/// an underlying tensor operation fails.
pub fn alpha_dropout(input: &Tensor, p: f64, training: bool, inplace: bool) -> TorshResult<Tensor> {
    if !(0.0..=1.0).contains(&p) {
        return Err(TorshError::invalid_argument_with_context(
            &format!(
                "Alpha dropout probability must be between 0 and 1, got {}",
                p
            ),
            "alpha_dropout",
        ));
    }
    if !training || p == 0.0 {
        return Ok(input.clone());
    }
    // No in-place tensor primitive is used here; both modes produce a fresh tensor.
    let _ = inplace;
    if p == 1.0 {
        // Everything is dropped and `a` below would be 1/sqrt(0) = inf.
        return input.clone().mul_scalar(0.0);
    }
    // SELU constants; coefficients are computed in f64 so p close to 1 does not round to 1.
    const SELU_ALPHA: f64 = 1.673_263_242_354_377_2;
    const SELU_SCALE: f64 = 1.050_700_987_355_480_5;
    let alpha_p = -SELU_ALPHA * SELU_SCALE;
    // Note the inverse square root: `a` rescales the variance back to 1.
    let a = ((1.0 - p) * (1.0 + p * alpha_p * alpha_p)).sqrt().recip();
    let b = -a * alpha_p * p;
    let (alpha_p, a, b) = (alpha_p as f32, a as f32, b as f32);
    let threshold = p as f32;
    let random_tensor = rand_like(input)?;
    let random_data = random_tensor.data()?;
    let mask_data: Vec<f32> = random_data
        .iter()
        .map(|&x| if x > threshold { 1.0 } else { 0.0 })
        .collect();
    let mask = Tensor::from_data(mask_data, input.shape().dims().to_vec(), input.device())?;
    // Dropped positions (mask == 0) take the saturation value alpha_p instead of 0.
    let not_mask = mask.neg()?.add_scalar(1.0)?;
    let alpha_term = not_mask.mul_scalar(alpha_p)?;
    input
        .mul_op(&mask)?
        .add_op(&alpha_term)?
        .mul_scalar(a)?
        .add_scalar(b)
}
/// Applies alpha dropout to whole feature maps (channels).
///
/// One uniform draw is made per `(sample, channel)` pair of an `(N, C, ...)` input. The channel
/// is kept when the draw exceeds `p`; otherwise every element of that channel is replaced by the
/// SELU saturation value `α' = -λα`. The result is then transformed as `a*x + b`, with
/// `a = ((1 - p)(1 + p α'²))^(-1/2)` and `b = -a α' p` (Klambauer et al., 2017).
///
/// With `p == 1` every channel is dropped and a zero tensor is returned. When `training` is
/// `false` or `p == 0` the input is returned unchanged.
///
/// # Arguments
/// * `input` - tensor of rank >= 2 laid out as `(N, C, ...)`.
/// * `p` - probability of dropping a channel; must lie in `[0, 1]`.
/// * `training` - the operation is only applied when `true`.
/// * `inplace` - accepted for API compatibility; a new tensor is always returned.
///
/// # Errors
/// Returns an error if `p` is outside `[0, 1]` (or NaN), if `input` has fewer than 2 dimensions
/// while training, or if an underlying tensor operation fails.
pub fn feature_alpha_dropout(
    input: &Tensor,
    p: f64,
    training: bool,
    inplace: bool,
) -> TorshResult<Tensor> {
    if !(0.0..=1.0).contains(&p) {
        return Err(TorshError::invalid_argument_with_context(
            &format!(
                "Feature alpha dropout probability must be between 0 and 1, got {}",
                p
            ),
            "feature_alpha_dropout",
        ));
    }
    if !training || p == 0.0 {
        return Ok(input.clone());
    }
    let shape = input.shape().dims().to_vec();
    if shape.len() < 2 {
        return Err(TorshError::invalid_argument_with_context(
            "Feature alpha dropout requires at least 2D input",
            "feature_alpha_dropout",
        ));
    }
    // No in-place tensor primitive is used here; both modes produce a fresh tensor.
    let _ = inplace;
    if p == 1.0 {
        // Everything is dropped and `a` below would be 1/sqrt(0) = inf.
        return input.clone().mul_scalar(0.0);
    }
    // SELU constants; coefficients are computed in f64 so p close to 1 does not round to 1.
    const SELU_ALPHA: f64 = 1.673_263_242_354_377_2;
    const SELU_SCALE: f64 = 1.050_700_987_355_480_5;
    let alpha_p = -SELU_ALPHA * SELU_SCALE;
    // Note the inverse square root: `a` rescales the variance back to 1.
    let a = ((1.0 - p) * (1.0 + p * alpha_p * alpha_p)).sqrt().recip();
    let b = -a * alpha_p * p;
    let (alpha_p, a, b) = (alpha_p as f32, a as f32, b as f32);
    let threshold = p as f32;
    // Mask has shape (N, C, 1, ..., 1): one draw per channel.
    let mut mask_shape = vec![shape[0], shape[1]];
    mask_shape.resize(shape.len(), 1);
    let random_tensor = rand(&mask_shape, Some(0.0), Some(1.0), None)?;
    // Row-major layout: channel (n, c) occupies a contiguous run of prod(shape[2..]) elements,
    // so repeating each per-channel value that many times yields the broadcast mask.
    let per_channel: usize = shape[2..].iter().product();
    let broadcast_data: Vec<f32> = random_tensor
        .to_vec()?
        .iter()
        .map(|&x| if x > threshold { 1.0 } else { 0.0 })
        .flat_map(|keep| std::iter::repeat(keep).take(per_channel))
        .collect();
    let mask = Tensor::from_data(broadcast_data, shape, input.device())?;
    // Dropped channels (mask == 0) take the saturation value alpha_p instead of 0.
    let not_mask = mask.neg()?.add_scalar(1.0)?;
    let alpha_term = not_mask.mul_scalar(alpha_p)?;
    input
        .mul_op(&mask)?
        .add_op(&alpha_term)?
        .mul_scalar(a)?
        .add_scalar(b)
}
/// Placeholder for 2-D fractional max pooling that can also return indices.
///
/// NOTE(review): not implemented. All pooling parameters (`_kernel_size`, `_output_size`,
/// `_output_ratio`, `_return_indices`) are ignored. The function returns a clone of `input`
/// paired with `None`, even when indices were requested. Callers expecting pooled output
/// silently receive the un-pooled tensor.
///
/// # Errors
/// Never returns an error in its current form.
pub fn fractional_max_pool2d_with_indices(
    input: &Tensor,
    _kernel_size: (usize, usize),
    _output_size: Option<(usize, usize)>,
    _output_ratio: Option<(f64, f64)>,
    _return_indices: bool,
) -> TorshResult<(Tensor, Option<Tensor>)> {
    // TODO: implement fractional max pooling; for now this is an identity pass-through.
    let passthrough = input.clone();
    let indices: Option<Tensor> = None;
    Ok((passthrough, indices))
}
/// Applies multiplicative Gaussian noise ("Gaussian dropout").
///
/// During training `input` is multiplied element-wise by `1 + σ·z`, where
/// `σ = sqrt(p / (1 - p))` and `z` comes from `randn_like`. This presumably means standard
/// normal draws; confirm against `torsh_tensor`. With standard normal `z`, the noise has
/// mean 1 and variance `p / (1 - p)`, the same as the rescaled Bernoulli mask of standard
/// dropout. When `training` is `false` or `p == 0` the input is returned unchanged.
///
/// # Arguments
/// * `input` - tensor to perturb (any shape).
/// * `p` - equivalent dropout probability; must lie in `[0, 1)` (`p == 1` gives infinite noise).
/// * `training` - noise is only applied when `true`.
/// * `inplace` - accepted for API compatibility; a new tensor is always returned.
///
/// # Errors
/// Returns an error if `p` is outside `[0, 1)` (or NaN), even when `training` is `false`, or if
/// an underlying tensor operation fails.
pub fn gaussian_dropout(
    input: &Tensor,
    p: f64,
    training: bool,
    inplace: bool,
) -> TorshResult<Tensor> {
    // Validate first so an invalid probability is reported in eval mode as well.
    if !(0.0..1.0).contains(&p) {
        return Err(TorshError::invalid_argument_with_context(
            &format!("Gaussian dropout probability must be in [0, 1), got {}", p),
            "gaussian_dropout",
        ));
    }
    if !training || p == 0.0 {
        return Ok(input.clone());
    }
    // No in-place tensor primitive is used here; both modes produce a fresh tensor.
    let _ = inplace;
    let std = (p / (1.0 - p)).sqrt();
    let noise = torsh_tensor::creation::randn_like(input)?
        .mul_scalar(std as f32)?
        .add_scalar(1.0)?;
    input.mul_op(&noise)
}
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_tensor::creation::ones;

    /// `dropout` rejects out-of-range probabilities while training and accepts the no-op cases
    /// (`p == 0` while training, any valid `p` outside training).
    #[test]
    fn test_dropout_probability_validation() -> TorshResult<()> {
        let x = ones::<f32>(&[2, 3])?;

        for bad_p in [-0.1, 1.1] {
            assert!(dropout(&x, bad_p, true, false).is_err());
        }

        let no_op_cases = [(0.0, true), (0.5, false)];
        for (p, training) in no_op_cases {
            assert!(dropout(&x, p, training, false).is_ok());
        }

        Ok(())
    }
}