use std::collections::HashMap;
use axonml_autograd::Variable;
use axonml_nn::{Module, Parameter};
#[cfg(test)]
use axonml_tensor::Tensor;
use super::backbone::ThermalBackbone;
use super::head::DecoupledHead;
use super::neck::ThermalFPN;
/// Thermal imaging domain label.
///
/// Every variant has a stable integer index ([`ThermalDomain::index`]) and a
/// lowercase name ([`ThermalDomain::name`]). [`ThermalDomain::count`] is the
/// number of variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ThermalDomain {
    Wildlife,
    Human,
    Interstellar,
    Vehicle,
    General,
}

impl ThermalDomain {
    /// Every domain, ordered so that `ALL[d.index()] == d`.
    ///
    /// This is the single source of truth for the variant count and for the
    /// index -> variant mapping; keep it in sync with [`ThermalDomain::index`].
    pub const ALL: [ThermalDomain; 5] = [
        Self::Wildlife,
        Self::Human,
        Self::Interstellar,
        Self::Vehicle,
        Self::General,
    ];

    /// Stable integer index of this domain (`0..Self::count()`).
    pub fn index(&self) -> usize {
        match self {
            Self::Wildlife => 0,
            Self::Human => 1,
            Self::Interstellar => 2,
            Self::Vehicle => 3,
            Self::General => 4,
        }
    }

    /// Checked inverse of [`ThermalDomain::index`].
    ///
    /// Returns `None` when `i >= Self::count()`.
    pub fn try_from_index(i: usize) -> Option<Self> {
        Self::ALL.get(i).copied()
    }

    /// Lenient inverse of [`ThermalDomain::index`].
    ///
    /// Any index outside `0..Self::count()` maps to [`ThermalDomain::General`].
    /// Use [`ThermalDomain::try_from_index`] to detect invalid indices instead.
    pub fn from_index(i: usize) -> Self {
        Self::try_from_index(i).unwrap_or(Self::General)
    }

    /// Lowercase name of this domain.
    pub fn name(&self) -> &'static str {
        match self {
            Self::Wildlife => "wildlife",
            Self::Human => "human",
            Self::Interstellar => "interstellar",
            Self::Vehicle => "vehicle",
            Self::General => "general",
        }
    }

    /// Number of domains (derived from [`ThermalDomain::ALL`]).
    pub fn count() -> usize {
        Self::ALL.len()
    }
}
/// Architecture hyper-parameters for [`NightVision`].
///
/// Use one of the preset constructors (`wildlife`, `human`, `interstellar`,
/// `multi_domain`, `edge`) or fill in the fields directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NightVisionConfig {
    /// Number of input image channels; passed to `ThermalBackbone::new`.
    pub in_channels: usize,
    /// Number of object classes predicted by each detection head.
    pub num_classes: usize,
    /// Number of domain outputs per head. The single-domain presets use `0`,
    /// for which the heads are expected to emit no domain output.
    pub num_domains: usize,
    /// Channel width of the FPN outputs, which is also the input width of each head.
    pub fpn_channels: usize,
    /// Nominal input image size (presumably pixels per side).
    /// NOTE(review): `NightVision` itself never reads this field — confirm
    /// that it is consumed by the data pipeline.
    pub img_size: usize,
}
impl NightVisionConfig {
    /// Shared starting point for the presets: one input channel, no domain
    /// outputs, 128 FPN channels and an image size of 320.
    fn baseline(num_classes: usize) -> Self {
        Self {
            in_channels: 1,
            num_classes,
            num_domains: 0,
            fpn_channels: 128,
            img_size: 320,
        }
    }

    /// Single-channel preset with one class per species.
    pub fn wildlife(num_species: usize) -> Self {
        Self::baseline(num_species)
    }

    /// Single-channel, single-class preset.
    pub fn human() -> Self {
        Self::baseline(1)
    }

    /// Preset with `bands` input channels and an image size of 512.
    pub fn interstellar(num_classes: usize, bands: usize) -> Self {
        Self {
            in_channels: bands,
            img_size: 512,
            ..Self::baseline(num_classes)
        }
    }

    /// Single-channel preset with one domain output per [`ThermalDomain`].
    pub fn multi_domain(num_classes: usize) -> Self {
        Self {
            num_domains: ThermalDomain::count(),
            ..Self::baseline(num_classes)
        }
    }

    /// Reduced-size preset: 64 FPN channels and an image size of 256.
    pub fn edge(num_classes: usize) -> Self {
        Self {
            fpn_channels: 64,
            img_size: 256,
            ..Self::baseline(num_classes)
        }
    }
}
impl Default for NightVisionConfig {
fn default() -> Self {
Self::multi_domain(10)
}
}
/// Three-level detector: a thermal backbone, an FPN neck, and one decoupled
/// detection head per pyramid level (P3, P4, P5).
pub struct NightVision {
    /// Feature extractor; `forward` returns three feature maps that feed `neck`.
    backbone: ThermalBackbone,
    /// Feature pyramid neck. It is built with `config.fpn_channels` as its
    /// last argument, and its three outputs go to the three heads.
    neck: ThermalFPN,
    /// Head applied to the first (P3) neck output.
    head_p3: DecoupledHead,
    /// Head applied to the second (P4) neck output.
    head_p4: DecoupledHead,
    /// Head applied to the third (P5) neck output.
    head_p5: DecoupledHead,
    /// The configuration this model was built from.
    config: NightVisionConfig,
}
impl NightVision {
pub fn new(config: NightVisionConfig) -> Self {
let fpn_ch = config.fpn_channels;
Self {
backbone: ThermalBackbone::new(config.in_channels),
neck: ThermalFPN::new(64, 128, 256, fpn_ch),
head_p3: DecoupledHead::new(fpn_ch, config.num_classes, config.num_domains),
head_p4: DecoupledHead::new(fpn_ch, config.num_classes, config.num_domains),
head_p5: DecoupledHead::new(fpn_ch, config.num_classes, config.num_domains),
config,
}
}
pub fn config(&self) -> &NightVisionConfig {
&self.config
}
pub fn forward_detection(
&self,
x: &Variable,
) -> Vec<(Variable, Variable, Variable, Option<Variable>)> {
let (p3, p4, p5) = self.backbone.forward(x);
let (fpn3, fpn4, fpn5) = self.neck.forward(&p3, &p4, &p5);
let out3 = self.head_p3.forward(&fpn3);
let out4 = self.head_p4.forward(&fpn4);
let out5 = self.head_p5.forward(&fpn5);
vec![out3, out4, out5]
}
pub fn forward_flat(&self, x: &Variable) -> (Variable, Variable, Variable) {
let outputs = self.forward_detection(x);
let batch = x.shape()[0];
let mut all_cls = Vec::new();
let mut all_bbox = Vec::new();
let mut all_obj = Vec::new();
for (cls, bbox, obj, _) in &outputs {
let cls_shape = cls.shape();
let h = cls_shape[2];
let w = cls_shape[3];
let n_anchors = h * w;
let cls_flat = cls
.reshape(&[batch, self.config.num_classes, n_anchors])
.transpose(1, 2); let bbox_flat = bbox.reshape(&[batch, 4, n_anchors]).transpose(1, 2); let obj_flat = obj.reshape(&[batch, 1, n_anchors]).transpose(1, 2);
all_cls.push(cls_flat);
all_bbox.push(bbox_flat);
all_obj.push(obj_flat);
}
let cls_refs: Vec<&Variable> = all_cls.iter().collect();
let bbox_refs: Vec<&Variable> = all_bbox.iter().collect();
let obj_refs: Vec<&Variable> = all_obj.iter().collect();
(
Variable::cat(&cls_refs, 1), Variable::cat(&bbox_refs, 1), Variable::cat(&obj_refs, 1), )
}
}
impl Module for NightVision {
    /// Runs the full detector and returns only the flattened class output
    /// (the first element of `forward_flat`). Box and objectness outputs are dropped.
    fn forward(&self, input: &Variable) -> Variable {
        self.forward_flat(input).0
    }

    /// All trainable parameters, in the order backbone, neck, P3 head, P4 head, P5 head.
    fn parameters(&self) -> Vec<Parameter> {
        let heads = [&self.head_p3, &self.head_p4, &self.head_p5];
        let mut params = self.backbone.parameters();
        params.extend(self.neck.parameters());
        for head in heads {
            params.extend(head.parameters());
        }
        params
    }

    /// Named parameters. Backbone and neck keys are merged as-is; each head's
    /// keys come from `DecoupledHead::named_parameters` with `"head_p3"`,
    /// `"head_p4"` or `"head_p5"` as the prefix.
    ///
    /// NOTE(review): `HashMap::extend` overwrites on key collision, and the
    /// backbone and neck keys carry no prefix here. Confirm that their key
    /// sets are disjoint.
    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut named = self.backbone.named_parameters();
        named.extend(self.neck.named_parameters());
        for (prefix, head) in [
            ("head_p3", &self.head_p3),
            ("head_p4", &self.head_p4),
            ("head_p5", &self.head_p5),
        ] {
            named.extend(head.named_parameters(prefix));
        }
        named
    }

    fn name(&self) -> &'static str {
        "NightVision"
    }

    /// Propagates the training flag to every sub-module.
    fn set_training(&mut self, training: bool) {
        self.backbone.set_training(training);
        self.neck.set_training(training);
        for head in [&mut self.head_p3, &mut self.head_p4, &mut self.head_p5] {
            head.set_training(training);
        }
    }
}
/// Relative weights for the individual NightVision loss terms.
///
/// NOTE(review): nothing in this file consumes these weights. The field
/// descriptions below follow the field names; confirm them against the
/// training code that uses this struct.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct NightVisionLoss {
    /// Weight of the classification term.
    pub cls_weight: f32,
    /// Weight of the bounding-box regression term.
    pub bbox_weight: f32,
    /// Weight of the objectness term.
    pub obj_weight: f32,
    /// Weight of the domain term.
    pub domain_weight: f32,
}

impl Default for NightVisionLoss {
    /// Defaults: `cls = 1.0`, `bbox = 5.0`, `obj = 1.0`, `domain = 0.5`.
    fn default() -> Self {
        Self {
            cls_weight: 1.0,
            bbox_weight: 5.0,
            obj_weight: 1.0,
            domain_weight: 0.5,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `[1, channels, size, size]` input filled with 0.5 that does
    /// not require gradients.
    fn constant_input(channels: usize, size: usize) -> Variable {
        let data = vec![0.5f32; channels * size * size];
        Variable::new(
            Tensor::from_vec(data, &[1, channels, size, size]).unwrap(),
            false,
        )
    }

    /// Checks the shape invariants of `forward_flat` for a batch of one:
    /// all outputs share the batch size and the total anchor count, and
    /// their last axis holds `num_classes`, 4 and 1 values respectively.
    fn assert_flat_shapes(cls: &Variable, bbox: &Variable, obj: &Variable, num_classes: usize) {
        let (c, b, o) = (cls.shape(), bbox.shape(), obj.shape());
        assert_eq!(c[0], 1);
        assert_eq!(b[0], 1);
        assert_eq!(o[0], 1);
        assert_eq!(c[1], b[1], "cls and bbox anchor counts differ");
        assert_eq!(c[1], o[1], "cls and obj anchor counts differ");
        assert_eq!(c[2], num_classes);
        assert_eq!(b[2], 4);
        assert_eq!(o[2], 1);
    }

    #[test]
    fn test_nightvision_wildlife_forward() {
        let model = NightVision::new(NightVisionConfig::wildlife(10));
        let outputs = model.forward_detection(&constant_input(1, 128));
        assert_eq!(outputs.len(), 3);
        for (cls, bbox, obj, domain) in &outputs {
            assert_eq!(cls.shape()[0], 1);
            assert_eq!(cls.shape()[1], 10);
            assert_eq!(bbox.shape()[1], 4);
            assert_eq!(obj.shape()[1], 1);
            assert!(domain.is_none());
        }
    }

    #[test]
    fn test_nightvision_human_forward() {
        let model = NightVision::new(NightVisionConfig::human());
        let (cls, bbox, obj) = model.forward_flat(&constant_input(1, 64));
        assert_flat_shapes(&cls, &bbox, &obj, 1);
    }

    #[test]
    fn test_nightvision_multi_domain() {
        let model = NightVision::new(NightVisionConfig::multi_domain(5));
        let outputs = model.forward_detection(&constant_input(1, 64));
        assert_eq!(outputs.len(), 3);
        for (_, _, _, domain) in &outputs {
            let d = domain
                .as_ref()
                .expect("multi-domain config must produce domain outputs");
            assert_eq!(d.shape()[1], ThermalDomain::count());
        }
    }

    #[test]
    fn test_nightvision_param_count() {
        let model = NightVision::new(NightVisionConfig::wildlife(10));
        let count: usize = model.parameters().iter().map(|p| p.numel()).sum();
        println!("NightVision(wildlife, 10 classes): {} params", count);
        assert!(count > 100_000);
        assert!(count < 5_000_000);
    }

    #[test]
    fn test_nightvision_interstellar() {
        let model = NightVision::new(NightVisionConfig::interstellar(3, 3));
        let (cls, bbox, obj) = model.forward_flat(&constant_input(3, 64));
        assert_flat_shapes(&cls, &bbox, &obj, 3);
    }
}