use crate::error::{NeuralError, Result};
use crate::layers::conv::PaddingMode;
use crate::layers::{BatchNorm, Conv2D, Dense, Dropout, Layer};
use scirs2_core::ndarray::{Array, IxDyn, ScalarOperand};
use scirs2_core::numeric::{Float, FromPrimitive, NumAssign, ToPrimitive};
use scirs2_core::random::SeedableRng;
use std::collections::HashMap;
use std::fmt::Debug;
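/// The VGG depth variants from Simonyan & Zisserman (2014), named by their
/// total number of weight layers (convolutional plus fully connected).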
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VGGVariant {
VGG11,
VGG13,
VGG16,
VGG19,
}
impl VGGVariant {
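    /// Per-stage convolutional output channel counts; each inner `Vec` is one
    /// stage of 3x3 convolutions that ends in a 2x2 max pool.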
fn layer_config(&self) -> Vec<Vec<usize>> {
match self {
VGGVariant::VGG11 => vec![
vec![64],
vec![128],
vec![256, 256],
vec![512, 512],
vec![512, 512],
],
VGGVariant::VGG13 => vec![
vec![64, 64],
vec![128, 128],
vec![256, 256],
vec![512, 512],
vec![512, 512],
],
VGGVariant::VGG16 => vec![
vec![64, 64],
vec![128, 128],
vec![256, 256, 256],
vec![512, 512, 512],
vec![512, 512, 512],
],
VGGVariant::VGG19 => vec![
vec![64, 64],
vec![128, 128],
vec![256, 256, 256, 256],
vec![512, 512, 512, 512],
vec![512, 512, 512, 512],
],
}
}
    pub fn name(&self) -> &'static str {
match self {
VGGVariant::VGG11 => "VGG-11",
VGGVariant::VGG13 => "VGG-13",
VGGVariant::VGG16 => "VGG-16",
VGGVariant::VGG19 => "VGG-19",
}
}
pub fn num_conv_layers(&self) -> usize {
self.layer_config().iter().map(|block| block.len()).sum()
}
}
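/// Configuration for building a [`VGG`] model.
///
/// `channel_divisor` scales every convolutional channel count down by an
/// integer factor (clamped so each layer keeps at least one channel), which
/// is handy for constructing small models in tests.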
#[derive(Debug, Clone)]
pub struct VGGConfig {
pub variant: VGGVariant,
pub batch_norm: bool,
pub input_channels: usize,
pub num_classes: usize,
pub dropout_rate: f64,
pub fc_hidden_units: usize,
pub channel_divisor: usize,
}
impl VGGConfig {
pub fn vgg11(input_channels: usize, num_classes: usize) -> Self {
Self {
variant: VGGVariant::VGG11,
batch_norm: false,
input_channels,
num_classes,
dropout_rate: 0.5,
fc_hidden_units: 4096,
channel_divisor: 1,
}
}
pub fn vgg11_bn(input_channels: usize, num_classes: usize) -> Self {
Self {
variant: VGGVariant::VGG11,
batch_norm: true,
input_channels,
num_classes,
dropout_rate: 0.5,
fc_hidden_units: 4096,
channel_divisor: 1,
}
}
pub fn vgg13(input_channels: usize, num_classes: usize) -> Self {
Self {
variant: VGGVariant::VGG13,
batch_norm: false,
input_channels,
num_classes,
dropout_rate: 0.5,
fc_hidden_units: 4096,
channel_divisor: 1,
}
}
pub fn vgg13_bn(input_channels: usize, num_classes: usize) -> Self {
Self {
variant: VGGVariant::VGG13,
batch_norm: true,
input_channels,
num_classes,
dropout_rate: 0.5,
fc_hidden_units: 4096,
channel_divisor: 1,
}
}
pub fn vgg16(input_channels: usize, num_classes: usize) -> Self {
Self {
variant: VGGVariant::VGG16,
batch_norm: false,
input_channels,
num_classes,
dropout_rate: 0.5,
fc_hidden_units: 4096,
channel_divisor: 1,
}
}
pub fn vgg16_bn(input_channels: usize, num_classes: usize) -> Self {
Self {
variant: VGGVariant::VGG16,
batch_norm: true,
input_channels,
num_classes,
dropout_rate: 0.5,
fc_hidden_units: 4096,
channel_divisor: 1,
}
}
pub fn vgg19(input_channels: usize, num_classes: usize) -> Self {
Self {
variant: VGGVariant::VGG19,
batch_norm: false,
input_channels,
num_classes,
dropout_rate: 0.5,
fc_hidden_units: 4096,
channel_divisor: 1,
}
}
pub fn vgg19_bn(input_channels: usize, num_classes: usize) -> Self {
Self {
variant: VGGVariant::VGG19,
batch_norm: true,
input_channels,
num_classes,
dropout_rate: 0.5,
fc_hidden_units: 4096,
channel_divisor: 1,
}
}
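    /// Builder-style setters for tuning a configuration. A minimal sketch:
    ///
    /// ```ignore
    /// let config = VGGConfig::vgg16(3, 10)
    ///     .with_dropout(0.3)
    ///     .with_batch_norm(true)
    ///     .with_channel_divisor(4); // shrink conv widths to 1/4
    /// ```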
pub fn with_dropout(mut self, rate: f64) -> Self {
self.dropout_rate = rate;
self
}
pub fn with_batch_norm(mut self, batch_norm: bool) -> Self {
self.batch_norm = batch_norm;
self
}
pub fn with_fc_hidden_units(mut self, units: usize) -> Self {
self.fc_hidden_units = units;
self
}
pub fn with_channel_divisor(mut self, divisor: usize) -> Self {
self.channel_divisor = divisor.max(1);
self
}
fn effective_layer_config(&self) -> Vec<Vec<usize>> {
let base_config = self.variant.layer_config();
if self.channel_divisor <= 1 {
return base_config;
}
base_config
.into_iter()
.map(|block| {
block
.into_iter()
.map(|ch| (ch / self.channel_divisor).max(1))
.collect()
})
.collect()
}
}
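/// One convolutional unit: a 3x3 same-padding convolution, optional batch
/// normalization, then ReLU.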
struct VGGConvBlock<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> {
conv: Conv2D<F>,
bn: Option<BatchNorm<F>>,
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> VGGConvBlock<F> {
fn new(in_channels: usize, out_channels: usize, use_bn: bool) -> Result<Self> {
let conv = Conv2D::new(in_channels, out_channels, (3, 3), (1, 1), None)?
.with_padding(PaddingMode::Same);
let bn = if use_bn {
let mut rng = scirs2_core::random::rngs::SmallRng::from_seed([42; 32]);
Some(BatchNorm::new(out_channels, 1e-5, 0.1, &mut rng)?)
} else {
None
};
Ok(Self { conv, bn })
}
fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
let mut x = self.conv.forward(input)?;
if let Some(ref bn) = self.bn {
x = bn.forward(&x)?;
}
        x = x.mapv(|v: F| v.max(F::zero())); // ReLU
Ok(x)
}
fn update(&mut self, learning_rate: F) -> Result<()> {
self.conv.update(learning_rate)?;
if let Some(ref mut bn) = self.bn {
bn.update(learning_rate)?;
}
Ok(())
}
fn params(&self) -> Vec<Array<F, IxDyn>> {
let mut result = self.conv.params();
if let Some(ref bn) = self.bn {
result.extend(bn.params());
}
result
}
fn parameter_count(&self) -> usize {
let mut count = self.conv.parameter_count();
if let Some(ref bn) = self.bn {
count += bn.parameter_count();
}
count
}
}
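/// A feature stage: consecutive conv blocks followed by a single 2x2 max pool.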
struct VGGFeatureStage<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> {
blocks: Vec<VGGConvBlock<F>>,
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> VGGFeatureStage<F> {
fn new(channels: &[usize], in_channels: usize, use_bn: bool) -> Result<Self> {
let mut blocks = Vec::with_capacity(channels.len());
let mut current_in = in_channels;
for &out_ch in channels {
blocks.push(VGGConvBlock::new(current_in, out_ch, use_bn)?);
current_in = out_ch;
}
Ok(Self { blocks })
}
fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
let mut x = input.clone();
for block in &self.blocks {
x = block.forward(&x)?;
}
x = max_pool_2x2(&x)?;
Ok(x)
}
fn update(&mut self, learning_rate: F) -> Result<()> {
for block in &mut self.blocks {
block.update(learning_rate)?;
}
Ok(())
}
fn params(&self) -> Vec<Array<F, IxDyn>> {
let mut result = Vec::new();
for block in &self.blocks {
result.extend(block.params());
}
result
}
fn parameter_count(&self) -> usize {
self.blocks.iter().map(|b| b.parameter_count()).sum()
}
}
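/// 2x2 max pooling with stride 2 over a `[batch, channels, height, width]`
/// tensor. Odd trailing rows/columns are dropped (floor division).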
fn max_pool_2x2<F: Float + Debug>(input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
let shape = input.shape();
if shape.len() != 4 {
return Err(NeuralError::InferenceError(format!(
"Expected 4D input for max pooling, got shape {:?}",
shape
)));
}
let batch_size = shape[0];
let channels = shape[1];
let height = shape[2];
let width = shape[3];
let out_h = height / 2;
let out_w = width / 2;
let mut output = Array::from_elem(
IxDyn(&[batch_size, channels, out_h, out_w]),
F::neg_infinity(),
);
for b in 0..batch_size {
for c in 0..channels {
for oh in 0..out_h {
for ow in 0..out_w {
let h_start = oh * 2;
let w_start = ow * 2;
let mut max_val = F::neg_infinity();
for dh in 0..2 {
for dw in 0..2 {
let h = h_start + dh;
let w = w_start + dw;
if h < height && w < width {
let val = input[[b, c, h, w]];
if val > max_val {
max_val = val;
}
}
}
}
output[[b, c, oh, ow]] = max_val;
}
}
}
}
Ok(output)
}
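/// VGG image classifier (Simonyan & Zisserman, 2014): stacked 3x3 conv
/// stages, adaptive 7x7 average pooling, and a three-layer fully connected
/// head.
///
/// A minimal usage sketch (shapes follow the standard 224x224 ImageNet
/// setting; any spatial size that survives five 2x2 pools works):
///
/// ```ignore
/// let model: VGG<f32> = VGG::vgg16(3, 1000).expect("failed to build VGG-16");
/// let input = Array::zeros(IxDyn(&[1, 3, 224, 224]));
/// let logits = model.forward(&input).expect("forward pass failed");
/// assert_eq!(logits.shape(), &[1, 1000]);
/// ```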
pub struct VGG<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> {
config: VGGConfig,
features: Vec<VGGFeatureStage<F>>,
fc1: Dense<F>,
dropout1: Dropout<F>,
fc2: Dense<F>,
dropout2: Dropout<F>,
fc3: Dense<F>,
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> VGG<F> {
pub fn new(config: VGGConfig) -> Result<Self> {
let layer_configs = config.effective_layer_config();
let mut features = Vec::with_capacity(layer_configs.len());
let mut in_channels = config.input_channels;
for block_channels in &layer_configs {
let stage = VGGFeatureStage::new(block_channels, in_channels, config.batch_norm)?;
in_channels = *block_channels.last().ok_or_else(|| {
NeuralError::InvalidArchitecture("Empty block channel configuration".to_string())
})?;
features.push(stage);
}
        // Features are adaptively average-pooled to 7x7 before flattening.
        let fc_input_size = in_channels * 7 * 7;
let mut rng = scirs2_core::random::rngs::SmallRng::from_seed([42; 32]);
let fc1 = Dense::new(fc_input_size, config.fc_hidden_units, None, &mut rng)?;
let dropout1 = Dropout::new(config.dropout_rate, &mut rng)?;
let mut rng2 = scirs2_core::random::rngs::SmallRng::from_seed([43; 32]);
let fc2 = Dense::new(
config.fc_hidden_units,
config.fc_hidden_units,
None,
&mut rng2,
)?;
let dropout2 = Dropout::new(config.dropout_rate, &mut rng2)?;
let mut rng3 = scirs2_core::random::rngs::SmallRng::from_seed([44; 32]);
let fc3 = Dense::new(config.fc_hidden_units, config.num_classes, None, &mut rng3)?;
Ok(Self {
config,
features,
fc1,
dropout1,
fc2,
dropout2,
fc3,
})
}
pub fn vgg11(input_channels: usize, num_classes: usize) -> Result<Self> {
Self::new(VGGConfig::vgg11(input_channels, num_classes))
}
pub fn vgg11_bn(input_channels: usize, num_classes: usize) -> Result<Self> {
Self::new(VGGConfig::vgg11_bn(input_channels, num_classes))
}
pub fn vgg13(input_channels: usize, num_classes: usize) -> Result<Self> {
Self::new(VGGConfig::vgg13(input_channels, num_classes))
}
pub fn vgg13_bn(input_channels: usize, num_classes: usize) -> Result<Self> {
Self::new(VGGConfig::vgg13_bn(input_channels, num_classes))
}
pub fn vgg16(input_channels: usize, num_classes: usize) -> Result<Self> {
Self::new(VGGConfig::vgg16(input_channels, num_classes))
}
pub fn vgg16_bn(input_channels: usize, num_classes: usize) -> Result<Self> {
Self::new(VGGConfig::vgg16_bn(input_channels, num_classes))
}
pub fn vgg19(input_channels: usize, num_classes: usize) -> Result<Self> {
Self::new(VGGConfig::vgg19(input_channels, num_classes))
}
pub fn vgg19_bn(input_channels: usize, num_classes: usize) -> Result<Self> {
Self::new(VGGConfig::vgg19_bn(input_channels, num_classes))
}
pub fn config(&self) -> &VGGConfig {
&self.config
}
pub fn total_parameter_count(&self) -> usize {
let feature_params: usize = self.features.iter().map(|s| s.parameter_count()).sum();
let classifier_params =
self.fc1.parameter_count() + self.fc2.parameter_count() + self.fc3.parameter_count();
feature_params + classifier_params
}
pub fn num_stages(&self) -> usize {
self.features.len()
}
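    /// Runs only the convolutional stages, returning the raw feature map
    /// (no 7x7 pooling, no classifier). With five stages each halving the
    /// spatial size, a 32x32 input yields a 1x1 map per channel.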
pub fn extract_features(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
let shape = input.shape();
if shape.len() != 4 {
return Err(NeuralError::InferenceError(format!(
"Expected 4D input [batch, channels, height, width], got shape {:?}",
shape
)));
}
if shape[1] != self.config.input_channels {
return Err(NeuralError::InferenceError(format!(
"Expected {} input channels, got {}",
self.config.input_channels, shape[1]
)));
}
let mut x = input.clone();
for stage in &self.features {
x = stage.forward(&x)?;
}
Ok(x)
}
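    /// Adaptive average pooling to a fixed `target_h x target_w` output,
    /// averaging over evenly partitioned input windows (in the spirit of
    /// PyTorch's `AdaptiveAvgPool2d`).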
fn adaptive_avg_pool(
input: &Array<F, IxDyn>,
target_h: usize,
target_w: usize,
) -> Result<Array<F, IxDyn>> {
let shape = input.shape();
if shape.len() != 4 {
return Err(NeuralError::InferenceError(format!(
"Expected 4D input for adaptive avg pooling, got shape {:?}",
shape
)));
}
let batch_size = shape[0];
let channels = shape[1];
let in_h = shape[2];
let in_w = shape[3];
let mut output = Array::zeros(IxDyn(&[batch_size, channels, target_h, target_w]));
for b in 0..batch_size {
for c in 0..channels {
for oh in 0..target_h {
for ow in 0..target_w {
let h_start = (oh * in_h) / target_h;
let h_end = ((oh + 1) * in_h) / target_h;
let w_start = (ow * in_w) / target_w;
let w_end = ((ow + 1) * in_w) / target_w;
let mut sum = F::zero();
let mut count = 0usize;
for h in h_start..h_end {
for w in w_start..w_end {
sum += input[[b, c, h, w]];
count += 1;
}
}
                        if count > 0 {
                            let count_f = F::from(count).ok_or_else(|| {
                                NeuralError::InferenceError(
                                    "Failed to convert pool count to float".to_string(),
                                )
                            })?;
                            output[[b, c, oh, ow]] = sum / count_f;
                        }
}
}
}
}
Ok(output)
}
}
impl<
F: Float
+ Debug
+ ScalarOperand
+ Send
+ Sync
+ NumAssign
+ ToPrimitive
+ FromPrimitive
+ 'static,
> VGG<F>
{
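    /// Exports parameters under torchvision-style names:
    /// `features.{stage}.{block}.{conv|bn}.{weight|bias}` for the feature
    /// extractor and `classifier.{0,3,6}.{weight|bias}` for the three linear
    /// layers (the indices mirror torchvision's `Sequential` classifier).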
pub fn extract_named_params(&self) -> Vec<(String, Array<F, IxDyn>)> {
let mut result = Vec::new();
for (stage_idx, stage) in self.features.iter().enumerate() {
for (block_idx, block) in stage.blocks.iter().enumerate() {
for (i, p) in block.conv.params().iter().enumerate() {
let suffix = if i == 0 { "weight" } else { "bias" };
result.push((
format!("features.{stage_idx}.{block_idx}.conv.{suffix}"),
p.clone(),
));
}
if let Some(ref bn) = block.bn {
for (i, p) in bn.params().iter().enumerate() {
let suffix = if i == 0 { "weight" } else { "bias" };
result.push((
format!("features.{stage_idx}.{block_idx}.bn.{suffix}"),
p.clone(),
));
}
}
}
}
for (i, p) in self.fc1.params().iter().enumerate() {
let suffix = if i == 0 { "weight" } else { "bias" };
result.push((format!("classifier.0.{suffix}"), p.clone()));
}
for (i, p) in self.fc2.params().iter().enumerate() {
let suffix = if i == 0 { "weight" } else { "bias" };
result.push((format!("classifier.3.{suffix}"), p.clone()));
}
for (i, p) in self.fc3.params().iter().enumerate() {
let suffix = if i == 0 { "weight" } else { "bias" };
result.push((format!("classifier.6.{suffix}"), p.clone()));
}
result
}
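    /// Loads parameters by the names produced by [`Self::extract_named_params`].
    /// Entries missing from the map are silently skipped, so partial loads are
    /// allowed.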
pub fn load_named_params(
&mut self,
params_map: &HashMap<String, Array<F, IxDyn>>,
) -> Result<()> {
for (stage_idx, stage) in self.features.iter_mut().enumerate() {
for (block_idx, block) in stage.blocks.iter_mut().enumerate() {
let conv_weight_key = format!("features.{stage_idx}.{block_idx}.conv.weight");
if let Some(w) = params_map.get(&conv_weight_key) {
let mut ps = vec![w.clone()];
let conv_bias_key = format!("features.{stage_idx}.{block_idx}.conv.bias");
if let Some(b) = params_map.get(&conv_bias_key) {
ps.push(b.clone());
}
block.conv.set_params(&ps)?;
}
if let Some(ref mut bn) = block.bn {
let bn_weight_key = format!("features.{stage_idx}.{block_idx}.bn.weight");
if let Some(w) = params_map.get(&bn_weight_key) {
let mut ps = vec![w.clone()];
let bn_bias_key = format!("features.{stage_idx}.{block_idx}.bn.bias");
if let Some(b) = params_map.get(&bn_bias_key) {
ps.push(b.clone());
}
bn.set_params(&ps)?;
}
}
}
}
if let Some(w) = params_map.get("classifier.0.weight") {
let mut ps = vec![w.clone()];
if let Some(b) = params_map.get("classifier.0.bias") {
ps.push(b.clone());
}
self.fc1.set_params(&ps)?;
}
if let Some(w) = params_map.get("classifier.3.weight") {
let mut ps = vec![w.clone()];
if let Some(b) = params_map.get("classifier.3.bias") {
ps.push(b.clone());
}
self.fc2.set_params(&ps)?;
}
if let Some(w) = params_map.get("classifier.6.weight") {
let mut ps = vec![w.clone()];
if let Some(b) = params_map.get("classifier.6.bias") {
ps.push(b.clone());
}
self.fc3.set_params(&ps)?;
}
Ok(())
}
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> Layer<F> for VGG<F> {
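    /// Full inference pipeline: conv stages, adaptive 7x7 average pooling,
    /// flattening, then the FC/ReLU/dropout classifier head.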
fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
let shape = input.shape();
if shape.len() != 4 {
return Err(NeuralError::InferenceError(format!(
"Expected 4D input [batch, channels, height, width], got shape {:?}",
shape
)));
}
if shape[1] != self.config.input_channels {
return Err(NeuralError::InferenceError(format!(
"Expected {} input channels, got {}",
self.config.input_channels, shape[1]
)));
}
let batch_size = shape[0];
let mut x = input.clone();
for stage in &self.features {
x = stage.forward(&x)?;
}
x = Self::adaptive_avg_pool(&x, 7, 7)?;
let channels = x.shape()[1];
let height = x.shape()[2];
let width = x.shape()[3];
let flat_size = channels * height * width;
let x = x
.into_shape_with_order(IxDyn(&[batch_size, flat_size]))
.map_err(|e| {
NeuralError::InferenceError(format!("Failed to flatten feature map: {}", e))
})?;
let mut x = self.fc1.forward(&x)?;
        x = x.mapv(|v: F| v.max(F::zero())); // ReLU
        x = self.dropout1.forward(&x)?;
        x = self.fc2.forward(&x)?;
        x = x.mapv(|v: F| v.max(F::zero())); // ReLU
        x = self.dropout2.forward(&x)?;
x = self.fc3.forward(&x)?;
Ok(x)
}
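    // NOTE: backward is a pass-through stub; gradients are not propagated
    // through the VGG layers by this implementation.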
fn backward(
&self,
_input: &Array<F, IxDyn>,
grad_output: &Array<F, IxDyn>,
) -> Result<Array<F, IxDyn>> {
Ok(grad_output.clone())
}
fn update(&mut self, learning_rate: F) -> Result<()> {
for stage in &mut self.features {
stage.update(learning_rate)?;
}
self.fc1.update(learning_rate)?;
self.fc2.update(learning_rate)?;
self.fc3.update(learning_rate)?;
Ok(())
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
self
}
fn params(&self) -> Vec<Array<F, IxDyn>> {
let mut result = Vec::new();
for stage in &self.features {
result.extend(stage.params());
}
result.extend(self.fc1.params());
result.extend(self.fc2.params());
result.extend(self.fc3.params());
result
}
fn parameter_count(&self) -> usize {
self.total_parameter_count()
}
fn layer_type(&self) -> &str {
"VGG"
}
fn layer_description(&self) -> String {
format!(
"VGG(variant={}, batch_norm={}, classes={}, params={})",
self.config.variant.name(),
self.config.batch_norm,
self.config.num_classes,
self.total_parameter_count()
)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_vgg_variant_layer_counts() {
assert_eq!(VGGVariant::VGG11.num_conv_layers(), 8);
assert_eq!(VGGVariant::VGG13.num_conv_layers(), 10);
assert_eq!(VGGVariant::VGG16.num_conv_layers(), 13);
assert_eq!(VGGVariant::VGG19.num_conv_layers(), 16);
}
#[test]
fn test_vgg_variant_names() {
assert_eq!(VGGVariant::VGG11.name(), "VGG-11");
assert_eq!(VGGVariant::VGG13.name(), "VGG-13");
assert_eq!(VGGVariant::VGG16.name(), "VGG-16");
assert_eq!(VGGVariant::VGG19.name(), "VGG-19");
}
#[test]
fn test_vgg_config_vgg11() {
let config = VGGConfig::vgg11(3, 1000);
assert_eq!(config.input_channels, 3);
assert_eq!(config.num_classes, 1000);
assert!(!config.batch_norm);
assert_eq!(config.variant, VGGVariant::VGG11);
assert!((config.dropout_rate - 0.5).abs() < 1e-10);
}
#[test]
fn test_vgg_config_vgg16_bn() {
let config = VGGConfig::vgg16_bn(3, 100);
assert!(config.batch_norm);
assert_eq!(config.variant, VGGVariant::VGG16);
assert_eq!(config.num_classes, 100);
}
#[test]
fn test_vgg_config_builder_methods() {
let config = VGGConfig::vgg19(3, 1000)
.with_dropout(0.3)
.with_batch_norm(true)
.with_fc_hidden_units(2048);
assert!((config.dropout_rate - 0.3).abs() < 1e-10);
assert!(config.batch_norm);
assert_eq!(config.fc_hidden_units, 2048);
}
#[test]
fn test_vgg11_creation() {
let model: VGG<f64> = VGG::vgg11(3, 10).expect("Failed to create VGG-11");
assert_eq!(model.num_stages(), 5);
assert_eq!(model.config().variant, VGGVariant::VGG11);
assert!(model.total_parameter_count() > 0);
}
#[test]
fn test_vgg16_creation() {
let model: VGG<f64> = VGG::vgg16(3, 1000).expect("Failed to create VGG-16");
assert_eq!(model.num_stages(), 5);
let param_count = model.total_parameter_count();
assert!(
param_count > 100_000_000,
"VGG-16 should have >100M params, got {}",
param_count
);
}
#[test]
fn test_vgg19_bn_creation() {
let config_bn = VGGConfig::vgg19_bn(3, 100);
assert_eq!(config_bn.variant, VGGVariant::VGG19);
assert!(config_bn.batch_norm);
let model_bn: VGG<f32> = VGG::new(
VGGConfig::vgg19_bn(1, 10)
.with_dropout(0.0)
.with_fc_hidden_units(16)
.with_channel_divisor(32),
)
.expect("Failed to create VGG-19-BN (scaled)");
assert_eq!(model_bn.num_stages(), 5);
assert!(model_bn.config().batch_norm);
let model_no_bn: VGG<f32> = VGG::new(
VGGConfig::vgg19(1, 10)
.with_dropout(0.0)
.with_fc_hidden_units(16)
.with_channel_divisor(32),
)
.expect("Failed to create VGG-19 (scaled)");
assert!(
model_bn.total_parameter_count() > model_no_bn.total_parameter_count(),
"BN model params {} should exceed non-BN model params {}",
model_bn.total_parameter_count(),
model_no_bn.total_parameter_count()
);
}
#[test]
fn test_vgg_forward_pass() {
let model: VGG<f64> = VGG::new(
VGGConfig::vgg11(1, 10)
.with_dropout(0.0)
.with_fc_hidden_units(16)
.with_channel_divisor(16),
)
.expect("Failed to create VGG");
let input = Array::zeros(IxDyn(&[1, 1, 32, 32]));
let output = model.forward(&input).expect("Forward pass failed");
assert_eq!(output.shape(), &[1, 10]);
}
#[test]
fn test_vgg_forward_larger_input() {
let model: VGG<f64> = VGG::new(
VGGConfig::vgg11(1, 5)
.with_dropout(0.0)
.with_fc_hidden_units(16)
.with_channel_divisor(16),
)
.expect("Failed to create VGG");
let input = Array::zeros(IxDyn(&[2, 1, 64, 64]));
let output = model.forward(&input).expect("Forward pass failed");
assert_eq!(output.shape(), &[2, 5]);
}
#[test]
fn test_vgg_feature_extraction() {
let model: VGG<f64> = VGG::new(
VGGConfig::vgg11(1, 10)
.with_dropout(0.0)
.with_fc_hidden_units(16)
.with_channel_divisor(16),
)
.expect("Failed to create VGG");
let input = Array::zeros(IxDyn(&[1, 1, 32, 32]));
let features = model
.extract_features(&input)
.expect("Feature extraction failed");
assert_eq!(features.shape()[0], 1);
        assert_eq!(features.shape()[1], 32); // 512 / channel_divisor(16)
    }
#[test]
fn test_vgg_invalid_input_shape() {
let model: VGG<f64> = VGG::new(
VGGConfig::vgg11(3, 10)
.with_dropout(0.0)
.with_fc_hidden_units(32),
)
.expect("Failed to create VGG");
let input_3d = Array::zeros(IxDyn(&[1, 3, 32]));
assert!(model.forward(&input_3d).is_err());
let input_wrong_channels = Array::zeros(IxDyn(&[1, 1, 32, 32]));
assert!(model.forward(&input_wrong_channels).is_err());
}
#[test]
fn test_vgg_named_params() {
let model: VGG<f64> = VGG::new(
VGGConfig::vgg11(1, 10)
.with_dropout(0.0)
.with_fc_hidden_units(32),
)
.expect("Failed to create VGG");
let named_params = model.extract_named_params();
assert!(!named_params.is_empty());
let has_feature_param = named_params
.iter()
.any(|(name, _)| name.starts_with("features."));
let has_classifier_param = named_params
.iter()
.any(|(name, _)| name.starts_with("classifier."));
assert!(has_feature_param, "Should have feature parameters");
assert!(has_classifier_param, "Should have classifier parameters");
}
#[test]
fn test_vgg_layer_trait() {
let model: VGG<f64> = VGG::new(
VGGConfig::vgg11(1, 10)
.with_dropout(0.0)
.with_fc_hidden_units(32),
)
.expect("Failed to create VGG");
assert_eq!(model.layer_type(), "VGG");
assert!(model.parameter_count() > 0);
let desc = model.layer_description();
assert!(desc.contains("VGG-11"));
}
#[test]
fn test_vgg_update() {
let mut model: VGG<f64> = VGG::new(
VGGConfig::vgg11(1, 10)
.with_dropout(0.0)
.with_fc_hidden_units(32),
)
.expect("Failed to create VGG");
model.update(0.001).expect("Update failed");
}
#[test]
fn test_vgg_all_variants_create() {
for variant in &[
VGGVariant::VGG11,
VGGVariant::VGG13,
VGGVariant::VGG16,
VGGVariant::VGG19,
] {
let config = VGGConfig {
variant: *variant,
batch_norm: false,
input_channels: 1,
num_classes: 5,
dropout_rate: 0.0,
fc_hidden_units: 32,
channel_divisor: 1,
};
let model: VGG<f64> = VGG::new(config).expect("Failed to create model");
assert_eq!(model.config().variant, *variant);
}
}
#[test]
fn test_vgg_bn_variants_create() {
for variant in &[
VGGVariant::VGG11,
VGGVariant::VGG13,
VGGVariant::VGG16,
VGGVariant::VGG19,
] {
let config = VGGConfig {
variant: *variant,
batch_norm: true,
input_channels: 1,
num_classes: 5,
dropout_rate: 0.0,
fc_hidden_units: 32,
channel_divisor: 1,
};
let model: VGG<f64> = VGG::new(config).expect("Failed to create BN model");
assert!(model.config().batch_norm);
}
}
#[test]
fn test_max_pool_2x2() {
let mut input = Array::zeros(IxDyn(&[1, 1, 4, 4]));
input[[0, 0, 0, 0]] = 1.0_f64;
input[[0, 0, 0, 1]] = 2.0;
input[[0, 0, 1, 0]] = 3.0;
input[[0, 0, 1, 1]] = 4.0;
input[[0, 0, 2, 2]] = 5.0;
input[[0, 0, 2, 3]] = 6.0;
input[[0, 0, 3, 2]] = 7.0;
input[[0, 0, 3, 3]] = 8.0;
let output = max_pool_2x2(&input).expect("Max pool failed");
assert_eq!(output.shape(), &[1, 1, 2, 2]);
assert!((output[[0, 0, 0, 0]] - 4.0).abs() < 1e-10);
assert!((output[[0, 0, 1, 1]] - 8.0).abs() < 1e-10);
}
#[test]
fn test_vgg_load_named_params() {
let mut model: VGG<f64> = VGG::new(
VGGConfig::vgg11(1, 5)
.with_dropout(0.0)
.with_fc_hidden_units(32),
)
.expect("Failed to create VGG");
let named_params = model.extract_named_params();
let params_map: HashMap<String, Array<f64, IxDyn>> = named_params.into_iter().collect();
model
.load_named_params(¶ms_map)
.expect("Load named params failed");
}
#[test]
fn test_vgg_f32_support() {
let model: VGG<f32> = VGG::new(
VGGConfig::vgg11(1, 5)
.with_dropout(0.0)
.with_fc_hidden_units(16)
.with_channel_divisor(16),
)
.expect("Failed to create VGG f32");
let input = Array::zeros(IxDyn(&[1, 1, 32, 32]));
let output = model.forward(&input).expect("Forward pass failed for f32");
assert_eq!(output.shape(), &[1, 5]);
}
}