#![allow(missing_docs)]
use axonml_autograd::Variable;
use axonml_nn::{BatchNorm2d, Conv2d, Module, Parameter, ReLU};
/// Shared convolutional stem: two conv/BN/ReLU stages, each with stride 2,
/// taking a 3-channel input to 64 channels at 1/4 spatial resolution.
pub struct SharedStem {
// 3 -> 32, 3x3 kernel, stride 2, padding 1 (see `new`)
conv1: Conv2d,
bn1: BatchNorm2d,
// 32 -> 64, 3x3 kernel, stride 2, padding 1
conv2: Conv2d,
bn2: BatchNorm2d,
// stateless activation reused by both stages
relu: ReLU,
}
impl SharedStem {
    /// Builds the stem: 3 -> 32 -> 64 channels, each stage a 3x3 conv
    /// (stride 2, padding 1) followed by batch norm and ReLU.
    pub fn new() -> Self {
        let conv1 = Conv2d::with_options(3, 32, (3, 3), (2, 2), (1, 1), true);
        let conv2 = Conv2d::with_options(32, 64, (3, 3), (2, 2), (1, 1), true);
        Self {
            conv1,
            bn1: BatchNorm2d::new(32),
            conv2,
            bn2: BatchNorm2d::new(64),
            relu: ReLU,
        }
    }

    /// Runs conv -> bn -> relu twice; overall spatial downsampling is 4x.
    pub fn forward(&self, x: &Variable) -> Variable {
        let h = self.conv1.forward(x);
        let h = self.relu.forward(&self.bn1.forward(&h));
        let h = self.conv2.forward(&h);
        self.relu.forward(&self.bn2.forward(&h))
    }

    /// Collects the trainable parameters of both conv/BN pairs, in order.
    pub fn parameters(&self) -> Vec<Parameter> {
        let mut params = self.conv1.parameters();
        params.extend(self.bn1.parameters());
        params.extend(self.conv2.parameters());
        params.extend(self.bn2.parameters());
        params
    }

    /// Switches the batch-norm layers to inference mode.
    pub fn eval(&mut self) {
        for bn in [&mut self.bn1, &mut self.bn2] {
            bn.eval();
        }
    }

    /// Switches the batch-norm layers back to training mode.
    pub fn train(&mut self) {
        for bn in [&mut self.bn1, &mut self.bn2] {
            bn.train();
        }
    }
}
/// `Default` delegates to `new` so the stem can be built via
/// `SharedStem::default()` or struct-update syntax.
impl Default for SharedStem {
fn default() -> Self {
Self::new()
}
}
/// MobileNetV2-style inverted residual block:
/// 1x1 expand -> 3x3 depthwise -> 1x1 project, with an optional shortcut.
pub struct InvertedResidualBlock {
// 1x1 pointwise conv widening in_ch -> in_ch * expand_ratio
expand: Conv2d,
expand_bn: BatchNorm2d,
// 3x3 depthwise conv (groups == channels); carries the block's stride
dw: Conv2d,
dw_bn: BatchNorm2d,
// 1x1 pointwise conv narrowing to out_ch (no activation afterwards)
project: Conv2d,
project_bn: BatchNorm2d,
// projected 1x1 conv + BN shortcut, present only when stride == 1
// but the channel count changes (see `new`)
shortcut: Option<(Conv2d, BatchNorm2d)>,
relu: ReLU,
// identity residual is used when stride == 1 and in_ch == out_ch
use_residual: bool,
}
impl InvertedResidualBlock {
pub fn new(in_ch: usize, out_ch: usize, stride: usize, expand_ratio: usize) -> Self {
let mid_ch = in_ch * expand_ratio;
let expand = Conv2d::with_options(in_ch, mid_ch, (1, 1), (1, 1), (0, 0), true);
let expand_bn = BatchNorm2d::new(mid_ch);
let dw = Conv2d::with_groups(
mid_ch,
mid_ch,
(3, 3),
(stride, stride),
(1, 1),
true,
mid_ch,
);
let dw_bn = BatchNorm2d::new(mid_ch);
let project = Conv2d::with_options(mid_ch, out_ch, (1, 1), (1, 1), (0, 0), true);
let project_bn = BatchNorm2d::new(out_ch);
let use_residual = stride == 1 && in_ch == out_ch;
let shortcut = if !use_residual && stride != 1 {
None } else if !use_residual {
Some((
Conv2d::with_options(in_ch, out_ch, (1, 1), (stride, stride), (0, 0), true),
BatchNorm2d::new(out_ch),
))
} else {
None
};
Self {
expand,
expand_bn,
dw,
dw_bn,
project,
project_bn,
shortcut,
relu: ReLU,
use_residual,
}
}
pub fn forward(&self, x: &Variable) -> Variable {
let out = self
.relu
.forward(&self.expand_bn.forward(&self.expand.forward(x)));
let out = self
.relu
.forward(&self.dw_bn.forward(&self.dw.forward(&out)));
let out = self.project_bn.forward(&self.project.forward(&out));
if self.use_residual {
out.add_var(x)
} else if let Some((ref conv, ref bn)) = self.shortcut {
out.add_var(&bn.forward(&conv.forward(x)))
} else {
out
}
}
pub fn parameters(&self) -> Vec<Parameter> {
let mut p = Vec::new();
p.extend(self.expand.parameters());
p.extend(self.expand_bn.parameters());
p.extend(self.dw.parameters());
p.extend(self.dw_bn.parameters());
p.extend(self.project.parameters());
p.extend(self.project_bn.parameters());
if let Some((ref c, ref bn)) = self.shortcut {
p.extend(c.parameters());
p.extend(bn.parameters());
}
p
}
pub fn eval(&mut self) {
self.expand_bn.eval();
self.dw_bn.eval();
self.project_bn.eval();
if let Some((_, ref mut bn)) = self.shortcut {
bn.eval();
}
}
pub fn train(&mut self) {
self.expand_bn.train();
self.dw_bn.train();
self.project_bn.train();
if let Some((_, ref mut bn)) = self.shortcut {
bn.train();
}
}
}
/// Three stages of inverted residual blocks (each a stride-2 block then a
/// stride-1 block), widening 64 -> 96 -> 128 -> 192 channels (see `new`).
/// `forward` returns the feature map after every stage.
pub struct VentralPathway {
stage1: Vec<InvertedResidualBlock>,
stage2: Vec<InvertedResidualBlock>,
stage3: Vec<InvertedResidualBlock>,
}
impl VentralPathway {
pub fn new() -> Self {
Self {
stage1: vec![
InvertedResidualBlock::new(64, 96, 2, 2),
InvertedResidualBlock::new(96, 96, 1, 2),
],
stage2: vec![
InvertedResidualBlock::new(96, 128, 2, 2),
InvertedResidualBlock::new(128, 128, 1, 2),
],
stage3: vec![
InvertedResidualBlock::new(128, 192, 2, 2),
InvertedResidualBlock::new(192, 192, 1, 2),
],
}
}
pub fn forward(&self, x: &Variable) -> (Variable, Variable, Variable) {
let mut out = x.clone();
for block in &self.stage1 {
out = block.forward(&out);
}
let v1 = out.clone();
for block in &self.stage2 {
out = block.forward(&out);
}
let v2 = out.clone();
for block in &self.stage3 {
out = block.forward(&out);
}
let v3 = out;
(v1, v2, v3)
}
pub fn parameters(&self) -> Vec<Parameter> {
let mut p = Vec::new();
for block in &self.stage1 {
p.extend(block.parameters());
}
for block in &self.stage2 {
p.extend(block.parameters());
}
for block in &self.stage3 {
p.extend(block.parameters());
}
p
}
pub fn eval(&mut self) {
for b in &mut self.stage1 {
b.eval();
}
for b in &mut self.stage2 {
b.eval();
}
for b in &mut self.stage3 {
b.eval();
}
}
pub fn train(&mut self) {
for b in &mut self.stage1 {
b.train();
}
for b in &mut self.stage2 {
b.train();
}
for b in &mut self.stage3 {
b.train();
}
}
}
/// `Default` delegates to `new` so the pathway can be built via
/// `VentralPathway::default()`.
impl Default for VentralPathway {
fn default() -> Self {
Self::new()
}
}
/// Three plain conv/BN/ReLU stages with 5x5 kernels and stride 2,
/// widening 64 -> 48 -> 64 -> 96 channels (see `new`). `forward` returns
/// the output of every stage.
pub struct DorsalPathway {
// 64 -> 48, 5x5 kernel, stride 2, padding 2
conv1: Conv2d,
bn1: BatchNorm2d,
// 48 -> 64, 5x5 kernel, stride 2, padding 2
conv2: Conv2d,
bn2: BatchNorm2d,
// 64 -> 96, 5x5 kernel, stride 2, padding 2
conv3: Conv2d,
bn3: BatchNorm2d,
// stateless activation reused by all stages
relu: ReLU,
}
impl DorsalPathway {
pub fn new() -> Self {
Self {
conv1: Conv2d::with_options(64, 48, (5, 5), (2, 2), (2, 2), true),
bn1: BatchNorm2d::new(48),
conv2: Conv2d::with_options(48, 64, (5, 5), (2, 2), (2, 2), true),
bn2: BatchNorm2d::new(64),
conv3: Conv2d::with_options(64, 96, (5, 5), (2, 2), (2, 2), true),
bn3: BatchNorm2d::new(96),
relu: ReLU,
}
}
pub fn forward(&self, x: &Variable) -> (Variable, Variable, Variable) {
let d1 = self.relu.forward(&self.bn1.forward(&self.conv1.forward(x)));
let d2 = self
.relu
.forward(&self.bn2.forward(&self.conv2.forward(&d1)));
let d3 = self
.relu
.forward(&self.bn3.forward(&self.conv3.forward(&d2)));
(d1, d2, d3)
}
pub fn parameters(&self) -> Vec<Parameter> {
let mut p = Vec::new();
p.extend(self.conv1.parameters());
p.extend(self.bn1.parameters());
p.extend(self.conv2.parameters());
p.extend(self.bn2.parameters());
p.extend(self.conv3.parameters());
p.extend(self.bn3.parameters());
p
}
pub fn eval(&mut self) {
self.bn1.eval();
self.bn2.eval();
self.bn3.eval();
}
pub fn train(&mut self) {
self.bn1.train();
self.bn2.train();
self.bn3.train();
}
}
/// `Default` delegates to `new` so the pathway can be built via
/// `DorsalPathway::default()`.
impl Default for DorsalPathway {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use axonml_tensor::Tensor;
// Stem halves the spatial size twice: 320 -> 160 -> 80, 3 -> 64 channels.
#[test]
fn test_shared_stem() {
let stem = SharedStem::new();
let x = Variable::new(
Tensor::from_vec(vec![0.1; 3 * 320 * 320], &[1, 3, 320, 320]).unwrap(),
false,
);
let out = stem.forward(&x);
assert_eq!(out.shape(), vec![1, 64, 80, 80]);
}
// stride 1, in_ch == out_ch: identity residual path, shape unchanged.
#[test]
fn test_inverted_residual_same() {
let block = InvertedResidualBlock::new(96, 96, 1, 2);
let x = Variable::new(
Tensor::from_vec(vec![0.1; 96 * 8 * 8], &[1, 96, 8, 8]).unwrap(),
false,
);
let out = block.forward(&x);
assert_eq!(out.shape(), vec![1, 96, 8, 8]);
}
// stride 2: spatial size halves and channels change; no shortcut path.
#[test]
fn test_inverted_residual_downsample() {
let block = InvertedResidualBlock::new(64, 96, 2, 2);
let x = Variable::new(
Tensor::from_vec(vec![0.1; 64 * 16 * 16], &[1, 64, 16, 16]).unwrap(),
false,
);
let out = block.forward(&x);
assert_eq!(out.shape(), vec![1, 96, 8, 8]);
}
// Each stage halves the spatial size once: 80 -> 40 -> 20 -> 10.
#[test]
fn test_ventral_pathway() {
let ventral = VentralPathway::new();
let x = Variable::new(
Tensor::from_vec(vec![0.1; 64 * 80 * 80], &[1, 64, 80, 80]).unwrap(),
false,
);
let (v1, v2, v3) = ventral.forward(&x);
assert_eq!(v1.shape(), vec![1, 96, 40, 40]);
assert_eq!(v2.shape(), vec![1, 128, 20, 20]);
assert_eq!(v3.shape(), vec![1, 192, 10, 10]);
}
// Same multi-scale pattern as ventral, but via plain 5x5 stride-2 convs.
#[test]
fn test_dorsal_pathway() {
let dorsal = DorsalPathway::new();
let x = Variable::new(
Tensor::from_vec(vec![0.1; 64 * 80 * 80], &[1, 64, 80, 80]).unwrap(),
false,
);
let (d1, d2, d3) = dorsal.forward(&x);
assert_eq!(d1.shape(), vec![1, 48, 40, 40]);
assert_eq!(d2.shape(), vec![1, 64, 20, 20]);
assert_eq!(d3.shape(), vec![1, 96, 10, 10]);
}
// Loose budget checks guard against accidental layer-size regressions.
#[test]
fn test_stem_param_count() {
let stem = SharedStem::new();
let total: usize = stem.parameters().iter().map(|p| p.numel()).sum();
assert!(total > 1000);
assert!(total < 50_000);
}
#[test]
fn test_ventral_param_count() {
let ventral = VentralPathway::new();
let total: usize = ventral.parameters().iter().map(|p| p.numel()).sum();
assert!(total > 50_000);
assert!(total < 500_000);
}
#[test]
fn test_dorsal_param_count() {
let dorsal = DorsalPathway::new();
let total: usize = dorsal.parameters().iter().map(|p| p.numel()).sum();
assert!(total > 10_000);
assert!(total < 350_000);
}
}