use crate::error::{NeuralError, Result};
use crate::layers::conv::PaddingMode;
use crate::layers::{BatchNorm, Conv2D, Dense, Dropout, Layer};
use scirs2_core::ndarray::{Array, IxDyn, ScalarOperand};
use scirs2_core::numeric::{Float, FromPrimitive, NumAssign, ToPrimitive};
use scirs2_core::random::SeedableRng;
use std::collections::HashMap;
use std::fmt::Debug;
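/// Residual block flavor: `Basic` uses two 3x3 convolutions (ResNet-18/34);
/// `Bottleneck` uses a 1x1/3x3/1x1 stack with 4x channel expansion
/// (ResNet-50/101/152).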
#[derive(Debug, Clone)]
pub enum ResNetBlock {
Basic,
Bottleneck,
}
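/// One residual stage: `blocks` consecutive blocks at `channels` base
/// channels, with the first block applying `stride` for downsampling.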
#[derive(Debug, Clone)]
pub struct ResNetLayer {
pub blocks: usize,
pub channels: usize,
pub stride: usize,
}
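/// Describes a full ResNet architecture; the `resnet18`..`resnet152`
/// constructors produce the standard variants from He et al. (2015).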
#[derive(Debug, Clone)]
pub struct ResNetConfig {
pub block: ResNetBlock,
pub layers: Vec<ResNetLayer>,
pub input_channels: usize,
pub num_classes: usize,
pub dropout_rate: f64,
}
impl ResNetConfig {
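    /// ResNet-18: `Basic` blocks in a [2, 2, 2, 2] stage layout.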
pub fn resnet18(input_channels: usize, num_classes: usize) -> Self {
Self {
block: ResNetBlock::Basic,
layers: vec![
ResNetLayer {
blocks: 2,
channels: 64,
stride: 1,
},
ResNetLayer {
blocks: 2,
channels: 128,
stride: 2,
},
ResNetLayer {
blocks: 2,
channels: 256,
stride: 2,
},
ResNetLayer {
blocks: 2,
channels: 512,
stride: 2,
},
],
input_channels,
num_classes,
dropout_rate: 0.0,
}
}
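    /// ResNet-34: `Basic` blocks in a [3, 4, 6, 3] stage layout.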
pub fn resnet34(input_channels: usize, num_classes: usize) -> Self {
Self {
block: ResNetBlock::Basic,
layers: vec![
ResNetLayer {
blocks: 3,
channels: 64,
stride: 1,
},
ResNetLayer {
blocks: 4,
channels: 128,
stride: 2,
},
ResNetLayer {
blocks: 6,
channels: 256,
stride: 2,
},
ResNetLayer {
blocks: 3,
channels: 512,
stride: 2,
},
],
input_channels,
num_classes,
dropout_rate: 0.0,
}
}
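    /// ResNet-50: `Bottleneck` blocks in a [3, 4, 6, 3] stage layout.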
pub fn resnet50(input_channels: usize, num_classes: usize) -> Self {
Self {
block: ResNetBlock::Bottleneck,
layers: vec![
ResNetLayer {
blocks: 3,
channels: 64,
stride: 1,
},
ResNetLayer {
blocks: 4,
channels: 128,
stride: 2,
},
ResNetLayer {
blocks: 6,
channels: 256,
stride: 2,
},
ResNetLayer {
blocks: 3,
channels: 512,
stride: 2,
},
],
input_channels,
num_classes,
dropout_rate: 0.0,
}
}
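    /// ResNet-101: `Bottleneck` blocks in a [3, 4, 23, 3] stage layout.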
pub fn resnet101(input_channels: usize, num_classes: usize) -> Self {
Self {
block: ResNetBlock::Bottleneck,
layers: vec![
ResNetLayer {
blocks: 3,
channels: 64,
stride: 1,
},
ResNetLayer {
blocks: 4,
channels: 128,
stride: 2,
},
ResNetLayer {
blocks: 23,
channels: 256,
stride: 2,
},
ResNetLayer {
blocks: 3,
channels: 512,
stride: 2,
},
],
input_channels,
num_classes,
dropout_rate: 0.0,
}
}
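    /// ResNet-152: `Bottleneck` blocks in a [3, 8, 36, 3] stage layout.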
pub fn resnet152(input_channels: usize, num_classes: usize) -> Self {
Self {
block: ResNetBlock::Bottleneck,
layers: vec![
ResNetLayer {
blocks: 3,
channels: 64,
stride: 1,
},
ResNetLayer {
blocks: 8,
channels: 128,
stride: 2,
},
ResNetLayer {
blocks: 36,
channels: 256,
stride: 2,
},
ResNetLayer {
blocks: 3,
channels: 512,
stride: 2,
},
],
input_channels,
num_classes,
dropout_rate: 0.0,
}
}
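    /// Returns the config with dropout of `rate` applied before the classifier head.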
pub fn with_dropout(mut self, rate: f64) -> Self {
self.dropout_rate = rate;
self
}
}
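/// A basic residual block: two 3x3 conv + batch-norm pairs with ReLU, plus an
/// optional 1x1 projection on the skip connection.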
struct BasicBlock<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> {
conv1: Conv2D<F>,
bn1: BatchNorm<F>,
conv2: Conv2D<F>,
bn2: BatchNorm<F>,
downsample: Option<(Conv2D<F>, BatchNorm<F>)>,
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> Clone for BasicBlock<F> {
fn clone(&self) -> Self {
Self {
conv1: self.conv1.clone(),
bn1: self.bn1.clone(),
conv2: self.conv2.clone(),
bn2: self.bn2.clone(),
downsample: self.downsample.clone(),
}
}
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> BasicBlock<F> {
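    /// Creates a basic block. When `downsample` is true, a 1x1 conv + batch
    /// norm projects the identity branch to `out_channels` at `stride`.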
pub fn new(
in_channels: usize,
out_channels: usize,
stride: usize,
downsample: bool,
) -> Result<Self> {
let stride_tuple = (stride, stride);
        let conv1 = Conv2D::new(in_channels, out_channels, (3, 3), stride_tuple, None)?
            .with_padding(PaddingMode::Same);
        let mut rng_bn1 = scirs2_core::random::rngs::SmallRng::from_seed([43; 32]);
        let bn1 = BatchNorm::new(out_channels, 1e-5, 0.1, &mut rng_bn1)?;
        let conv2 = Conv2D::new(out_channels, out_channels, (3, 3), (1, 1), None)?
            .with_padding(PaddingMode::Same);
        let mut rng_bn2 = scirs2_core::random::rngs::SmallRng::from_seed([45; 32]);
        let bn2 = BatchNorm::new(out_channels, 1e-5, 0.1, &mut rng_bn2)?;
        let downsample = if downsample {
            // 1x1 projection aligns the skip connection with the main path's
            // channel count and stride.
            let ds_conv = Conv2D::new(in_channels, out_channels, (1, 1), stride_tuple, None)?
                .with_padding(PaddingMode::Valid);
            let mut rng_ds = scirs2_core::random::rngs::SmallRng::from_seed([47; 32]);
            let ds_bn = BatchNorm::new(out_channels, 1e-5, 0.1, &mut rng_ds)?;
            Some((ds_conv, ds_bn))
        } else {
            None
        };
Ok(Self {
conv1,
bn1,
conv2,
bn2,
downsample,
})
}
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> Layer<F>
for BasicBlock<F>
{
fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
let mut x = self.conv1.forward(input)?;
x = self.bn1.forward(&x)?;
x = x.mapv(|v: F| v.max(F::zero()));
x = self.conv2.forward(&x)?;
x = self.bn2.forward(&x)?;
let identity = if let Some((ref conv, ref bn)) = self.downsample {
let ds = conv.forward(input)?;
bn.forward(&ds)?
} else {
input.clone()
};
let x = &x + &identity;
let x = x.mapv(|v: F| v.max(F::zero()));
Ok(x)
}
fn backward(
&self,
_input: &Array<F, IxDyn>,
grad_output: &Array<F, IxDyn>,
) -> Result<Array<F, IxDyn>> {
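        // Placeholder: the gradient is passed through unchanged; full
        // backpropagation through the block is not implemented.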
Ok(grad_output.clone())
}
fn update(&mut self, learning_rate: F) -> Result<()> {
self.conv1.update(learning_rate)?;
self.bn1.update(learning_rate)?;
self.conv2.update(learning_rate)?;
self.bn2.update(learning_rate)?;
if let Some((ref mut conv, ref mut bn)) = self.downsample {
conv.update(learning_rate)?;
bn.update(learning_rate)?;
}
Ok(())
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
self
}
}
impl<
F: Float
+ Debug
+ ScalarOperand
+ Send
+ Sync
+ NumAssign
+ ToPrimitive
+ FromPrimitive
+ 'static,
> BasicBlock<F>
{
pub(crate) fn extract_named_params(&self, prefix: &str) -> Vec<(String, Array<F, IxDyn>)> {
let mut result = Vec::new();
for (i, p) in self.conv1.params().iter().enumerate() {
let suffix = if i == 0 { "weight" } else { "bias" };
result.push((format!("{prefix}.conv1.{suffix}"), p.clone()));
}
for (i, p) in self.bn1.params().iter().enumerate() {
let suffix = if i == 0 { "weight" } else { "bias" };
result.push((format!("{prefix}.bn1.{suffix}"), p.clone()));
}
for (i, p) in self.conv2.params().iter().enumerate() {
let suffix = if i == 0 { "weight" } else { "bias" };
result.push((format!("{prefix}.conv2.{suffix}"), p.clone()));
}
for (i, p) in self.bn2.params().iter().enumerate() {
let suffix = if i == 0 { "weight" } else { "bias" };
result.push((format!("{prefix}.bn2.{suffix}"), p.clone()));
}
if let Some((ref conv, ref bn)) = self.downsample {
for (i, p) in conv.params().iter().enumerate() {
let suffix = if i == 0 { "weight" } else { "bias" };
result.push((format!("{prefix}.downsample.0.{suffix}"), p.clone()));
}
for (i, p) in bn.params().iter().enumerate() {
let suffix = if i == 0 { "weight" } else { "bias" };
result.push((format!("{prefix}.downsample.1.{suffix}"), p.clone()));
}
}
result
}
pub(crate) fn load_named_params(
&mut self,
prefix: &str,
params_map: &HashMap<String, Array<F, IxDyn>>,
) -> Result<()> {
if let Some(w) = params_map.get(&format!("{prefix}.conv1.weight")) {
let mut ps = vec![w.clone()];
if let Some(b) = params_map.get(&format!("{prefix}.conv1.bias")) {
ps.push(b.clone());
}
self.conv1.set_params(&ps)?;
}
if let Some(w) = params_map.get(&format!("{prefix}.bn1.weight")) {
let mut ps = vec![w.clone()];
if let Some(b) = params_map.get(&format!("{prefix}.bn1.bias")) {
ps.push(b.clone());
}
self.bn1.set_params(&ps)?;
}
if let Some(w) = params_map.get(&format!("{prefix}.conv2.weight")) {
let mut ps = vec![w.clone()];
if let Some(b) = params_map.get(&format!("{prefix}.conv2.bias")) {
ps.push(b.clone());
}
self.conv2.set_params(&ps)?;
}
if let Some(w) = params_map.get(&format!("{prefix}.bn2.weight")) {
let mut ps = vec![w.clone()];
if let Some(b) = params_map.get(&format!("{prefix}.bn2.bias")) {
ps.push(b.clone());
}
self.bn2.set_params(&ps)?;
}
if let Some((ref mut conv, ref mut bn)) = self.downsample {
if let Some(w) = params_map.get(&format!("{prefix}.downsample.0.weight")) {
let mut ps = vec![w.clone()];
if let Some(b) = params_map.get(&format!("{prefix}.downsample.0.bias")) {
ps.push(b.clone());
}
conv.set_params(&ps)?;
}
if let Some(w) = params_map.get(&format!("{prefix}.downsample.1.weight")) {
let mut ps = vec![w.clone()];
if let Some(b) = params_map.get(&format!("{prefix}.downsample.1.bias")) {
ps.push(b.clone());
}
bn.set_params(&ps)?;
}
}
Ok(())
}
}
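/// A bottleneck residual block: a 1x1 reduction, a 3x3 convolution, and a 1x1
/// expansion (by `EXPANSION` = 4), each followed by batch norm, with an
/// optional projected skip connection.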
struct BottleneckBlock<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> {
conv1: Conv2D<F>,
bn1: BatchNorm<F>,
conv2: Conv2D<F>,
bn2: BatchNorm<F>,
conv3: Conv2D<F>,
bn3: BatchNorm<F>,
downsample: Option<(Conv2D<F>, BatchNorm<F>)>,
#[allow(dead_code)]
expansion: usize,
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> Clone
for BottleneckBlock<F>
{
fn clone(&self) -> Self {
Self {
conv1: self.conv1.clone(),
bn1: self.bn1.clone(),
conv2: self.conv2.clone(),
bn2: self.bn2.clone(),
conv3: self.conv3.clone(),
bn3: self.bn3.clone(),
downsample: self.downsample.clone(),
expansion: self.expansion,
}
}
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> BottleneckBlock<F> {
const EXPANSION: usize = 4;
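    /// Creates a bottleneck block. `out_channels` is the expanded width, so
    /// the internal 3x3 convolution runs at `out_channels / EXPANSION`
    /// channels.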
pub fn new(
in_channels: usize,
out_channels: usize,
stride: usize,
downsample: bool,
) -> Result<Self> {
let bottleneck_channels = out_channels / Self::EXPANSION;
let stride_tuple = (stride, stride);
        let conv1 = Conv2D::new(in_channels, bottleneck_channels, (1, 1), (1, 1), None)?
            .with_padding(PaddingMode::Valid);
        let mut rng_bn1 = scirs2_core::random::rngs::SmallRng::from_seed([49; 32]);
        let bn1 = BatchNorm::new(bottleneck_channels, 1e-5, 0.1, &mut rng_bn1)?;
        let conv2 = Conv2D::new(
            bottleneck_channels,
            bottleneck_channels,
            (3, 3),
            stride_tuple,
            None,
        )?
        .with_padding(PaddingMode::Same);
        let mut rng_bn2 = scirs2_core::random::rngs::SmallRng::from_seed([51; 32]);
        let bn2 = BatchNorm::new(bottleneck_channels, 1e-5, 0.1, &mut rng_bn2)?;
        let conv3 = Conv2D::new(bottleneck_channels, out_channels, (1, 1), (1, 1), None)?
            .with_padding(PaddingMode::Valid);
        let mut rng_bn3 = scirs2_core::random::rngs::SmallRng::from_seed([53; 32]);
        let bn3 = BatchNorm::new(out_channels, 1e-5, 0.1, &mut rng_bn3)?;
        let downsample = if downsample {
            // 1x1 projection aligns the skip connection with the main path's
            // channel count and stride.
            let ds_conv = Conv2D::new(in_channels, out_channels, (1, 1), stride_tuple, None)?
                .with_padding(PaddingMode::Valid);
            let mut rng_ds = scirs2_core::random::rngs::SmallRng::from_seed([55; 32]);
            let ds_bn = BatchNorm::new(out_channels, 1e-5, 0.1, &mut rng_ds)?;
            Some((ds_conv, ds_bn))
        } else {
            None
        };
Ok(Self {
conv1,
bn1,
conv2,
bn2,
conv3,
bn3,
downsample,
expansion: Self::EXPANSION,
})
}
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> Layer<F>
for BottleneckBlock<F>
{
fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
let mut x = self.conv1.forward(input)?;
x = self.bn1.forward(&x)?;
x = x.mapv(|v: F| v.max(F::zero()));
x = self.conv2.forward(&x)?;
x = self.bn2.forward(&x)?;
x = x.mapv(|v: F| v.max(F::zero()));
x = self.conv3.forward(&x)?;
x = self.bn3.forward(&x)?;
let identity = if let Some((ref conv, ref bn)) = self.downsample {
let ds = conv.forward(input)?;
bn.forward(&ds)?
} else {
input.clone()
};
let x = &x + &identity;
let x = x.mapv(|v: F| v.max(F::zero()));
Ok(x)
}
fn backward(
&self,
_input: &Array<F, IxDyn>,
grad_output: &Array<F, IxDyn>,
) -> Result<Array<F, IxDyn>> {
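        // Placeholder: the gradient is passed through unchanged; full
        // backpropagation through the block is not implemented.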
Ok(grad_output.clone())
}
fn update(&mut self, learning_rate: F) -> Result<()> {
self.conv1.update(learning_rate)?;
self.bn1.update(learning_rate)?;
self.conv2.update(learning_rate)?;
self.bn2.update(learning_rate)?;
self.conv3.update(learning_rate)?;
self.bn3.update(learning_rate)?;
if let Some((ref mut conv, ref mut bn)) = self.downsample {
conv.update(learning_rate)?;
bn.update(learning_rate)?;
}
Ok(())
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
self
}
}
impl<
F: Float
+ Debug
+ ScalarOperand
+ Send
+ Sync
+ NumAssign
+ ToPrimitive
+ FromPrimitive
+ 'static,
> BottleneckBlock<F>
{
pub(crate) fn extract_named_params(&self, prefix: &str) -> Vec<(String, Array<F, IxDyn>)> {
let mut result = Vec::new();
for (i, p) in self.conv1.params().iter().enumerate() {
let suffix = if i == 0 { "weight" } else { "bias" };
result.push((format!("{prefix}.conv1.{suffix}"), p.clone()));
}
for (i, p) in self.bn1.params().iter().enumerate() {
let suffix = if i == 0 { "weight" } else { "bias" };
result.push((format!("{prefix}.bn1.{suffix}"), p.clone()));
}
for (i, p) in self.conv2.params().iter().enumerate() {
let suffix = if i == 0 { "weight" } else { "bias" };
result.push((format!("{prefix}.conv2.{suffix}"), p.clone()));
}
for (i, p) in self.bn2.params().iter().enumerate() {
let suffix = if i == 0 { "weight" } else { "bias" };
result.push((format!("{prefix}.bn2.{suffix}"), p.clone()));
}
for (i, p) in self.conv3.params().iter().enumerate() {
let suffix = if i == 0 { "weight" } else { "bias" };
result.push((format!("{prefix}.conv3.{suffix}"), p.clone()));
}
for (i, p) in self.bn3.params().iter().enumerate() {
let suffix = if i == 0 { "weight" } else { "bias" };
result.push((format!("{prefix}.bn3.{suffix}"), p.clone()));
}
if let Some((ref conv, ref bn)) = self.downsample {
for (i, p) in conv.params().iter().enumerate() {
let suffix = if i == 0 { "weight" } else { "bias" };
result.push((format!("{prefix}.downsample.0.{suffix}"), p.clone()));
}
for (i, p) in bn.params().iter().enumerate() {
let suffix = if i == 0 { "weight" } else { "bias" };
result.push((format!("{prefix}.downsample.1.{suffix}"), p.clone()));
}
}
result
}
pub(crate) fn load_named_params(
&mut self,
prefix: &str,
params_map: &HashMap<String, Array<F, IxDyn>>,
) -> Result<()> {
if let Some(w) = params_map.get(&format!("{prefix}.conv1.weight")) {
let mut ps = vec![w.clone()];
if let Some(b) = params_map.get(&format!("{prefix}.conv1.bias")) {
ps.push(b.clone());
}
self.conv1.set_params(&ps)?;
}
if let Some(w) = params_map.get(&format!("{prefix}.bn1.weight")) {
let mut ps = vec![w.clone()];
if let Some(b) = params_map.get(&format!("{prefix}.bn1.bias")) {
ps.push(b.clone());
}
self.bn1.set_params(&ps)?;
}
if let Some(w) = params_map.get(&format!("{prefix}.conv2.weight")) {
let mut ps = vec![w.clone()];
if let Some(b) = params_map.get(&format!("{prefix}.conv2.bias")) {
ps.push(b.clone());
}
self.conv2.set_params(&ps)?;
}
if let Some(w) = params_map.get(&format!("{prefix}.bn2.weight")) {
let mut ps = vec![w.clone()];
if let Some(b) = params_map.get(&format!("{prefix}.bn2.bias")) {
ps.push(b.clone());
}
self.bn2.set_params(&ps)?;
}
if let Some(w) = params_map.get(&format!("{prefix}.conv3.weight")) {
let mut ps = vec![w.clone()];
if let Some(b) = params_map.get(&format!("{prefix}.conv3.bias")) {
ps.push(b.clone());
}
self.conv3.set_params(&ps)?;
}
if let Some(w) = params_map.get(&format!("{prefix}.bn3.weight")) {
let mut ps = vec![w.clone()];
if let Some(b) = params_map.get(&format!("{prefix}.bn3.bias")) {
ps.push(b.clone());
}
self.bn3.set_params(&ps)?;
}
if let Some((ref mut conv, ref mut bn)) = self.downsample {
if let Some(w) = params_map.get(&format!("{prefix}.downsample.0.weight")) {
let mut ps = vec![w.clone()];
if let Some(b) = params_map.get(&format!("{prefix}.downsample.0.bias")) {
ps.push(b.clone());
}
conv.set_params(&ps)?;
}
if let Some(w) = params_map.get(&format!("{prefix}.downsample.1.weight")) {
let mut ps = vec![w.clone()];
if let Some(b) = params_map.get(&format!("{prefix}.downsample.1.bias")) {
ps.push(b.clone());
}
bn.set_params(&ps)?;
}
}
Ok(())
}
}
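/// A ResNet classifier: conv + batch-norm stem, a flattened list of residual
/// blocks (only the list matching the configured block type is populated),
/// global average pooling, optional dropout, and a dense classifier head.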
pub struct ResNet<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> {
conv1: Conv2D<F>,
bn1: BatchNorm<F>,
layer1: Vec<BasicBlock<F>>,
layer1_bottleneck: Vec<BottleneckBlock<F>>,
fc: Dense<F>,
dropout: Option<Dropout<F>>,
config: ResNetConfig,
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> ResNet<F> {
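    /// Builds a ResNet from an architecture description.
    ///
    /// A minimal usage sketch (marked `ignore`; assumes a `Result`-returning
    /// caller for `?`):
    /// ```ignore
    /// let config = ResNetConfig::resnet50(3, 1000).with_dropout(0.2);
    /// let model: ResNet<f32> = ResNet::new(config)?;
    /// ```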
pub fn new(config: ResNetConfig) -> Result<Self> {
        // Stem: 7x7 stride-2 convolution plus batch norm. Note: the reference
        // ResNet stem also applies a 3x3 stride-2 max pool after the ReLU;
        // this simplified implementation omits it.
        let conv1 = Conv2D::new(config.input_channels, 64, (7, 7), (2, 2), None)?
            .with_padding(PaddingMode::Same);
        let mut rng_bn1 = scirs2_core::random::rngs::SmallRng::from_seed([57; 32]);
        let bn1 = BatchNorm::new(64, 1e-5, 0.1, &mut rng_bn1)?;
        // Materialize the residual stages described by the config into a
        // flattened list; only the list matching `config.block` is populated.
        let mut layer1 = Vec::new();
        let mut layer1_bottleneck = Vec::new();
        let mut in_channels = 64;
        for layer in &config.layers {
            let out_channels = match config.block {
                ResNetBlock::Basic => layer.channels,
                ResNetBlock::Bottleneck => layer.channels * 4,
            };
            for i in 0..layer.blocks {
                // Only the first block of a stage downsamples or changes the
                // channel count, so only it needs a projection shortcut.
                let stride = if i == 0 { layer.stride } else { 1 };
                let ds = stride != 1 || in_channels != out_channels;
                match config.block {
                    ResNetBlock::Basic => {
                        layer1.push(BasicBlock::new(in_channels, out_channels, stride, ds)?)
                    }
                    ResNetBlock::Bottleneck => layer1_bottleneck
                        .push(BottleneckBlock::new(in_channels, out_channels, stride, ds)?),
                }
                in_channels = out_channels;
            }
        }
        let fc_in_features = match config.block {
            ResNetBlock::Basic => config.layers.last().map(|l| l.channels).unwrap_or(512),
            ResNetBlock::Bottleneck => config.layers.last().map(|l| l.channels * 4).unwrap_or(2048),
        };
        let mut rng_fc = scirs2_core::random::rngs::SmallRng::from_seed([58; 32]);
        let fc = Dense::new(fc_in_features, config.num_classes, None, &mut rng_fc)?;
        let dropout = if config.dropout_rate > 0.0 {
            let mut rng_dropout = scirs2_core::random::rngs::SmallRng::from_seed([59; 32]);
            Some(Dropout::new(config.dropout_rate, &mut rng_dropout)?)
        } else {
            None
        };
Ok(Self {
conv1,
bn1,
layer1,
layer1_bottleneck,
fc,
dropout,
config,
})
}
pub fn resnet18(input_channels: usize, num_classes: usize) -> Result<Self> {
let config = ResNetConfig::resnet18(input_channels, num_classes);
Self::new(config)
}
pub fn resnet34(input_channels: usize, num_classes: usize) -> Result<Self> {
let config = ResNetConfig::resnet34(input_channels, num_classes);
Self::new(config)
}
pub fn resnet50(input_channels: usize, num_classes: usize) -> Result<Self> {
let config = ResNetConfig::resnet50(input_channels, num_classes);
Self::new(config)
}
pub fn resnet101(input_channels: usize, num_classes: usize) -> Result<Self> {
let config = ResNetConfig::resnet101(input_channels, num_classes);
Self::new(config)
}
pub fn resnet152(input_channels: usize, num_classes: usize) -> Result<Self> {
let config = ResNetConfig::resnet152(input_channels, num_classes);
Self::new(config)
}
pub fn config(&self) -> &ResNetConfig {
&self.config
}
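    /// Global average pooling: reduces `[batch, channels, height, width]` to
    /// `[batch, channels]` by averaging over the spatial dimensions.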
fn global_avg_pool(x: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
let shape = x.shape();
if shape.len() != 4 {
return Err(NeuralError::InferenceError(format!(
"Expected 4D input for average pooling, got shape {:?}",
shape
)));
}
let batch_size = shape[0];
let channels = shape[1];
let height = shape[2];
let width = shape[3];
let mut output = Array::zeros(IxDyn(&[batch_size, channels]));
let count = F::from(height * width).expect("Failed to convert to float");
for b in 0..batch_size {
for c in 0..channels {
let mut sum = F::zero();
for h in 0..height {
for w in 0..width {
sum += x[[b, c, h, w]];
}
}
output[[b, c]] = sum / count;
}
}
Ok(output)
}
}
impl<
F: Float
+ Debug
+ ScalarOperand
+ Send
+ Sync
+ NumAssign
+ ToPrimitive
+ FromPrimitive
+ 'static,
> ResNet<F>
{
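    /// Exports all parameters as `(name, value)` pairs using PyTorch-style
    /// names (e.g. `conv1.weight`, `layer1.0.bn2.bias`).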
pub fn extract_named_params(&self) -> Result<Vec<(String, Array<F, IxDyn>)>> {
let mut result = Vec::new();
for (i, p) in self.conv1.params().iter().enumerate() {
let suffix = if i == 0 { "weight" } else { "bias" };
result.push((format!("conv1.{suffix}"), p.clone()));
}
for (i, p) in self.bn1.params().iter().enumerate() {
let suffix = if i == 0 { "weight" } else { "bias" };
result.push((format!("bn1.{suffix}"), p.clone()));
}
for (idx, block) in self.layer1.iter().enumerate() {
let block_params = block.extract_named_params(&format!("layer1.{idx}"));
result.extend(block_params);
}
for (idx, block) in self.layer1_bottleneck.iter().enumerate() {
let block_params = block.extract_named_params(&format!("layer1.{idx}"));
result.extend(block_params);
}
for (i, p) in self.fc.params().iter().enumerate() {
let suffix = if i == 0 { "weight" } else { "bias" };
result.push((format!("fc.{suffix}"), p.clone()));
}
Ok(result)
}
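    /// Loads parameters by name; names absent from `params_map` leave the
    /// corresponding layer's parameters unchanged.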
pub fn load_named_params(
&mut self,
params_map: &HashMap<String, Array<F, IxDyn>>,
) -> Result<()> {
if let Some(w) = params_map.get("conv1.weight") {
let mut ps = vec![w.clone()];
if let Some(b) = params_map.get("conv1.bias") {
ps.push(b.clone());
}
self.conv1.set_params(&ps)?;
}
if let Some(w) = params_map.get("bn1.weight") {
let mut ps = vec![w.clone()];
if let Some(b) = params_map.get("bn1.bias") {
ps.push(b.clone());
}
self.bn1.set_params(&ps)?;
}
for (idx, block) in self.layer1.iter_mut().enumerate() {
block.load_named_params(&format!("layer1.{idx}"), params_map)?;
}
for (idx, block) in self.layer1_bottleneck.iter_mut().enumerate() {
block.load_named_params(&format!("layer1.{idx}"), params_map)?;
}
if let Some(w) = params_map.get("fc.weight") {
let mut ps = vec![w.clone()];
if let Some(b) = params_map.get("fc.bias") {
ps.push(b.clone());
}
self.fc.set_params(&ps)?;
}
Ok(())
}
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> Layer<F> for ResNet<F> {
fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
let mut x = self.conv1.forward(input)?;
x = self.bn1.forward(&x)?;
x = x.mapv(|v: F| v.max(F::zero()));
for block in &self.layer1 {
x = block.forward(&x)?;
}
for block in &self.layer1_bottleneck {
x = block.forward(&x)?;
}
x = Self::global_avg_pool(&x)?;
if let Some(ref dropout) = self.dropout {
x = dropout.forward(&x)?;
}
x = self.fc.forward(&x)?;
Ok(x)
}
fn backward(
&self,
_input: &Array<F, IxDyn>,
grad_output: &Array<F, IxDyn>,
) -> Result<Array<F, IxDyn>> {
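        // Placeholder: the gradient is passed through unchanged; full
        // backpropagation through the network is not implemented.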
Ok(grad_output.clone())
}
fn update(&mut self, learning_rate: F) -> Result<()> {
self.conv1.update(learning_rate)?;
self.bn1.update(learning_rate)?;
for block in &mut self.layer1 {
block.update(learning_rate)?;
}
for block in &mut self.layer1_bottleneck {
block.update(learning_rate)?;
}
self.fc.update(learning_rate)?;
Ok(())
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
self
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_resnet_config_18() {
let config = ResNetConfig::resnet18(3, 1000);
assert_eq!(config.input_channels, 3);
assert_eq!(config.num_classes, 1000);
assert_eq!(config.layers.len(), 4);
assert!(matches!(config.block, ResNetBlock::Basic));
}
#[test]
fn test_resnet_config_50() {
let config = ResNetConfig::resnet50(3, 1000);
assert!(matches!(config.block, ResNetBlock::Bottleneck));
assert_eq!(config.layers.len(), 4);
}
#[test]
fn test_resnet_config_with_dropout() {
let config = ResNetConfig::resnet18(3, 100).with_dropout(0.5);
assert_eq!(config.dropout_rate, 0.5);
}
#[test]
fn test_resnet_config_variants() {
let config34 = ResNetConfig::resnet34(3, 1000);
assert_eq!(config34.layers[0].blocks, 3);
assert_eq!(config34.layers[1].blocks, 4);
let config101 = ResNetConfig::resnet101(3, 1000);
assert_eq!(config101.layers[2].blocks, 23);
let config152 = ResNetConfig::resnet152(3, 1000);
assert_eq!(config152.layers[2].blocks, 36);
}
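    #[test]
    fn test_resnet18_block_construction() {
        // Sanity check (relies on `ResNet::new` flattening every configured
        // block into one list): ResNet-18 has 2 + 2 + 2 + 2 = 8 basic blocks
        // and no bottleneck blocks.
        let model: ResNet<f32> = ResNet::resnet18(3, 10).expect("construction failed");
        assert_eq!(model.layer1.len(), 8);
        assert!(model.layer1_bottleneck.is_empty());
    }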
}