use axonml_autograd::Variable;
use axonml_nn::Parameter;
use axonml_tensor::Tensor;
use super::NexusConfig;
use super::backbone::{DorsalPathway, SharedStem, VentralPathway};
use super::fusion::MultiScaleFusion;
use super::heads::{ClassHead, ProposalHead, UncertaintyBBoxHead};
use super::memory::ObjectMemoryBank;
use super::predictive::MultiScalePredictiveCoding;
use crate::ops::{NexusDetection, nms};
/// Nexus video object detector.
///
/// A stateful, frame-by-frame detection model: a shared stem feeds two
/// parallel backbone pathways whose multi-scale outputs are fused, passed
/// through a temporal predictive-coding stage, scored by per-scale proposal
/// heads, and finally associated across frames by a recurrent object memory
/// bank (see [`Nexus::detect`]). Call [`Nexus::reset`] between video clips.
pub struct Nexus {
// First stage applied to every frame; its output feeds both pathways.
stem: SharedStem,
// Ventral pathway; produces a 3-level feature pyramid from the stem output.
ventral: VentralPathway,
// Dorsal pathway; produces a parallel 3-level pyramid from the same stem output.
dorsal: DorsalPathway,
// Combines the ventral and dorsal pyramids into one fused pyramid (f1..f3).
fusion: MultiScaleFusion,
// Temporal predictive-coding stage; stateful across frames (cleared by `reset`).
predictive: MultiScalePredictiveCoding,
// One proposal head per pyramid level (strides 8 / 16 / 32 in `detect`).
proposal_heads: Vec<ProposalHead>,
// Recurrent per-object memory; matches NMS survivors to tracked slots each frame.
memory: ObjectMemoryBank,
// NOTE(review): constructed and included in `parameters()`, but never invoked
// by `detect`, which hardcodes `class_id: 0` — confirm whether classification
// is wired up elsewhere or still pending.
class_head: ClassHead,
// Predicts a (mean, log-variance) bbox refinement from each memory slot's hidden state.
bbox_head: UncertaintyBBoxHead,
// Model hyperparameters (thresholds, input size, class count, memory size).
config: NexusConfig,
// Number of frames processed since construction or the last `reset`.
total_frames: u64,
}
impl Nexus {
/// Creates a `Nexus` with the default [`NexusConfig`].
pub fn new() -> Self {
Self::with_config(NexusConfig::default())
}
/// Creates a `Nexus` from an explicit configuration.
///
/// `fused_ch` (96) is the per-level channel count assumed for the fusion
/// output and threaded into the predictive stage, proposal heads, and
/// memory bank — NOTE(review): hard-coded here; confirm it matches what
/// `MultiScaleFusion` actually produces.
pub fn with_config(config: NexusConfig) -> Self {
let fused_ch = 96; let hidden_size = config.memory_hidden_size;
// 3x3 ROI window; `roi_dim` is the flattened ROI feature length fed to the class head.
let roi_size = 3;
let roi_dim = fused_ch * roi_size * roi_size;
Self {
stem: SharedStem::new(),
ventral: VentralPathway::new(),
dorsal: DorsalPathway::new(),
fusion: MultiScaleFusion::new(),
predictive: MultiScalePredictiveCoding::new(fused_ch),
// One head per pyramid level; `detect` pairs them with strides 8/16/32.
proposal_heads: vec![
ProposalHead::new(fused_ch),
ProposalHead::new(fused_ch),
ProposalHead::new(fused_ch),
],
memory: ObjectMemoryBank::new(fused_ch, hidden_size, roi_size),
class_head: ClassHead::new(hidden_size, roi_dim, config.num_classes),
bbox_head: UncertaintyBBoxHead::new(hidden_size),
config,
total_frames: 0,
}
}
/// Runs one full detection pass on `frame` and returns the tracked detections.
///
/// Pipeline: stem -> (ventral, dorsal) -> fusion -> predictive coding ->
/// per-scale proposals -> NMS -> memory update -> per-slot bbox refinement.
/// Mutates temporal state (`predictive`, `memory`, `total_frames`), so frames
/// must be fed in order; call [`Nexus::reset`] between clips.
pub fn detect(&mut self, frame: &Variable) -> Vec<NexusDetection> {
self.total_frames += 1;
let shape = frame.shape();
// Assumes NCHW layout: shape = [batch, channels, height, width] — TODO confirm.
let (_, _, img_h, img_w) = (shape[0], shape[1], shape[2], shape[3]);
let stem_out = self.stem.forward(frame);
let (v1, v2, v3) = self.ventral.forward(&stem_out);
let (d1, d2, d3) = self.dorsal.forward(&stem_out);
let (f1, f2, f3) = self.fusion.forward((&v1, &v2, &v3), (&d1, &d2, &d3));
// Predictive coding returns (gated feature, surprise) per level; the
// surprise signals are unused here. Stateful across frames.
let ((g1, _s1), (g2, _s2), (g3, _s3)) = self.predictive.forward(&f1, &f2, &f3);
let scales = [&g1, &g2, &g3];
// Pixel stride of each pyramid level, used to map grid cells back to image coords.
let strides = [8.0f32, 16.0, 32.0];
let mut all_proposals: Vec<[f32; 4]> = Vec::new();
let mut all_scores: Vec<f32> = Vec::new();
for (scale_idx, (feat, stride)) in scales.iter().zip(strides.iter()).enumerate() {
let (cls, bbox, center) = self.proposal_heads[scale_idx].forward(feat);
let cls_data = cls.data().to_vec();
let bbox_data = bbox.data().to_vec();
let center_data = center.data().to_vec();
// Spatial size of this level's prediction grid (NCHW dims 2 and 3).
let fh = cls.shape()[2];
let fw = cls.shape()[3];
for fy in 0..fh {
for fx in 0..fw {
// Manual sigmoid of the raw logits. Indexing `fy * fw + fx` reads
// channel 0 only — assumes a single objectness logit channel in
// `cls` and `center`; TODO confirm against ProposalHead.
let cls_score = 1.0 / (1.0 + (-cls_data[fy * fw + fx]).exp());
let centerness = 1.0 / (1.0 + (-center_data[fy * fw + fx]).exp());
// Geometric mean of classification score and centerness.
let score = (cls_score * centerness).sqrt();
if score > self.config.proposal_threshold {
// Cell center in image coordinates.
let cx = (fx as f32 + 0.5) * stride;
let cy = (fy as f32 + 0.5) * stride;
// bbox regression is channel-major [4, fh, fw]:
// ch 0/1 = center offsets (pixels), ch 2/3 = log-scale
// width/height relative to this level's stride.
let dx = bbox_data[0 * fh * fw + fy * fw + fx];
let dy = bbox_data[fh * fw + fy * fw + fx];
let dw = bbox_data[2 * fh * fw + fy * fw + fx];
let dh = bbox_data[3 * fh * fw + fy * fw + fx];
let bw = dw.exp() * stride;
let bh = dh.exp() * stride;
// Decode to corner form and clamp to the image bounds.
let x1 = (cx + dx - bw / 2.0).max(0.0).min(img_w as f32);
let y1 = (cy + dy - bh / 2.0).max(0.0).min(img_h as f32);
let x2 = (cx + dx + bw / 2.0).max(0.0).min(img_w as f32);
let y2 = (cy + dy + bh / 2.0).max(0.0).min(img_h as f32);
// Discard boxes clamped to zero area.
if x2 > x1 && y2 > y1 {
all_proposals.push([x1, y1, x2, y2]);
all_scores.push(score);
}
}
}
}
}
if all_proposals.is_empty() {
// Still tick the memory bank with no proposals so slot aging/decay
// advances on empty frames. 1.0/16.0 matches g2's stride-16 level.
self.memory.update(&g2, &[], &[], 1.0 / 16.0);
return Vec::new();
}
// Flatten proposals to an [n, 4] tensor for NMS.
let boxes_flat: Vec<f32> = all_proposals
.iter()
.flat_map(|b| b.iter().copied())
.collect();
let n = all_proposals.len();
let boxes_tensor = Tensor::from_vec(boxes_flat, &[n, 4]).unwrap();
let scores_tensor = Tensor::from_vec(all_scores.clone(), &[n]).unwrap();
// `kept` holds indices of NMS survivors into the proposal/score vectors.
let kept = nms(&boxes_tensor, &scores_tensor, self.config.nms_threshold);
let nms_proposals: Vec<[f32; 4]> = kept.iter().map(|&i| all_proposals[i]).collect();
let nms_scores: Vec<f32> = kept.iter().map(|&i| all_scores[i]).collect();
// ROI features are pooled from g2 (stride 16), hence spatial_scale = 1/16.
// Returns one hidden state per tracked slot (order assumed to match
// `slots()` below — TODO confirm in ObjectMemoryBank).
let spatial_scale = 1.0 / 16.0; let hidden_states = self
.memory
.update(&g2, &nms_proposals, &nms_scores, spatial_scale);
let mut detections = Vec::new();
for (si, slot) in self.memory.slots().iter().enumerate() {
// Guard against more slots than returned hidden states.
if si >= hidden_states.len() {
break;
}
let (bbox_mean, bbox_log_var) = self.bbox_head.forward(&hidden_states[si]);
let mean_data = bbox_mean.data().to_vec();
let logvar_data = bbox_log_var.data().to_vec();
// Apply the predicted refinement, damped by 0.1, to the slot's box.
let refined_bbox = [
slot.bbox[0] + mean_data[0] * 0.1,
slot.bbox[1] + mean_data[1] * 0.1,
slot.bbox[2] + mean_data[2] * 0.1,
slot.bbox[3] + mean_data[3] * 0.1,
];
detections.push(NexusDetection {
bbox_mean: refined_bbox,
bbox_log_var: [
logvar_data[0],
logvar_data[1],
logvar_data[2],
logvar_data[3],
],
confidence: slot.confidence,
// NOTE(review): class_id is hardcoded to 0 and `self.class_head`
// is never invoked — classification appears unimplemented here.
class_id: 0, tracking_id: slot.id,
frames_tracked: slot.frames_tracked,
});
}
detections
}
/// Training-mode forward pass: runs the backbone/fusion/predictive stages
/// and returns the raw per-scale head outputs (no thresholding, NMS, or
/// memory update). Note: unlike `detect`, this does NOT bump `total_frames`,
/// but it DOES advance the stateful `predictive` module.
pub fn forward_train(&mut self, frame: &Variable) -> super::NexusTrainOutput {
let stem_out = self.stem.forward(frame);
let (v1, v2, v3) = self.ventral.forward(&stem_out);
let (d1, d2, d3) = self.dorsal.forward(&stem_out);
let (f1, f2, f3) = self.fusion.forward((&v1, &v2, &v3), (&d1, &d2, &d3));
let ((g1, _), (g2, _), (g3, _)) = self.predictive.forward(&f1, &f2, &f3);
let scales = [&g1, &g2, &g3];
let mut scale_outputs = Vec::with_capacity(3);
for (i, feat) in scales.iter().enumerate() {
let (cls, bbox, center) = self.proposal_heads[i].forward(feat);
scale_outputs.push(super::NexusScaleOutput {
cls_logits: cls,
bbox_pred: bbox,
centerness: center,
});
}
super::NexusTrainOutput {
scales: scale_outputs,
}
}
/// Convenience alias for [`Nexus::detect`] on a video stream frame.
pub fn detect_video_frame(&mut self, frame: &Variable) -> Vec<NexusDetection> {
self.detect(frame)
}
/// Clears all temporal state (predictive coding, object memory, frame
/// counter). Call between independent video clips.
pub fn reset(&mut self) {
self.predictive.reset();
self.memory.reset();
self.total_frames = 0;
}
/// Number of frames processed since construction or the last [`Nexus::reset`].
pub fn total_frames(&self) -> u64 {
self.total_frames
}
/// Collects the trainable parameters of every submodule.
pub fn parameters(&self) -> Vec<Parameter> {
let mut p = Vec::new();
p.extend(self.stem.parameters());
p.extend(self.ventral.parameters());
p.extend(self.dorsal.parameters());
p.extend(self.fusion.parameters());
p.extend(self.predictive.parameters());
for head in &self.proposal_heads {
p.extend(head.parameters());
}
p.extend(self.memory.parameters());
p.extend(self.class_head.parameters());
p.extend(self.bbox_head.parameters());
p
}
/// Switches the backbone, fusion, predictive, and proposal-head submodules
/// to evaluation mode.
///
/// NOTE(review): `memory`, `class_head`, and `bbox_head` are not toggled
/// here (or in `train`) even though `parameters()` includes them — confirm
/// they have no mode-dependent behavior.
pub fn eval(&mut self) {
self.stem.eval();
self.ventral.eval();
self.dorsal.eval();
self.fusion.eval();
self.predictive.eval();
for head in &mut self.proposal_heads {
head.eval();
}
}
/// Switches the same submodules as [`Nexus::eval`] back to training mode.
pub fn train(&mut self) {
self.stem.train();
self.ventral.train();
self.dorsal.train();
self.fusion.train();
self.predictive.train();
for head in &mut self.proposal_heads {
head.train();
}
}
}
impl Default for Nexus {
fn default() -> Self {
Self::new()
}
}
impl crate::camera::pipeline::DetectionModel for Nexus {
    type Output = Vec<NexusDetection>;

    /// Forwards to the inherent [`Nexus::detect`].
    fn detect(&mut self, input: &Variable) -> Vec<NexusDetection> {
        // Inherent methods take precedence over trait methods in method-call
        // resolution, so this dispatches to `Nexus::detect`, not back here.
        self.detect(input)
    }

    /// Expected (width, height) of input frames, taken from the config.
    fn input_size(&self) -> (u32, u32) {
        let w = self.config.input_width;
        let h = self.config.input_height;
        (w, h)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a 1-batch, 3-channel frame of size `h` x `w` filled with `val`.
    fn make_frame(h: usize, w: usize, val: f32) -> Variable {
        let pixels = vec![val; 3 * h * w];
        let tensor = Tensor::from_vec(pixels, &[1, 3, h, w]).unwrap();
        Variable::new(tensor, false)
    }

    /// Default config with a square input of the given side length.
    fn small_config(side: u32) -> NexusConfig {
        NexusConfig {
            input_width: side,
            input_height: side,
            ..NexusConfig::default()
        }
    }

    #[test]
    fn test_nexus_creation() {
        assert_eq!(Nexus::new().total_frames(), 0);
    }

    #[test]
    fn test_nexus_param_count() {
        let params = Nexus::new().parameters();
        let total: usize = params.iter().map(|p| p.numel()).sum();
        assert!(total > 200_000, "Nexus too small: {total} params");
        assert!(total < 2_000_000, "Nexus too large: {total} params");
    }

    #[test]
    fn test_nexus_single_frame() {
        let mut model = Nexus::new();
        let detections = model.detect(&make_frame(320, 320, 0.5));
        assert_eq!(model.total_frames(), 1);
        for det in &detections {
            assert!((0.0..=1.0).contains(&det.confidence));
            assert!(det.bbox_log_var.iter().all(|v| v.is_finite()));
        }
    }

    #[test]
    fn test_nexus_two_frame_predictive() {
        let mut model = Nexus::with_config(small_config(128));
        model.detect(&make_frame(128, 128, 0.3));
        // After the first frame, the predictive stage must hold a prediction.
        assert!(model.predictive.scale1.has_prediction());
        model.detect(&make_frame(128, 128, 0.3));
        assert_eq!(model.total_frames(), 2);
    }

    #[test]
    fn test_nexus_uncertainty_finite() {
        // Zero threshold so every grid cell yields a proposal.
        let mut cfg = small_config(64);
        cfg.proposal_threshold = 0.0;
        let mut model = Nexus::with_config(cfg);
        for det in &model.detect(&make_frame(64, 64, 0.5)) {
            assert!(det.bbox_mean.iter().all(|v| v.is_finite()));
            assert!(det.bbox_log_var.iter().all(|v| v.is_finite()));
        }
    }

    #[test]
    fn test_nexus_reset() {
        let mut model = Nexus::with_config(small_config(64));
        model.detect(&make_frame(64, 64, 0.5));
        model.reset();
        assert_eq!(model.total_frames(), 0);
        assert!(!model.predictive.scale1.has_prediction());
        assert_eq!(model.memory.num_tracked(), 0);
    }

    #[test]
    fn test_nexus_video_api() {
        let mut model = Nexus::with_config(small_config(64));
        let _det = model.detect_video_frame(&make_frame(64, 64, 0.4));
        assert_eq!(model.total_frames(), 1);
    }
}