#![allow(missing_docs)]
use axonml_autograd::Variable;
use axonml_nn::{BatchNorm2d, Conv2d, Linear, Module, Parameter};
/// Dense (anchor-free) proposal head: a shared conv + batch-norm trunk
/// feeding three parallel 1x1 prediction branches — objectness/class score,
/// bounding-box regression, and centerness (see `new` for channel widths).
pub struct ProposalHead {
// Shared 3x3 feature transform (in_channels -> in_channels).
conv: Conv2d,
// Normalizes the trunk output; the only stateful (train/eval) submodule here.
bn: BatchNorm2d,
// 1x1 branch producing a single-channel score map.
cls_conv: Conv2d,
// 1x1 branch producing a 4-channel box-regression map.
bbox_conv: Conv2d,
// 1x1 branch producing a single-channel centerness map.
center_conv: Conv2d,
}
impl ProposalHead {
    /// Builds the head for feature maps with `in_channels` channels.
    ///
    /// The trunk conv keeps the channel count and spatial size (3x3, stride 1,
    /// padding 1 — assuming `with_options(in, out, kernel, stride, padding,
    /// bias)`; the shape test below confirms spatial size is preserved).
    pub fn new(in_channels: usize) -> Self {
        let trunk = Conv2d::with_options(in_channels, in_channels, (3, 3), (1, 1), (1, 1), true);
        let cls = Conv2d::with_options(in_channels, 1, (1, 1), (1, 1), (0, 0), true);
        let bbox = Conv2d::with_options(in_channels, 4, (1, 1), (1, 1), (0, 0), true);
        let center = Conv2d::with_options(in_channels, 1, (1, 1), (1, 1), (0, 0), true);
        Self {
            conv: trunk,
            bn: BatchNorm2d::new(in_channels),
            cls_conv: cls,
            bbox_conv: bbox,
            center_conv: center,
        }
    }

    /// Runs the head on a feature map, returning `(cls, bbox, center)` maps.
    /// The trunk applies conv -> batch-norm -> ReLU; each branch then reads
    /// the same activated features.
    pub fn forward(&self, x: &Variable) -> (Variable, Variable, Variable) {
        let trunk_out = self.conv.forward(x);
        let feat = self.bn.forward(&trunk_out).relu();
        (
            self.cls_conv.forward(&feat),
            self.bbox_conv.forward(&feat),
            self.center_conv.forward(&feat),
        )
    }

    /// Collects trainable parameters from every submodule, trunk first.
    pub fn parameters(&self) -> Vec<Parameter> {
        self.conv
            .parameters()
            .into_iter()
            .chain(self.bn.parameters())
            .chain(self.cls_conv.parameters())
            .chain(self.bbox_conv.parameters())
            .chain(self.center_conv.parameters())
            .collect()
    }

    /// Switches the batch-norm layer (the only mode-dependent module) to
    /// inference mode.
    pub fn eval(&mut self) {
        self.bn.eval();
    }

    /// Switches the batch-norm layer back to training mode.
    pub fn train(&mut self) {
        self.bn.train();
    }
}
/// Two-layer MLP classifier over concatenated hidden + ROI features
/// (input width is `hidden_size + roi_dim`; see `new`).
pub struct ClassHead {
// hidden_size + roi_dim -> 128, followed by ReLU in `forward`.
fc1: Linear,
// 128 -> num_classes logits.
fc2: Linear,
}
impl ClassHead {
    /// Builds the classifier: `(hidden_size + roi_dim) -> 128 -> num_classes`.
    pub fn new(hidden_size: usize, roi_dim: usize, num_classes: usize) -> Self {
        let input_width = hidden_size + roi_dim;
        Self {
            fc1: Linear::new(input_width, 128),
            fc2: Linear::new(128, num_classes),
        }
    }

    /// Maps a feature vector to raw class logits (fc1 -> ReLU -> fc2;
    /// no softmax applied here).
    pub fn forward(&self, x: &Variable) -> Variable {
        let hidden = self.fc1.forward(x).relu();
        self.fc2.forward(&hidden)
    }

    /// Collects trainable parameters from both layers, fc1 first.
    pub fn parameters(&self) -> Vec<Parameter> {
        self.fc1
            .parameters()
            .into_iter()
            .chain(self.fc2.parameters())
            .collect()
    }
}
/// Box-regression head that also predicts per-coordinate uncertainty:
/// a shared trunk feeds separate mean and log-variance outputs (4 values
/// each — presumably one per box coordinate; confirm against caller).
pub struct UncertaintyBBoxHead {
// hidden_size -> 64 trunk, ReLU applied in `forward`.
shared: Linear,
// 64 -> 4 predicted box deltas (mean of the distribution).
mean_head: Linear,
// 64 -> 4 predicted log-variances for those deltas.
logvar_head: Linear,
}
impl UncertaintyBBoxHead {
    /// Builds the head: shared `hidden_size -> 64` trunk plus two `64 -> 4`
    /// output layers (mean and log-variance).
    pub fn new(hidden_size: usize) -> Self {
        Self {
            shared: Linear::new(hidden_size, 64),
            mean_head: Linear::new(64, 4),
            logvar_head: Linear::new(64, 4),
        }
    }

    /// Returns `(mean, log_var)` box predictions computed from the shared
    /// ReLU-activated trunk features.
    pub fn forward(&self, x: &Variable) -> (Variable, Variable) {
        let trunk = self.shared.forward(x).relu();
        (self.mean_head.forward(&trunk), self.logvar_head.forward(&trunk))
    }

    /// Collects trainable parameters: shared trunk, then mean, then log-var.
    pub fn parameters(&self) -> Vec<Parameter> {
        self.shared
            .parameters()
            .into_iter()
            .chain(self.mean_head.parameters())
            .chain(self.logvar_head.parameters())
            .collect()
    }
}
/// Channel-preserving two-conv block (conv -> BN -> ReLU -> conv); used to
/// predict features across time — NOTE(review): exact temporal role not
/// visible from this file, confirm at the call site.
pub struct TemporalPredictor {
// First 3x3 conv, channels -> channels.
conv1: Conv2d,
// Normalizes conv1 output; the only stateful (train/eval) submodule here.
bn1: BatchNorm2d,
// Second 3x3 conv, channels -> channels; no norm/activation after it.
conv2: Conv2d,
}
impl TemporalPredictor {
    /// Builds the predictor; both convs keep `channels` channels and, with
    /// padding 1 on a 3x3 kernel, the spatial size (the shape test below
    /// confirms 10x10 in -> 10x10 out).
    pub fn new(channels: usize) -> Self {
        let make_conv =
            || Conv2d::with_options(channels, channels, (3, 3), (1, 1), (1, 1), true);
        Self {
            conv1: make_conv(),
            bn1: BatchNorm2d::new(channels),
            conv2: make_conv(),
        }
    }

    /// Applies conv1 -> BN -> ReLU -> conv2; the final conv output is
    /// returned raw (no activation).
    pub fn forward(&self, x: &Variable) -> Variable {
        let activated = self.bn1.forward(&self.conv1.forward(x)).relu();
        self.conv2.forward(&activated)
    }

    /// Collects trainable parameters from all three submodules, in order.
    pub fn parameters(&self) -> Vec<Parameter> {
        self.conv1
            .parameters()
            .into_iter()
            .chain(self.bn1.parameters())
            .chain(self.conv2.parameters())
            .collect()
    }

    /// Switches batch-norm (the only mode-dependent module) to inference mode.
    pub fn eval(&mut self) {
        self.bn1.eval();
    }

    /// Switches batch-norm back to training mode.
    pub fn train(&mut self) {
        self.bn1.train();
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_tensor::Tensor;

    /// ProposalHead preserves spatial size and emits 1/4/1-channel maps.
    #[test]
    fn test_proposal_head() {
        let head = ProposalHead::new(96);
        let x = Variable::new(
            Tensor::from_vec(vec![0.1; 96 * 10 * 10], &[1, 96, 10, 10]).unwrap(),
            false,
        );
        let (cls, bbox, center) = head.forward(&x);
        assert_eq!(cls.shape(), vec![1, 1, 10, 10]);
        assert_eq!(bbox.shape(), vec![1, 4, 10, 10]);
        assert_eq!(center.shape(), vec![1, 1, 10, 10]);
    }

    /// ClassHead maps a (hidden + roi)-wide vector to `num_classes` logits.
    #[test]
    fn test_class_head() {
        let head = ClassHead::new(64, 288, 20);
        // Derive the input width from the constructor args instead of
        // hard-coding 352 (was `1 * (64 + 288)` with a magic `[1, 352]`).
        const INPUT_WIDTH: usize = 64 + 288;
        let x = Variable::new(
            Tensor::from_vec(vec![0.1; INPUT_WIDTH], &[1, INPUT_WIDTH]).unwrap(),
            false,
        );
        let out = head.forward(&x);
        assert_eq!(out.shape(), vec![1, 20]);
    }

    /// UncertaintyBBoxHead emits 4 means and 4 finite log-variances.
    #[test]
    fn test_uncertainty_bbox_head() {
        let head = UncertaintyBBoxHead::new(64);
        let x = Variable::new(Tensor::from_vec(vec![0.1; 64], &[1, 64]).unwrap(), false);
        let (mean, log_var) = head.forward(&x);
        assert_eq!(mean.shape(), vec![1, 4]);
        assert_eq!(log_var.shape(), vec![1, 4]);
        assert!(log_var.data().to_vec().iter().all(|v| v.is_finite()));
    }

    /// TemporalPredictor is shape-preserving (channels and spatial dims).
    #[test]
    fn test_temporal_predictor() {
        let pred = TemporalPredictor::new(96);
        let x = Variable::new(
            Tensor::from_vec(vec![0.1; 96 * 10 * 10], &[1, 96, 10, 10]).unwrap(),
            false,
        );
        let out = pred.forward(&x);
        assert_eq!(out.shape(), vec![1, 96, 10, 10]);
    }

    /// Sanity-bound the parameter count so accidental layer changes are caught.
    #[test]
    fn test_proposal_head_params() {
        let head = ProposalHead::new(96);
        let total: usize = head.parameters().iter().map(|p| p.numel()).sum();
        assert!(total > 5_000);
        assert!(total < 200_000);
    }
}