use crate::error::{NeuralError, Result};
use crate::layers::conv::PaddingMode;
use crate::layers::{BatchNorm, Conv2D, Dense, Dropout, Layer};
use scirs2_core::ndarray::{Array, IxDyn, ScalarOperand};
use scirs2_core::numeric::{Float, NumAssign};
use scirs2_core::random::{rngs::SmallRng, SeedableRng};
use std::fmt::Debug;
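/// Swish (SiLU) activation: `swish(x) = x * sigmoid(x) = x / (1 + e^(-x))`.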
#[allow(dead_code)]
pub fn swish<F: Float>(x: F) -> F {
x * (F::one() + (-x).exp()).recip()
}
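/// Configuration for a single MBConv (mobile inverted bottleneck) block.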
#[derive(Debug, Clone)]
pub struct MBConvConfig {
pub input_channels: usize,
pub output_channels: usize,
pub kernel_size: usize,
pub stride: usize,
pub expand_ratio: usize,
pub use_se: bool,
pub drop_connect_rate: f64,
}
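/// One stage of the network: a single MBConv configuration repeated `num_blocks` times.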
pub struct EfficientNetStage {
pub mbconv_config: MBConvConfig,
pub num_blocks: usize,
}
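/// Architecture hyperparameters for an EfficientNet variant. Width, depth,
/// and input resolution are scaled jointly following the compound-scaling
/// rule of Tan & Le (2019).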
pub struct EfficientNetConfig {
pub width_coefficient: f64,
pub depth_coefficient: f64,
pub resolution: usize,
pub dropout_rate: f64,
pub stages: Vec<EfficientNetStage>,
pub input_channels: usize,
pub num_classes: usize,
}
impl EfficientNetConfig {
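    /// Baseline B0 configuration: seven MBConv stages at 224x224 input,
    /// matching the stage layout in the EfficientNet paper.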
pub fn efficientnet_b0(input_channels: usize, num_classes: usize) -> Self {
let stages = vec![
EfficientNetStage {
mbconv_config: MBConvConfig {
input_channels: 32,
output_channels: 16,
kernel_size: 3,
stride: 1,
expand_ratio: 1,
use_se: true,
drop_connect_rate: 0.2,
},
num_blocks: 1,
},
EfficientNetStage {
mbconv_config: MBConvConfig {
input_channels: 16,
output_channels: 24,
kernel_size: 3,
stride: 2,
expand_ratio: 6,
use_se: true,
drop_connect_rate: 0.2,
},
num_blocks: 2,
},
EfficientNetStage {
mbconv_config: MBConvConfig {
input_channels: 24,
output_channels: 40,
kernel_size: 5,
stride: 2,
expand_ratio: 6,
use_se: true,
drop_connect_rate: 0.2,
},
num_blocks: 2,
},
EfficientNetStage {
mbconv_config: MBConvConfig {
input_channels: 40,
output_channels: 80,
kernel_size: 3,
stride: 2,
expand_ratio: 6,
use_se: true,
drop_connect_rate: 0.2,
},
num_blocks: 3,
},
EfficientNetStage {
mbconv_config: MBConvConfig {
input_channels: 80,
output_channels: 112,
kernel_size: 5,
stride: 1,
expand_ratio: 6,
use_se: true,
drop_connect_rate: 0.2,
},
num_blocks: 3,
},
EfficientNetStage {
mbconv_config: MBConvConfig {
input_channels: 112,
output_channels: 192,
kernel_size: 5,
stride: 2,
expand_ratio: 6,
use_se: true,
drop_connect_rate: 0.2,
},
num_blocks: 4,
},
EfficientNetStage {
mbconv_config: MBConvConfig {
input_channels: 192,
output_channels: 320,
kernel_size: 3,
stride: 1,
expand_ratio: 6,
use_se: true,
drop_connect_rate: 0.2,
},
num_blocks: 1,
},
];
Self {
width_coefficient: 1.0,
depth_coefficient: 1.0,
resolution: 224,
dropout_rate: 0.2,
stages,
input_channels,
num_classes,
}
}
pub fn efficientnet_b1(input_channels: usize, num_classes: usize) -> Self {
let mut config = Self::efficientnet_b0(input_channels, num_classes);
config.width_coefficient = 1.0;
config.depth_coefficient = 1.1;
config.resolution = 240;
config.dropout_rate = 0.2;
config
}
pub fn efficientnet_b2(input_channels: usize, num_classes: usize) -> Self {
let mut config = Self::efficientnet_b0(input_channels, num_classes);
config.width_coefficient = 1.1;
config.depth_coefficient = 1.2;
config.resolution = 260;
config.dropout_rate = 0.3;
config
}
pub fn efficientnet_b3(input_channels: usize, num_classes: usize) -> Self {
let mut config = Self::efficientnet_b0(input_channels, num_classes);
config.width_coefficient = 1.2;
config.depth_coefficient = 1.4;
config.resolution = 300;
config.dropout_rate = 0.3;
config
}
pub fn efficientnet_b4(input_channels: usize, num_classes: usize) -> Self {
let mut config = Self::efficientnet_b0(input_channels, num_classes);
config.width_coefficient = 1.4;
config.depth_coefficient = 1.8;
config.resolution = 380;
config.dropout_rate = 0.4;
config
}
pub fn efficientnet_b5(input_channels: usize, num_classes: usize) -> Self {
let mut config = Self::efficientnet_b0(input_channels, num_classes);
config.width_coefficient = 1.6;
config.depth_coefficient = 2.2;
config.resolution = 456;
config.dropout_rate = 0.4;
config
}
pub fn efficientnet_b6(input_channels: usize, num_classes: usize) -> Self {
let mut config = Self::efficientnet_b0(input_channels, num_classes);
config.width_coefficient = 1.8;
config.depth_coefficient = 2.6;
config.resolution = 528;
config.dropout_rate = 0.5;
config
}
pub fn efficientnet_b7(input_channels: usize, num_classes: usize) -> Self {
let mut config = Self::efficientnet_b0(input_channels, num_classes);
config.width_coefficient = 2.0;
config.depth_coefficient = 3.1;
config.resolution = 600;
config.dropout_rate = 0.5;
config
}
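    /// Scales a channel count by `width_coefficient`, rounding up to the
    /// nearest multiple of 8. (This is a simplification of the paper's
    /// `round_filters`, which rounds to the nearest multiple and never
    /// drops below 90% of the scaled value.)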
pub fn scale_channels(&self, channels: usize) -> usize {
let scaled = (channels as f64 * self.width_coefficient).round();
(scaled as usize).div_ceil(8) * 8
}
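    /// Scales a per-stage block count by `depth_coefficient`, rounding up
    /// so every stage keeps at least one block.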
pub fn scale_depth(&self, depth: usize) -> usize {
(depth as f64 * self.depth_coefficient).ceil() as usize
}
}
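/// Squeeze-and-Excitation block (Hu et al., 2018): global-average-pools each
/// channel, feeds the pooled vector through a two-layer bottleneck, and
/// rescales the input channels by the resulting sigmoid gates.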
struct SqueezeExcitation<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> {
input_channels: usize,
#[allow(dead_code)]
squeeze_channels: usize,
fc1: Conv2D<F>,
fc2: Conv2D<F>,
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> SqueezeExcitation<F> {
    pub fn new(input_channels: usize, squeeze_channels: usize) -> Result<Self> {
        // 1x1 convolutions act as per-channel fully connected layers.
        let fc1 = Conv2D::new(input_channels, squeeze_channels, (1, 1), (1, 1), None)?
            .with_padding(PaddingMode::Valid);
        let fc2 = Conv2D::new(squeeze_channels, input_channels, (1, 1), (1, 1), None)?
            .with_padding(PaddingMode::Valid);
Ok(Self {
input_channels,
squeeze_channels,
fc1,
fc2,
})
}
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> Layer<F> for SqueezeExcitation<F> {
fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
let shape = input.shape();
if shape.len() != 4 {
return Err(NeuralError::InferenceError(format!(
"Expected 4D input, got {:?}",
shape
)));
}
let batch_size = shape[0];
let channels = shape[1];
let height = shape[2];
let width = shape[3];
if channels != self.input_channels {
return Err(NeuralError::InferenceError(format!(
"Expected {} input channels, got {}",
self.input_channels, channels
)));
}
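        // Squeeze: global average pool over the spatial dims, giving one
        // scalar per (batch, channel) in a [N, C, 1, 1] tensor.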
let mut x = Array::zeros(IxDyn(&[batch_size, channels, 1, 1]));
for b in 0..batch_size {
for c in 0..channels {
let mut sum = F::zero();
for h in 0..height {
for w in 0..width {
sum += input[[b, c, h, w]];
}
}
x[[b, c, 0, 0]] =
sum / F::from(height * width).expect("Failed to convert to float");
}
}
        // Excite: bottleneck conv + ReLU, then expand conv + sigmoid to
        // produce per-channel gates in (0, 1). (The reference implementation
        // uses Swish in place of ReLU here.)
        let x = self.fc1.forward(&x)?;
        let x = x.mapv(|v: F| v.max(F::zero()));
        let x = self.fc2.forward(&x)?;
        let x = x.mapv(|v| F::one() / (F::one() + (-v).exp()));
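        // Scale: multiply every spatial position by its channel's gate.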
let mut result = input.clone();
for b in 0..batch_size {
for c in 0..channels {
let scale = x[[b, c, 0, 0]];
for h in 0..height {
for w in 0..width {
result[[b, c, h, w]] = input[[b, c, h, w]] * scale;
}
}
}
}
Ok(result)
}
fn backward(
&self,
_input: &Array<F, IxDyn>,
grad_output: &Array<F, IxDyn>,
) -> Result<Array<F, IxDyn>> {
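        // Placeholder: the gradient is passed through unchanged; a true
        // backward pass for this block is not implemented.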
Ok(grad_output.clone())
}
fn update(&mut self, learning_rate: F) -> Result<()> {
self.fc1.update(learning_rate)?;
self.fc2.update(learning_rate)?;
Ok(())
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
self
}
}
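/// Mobile inverted bottleneck block (MBConv): expand (1x1 conv) -> depthwise
/// conv -> optional squeeze-excitation -> linear projection (1x1 conv), with
/// a skip connection when the stride is 1 and channel counts match.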
struct MBConvBlock<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> {
#[allow(dead_code)]
config: MBConvConfig,
has_skip_connection: bool,
expand_conv: Option<Conv2D<F>>,
expand_bn: Option<BatchNorm<F>>,
depthwise_conv: Conv2D<F>,
depthwise_bn: BatchNorm<F>,
se: Option<SqueezeExcitation<F>>,
project_conv: Conv2D<F>,
project_bn: BatchNorm<F>,
drop_connect_rate: F,
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> MBConvBlock<F> {
pub fn new(config: MBConvConfig) -> Result<Self> {
let input_channels = config.input_channels;
let output_channels = config.output_channels;
let expand_ratio = config.expand_ratio;
let kernel_size = config.kernel_size;
let stride = config.stride;
let use_se = config.use_se;
let drop_connect_rate =
F::from(config.drop_connect_rate).expect("Failed to convert to float");
let mut rng = SmallRng::from_seed([42; 32]);
let has_skip_connection = input_channels == output_channels && stride == 1;
        let expanded_channels = input_channels * expand_ratio;
        // The expansion conv is skipped when the ratio is 1 (no widening).
        let (expand_conv, expand_bn) = if expand_ratio != 1 {
            let conv = Conv2D::new(input_channels, expanded_channels, (1, 1), (1, 1), None)?
                .with_padding(PaddingMode::Valid);
            let bn = BatchNorm::new(expanded_channels, 1e-3, 0.01, &mut rng)?;
            (Some(conv), Some(bn))
        } else {
            (None, None)
        };
let depthwise_conv = Conv2D::new(
expanded_channels,
expanded_channels,
(kernel_size, kernel_size),
(stride, stride),
None,
)?
.with_padding(PaddingMode::Same);
let depthwise_bn = BatchNorm::new(expanded_channels, 1e-3, 0.01, &mut rng)?;
        let se = if use_se {
            // Squeeze width: a quarter of the expanded channels. (The
            // reference implementation derives this from the block's
            // un-expanded input channels instead.)
            let squeeze_channels = (expanded_channels as f64 / 4.0).round() as usize;
Some(SqueezeExcitation::new(expanded_channels, squeeze_channels)?)
} else {
None
};
let project_conv = Conv2D::new(expanded_channels, output_channels, (1, 1), (1, 1), None)?
.with_padding(PaddingMode::Valid);
let project_bn = BatchNorm::new(output_channels, 1e-3, 0.01, &mut rng)?;
Ok(Self {
config,
has_skip_connection,
expand_conv,
expand_bn,
depthwise_conv,
depthwise_bn,
se,
project_conv,
project_bn,
drop_connect_rate,
})
}
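    /// Stochastic depth ("drop connect"): with probability
    /// `drop_connect_rate` the residual branch is zeroed out entirely;
    /// otherwise it is rescaled by `1 / keep_prob` so the expected value of
    /// the output is unchanged.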
    fn drop_connect<R: scirs2_core::random::Rng>(
        &self,
        input: &Array<F, IxDyn>,
        rng: &mut R,
    ) -> Array<F, IxDyn> {
        if self.drop_connect_rate <= F::zero() || !self.has_skip_connection {
            return input.clone();
        }
        let keep_prob = F::one() - self.drop_connect_rate;
        let drop_rate = self
            .drop_connect_rate
            .to_f64()
            .expect("drop_connect_rate should convert to f64");
        // Note: this drops the whole batch at once rather than per example.
        if rng.random::<f64>() > drop_rate {
            input.mapv(|x| x / keep_prob)
        } else {
            Array::zeros(IxDyn(input.shape()))
        }
    }
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> Layer<F> for MBConvBlock<F> {
    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
        // Fixed seed: drop-connect decisions are deterministic across calls.
        let mut rng = SmallRng::from_seed([42; 32]);
        let mut x = input.clone();
        // Expansion phase: a 1x1 conv widens the channels by `expand_ratio`.
        if let (Some(expand_conv), Some(expand_bn)) = (&self.expand_conv, &self.expand_bn) {
            x = expand_conv.forward(&x)?;
            x = expand_bn.forward(&x)?;
            x = x.mapv(swish);
        }
        // Depthwise phase: spatial convolution at the expanded width.
        x = self.depthwise_conv.forward(&x)?;
        x = self.depthwise_bn.forward(&x)?;
        x = x.mapv(swish);
        // Optional squeeze-and-excitation recalibrates the channels.
        if let Some(ref se) = self.se {
            x = se.forward(&x)?;
        }
        // Projection phase: a linear 1x1 conv back down to `output_channels`
        // (no activation, as in MobileNetV2-style inverted residuals).
        x = self.project_conv.forward(&x)?;
        x = self.project_bn.forward(&x)?;
        if self.has_skip_connection {
            // Stochastic depth on the residual branch, then add the identity
            // shortcut elementwise.
            x = self.drop_connect(&x, &mut rng);
            x = x + input;
        }
Ok(x)
}
fn backward(
&self,
_input: &Array<F, IxDyn>,
grad_output: &Array<F, IxDyn>,
) -> Result<Array<F, IxDyn>> {
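        // Placeholder: the gradient is passed through unchanged; a true
        // backward pass for this block is not implemented.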
Ok(grad_output.clone())
}
fn update(&mut self, learning_rate: F) -> Result<()> {
        if let (Some(expand_conv), Some(expand_bn)) =
            (self.expand_conv.as_mut(), self.expand_bn.as_mut())
        {
expand_conv.update(learning_rate)?;
expand_bn.update(learning_rate)?;
}
self.depthwise_conv.update(learning_rate)?;
self.depthwise_bn.update(learning_rate)?;
if let Some(ref mut se) = self.se {
se.update(learning_rate)?;
}
self.project_conv.update(learning_rate)?;
self.project_bn.update(learning_rate)?;
Ok(())
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
self
}
}
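/// EfficientNet: stem conv -> scaled stack of MBConv stages -> 1x1 head conv
/// -> global average pooling -> dropout -> linear classifier.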
pub struct EfficientNet<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> {
config: EfficientNetConfig,
stem_conv: Conv2D<F>,
stem_bn: BatchNorm<F>,
blocks: Vec<MBConvBlock<F>>,
head_conv: Conv2D<F>,
head_bn: BatchNorm<F>,
classifier: Dense<F>,
dropout: Dropout<F>,
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> EfficientNet<F> {
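    /// Builds the network from a config, applying width scaling to channel
    /// counts and depth scaling to the number of blocks in each stage.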
pub fn new(config: EfficientNetConfig) -> Result<Self> {
let mut rng = SmallRng::from_seed([42; 32]);
let num_classes = config.num_classes;
let input_channels = config.input_channels;
let stem_channels = config.scale_channels(32);
let stem_conv = Conv2D::new(input_channels, stem_channels, (3, 3), (2, 2), None)?
.with_padding(PaddingMode::Same);
let stem_bn = BatchNorm::new(stem_channels, 1e-3, 0.01, &mut rng)?;
let mut blocks = Vec::new();
let mut in_channels = stem_channels;
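        // Each stage's first block applies the configured stride and channel
        // change; the remaining blocks repeat at stride 1.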
for stage in &config.stages {
let num_blocks = config.scale_depth(stage.num_blocks);
let out_channels = config.scale_channels(stage.mbconv_config.output_channels);
let first_block_config = MBConvConfig {
input_channels: in_channels,
output_channels: out_channels,
kernel_size: stage.mbconv_config.kernel_size,
stride: stage.mbconv_config.stride,
expand_ratio: stage.mbconv_config.expand_ratio,
use_se: stage.mbconv_config.use_se,
drop_connect_rate: stage.mbconv_config.drop_connect_rate,
};
blocks.push(MBConvBlock::new(first_block_config)?);
for _ in 1..num_blocks {
let block_config = MBConvConfig {
input_channels: out_channels,
output_channels: out_channels,
kernel_size: stage.mbconv_config.kernel_size,
stride: 1,
expand_ratio: stage.mbconv_config.expand_ratio,
use_se: stage.mbconv_config.use_se,
drop_connect_rate: stage.mbconv_config.drop_connect_rate,
};
blocks.push(MBConvBlock::new(block_config)?);
}
in_channels = out_channels;
}
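        // Head widens to scale_channels(1280) before pooling and classification.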
let head_channels = config.scale_channels(1280);
let head_conv = Conv2D::new(in_channels, head_channels, (1, 1), (1, 1), None)?
.with_padding(PaddingMode::Valid);
let head_bn = BatchNorm::new(head_channels, 1e-3, 0.01, &mut rng)?;
let classifier = Dense::new(head_channels, num_classes, None, &mut rng)?;
let dropout = Dropout::new(config.dropout_rate, &mut rng)?;
Ok(Self {
config,
stem_conv,
stem_bn,
blocks,
head_conv,
head_bn,
classifier,
dropout,
})
}
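    /// Builds an EfficientNet-B0. The B1-B7 constructors below follow the
    /// same pattern with progressively scaled configs.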
pub fn efficientnet_b0(input_channels: usize, num_classes: usize) -> Result<Self> {
let config = EfficientNetConfig::efficientnet_b0(input_channels, num_classes);
Self::new(config)
}
pub fn efficientnet_b1(input_channels: usize, num_classes: usize) -> Result<Self> {
let config = EfficientNetConfig::efficientnet_b1(input_channels, num_classes);
Self::new(config)
}
pub fn efficientnet_b2(input_channels: usize, num_classes: usize) -> Result<Self> {
let config = EfficientNetConfig::efficientnet_b2(input_channels, num_classes);
Self::new(config)
}
pub fn efficientnet_b3(input_channels: usize, num_classes: usize) -> Result<Self> {
let config = EfficientNetConfig::efficientnet_b3(input_channels, num_classes);
Self::new(config)
}
pub fn efficientnet_b4(input_channels: usize, num_classes: usize) -> Result<Self> {
let config = EfficientNetConfig::efficientnet_b4(input_channels, num_classes);
Self::new(config)
}
pub fn efficientnet_b5(input_channels: usize, num_classes: usize) -> Result<Self> {
let config = EfficientNetConfig::efficientnet_b5(input_channels, num_classes);
Self::new(config)
}
pub fn efficientnet_b6(input_channels: usize, num_classes: usize) -> Result<Self> {
let config = EfficientNetConfig::efficientnet_b6(input_channels, num_classes);
Self::new(config)
}
pub fn efficientnet_b7(input_channels: usize, num_classes: usize) -> Result<Self> {
let config = EfficientNetConfig::efficientnet_b7(input_channels, num_classes);
Self::new(config)
}
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> Layer<F> for EfficientNet<F> {
fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
let shape = input.shape();
if shape.len() != 4 || shape[1] != self.config.input_channels {
return Err(NeuralError::InferenceError(format!(
"Expected input shape [batch_size, {}, height, width], got {:?}",
self.config.input_channels, shape
)));
}
let batch_size = shape[0];
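        // Stem: strided 3x3 conv halves the spatial resolution.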
let mut x = self.stem_conv.forward(input)?;
x = self.stem_bn.forward(&x)?;
x = x.mapv(swish);
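        // Body: the scaled stack of MBConv blocks.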
for block in &self.blocks {
x = block.forward(&x)?;
}
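        // Head: 1x1 conv up to the head width.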
x = self.head_conv.forward(&x)?;
x = self.head_bn.forward(&x)?;
x = x.mapv(swish);
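        // Global average pooling over the spatial dims: [N, C, H, W] -> [N, C].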
let channels = x.shape()[1];
let height = x.shape()[2];
let width = x.shape()[3];
let mut pooled = Array::zeros(IxDyn(&[batch_size, channels]));
for b in 0..batch_size {
for c in 0..channels {
let mut sum = F::zero();
for h in 0..height {
for w in 0..width {
sum += x[[b, c, h, w]];
}
}
pooled[[b, c]] = sum / F::from(height * width).expect("Failed to convert to float");
}
}
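        // Classifier: dropout followed by a dense projection to num_classes.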
let pooled = self.dropout.forward(&pooled)?;
let logits = self.classifier.forward(&pooled)?;
Ok(logits)
}
fn backward(
&self,
_input: &Array<F, IxDyn>,
grad_output: &Array<F, IxDyn>,
) -> Result<Array<F, IxDyn>> {
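        // Placeholder: the gradient is passed through unchanged; a true
        // backward pass for the full network is not implemented.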
Ok(grad_output.clone())
}
fn update(&mut self, learning_rate: F) -> Result<()> {
self.stem_conv.update(learning_rate)?;
self.stem_bn.update(learning_rate)?;
for block in &mut self.blocks {
block.update(learning_rate)?;
}
self.head_conv.update(learning_rate)?;
self.head_bn.update(learning_rate)?;
self.classifier.update(learning_rate)?;
Ok(())
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
self
}
}
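#[cfg(test)]
mod tests {
    use super::*;

    // A minimal check of the compound-scaling helpers; the expected values
    // follow directly from the rounding rules in `scale_channels` and
    // `scale_depth` above.
    #[test]
    fn b2_compound_scaling() {
        let config = EfficientNetConfig::efficientnet_b2(3, 1000);
        // width 1.1: 32 -> round(35.2) = 35 -> next multiple of 8 is 40
        assert_eq!(config.scale_channels(32), 40);
        // depth 1.2: ceil(4 * 1.2) = ceil(4.8) = 5
        assert_eq!(config.scale_depth(4), 5);
    }

    // Swish is exactly 0 at 0 and approaches the identity for large inputs.
    #[test]
    fn swish_basics() {
        assert_eq!(swish(0.0f64), 0.0);
        assert!((swish(10.0f64) - 10.0).abs() < 1e-3);
    }
}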