use crate::error::{NeuralError, Result};
use crate::layers::conv::PaddingMode;
use crate::layers::{BatchNorm, Conv2D, Dense, Dropout, Layer};
use scirs2_core::ndarray::{Array, IxDyn, ScalarOperand};
use scirs2_core::numeric::{Float, NumAssign};
use scirs2_core::random::SeedableRng;
use std::fmt::Debug;
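/// ReLU6 activation: clamps `x` to the range [0, 6].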
#[allow(dead_code)]
pub fn relu6<F: Float>(x: F) -> F {
let zero = F::zero();
let six = F::from(6.0).expect("Failed to convert constant to float");
x.max(zero).min(six)
}
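/// Hard sigmoid as used by MobileNetV3: `relu6(x + 3) / 6`, a
/// piecewise-linear approximation of the sigmoid.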
#[allow(dead_code)]
pub fn hard_sigmoid<F: Float>(x: F) -> F {
let zero = F::zero();
let six = F::from(6.0).expect("Failed to convert constant to float");
    let three = F::from(3.0).expect("Failed to convert constant to float");
    (x + three).max(zero).min(six) / six
}
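/// Hard swish (h-swish) from MobileNetV3: `x * hard_sigmoid(x)`, a cheap
/// approximation of the swish activation.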
#[allow(dead_code)]
pub fn hard_swish<F: Float>(x: F) -> F {
x * hard_sigmoid(x)
}
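/// MobileNet architecture variant.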
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MobileNetVersion {
V1,
V2,
V3Small,
V3Large,
}
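/// Configuration for a single convolutional block.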
#[derive(Debug, Clone)]
pub struct ConvBlockConfig {
pub input_channels: usize,
pub output_channels: usize,
pub kernel_size: usize,
pub stride: usize,
pub use_residual: bool,
pub expand_ratio: usize,
pub use_se: bool,
pub activation: String,
}
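/// Configuration describing a complete MobileNet model.
#[derive(Debug, Clone)]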
pub struct MobileNetConfig {
pub version: MobileNetVersion,
pub width_multiplier: f64,
pub resolution_multiplier: f64,
pub dropout_rate: f64,
pub blocks: Vec<ConvBlockConfig>,
pub input_channels: usize,
pub num_classes: usize,
}
impl MobileNetConfig {
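    /// MobileNetV1 configuration: a stem convolution followed by 13
    /// depthwise-separable blocks.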
pub fn mobilenet_v1(input_channels: usize, num_classes: usize) -> Self {
let mut blocks = Vec::new();
blocks.push(ConvBlockConfig {
input_channels,
output_channels: 32,
kernel_size: 3,
stride: 2,
use_residual: false,
expand_ratio: 1,
use_se: false,
activation: "relu".to_string(),
});
let channels = [
64, 128, 128, 256, 256, 512, 512, 512, 512, 512, 512, 1024, 1024,
];
let strides = [1, 2, 1, 2, 1, 2, 1, 1, 1, 1, 1, 2, 1];
let mut input_c = 32;
for i in 0..channels.len() {
blocks.push(ConvBlockConfig {
input_channels: input_c,
output_channels: channels[i],
kernel_size: 3,
stride: strides[i],
use_residual: false,
expand_ratio: 1,
use_se: false,
activation: "relu".to_string(),
});
input_c = channels[i];
}
Self {
version: MobileNetVersion::V1,
width_multiplier: 1.0,
resolution_multiplier: 1.0,
dropout_rate: 0.001,
blocks,
input_channels,
num_classes,
}
}
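    /// MobileNetV2 configuration: inverted residual blocks with linear
    /// bottlenecks and ReLU6 activations.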
pub fn mobilenet_v2(input_channels: usize, num_classes: usize) -> Self {
let mut blocks = Vec::new();
blocks.push(ConvBlockConfig {
input_channels,
output_channels: 32,
kernel_size: 3,
stride: 2,
use_residual: false,
expand_ratio: 1,
use_se: false,
activation: "relu6".to_string(),
});
        // MobileNetV2 stage layout: t = expansion factor, c = output
        // channels, s = stride, r = residual connection (only applied when
        // stride is 1 and channel counts match).
        let t = [1, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6];
        let c = [
            16, 24, 24, 32, 32, 32, 64, 64, 64, 64, 96, 96, 96, 160, 160, 160, 320,
        ];
        let s = [1, 2, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1];
        let r = [
            false, true, true, true, true, true, true, true, true, true, true, true, true, true,
            true, true, false,
        ];
        let mut input_c = 32;
for i in 0..t.len() {
blocks.push(ConvBlockConfig {
input_channels: input_c,
output_channels: c[i],
kernel_size: 3,
stride: s[i],
use_residual: r[i],
expand_ratio: t[i],
use_se: false,
activation: "relu6".to_string(),
});
input_c = c[i];
}
blocks.push(ConvBlockConfig {
input_channels: 320,
output_channels: 1280,
kernel_size: 1,
stride: 1,
use_residual: false,
expand_ratio: 1,
use_se: false,
activation: "relu6".to_string(),
});
Self {
version: MobileNetVersion::V2,
width_multiplier: 1.0,
resolution_multiplier: 1.0,
            dropout_rate: 0.2,
blocks,
input_channels,
num_classes,
}
}
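    /// MobileNetV3-Small configuration: inverted residual blocks with
    /// squeeze-and-excitation and hard-swish activations.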
pub fn mobilenet_v3_small(input_channels: usize, num_classes: usize) -> Self {
let mut blocks = Vec::new();
blocks.push(ConvBlockConfig {
input_channels,
output_channels: 16,
kernel_size: 3,
stride: 2,
use_residual: false,
expand_ratio: 1,
use_se: false,
activation: "hard_swish".to_string(),
});
        // Each tuple: (expansion ratio, output channels, kernel size, stride,
        // use squeeze-excitation, activation, use residual). The ratios
        // approximate the absolute expansion sizes from the MobileNetV3
        // paper (e.g. 4.5 ~= 72/16, 3.67 ~= 88/24).
        let configs = vec![
            (1.0, 16, 3, 2, true, "relu", false),
            (4.5, 24, 3, 2, false, "relu", false),
            (3.67, 24, 3, 1, false, "relu", true),
            (4.0, 40, 5, 2, true, "hard_swish", false),
            (6.0, 40, 5, 1, true, "hard_swish", true),
            (6.0, 40, 5, 1, true, "hard_swish", true),
            (3.0, 48, 5, 1, true, "hard_swish", false),
            (3.0, 48, 5, 1, true, "hard_swish", true),
            (6.0, 96, 5, 2, true, "hard_swish", false),
            (6.0, 96, 5, 1, true, "hard_swish", true),
            (6.0, 96, 5, 1, true, "hard_swish", true),
        ];
let mut input_c = 16;
for (exp_ratio, out_ch, kernel, stride, use_se, activation, use_res) in configs {
            // `expand_ratio` is stored as an integer, so fractional ratios
            // from the paper are rounded to the nearest whole number.
            let expand_ratio = exp_ratio.round() as usize;
blocks.push(ConvBlockConfig {
input_channels: input_c,
output_channels: out_ch,
kernel_size: kernel,
stride,
use_residual: use_res,
expand_ratio,
use_se,
activation: activation.to_string(),
});
input_c = out_ch;
}
        blocks.push(ConvBlockConfig {
            input_channels: 96,
            output_channels: 576,
            kernel_size: 1,
            stride: 1,
            use_residual: false,
            expand_ratio: 1,
            use_se: false,
            activation: "hard_swish".to_string(),
        });
        blocks.push(ConvBlockConfig {
            input_channels: 576,
            output_channels: 1024,
            kernel_size: 1,
            stride: 1,
            use_residual: false,
            expand_ratio: 1,
            use_se: false,
            activation: "hard_swish".to_string(),
        });
Self {
version: MobileNetVersion::V3Small,
width_multiplier: 1.0,
resolution_multiplier: 1.0,
dropout_rate: 0.2,
blocks,
input_channels,
num_classes,
}
}
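    /// MobileNetV3-Large configuration.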
pub fn mobilenet_v3_large(input_channels: usize, num_classes: usize) -> Self {
let mut blocks = Vec::new();
blocks.push(ConvBlockConfig {
input_channels,
output_channels: 16,
kernel_size: 3,
stride: 2,
use_residual: false,
expand_ratio: 1,
use_se: false,
activation: "hard_swish".to_string(),
});
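        // Each tuple: (expansion ratio, output channels, kernel size, stride,
        // use squeeze-excitation, activation, use residual); ratios
        // approximate the paper's absolute expansion sizes.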
let configs = vec![
(1.0, 16, 3, 1, false, "relu", false),
(4.0, 24, 3, 2, false, "relu", false),
(3.0, 24, 3, 1, false, "relu", true),
(3.0, 40, 5, 2, true, "relu", false),
(3.0, 40, 5, 1, true, "relu", true),
(3.0, 40, 5, 1, true, "relu", true),
(6.0, 80, 3, 2, false, "hard_swish", false),
(2.5, 80, 3, 1, false, "hard_swish", true),
(2.3, 80, 3, 1, false, "hard_swish", true),
(2.3, 80, 3, 1, false, "hard_swish", true),
(6.0, 112, 3, 1, true, "hard_swish", false),
(6.0, 112, 3, 1, true, "hard_swish", true),
(6.0, 160, 5, 2, true, "hard_swish", false),
(6.0, 160, 5, 1, true, "hard_swish", true),
(6.0, 160, 5, 1, true, "hard_swish", true),
];
let mut input_c = 16;
for (exp_ratio, out_ch, kernel, stride, use_se, activation, use_res) in configs {
            // `expand_ratio` is stored as an integer, so fractional ratios
            // from the paper are rounded to the nearest whole number.
            let expand_ratio = exp_ratio.round() as usize;
blocks.push(ConvBlockConfig {
input_channels: input_c,
output_channels: out_ch,
kernel_size: kernel,
stride,
use_residual: use_res,
expand_ratio,
use_se,
activation: activation.to_string(),
});
input_c = out_ch;
}
blocks.push(ConvBlockConfig {
input_channels: 160,
output_channels: 960,
kernel_size: 1,
stride: 1,
use_residual: false,
expand_ratio: 1,
use_se: false,
activation: "hard_swish".to_string(),
});
blocks.push(ConvBlockConfig {
input_channels: 960,
output_channels: 1280,
kernel_size: 1,
stride: 1,
use_residual: false,
expand_ratio: 1,
use_se: false,
activation: "hard_swish".to_string(),
});
Self {
version: MobileNetVersion::V3Large,
width_multiplier: 1.0,
resolution_multiplier: 1.0,
dropout_rate: 0.2,
blocks,
input_channels,
num_classes,
}
}
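    /// Scales a channel count by `width_multiplier`, with a floor of 8
    /// channels. (The reference MobileNet implementations additionally round
    /// to a multiple of 8; this simplified version only enforces the floor.)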
    pub fn scale_channels(&self, channels: usize) -> usize {
        let scaled = (channels as f64 * self.width_multiplier).round();
        scaled.max(8.0) as usize
    }
}
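/// Squeeze-and-excitation block: global average pooling followed by two 1x1
/// convolutions that produce per-channel scaling weights.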
struct SqueezeExcitation<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> {
input_channels: usize,
#[allow(dead_code)]
squeeze_channels: usize,
fc1: Conv2D<F>,
fc2: Conv2D<F>,
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> SqueezeExcitation<F> {
pub fn new(input_channels: usize, squeeze_channels: usize) -> Result<Self> {
let kernel_size = (1, 1);
let stride = (1, 1);
let fc1 = Conv2D::new(input_channels, squeeze_channels, kernel_size, stride, None)?
.with_padding(PaddingMode::Valid);
let fc2 = Conv2D::new(squeeze_channels, input_channels, kernel_size, stride, None)?
.with_padding(PaddingMode::Valid);
Ok(Self {
input_channels,
squeeze_channels,
fc1,
fc2,
})
}
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> Layer<F> for SqueezeExcitation<F> {
fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
let shape = input.shape();
if shape.len() != 4 {
return Err(NeuralError::InferenceError(format!(
"Expected 4D input, got {:?}",
shape
)));
}
let batch_size = shape[0];
let channels = shape[1];
let height = shape[2];
let width = shape[3];
if channels != self.input_channels {
return Err(NeuralError::InferenceError(format!(
"Expected {} input channels, got {}",
self.input_channels, channels
)));
}
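        // Squeeze: global average pooling over the spatial dimensions.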
let mut x = Array::zeros(IxDyn(&[batch_size, channels, 1, 1]));
for b in 0..batch_size {
for c in 0..channels {
let mut sum = F::zero();
for h in 0..height {
for w in 0..width {
sum += input[[b, c, h, w]];
}
}
x[[b, c, 0, 0]] =
sum / F::from(height * width).expect("Failed to convert to float");
}
}
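        // Excitation: 1x1 conv -> ReLU -> 1x1 conv -> hard sigmoid.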
let x = self.fc1.forward(&x)?;
let x = x.mapv(|v: F| v.max(F::zero()));
let x = self.fc2.forward(&x)?;
let x = x.mapv(hard_sigmoid);
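        // Scale: multiply each input channel by its excitation weight.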
let mut result = input.clone();
for b in 0..batch_size {
for c in 0..channels {
let scale = x[[b, c, 0, 0]];
for h in 0..height {
for w in 0..width {
result[[b, c, h, w]] = input[[b, c, h, w]] * scale;
}
}
}
}
Ok(result)
}
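    // NOTE: the backward passes in this module are pass-through placeholders;
    // they return the incoming gradient unchanged rather than computing real
    // gradients.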
fn backward(
&self,
_input: &Array<F, IxDyn>,
grad_output: &Array<F, IxDyn>,
) -> Result<Array<F, IxDyn>> {
Ok(grad_output.clone())
}
fn update(&mut self, learning_rate: F) -> Result<()> {
self.fc1.update(learning_rate)?;
self.fc2.update(learning_rate)?;
Ok(())
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
self
}
}
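/// Maps an activation name from a block config ("relu", "relu6",
/// "hard_swish", "hard_sigmoid") to a boxed function; unknown names fall
/// back to plain ReLU.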
#[allow(dead_code)]
fn get_activation<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign>(
name: &str,
) -> Box<dyn Fn(F) -> F + Send + Sync> {
match name {
"relu" => Box::new(|x: F| x.max(F::zero())),
"relu6" => Box::new(relu6),
"hard_swish" => Box::new(hard_swish),
"hard_sigmoid" => Box::new(hard_sigmoid),
        _ => Box::new(|x: F| x.max(F::zero())),
    }
}
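/// MobileNetV2/V3 inverted residual block: optional 1x1 expansion, depthwise
/// convolution, optional squeeze-and-excitation, and a linear 1x1 projection,
/// with a residual connection when the input and output shapes match.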
struct InvertedResidualBlock<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> {
config: ConvBlockConfig,
has_skip_connection: bool,
expand_conv: Option<Conv2D<F>>,
expand_bn: Option<BatchNorm<F>>,
depthwise_conv: Conv2D<F>,
depthwise_bn: BatchNorm<F>,
se: Option<SqueezeExcitation<F>>,
project_conv: Conv2D<F>,
project_bn: BatchNorm<F>,
activation: Box<dyn Fn(F) -> F + Send + Sync>,
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> InvertedResidualBlock<F> {
pub fn new(config: ConvBlockConfig, width_multiplier: f64) -> Result<Self> {
let input_channels = (config.input_channels as f64 * width_multiplier).round() as usize;
let output_channels = (config.output_channels as f64 * width_multiplier).round() as usize;
let expand_ratio = config.expand_ratio;
let kernel_size = config.kernel_size;
let stride = config.stride;
let use_se = config.use_se;
let mut rng = scirs2_core::random::rngs::SmallRng::from_seed([42; 32]);
let has_skip_connection =
input_channels == output_channels && stride == 1 && config.use_residual;
        // Channels after the optional 1x1 expansion; a ratio of 1 means no
        // expansion layer is inserted.
        let expanded_channels = input_channels * expand_ratio.max(1);
        let (expand_conv, expand_bn) = if expand_ratio > 1 {
            let conv = Conv2D::new(input_channels, expanded_channels, (1, 1), (1, 1), None)?
                .with_padding(PaddingMode::Valid);
            let bn = BatchNorm::new(expanded_channels, 1e-3, 0.01, &mut rng)?;
            (Some(conv), Some(bn))
        } else {
            (None, None)
        };
let kernel_size_tuple = (kernel_size, kernel_size);
let stride_tuple = (stride, stride);
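        // NOTE: a standard convolution stands in for a true depthwise
        // convolution here (no channel grouping is applied), which is a
        // simplification relative to the paper.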
let depthwise_conv = Conv2D::new(
expanded_channels,
expanded_channels,
kernel_size_tuple,
stride_tuple,
None,
)?
.with_padding(PaddingMode::Same);
let depthwise_bn = BatchNorm::new(expanded_channels, 1e-3, 0.01, &mut rng)?;
let se = if use_se {
let squeeze_channels = (expanded_channels as f64 / 4.0).round() as usize;
Some(SqueezeExcitation::new(expanded_channels, squeeze_channels)?)
} else {
None
};
let project_conv = Conv2D::new(expanded_channels, output_channels, (1, 1), (1, 1), None)?
.with_padding(PaddingMode::Valid);
let project_bn = BatchNorm::new(output_channels, 1e-3, 0.01, &mut rng)?;
let activation = get_activation(&config.activation);
Ok(Self {
config,
has_skip_connection,
expand_conv,
expand_bn,
depthwise_conv,
depthwise_bn,
se,
project_conv,
project_bn,
activation,
})
}
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> Layer<F>
for InvertedResidualBlock<F>
{
fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
let mut x = input.clone();
        // Optional 1x1 expansion.
        if let (Some(expand_conv), Some(expand_bn)) = (&self.expand_conv, &self.expand_bn) {
            x = expand_conv.forward(&x)?;
            x = expand_bn.forward(&x)?;
            x = x.mapv(|v| (self.activation)(v));
        }
        // Depthwise convolution (its stride handles downsampling).
        x = self.depthwise_conv.forward(&x)?;
        x = self.depthwise_bn.forward(&x)?;
        x = x.mapv(|v| (self.activation)(v));
        // Optional squeeze-and-excitation.
        if let Some(ref se) = self.se {
            x = se.forward(&x)?;
        }
        // Linear 1x1 projection (no activation, per the inverted-residual design).
        x = self.project_conv.forward(&x)?;
        x = self.project_bn.forward(&x)?;
        // Residual connection when stride is 1 and channel counts match.
        if self.has_skip_connection {
            x = &x + input;
        }
Ok(x)
}
fn backward(
&self,
_input: &Array<F, IxDyn>,
grad_output: &Array<F, IxDyn>,
) -> Result<Array<F, IxDyn>> {
Ok(grad_output.clone())
}
fn update(&mut self, learning_rate: F) -> Result<()> {
if let (Some(ref mut expand_conv), Some(ref mut expand_bn)) =
(&mut self.expand_conv, &mut self.expand_bn)
{
expand_conv.update(learning_rate)?;
expand_bn.update(learning_rate)?;
}
self.depthwise_conv.update(learning_rate)?;
self.depthwise_bn.update(learning_rate)?;
if let Some(ref mut se) = self.se {
se.update(learning_rate)?;
}
self.project_conv.update(learning_rate)?;
self.project_bn.update(learning_rate)?;
Ok(())
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
self
}
}
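/// MobileNetV1 building block: a depthwise convolution followed by a 1x1
/// pointwise convolution, each with batch normalization and ReLU.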
struct DepthwiseSeparableConv<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> {
depthwise_conv: Conv2D<F>,
depthwise_bn: BatchNorm<F>,
pointwise_conv: Conv2D<F>,
pointwise_bn: BatchNorm<F>,
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> DepthwiseSeparableConv<F> {
pub fn new(config: ConvBlockConfig, width_multiplier: f64) -> Result<Self> {
let input_channels = (config.input_channels as f64 * width_multiplier).round() as usize;
let output_channels = (config.output_channels as f64 * width_multiplier).round() as usize;
let kernel_size = config.kernel_size;
let stride = config.stride;
let mut rng = scirs2_core::random::rngs::SmallRng::from_seed([42; 32]);
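        // NOTE: as in `InvertedResidualBlock`, a standard convolution stands
        // in for a true grouped/depthwise convolution.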
let depthwise_conv = Conv2D::new(
input_channels,
input_channels,
(kernel_size, kernel_size),
(stride, stride),
None,
)?
.with_padding(PaddingMode::Same);
let depthwise_bn = BatchNorm::new(input_channels, 1e-3, 0.01, &mut rng)?;
let pointwise_conv = Conv2D::new(input_channels, output_channels, (1, 1), (1, 1), None)?
.with_padding(PaddingMode::Valid);
let pointwise_bn = BatchNorm::new(output_channels, 1e-3, 0.01, &mut rng)?;
Ok(Self {
depthwise_conv,
depthwise_bn,
pointwise_conv,
pointwise_bn,
})
}
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> Layer<F>
for DepthwiseSeparableConv<F>
{
fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
        // Depthwise conv -> BN -> ReLU, then pointwise 1x1 -> BN -> ReLU.
        let mut x = self.depthwise_conv.forward(input)?;
        x = self.depthwise_bn.forward(&x)?;
        x = x.mapv(|v: F| v.max(F::zero()));
        x = self.pointwise_conv.forward(&x)?;
        x = self.pointwise_bn.forward(&x)?;
        x = x.mapv(|v: F| v.max(F::zero()));
        Ok(x)
}
fn backward(
&self,
_input: &Array<F, IxDyn>,
grad_output: &Array<F, IxDyn>,
) -> Result<Array<F, IxDyn>> {
Ok(grad_output.clone())
}
fn update(&mut self, learning_rate: F) -> Result<()> {
self.depthwise_conv.update(learning_rate)?;
self.depthwise_bn.update(learning_rate)?;
self.pointwise_conv.update(learning_rate)?;
self.pointwise_bn.update(learning_rate)?;
Ok(())
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
self
}
}
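/// A MobileNet image classifier (V1, V2, or V3) operating on NCHW input.
///
/// A minimal usage sketch (marked `ignore` because the import paths depend on
/// your crate layout):
///
/// ```ignore
/// use scirs2_core::ndarray::{Array, IxDyn};
///
/// // Bring the `Layer` trait into scope for `forward`.
/// use crate::layers::Layer;
///
/// // 3 input channels (RGB), 1000 classes.
/// let model: MobileNet<f32> = MobileNet::mobilenet_v2(3, 1000).expect("build model");
/// let input = Array::<f32, _>::zeros(IxDyn(&[1, 3, 224, 224]));
/// let logits = model.forward(&input).expect("forward pass"); // shape [1, 1000]
/// ```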
pub struct MobileNet<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> {
config: MobileNetConfig,
stem_conv: Conv2D<F>,
stem_bn: BatchNorm<F>,
blocks: Vec<Box<dyn Layer<F>>>,
classifier: Dense<F>,
dropout: Dropout<F>,
stem_activation: Box<dyn Fn(F) -> F + Send + Sync>,
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> MobileNet<F> {
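    /// Builds a model from an explicit configuration. The first entry of
    /// `config.blocks` defines the stem convolution; the remaining entries
    /// define the feature-extractor blocks.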
pub fn new(config: MobileNetConfig) -> Result<Self> {
let input_channels = config.input_channels;
let num_classes = config.num_classes;
let width_multiplier = config.width_multiplier;
let mut rng = scirs2_core::random::rngs::SmallRng::from_seed([42; 32]);
let first_block = &config.blocks[0];
let stem_channels =
(first_block.output_channels as f64 * width_multiplier).round() as usize;
let kernel_size_tuple = (first_block.kernel_size, first_block.kernel_size);
let stride_tuple = (first_block.stride, first_block.stride);
let stem_conv = Conv2D::new(
input_channels,
stem_channels,
kernel_size_tuple,
stride_tuple,
None,
)?
.with_padding(PaddingMode::Same);
let stem_bn = BatchNorm::new(stem_channels, 1e-3, 0.01, &mut rng)?;
let stem_activation = get_activation(&first_block.activation);
let mut blocks: Vec<Box<dyn Layer<F>>> = Vec::new();
match config.version {
MobileNetVersion::V1 => {
for i in 1..config.blocks.len() {
let block =
DepthwiseSeparableConv::new(config.blocks[i].clone(), width_multiplier)?;
blocks.push(Box::new(block));
}
}
_ => {
for i in 1..config.blocks.len() {
let block =
InvertedResidualBlock::new(config.blocks[i].clone(), width_multiplier)?;
blocks.push(Box::new(block));
}
}
}
        // The classifier input must match the (width-scaled) channel count
        // produced by the final block after global average pooling.
        let last_channels = (config
            .blocks
            .last()
            .expect("MobileNet config must contain at least one block")
            .output_channels as f64
            * width_multiplier)
            .round() as usize;
let classifier = Dense::new(last_channels, num_classes, None, &mut rng)?;
let dropout = Dropout::new(config.dropout_rate, &mut rng)?;
Ok(Self {
config,
stem_conv,
stem_bn,
blocks,
classifier,
dropout,
stem_activation,
})
}
pub fn mobilenet_v1(input_channels: usize, num_classes: usize) -> Result<Self> {
let config = MobileNetConfig::mobilenet_v1(input_channels, num_classes);
Self::new(config)
}
pub fn mobilenet_v2(input_channels: usize, num_classes: usize) -> Result<Self> {
let config = MobileNetConfig::mobilenet_v2(input_channels, num_classes);
Self::new(config)
}
pub fn mobilenet_v3_small(input_channels: usize, num_classes: usize) -> Result<Self> {
let config = MobileNetConfig::mobilenet_v3_small(input_channels, num_classes);
Self::new(config)
}
pub fn mobilenet_v3_large(input_channels: usize, num_classes: usize) -> Result<Self> {
let config = MobileNetConfig::mobilenet_v3_large(input_channels, num_classes);
Self::new(config)
}
pub fn config(&self) -> &MobileNetConfig {
&self.config
}
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> Layer<F> for MobileNet<F> {
fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
let shape = input.shape();
if shape.len() != 4 || shape[1] != self.config.input_channels {
return Err(NeuralError::InferenceError(format!(
"Expected input shape [batch_size, {}, height, width], got {:?}",
self.config.input_channels, shape
)));
}
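        // Stem: convolution, batch norm, activation.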
let mut x = self.stem_conv.forward(input)?;
x = self.stem_bn.forward(&x)?;
        x = x.mapv(|v| (self.stem_activation)(v));
        for block in &self.blocks {
            x = block.forward(&x)?;
        }
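        // Global average pooling over the spatial dimensions.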
let batch_size = x.shape()[0];
let channels = x.shape()[1];
let height = x.shape()[2];
let width = x.shape()[3];
let mut pooled = Array::zeros(IxDyn(&[batch_size, channels]));
for b in 0..batch_size {
for c in 0..channels {
let mut sum = F::zero();
for h in 0..height {
for w in 0..width {
sum += x[[b, c, h, w]];
}
}
pooled[[b, c]] = sum / F::from(height * width).expect("Failed to convert to float");
}
}
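        // Dropout, then the final linear classifier.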
let pooled = self.dropout.forward(&pooled)?;
let logits = self.classifier.forward(&pooled)?;
Ok(logits)
}
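    // NOTE: like the blocks above, backward is a pass-through placeholder.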
fn backward(
&self,
_input: &Array<F, IxDyn>,
grad_output: &Array<F, IxDyn>,
) -> Result<Array<F, IxDyn>> {
Ok(grad_output.clone())
}
fn update(&mut self, learning_rate: F) -> Result<()> {
self.stem_conv.update(learning_rate)?;
self.stem_bn.update(learning_rate)?;
for block in &mut self.blocks {
block.update(learning_rate)?;
}
self.classifier.update(learning_rate)?;
Ok(())
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
self
}
}