#![allow(missing_docs)]
use axonml_autograd::Variable;
use axonml_nn::{BatchNorm2d, Conv2d, Module, Parameter};
use axonml_tensor::Tensor;
/// Learned, gated fusion of a ventral and a dorsal feature pathway.
///
/// A 1x1 convolution over the channel-concatenated inputs predicts a
/// single-channel sigmoid gate; the ventral features are weighted by the
/// gate and the dorsal features by its complement, then a second 1x1
/// convolution projects the gated concatenation to the output width.
pub struct PathwayFusion {
// 1x1 conv producing the raw (pre-sigmoid) single-channel gate.
gate_conv: Conv2d,
// Normalizes the raw gate logits before the sigmoid.
gate_bn: BatchNorm2d,
// 1x1 conv projecting the gated concatenation to `_out_channels`.
out_conv: Conv2d,
// Normalizes the projected output before the final ReLU.
out_bn: BatchNorm2d,
// Output channel count; kept for introspection (unused at runtime).
_out_channels: usize,
}
impl PathwayFusion {
    /// Builds a fusion block for inputs with `ventral_ch` and `dorsal_ch`
    /// channels, producing `out_ch` output channels.
    ///
    /// Both convolutions are 1x1 / stride 1 / no padding / with bias, so
    /// spatial dimensions pass through unchanged.
    pub fn new(ventral_ch: usize, dorsal_ch: usize, out_ch: usize) -> Self {
        let combined = ventral_ch + dorsal_ch;
        Self {
            gate_conv: Conv2d::with_options(combined, 1, (1, 1), (1, 1), (0, 0), true),
            gate_bn: BatchNorm2d::new(1),
            out_conv: Conv2d::with_options(combined, out_ch, (1, 1), (1, 1), (0, 0), true),
            out_bn: BatchNorm2d::new(out_ch),
            _out_channels: out_ch,
        }
    }

    /// Fuses the two pathways: `relu(bn(conv(cat(g*v, (1-g)*d))))` where the
    /// gate `g = sigmoid(bn(conv(cat(v, d))))` has a single channel.
    ///
    /// Assumes NCHW layout and that both inputs share batch and spatial
    /// dimensions (taken from `ventral`) — TODO confirm with callers.
    pub fn forward(&self, ventral: &Variable, dorsal: &Variable) -> Variable {
        let vs = ventral.shape();
        let (batch, ventral_channels) = (vs[0], vs[1]);
        let (height, width) = (vs[2], vs[3]);
        let dorsal_channels = dorsal.shape()[1];

        // Predict a single-channel gate from both pathways jointly.
        let stacked = Variable::cat(&[ventral, dorsal], 1);
        let gate = self
            .gate_bn
            .forward(&self.gate_conv.forward(&stacked))
            .sigmoid();

        // Broadcast the gate across each pathway's channels: ventral gets
        // `g`, dorsal gets `1 - g` (complement via a broadcast scalar one).
        let weighted_ventral =
            ventral.mul_var(&gate.expand(&[batch, ventral_channels, height, width]));
        let ones = Variable::new(Tensor::ones(&[1]), false);
        let complement = ones.sub_var(&gate.expand(&[batch, dorsal_channels, height, width]));
        let weighted_dorsal = dorsal.mul_var(&complement);

        // Project the gated concatenation down to the output width.
        let fused = Variable::cat(&[&weighted_ventral, &weighted_dorsal], 1);
        self.out_bn
            .forward(&self.out_conv.forward(&fused))
            .relu()
    }

    /// Returns every trainable parameter of both conv/BN pairs.
    pub fn parameters(&self) -> Vec<Parameter> {
        self.gate_conv
            .parameters()
            .into_iter()
            .chain(self.gate_bn.parameters())
            .chain(self.out_conv.parameters())
            .chain(self.out_bn.parameters())
            .collect()
    }

    /// Switches both batch-norm layers to inference statistics.
    pub fn eval(&mut self) {
        for bn in [&mut self.gate_bn, &mut self.out_bn] {
            bn.eval();
        }
    }

    /// Switches both batch-norm layers back to training mode.
    pub fn train(&mut self) {
        for bn in [&mut self.gate_bn, &mut self.out_bn] {
            bn.train();
        }
    }
}
/// Applies an independent [`PathwayFusion`] at each of three feature scales.
///
/// Each scale fuses one ventral/dorsal pair; scales do not interact.
pub struct MultiScaleFusion {
// Fusion for the finest scale (largest spatial resolution).
pub scale1: PathwayFusion,
// Fusion for the middle scale.
pub scale2: PathwayFusion,
// Fusion for the coarsest scale (smallest spatial resolution).
pub scale3: PathwayFusion,
}
impl MultiScaleFusion {
    /// Default per-scale channel configuration as `(ventral, dorsal, out)`.
    pub const DEFAULT_CHANNELS: [(usize, usize, usize); 3] =
        [(96, 48, 96), (128, 64, 96), (192, 96, 96)];

    /// Creates the fusion stack with the default channel configuration
    /// ([`Self::DEFAULT_CHANNELS`]): ventral/dorsal widths of 96/48, 128/64
    /// and 192/96, each fused to 96 output channels.
    pub fn new() -> Self {
        Self::with_channels(Self::DEFAULT_CHANNELS)
    }

    /// Creates the fusion stack with explicit `(ventral, dorsal, out)`
    /// channel counts per scale, generalizing the hard-coded widths in
    /// [`Self::new`].
    pub fn with_channels(channels: [(usize, usize, usize); 3]) -> Self {
        let [(v1, d1, o1), (v2, d2, o2), (v3, d3, o3)] = channels;
        Self {
            scale1: PathwayFusion::new(v1, d1, o1),
            scale2: PathwayFusion::new(v2, d2, o2),
            scale3: PathwayFusion::new(v3, d3, o3),
        }
    }

    /// Fuses each ventral/dorsal pair independently at its own scale.
    ///
    /// Tuple element `i` of `ventral` is fused with element `i` of `dorsal`
    /// by `scale{i+1}`; outputs are returned in the same order.
    pub fn forward(
        &self,
        ventral: (&Variable, &Variable, &Variable),
        dorsal: (&Variable, &Variable, &Variable),
    ) -> (Variable, Variable, Variable) {
        let f1 = self.scale1.forward(ventral.0, dorsal.0);
        let f2 = self.scale2.forward(ventral.1, dorsal.1);
        let f3 = self.scale3.forward(ventral.2, dorsal.2);
        (f1, f2, f3)
    }

    /// Collects the trainable parameters of all three scales.
    pub fn parameters(&self) -> Vec<Parameter> {
        let mut p = Vec::new();
        p.extend(self.scale1.parameters());
        p.extend(self.scale2.parameters());
        p.extend(self.scale3.parameters());
        p
    }

    /// Switches every scale's batch-norm layers to inference statistics.
    pub fn eval(&mut self) {
        self.scale1.eval();
        self.scale2.eval();
        self.scale3.eval();
    }

    /// Switches every scale's batch-norm layers back to training mode.
    pub fn train(&mut self) {
        self.scale1.train();
        self.scale2.train();
        self.scale3.train();
    }
}
impl Default for MultiScaleFusion {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a constant-valued, non-differentiable variable of `shape`.
    fn constant(fill: f32, shape: &[usize]) -> Variable {
        let numel: usize = shape.iter().product();
        Variable::new(Tensor::from_vec(vec![fill; numel], shape).unwrap(), false)
    }

    #[test]
    fn test_pathway_fusion_shapes() {
        let fusion = PathwayFusion::new(96, 48, 96);
        let ventral = constant(0.1, &[1, 96, 10, 10]);
        let dorsal = constant(0.1, &[1, 48, 10, 10]);
        let out = fusion.forward(&ventral, &dorsal);
        assert_eq!(out.shape(), vec![1, 96, 10, 10]);
    }

    #[test]
    fn test_multi_scale_fusion() {
        let fusion = MultiScaleFusion::new();
        let ventral = (
            constant(0.1, &[1, 96, 40, 40]),
            constant(0.1, &[1, 128, 20, 20]),
            constant(0.1, &[1, 192, 10, 10]),
        );
        let dorsal = (
            constant(0.1, &[1, 48, 40, 40]),
            constant(0.1, &[1, 64, 20, 20]),
            constant(0.1, &[1, 96, 10, 10]),
        );
        let (f1, f2, f3) = fusion.forward(
            (&ventral.0, &ventral.1, &ventral.2),
            (&dorsal.0, &dorsal.1, &dorsal.2),
        );
        assert_eq!(f1.shape(), vec![1, 96, 40, 40]);
        assert_eq!(f2.shape(), vec![1, 96, 20, 20]);
        assert_eq!(f3.shape(), vec![1, 96, 10, 10]);
    }

    #[test]
    fn test_fusion_gate_bounded() {
        let fusion = PathwayFusion::new(16, 8, 16);
        let ventral = constant(1.0, &[1, 16, 4, 4]);
        let dorsal = constant(0.5, &[1, 8, 4, 4]);
        let out = fusion.forward(&ventral, &dorsal);
        // The gate is a sigmoid, so the output must stay finite.
        assert!(out.data().to_vec().iter().all(|v| v.is_finite()));
    }
}