use crate::{Module, ModuleBase, Parameter};
use torsh_core::device::DeviceType;
#[cfg(feature = "std")]
use std::collections::HashMap;
#[cfg(not(feature = "std"))]
use hashbrown::HashMap;
use torsh_core::error::Result;
use torsh_tensor::{creation::*, Tensor};
/// Dropout layer state: shared module bookkeeping, a drop probability and an
/// `inplace` flag.
///
/// The `inplace` flag is only stored; the constructors below never act on it.
pub struct Dropout {
    base: ModuleBase,
    p: f32,
    inplace: bool,
}

impl Dropout {
    /// Builds a dropout layer with `inplace` set to `false`.
    ///
    /// `p` is clamped into `[0.0, 1.0]`.
    pub fn new(p: f32) -> Self {
        Self::with_inplace(p, false)
    }

    /// Builds a dropout layer with an explicit `inplace` flag.
    ///
    /// `p` is clamped into `[0.0, 1.0]`.
    pub fn with_inplace(p: f32, inplace: bool) -> Self {
        let base = ModuleBase::new();
        let p = p.clamp(0.0, 1.0);
        Dropout { base, p, inplace }
    }
}
impl Default for Dropout {
fn default() -> Self {
Self::new(0.5)
}
}
impl Module for Dropout {
    /// Applies the layer to `input`.
    ///
    /// * Not training, or `p == 0.0`: returns a clone of `input`.
    /// * `p == 1.0`: returns a zero tensor with the input's dimensions.
    /// * Otherwise: multiplies `input` by a constant tensor filled with
    ///   `1 - p`, then by a constant tensor filled with `1 / (1 - p)`.
    ///
    /// No random mask is sampled here; both factors are constant tensors.
    ///
    /// # Errors
    /// Propagates any error returned by tensor creation or multiplication.
    fn forward(&self, input: &Tensor) -> Result<Tensor> {
        // Identity both in evaluation mode and when nothing would be dropped.
        if !self.base.training() || self.p == 0.0 {
            return Ok(input.clone());
        }

        let shape = input.shape();
        let dims = shape.dims();

        // Dropping everything: skip the 1 / 0 rescale and emit zeros directly.
        if self.p == 1.0 {
            return zeros(dims);
        }

        let keep = 1.0 - self.p;
        let masked = input.mul_op(&full(dims, keep)?)?;
        let rescale = full(dims, 1.0 / keep)?;
        masked.mul_op(&rescale)
    }

    /// Reports whether the underlying module base is in training mode.
    fn training(&self) -> bool {
        self.base.training()
    }

    /// Sets the training flag on the underlying module base.
    fn set_training(&mut self, training: bool) {
        self.base.set_training(training);
    }

    /// Switches the layer to training mode.
    fn train(&mut self) {
        self.base.set_training(true);
    }

    /// Switches the layer to evaluation mode.
    fn eval(&mut self) {
        self.base.set_training(false);
    }

    /// Returns a clone of the parameter map held by the module base.
    fn parameters(&self) -> HashMap<String, Parameter> {
        self.base.parameters.clone()
    }

    /// Returns the named parameters reported by the module base.
    fn named_parameters(&self) -> HashMap<String, Parameter> {
        self.base.named_parameters()
    }

    /// Forwards the device move to the module base.
    fn to_device(&mut self, device: DeviceType) -> Result<()> {
        self.base.to_device(device)
    }
}
impl std::fmt::Debug for Dropout {
    /// Writes the `p` and `inplace` fields under the struct name `Dropout`;
    /// the `base` field is left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("Dropout");
        out.field("p", &self.p);
        out.field("inplace", &self.inplace);
        out.finish()
    }
}
/// DropConnect layer state: shared module bookkeeping plus a drop probability.
pub struct DropConnect {
    base: ModuleBase,
    p: f32,
}

impl DropConnect {
    /// Builds a DropConnect layer; `p` is clamped into `[0.0, 1.0]`.
    pub fn new(p: f32) -> Self {
        let base = ModuleBase::new();
        let p = p.clamp(0.0, 1.0);
        DropConnect { base, p }
    }
}
impl Module for DropConnect {
fn forward(&self, input: &Tensor) -> Result<Tensor> {
if !self.base.training() || self.p == 0.0 {
return Ok(input.clone());
}
let keep_prob = 1.0 - self.p;
let mask = full(input.shape().dims(), keep_prob)?; let dropped = input.mul_op(&mask)?;
let scale = 1.0 / keep_prob;
dropped.mul_op(&full(input.shape().dims(), scale)?)
}
fn parameters(&self) -> HashMap<String, Parameter> {
self.base.parameters.clone()
}
fn training(&self) -> bool {
self.base.training()
}
fn train(&mut self) {
self.base.set_training(true);
}
fn eval(&mut self) {
self.base.set_training(false);
}
fn set_training(&mut self, training: bool) {
self.base.set_training(training);
}
fn to_device(&mut self, device: DeviceType) -> Result<()> {
self.base.to_device(device)
}
fn named_parameters(&self) -> HashMap<String, Parameter> {
self.base.named_parameters()
}
}
/// Dropout2d layer state: shared module bookkeeping plus a drop probability.
pub struct Dropout2d {
    base: ModuleBase,
    p: f32,
}

impl Dropout2d {
    /// Builds a Dropout2d layer; `p` is clamped into `[0.0, 1.0]`.
    pub fn new(p: f32) -> Self {
        let base = ModuleBase::new();
        let p = p.clamp(0.0, 1.0);
        Dropout2d { base, p }
    }
}
impl Module for Dropout2d {
fn forward(&self, input: &Tensor) -> Result<Tensor> {
if !self.base.training() || self.p == 0.0 {
return Ok(input.clone());
}
let shape = input.shape();
let dims = shape.dims();
if dims.len() != 4 {
return Err(torsh_core::error::TorshError::InvalidArgument(
"Dropout2d expects 4D input (batch, channels, height, width)".to_string(),
));
}
let keep_prob = 1.0 - self.p;
let batch_size = dims[0];
let channels = dims[1];
let mask = full(&[batch_size, channels, 1, 1], keep_prob)?;
let dropped = input.mul_op(&mask)?;
let scale = 1.0 / keep_prob;
dropped.mul_op(&full(&[1], scale)?)
}
fn parameters(&self) -> HashMap<String, Parameter> {
self.base.parameters.clone()
}
fn training(&self) -> bool {
self.base.training()
}
fn train(&mut self) {
self.base.set_training(true);
}
fn eval(&mut self) {
self.base.set_training(false);
}
fn set_training(&mut self, training: bool) {
self.base.set_training(training);
}
fn to_device(&mut self, device: DeviceType) -> Result<()> {
self.base.to_device(device)
}
fn named_parameters(&self) -> HashMap<String, Parameter> {
self.base.named_parameters()
}
}
/// DropBlock2d layer state: shared module bookkeeping, a drop probability and
/// a block size.
pub struct DropBlock2d {
    base: ModuleBase,
    drop_prob: f32,
    // Stored by the constructor; marked dead because no non-test code reads it.
    #[allow(dead_code)]
    block_size: usize,
}

impl DropBlock2d {
    /// Builds a DropBlock2d layer.
    ///
    /// `drop_prob` is clamped into `[0.0, 1.0]`; `block_size` is stored as
    /// given, without validation.
    pub fn new(drop_prob: f32, block_size: usize) -> Self {
        let base = ModuleBase::new();
        let drop_prob = drop_prob.clamp(0.0, 1.0);
        DropBlock2d {
            base,
            drop_prob,
            block_size,
        }
    }
}
impl Module for DropBlock2d {
fn forward(&self, input: &Tensor) -> Result<Tensor> {
if !self.base.training() || self.drop_prob == 0.0 {
return Ok(input.clone());
}
let shape = input.shape();
let dims = shape.dims();
if dims.len() != 4 {
return Err(torsh_core::error::TorshError::InvalidArgument(
"DropBlock2d expects 4D input (batch, channels, height, width)".to_string(),
));
}
let keep_prob = 1.0 - self.drop_prob;
let mask = full(dims, keep_prob)?;
let dropped = input.mul_op(&mask)?;
let scale = 1.0 / keep_prob;
dropped.mul_op(&full(&[1], scale)?)
}
fn parameters(&self) -> HashMap<String, Parameter> {
self.base.parameters.clone()
}
fn training(&self) -> bool {
self.base.training()
}
fn train(&mut self) {
self.base.set_training(true);
}
fn eval(&mut self) {
self.base.set_training(false);
}
fn set_training(&mut self, training: bool) {
self.base.set_training(training);
}
fn to_device(&mut self, device: DeviceType) -> Result<()> {
self.base.to_device(device)
}
fn named_parameters(&self) -> HashMap<String, Parameter> {
self.base.named_parameters()
}
}
/// Stochastic-depth layer state: shared module bookkeeping, a drop
/// probability and a flag selecting whether the output is scaled.
pub struct StochasticDepth {
    base: ModuleBase,
    drop_prob: f32,
    scale_by_keep: bool,
}

impl StochasticDepth {
    /// Builds a layer with `scale_by_keep` set to `true`.
    ///
    /// `drop_prob` is clamped into `[0.0, 1.0]`.
    pub fn new(drop_prob: f32) -> Self {
        Self::with_scaling(drop_prob, true)
    }

    /// Builds a layer with an explicit `scale_by_keep` flag.
    ///
    /// `drop_prob` is clamped into `[0.0, 1.0]`.
    pub fn with_scaling(drop_prob: f32, scale_by_keep: bool) -> Self {
        let base = ModuleBase::new();
        let drop_prob = drop_prob.clamp(0.0, 1.0);
        StochasticDepth {
            base,
            drop_prob,
            scale_by_keep,
        }
    }
}
impl Module for StochasticDepth {
    /// Applies the layer to `input`.
    ///
    /// Returns a clone of `input` when the layer is not training, when
    /// `drop_prob == 0.0`, or when `scale_by_keep` is `false`. Otherwise
    /// multiplies `input` by a one-element tensor holding `1 - drop_prob`.
    ///
    /// No random drop decision is made here.
    ///
    /// NOTE(review): published stochastic-depth formulations divide surviving
    /// paths by the keep probability; this multiplies by it. Confirm that the
    /// scaling direction is intended.
    ///
    /// # Errors
    /// Propagates any error returned by tensor creation or multiplication.
    fn forward(&self, input: &Tensor) -> Result<Tensor> {
        let passthrough =
            !self.base.training() || self.drop_prob == 0.0 || !self.scale_by_keep;
        if passthrough {
            return Ok(input.clone());
        }

        let keep = 1.0 - self.drop_prob;
        let factor = full(&[1], keep)?;
        input.mul_op(&factor)
    }

    /// Reports whether the underlying module base is in training mode.
    fn training(&self) -> bool {
        self.base.training()
    }

    /// Sets the training flag on the underlying module base.
    fn set_training(&mut self, training: bool) {
        self.base.set_training(training);
    }

    /// Switches the layer to training mode.
    fn train(&mut self) {
        self.base.set_training(true);
    }

    /// Switches the layer to evaluation mode.
    fn eval(&mut self) {
        self.base.set_training(false);
    }

    /// Returns a clone of the parameter map held by the module base.
    fn parameters(&self) -> HashMap<String, Parameter> {
        self.base.parameters.clone()
    }

    /// Returns the named parameters reported by the module base.
    fn named_parameters(&self) -> HashMap<String, Parameter> {
        self.base.named_parameters()
    }

    /// Forwards the device move to the module base.
    fn to_device(&mut self, device: DeviceType) -> Result<()> {
        self.base.to_device(device)
    }
}
/// Alpha-dropout layer state: shared module bookkeeping, a drop probability
/// and two fixed constants.
pub struct AlphaDropout {
    base: ModuleBase,
    p: f32,
    alpha: f32,
    // Stored by the constructor; marked dead because nothing in this file reads it.
    #[allow(dead_code)]
    scale: f32,
}

impl AlphaDropout {
    /// Builds an alpha-dropout layer; `p` is clamped into `[0.0, 1.0]`.
    ///
    /// `alpha` and `scale` are always set to the same two hard-coded values.
    pub fn new(p: f32) -> Self {
        // Presumably the SELU saturation value and SELU scale — TODO confirm
        // against the self-normalizing networks paper.
        const ALPHA: f32 = -1.7580993408473766;
        const SCALE: f32 = 1.0507009873554804;

        let base = ModuleBase::new();
        let p = p.clamp(0.0, 1.0);
        AlphaDropout {
            base,
            p,
            alpha: ALPHA,
            scale: SCALE,
        }
    }
}
impl Module for AlphaDropout {
fn forward(&self, input: &Tensor) -> Result<Tensor> {
if !self.base.training() || self.p == 0.0 {
return Ok(input.clone());
}
let keep_prob = 1.0 - self.p;
let a = ((1.0 - keep_prob) * (1.0 + keep_prob * self.alpha * self.alpha)).sqrt();
let b = -a * self.alpha * keep_prob;
let mask = full(input.shape().dims(), keep_prob)?;
let dropped = input.mul_op(&mask)?;
let a_tensor = full(&[1], a)?;
let scaled = dropped.mul_op(&a_tensor)?;
let b_tensor = full(input.shape().dims(), b)?;
scaled.add(&b_tensor)
}
fn parameters(&self) -> HashMap<String, Parameter> {
self.base.parameters.clone()
}
fn training(&self) -> bool {
self.base.training()
}
fn train(&mut self) {
self.base.set_training(true);
}
fn eval(&mut self) {
self.base.set_training(false);
}
fn set_training(&mut self, training: bool) {
self.base.set_training(training);
}
fn to_device(&mut self, device: DeviceType) -> Result<()> {
self.base.to_device(device)
}
fn named_parameters(&self) -> HashMap<String, Parameter> {
self.base.named_parameters()
}
}
/// Cutout layer state: shared module bookkeeping, a hole count and a hole
/// side length.
pub struct Cutout {
    base: ModuleBase,
    // Stored by the constructor; marked dead because no non-test code reads it.
    #[allow(dead_code)]
    n_holes: usize,
    // Stored by the constructor; marked dead because no non-test code reads it.
    #[allow(dead_code)]
    length: usize,
}

impl Cutout {
    /// Builds a Cutout layer, storing both arguments as given.
    pub fn new(n_holes: usize, length: usize) -> Self {
        let base = ModuleBase::new();
        Cutout {
            base,
            n_holes,
            length,
        }
    }
}
impl Module for Cutout {
    /// Returns a clone of `input` on every path.
    ///
    /// The training flag and the input rank are inspected, but no holes are
    /// cut: neither `n_holes` nor `length` is consulted.
    fn forward(&self, input: &Tensor) -> Result<Tensor> {
        // Evaluation mode and inputs with fewer than two dimensions bail out early.
        if !self.base.training() || input.shape().dims().len() < 2 {
            return Ok(input.clone());
        }

        // Remaining path: the input is handed back unchanged.
        Ok(input.clone())
    }

    /// Reports whether the underlying module base is in training mode.
    fn training(&self) -> bool {
        self.base.training()
    }

    /// Sets the training flag on the underlying module base.
    fn set_training(&mut self, training: bool) {
        self.base.set_training(training);
    }

    /// Switches the layer to training mode.
    fn train(&mut self) {
        self.base.set_training(true);
    }

    /// Switches the layer to evaluation mode.
    fn eval(&mut self) {
        self.base.set_training(false);
    }

    /// Returns a clone of the parameter map held by the module base.
    fn parameters(&self) -> HashMap<String, Parameter> {
        self.base.parameters.clone()
    }

    /// Returns the named parameters reported by the module base.
    fn named_parameters(&self) -> HashMap<String, Parameter> {
        self.base.named_parameters()
    }

    /// Forwards the device move to the module base.
    fn to_device(&mut self, device: DeviceType) -> Result<()> {
        self.base.to_device(device)
    }
}
/// Mixup configuration holding a single `alpha` value.
pub struct Mixup {
    // Stored by the constructor; marked dead because `apply` never reads it.
    #[allow(dead_code)]
    alpha: f32,
}

impl Mixup {
    /// Stores `alpha` as given, without validation.
    pub fn new(alpha: f32) -> Self {
        Mixup { alpha }
    }

    /// Returns a clone of `input` paired with the fixed mixing coefficient `0.5`.
    ///
    /// The `_training` argument and the stored `alpha` are both ignored; no
    /// sampling or mixing takes place.
    pub fn apply(&self, input: &Tensor, _training: bool) -> Result<(Tensor, f32)> {
        let unchanged = input.clone();
        Ok((unchanged, 0.5))
    }
}
/// CutMix configuration holding a single `alpha` value.
pub struct CutMix {
    // Stored by the constructor; marked dead because `apply` never reads it.
    #[allow(dead_code)]
    alpha: f32,
}

impl CutMix {
    /// Stores `alpha` as given, without validation.
    pub fn new(alpha: f32) -> Self {
        CutMix { alpha }
    }

    /// Returns a clone of `input` paired with the fixed mixing coefficient `0.5`.
    ///
    /// The `_training` argument and the stored `alpha` are both ignored; no
    /// region is cut or pasted.
    pub fn apply(&self, input: &Tensor, _training: bool) -> Result<(Tensor, f32)> {
        let unchanged = input.clone();
        Ok((unchanged, 0.5))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `actual` is within `1e-6` of `expected`.
    fn approx(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < 1e-6
    }

    // Each test checks that a constructor stores the values it was given.

    #[test]
    fn test_dropout_creation() {
        let layer = Dropout::new(0.5);
        assert!(approx(layer.p, 0.5));
    }

    #[test]
    fn test_dropconnect_creation() {
        let layer = DropConnect::new(0.3);
        assert!(approx(layer.p, 0.3));
    }

    #[test]
    fn test_dropout2d_creation() {
        let layer = Dropout2d::new(0.5);
        assert!(approx(layer.p, 0.5));
    }

    #[test]
    fn test_dropblock_creation() {
        let layer = DropBlock2d::new(0.1, 7);
        assert!(approx(layer.drop_prob, 0.1));
        assert_eq!(layer.block_size, 7);
    }

    #[test]
    fn test_stochastic_depth_creation() {
        let layer = StochasticDepth::new(0.2);
        assert!(approx(layer.drop_prob, 0.2));
    }

    #[test]
    fn test_alpha_dropout_creation() {
        let layer = AlphaDropout::new(0.1);
        assert!(approx(layer.p, 0.1));
    }

    #[test]
    fn test_cutout_creation() {
        let layer = Cutout::new(1, 16);
        assert_eq!(layer.n_holes, 1);
        assert_eq!(layer.length, 16);
    }

    #[test]
    fn test_mixup_creation() {
        let augment = Mixup::new(1.0);
        assert!(approx(augment.alpha, 1.0));
    }

    #[test]
    fn test_cutmix_creation() {
        let augment = CutMix::new(1.0);
        assert!(approx(augment.alpha, 1.0));
    }
}