use std::collections::HashMap;
use axonml_autograd::Variable;
use axonml_nn::{Conv2d, Module, Parameter};
use super::backbone::ConvBNSiLU;
/// YOLOX-style decoupled detection head.
///
/// A shared 1x1 stem feeds two separate branches: a classification branch
/// (`cls_conv` -> `cls_pred`) and a regression branch (`reg_conv` ->
/// `reg_pred` / `obj_pred`). An optional `domain_pred` 1x1 conv is attached
/// when `num_domains > 0` (used for domain-adaptive training).
pub struct DecoupledHead {
    // 1x1 channel-mixing stem shared by both branches.
    stem: ConvBNSiLU,
    // 3x3 conv feeding the class-score prediction.
    cls_conv: ConvBNSiLU,
    // 1x1 conv producing `num_classes` channels of class logits.
    cls_pred: Conv2d,
    // 3x3 conv feeding box regression and objectness.
    reg_conv: ConvBNSiLU,
    // 1x1 conv producing 4 channels of box regression outputs.
    reg_pred: Conv2d,
    // 1x1 conv producing 1 channel of objectness logits.
    obj_pred: Conv2d,
    // Optional 1x1 conv producing `num_domains` channels; `None` when
    // the head was built with `num_domains == 0`.
    domain_pred: Option<Conv2d>,
    pub num_classes: usize,
    pub num_domains: usize,
}
impl DecoupledHead {
/// Builds a head for feature maps with `in_ch` channels.
///
/// The hidden width equals `in_ch`. A domain-classifier conv is created
/// only when `num_domains > 0`; otherwise `domain_pred` stays `None`.
pub fn new(in_ch: usize, num_classes: usize, num_domains: usize) -> Self {
    let hidden = in_ch;
    // Optional domain branch: present only for domain-adaptive setups.
    let domain_pred = (num_domains > 0)
        .then(|| Conv2d::with_options(hidden, num_domains, (1, 1), (1, 1), (0, 0), true));
    Self {
        stem: ConvBNSiLU::new(in_ch, hidden, 1, 1, 0),
        cls_conv: ConvBNSiLU::new(hidden, hidden, 3, 1, 1),
        cls_pred: Conv2d::with_options(hidden, num_classes, (1, 1), (1, 1), (0, 0), true),
        reg_conv: ConvBNSiLU::new(hidden, hidden, 3, 1, 1),
        reg_pred: Conv2d::with_options(hidden, 4, (1, 1), (1, 1), (0, 0), true),
        obj_pred: Conv2d::with_options(hidden, 1, (1, 1), (1, 1), (0, 0), true),
        domain_pred,
        num_classes,
        num_domains,
    }
}
pub fn forward(&self, x: &Variable) -> (Variable, Variable, Variable, Option<Variable>) {
let stem_out = self.stem.forward(x);
let cls_feat = self.cls_conv.forward(&stem_out);
let cls_out = self.cls_pred.forward(&cls_feat);
let reg_feat = self.reg_conv.forward(&stem_out);
let reg_out = self.reg_pred.forward(®_feat);
let obj_out = self.obj_pred.forward(®_feat);
let domain_out = self.domain_pred.as_ref().map(|dp| dp.forward(&cls_feat));
(cls_out, reg_out, obj_out, domain_out)
}
pub fn parameters(&self) -> Vec<Parameter> {
let mut p = self.stem.parameters();
p.extend(self.cls_conv.parameters());
p.extend(self.cls_pred.parameters());
p.extend(self.reg_conv.parameters());
p.extend(self.reg_pred.parameters());
p.extend(self.obj_pred.parameters());
if let Some(ref dp) = self.domain_pred {
p.extend(dp.parameters());
}
p
}
pub fn named_parameters(&self, prefix: &str) -> HashMap<String, Parameter> {
let mut p = self.stem.named_parameters(&format!("{}.stem", prefix));
p.extend(
self.cls_conv
.named_parameters(&format!("{}.cls_conv", prefix)),
);
for (k, v) in self.cls_pred.named_parameters() {
p.insert(format!("{}.cls_pred.{}", prefix, k), v);
}
p.extend(
self.reg_conv
.named_parameters(&format!("{}.reg_conv", prefix)),
);
for (k, v) in self.reg_pred.named_parameters() {
p.insert(format!("{}.reg_pred.{}", prefix, k), v);
}
for (k, v) in self.obj_pred.named_parameters() {
p.insert(format!("{}.obj_pred.{}", prefix, k), v);
}
if let Some(ref dp) = self.domain_pred {
for (k, v) in dp.named_parameters() {
p.insert(format!("{}.domain_pred.{}", prefix, k), v);
}
}
p
}
/// Switches the ConvBNSiLU blocks between training and eval mode.
///
/// Only the three ConvBNSiLU blocks are toggled; the plain Conv2d
/// prediction layers are presumably stateless w.r.t. training mode
/// (no `set_training` is called on them) — TODO confirm Conv2d has no
/// train/eval-dependent behavior.
pub fn set_training(&mut self, training: bool) {
    self.cls_conv.set_training(training);
    self.reg_conv.set_training(training);
    self.stem.set_training(training);
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;
    use axonml_tensor::Tensor;

    /// Builds a constant-valued (0.1) non-grad input of the given shape.
    fn make_input(shape: &[usize]) -> Variable {
        let numel: usize = shape.iter().product();
        Variable::new(Tensor::from_vec(vec![0.1; numel], shape).unwrap(), false)
    }

    #[test]
    fn test_decoupled_head_shapes() {
        let head = DecoupledHead::new(128, 10, 0);
        let (cls, reg, obj, domain) = head.forward(&make_input(&[1, 128, 8, 8]));
        assert_eq!(cls.data().shape(), &[1, 10, 8, 8]);
        assert_eq!(reg.data().shape(), &[1, 4, 8, 8]);
        assert_eq!(obj.data().shape(), &[1, 1, 8, 8]);
        assert!(domain.is_none());
    }

    #[test]
    fn test_decoupled_head_with_domain() {
        let head = DecoupledHead::new(64, 5, 3);
        let (cls, reg, obj, domain) = head.forward(&make_input(&[2, 64, 4, 4]));
        assert_eq!(cls.data().shape(), &[2, 5, 4, 4]);
        assert_eq!(reg.data().shape(), &[2, 4, 4, 4]);
        assert_eq!(obj.data().shape(), &[2, 1, 4, 4]);
        assert!(domain.is_some());
        assert_eq!(domain.unwrap().data().shape(), &[2, 3, 4, 4]);
    }

    #[test]
    fn test_decoupled_head_config_fields() {
        let head = DecoupledHead::new(128, 20, 4);
        assert_eq!(head.num_classes, 20);
        assert_eq!(head.num_domains, 4);
    }

    #[test]
    fn test_decoupled_head_params() {
        let head = DecoupledHead::new(64, 5, 0);
        assert!(!head.parameters().is_empty());
    }
}