#![allow(missing_docs)]
use axonml_autograd::Variable;
use axonml_nn::{GRUCell, Linear, Module, Parameter};
use axonml_tensor::Tensor;
use crate::ops::{box_iou, roi_align};
/// State kept for a single tracked object inside an [`ObjectMemoryBank`].
#[derive(Clone)]
pub struct ObjectSlot {
    /// Identifier handed out by the owning bank's counter when the slot is created.
    pub id: u64,
    /// Box of the most recent proposal associated with this slot
    /// (presumably `[x1, y1, x2, y2]` in image coordinates — confirm against `box_iou`).
    pub bbox: [f32; 4],
    /// Running confidence: seeded with the spawning proposal's score, blended with
    /// the new score on every match, and multiplied by the bank's decay rate on
    /// every frame without a match.
    pub confidence: f32,
    /// Recurrent hidden state produced by the bank's GRU cell for this object.
    pub hidden: Variable,
    /// Per-class logits; created empty when the slot is spawned and not written
    /// anywhere else in this file.
    pub class_logits: Vec<f32>,
    /// Number of frames in which this slot was spawned or matched (starts at 1).
    pub frames_tracked: u32,
    /// Consecutive frames without a matching proposal; reset to 0 on every match.
    pub frames_missing: u32,
}
/// Recurrent memory of tracked objects.
///
/// Each frame, incoming proposals are associated with existing slots by IoU;
/// matched slots have their hidden state advanced through a GRU cell, unmatched
/// slots age and are eventually dropped, and unmatched proposals may start new
/// slots.
pub struct ObjectMemoryBank {
    /// Projects a flattened ROI feature (`feat_channels * roi_size * roi_size`
    /// values) down to `hidden_size`.
    roi_project: Linear,
    /// Recurrent cell constructed with `hidden_size` for both of its size arguments.
    gru: GRUCell,
    /// Width of the per-slot hidden state.
    hidden_size: usize,
    /// Side length of the square ROI-align output.
    roi_size: usize,
    /// Channel count used when flattening the ROI-align output.
    feat_channels: usize,
    /// Currently tracked objects.
    slots: Vec<ObjectSlot>,
    /// Identifier that will be given to the next spawned slot.
    next_id: u64,
    /// A slot/proposal pair is eligible for matching only when its IoU is
    /// strictly greater than this value.
    pub match_threshold: f32,
    /// A slot is removed once its `frames_missing` exceeds this value.
    pub max_missing: u32,
    /// Factor applied to a slot's confidence on every frame without a match.
    pub decay_rate: f32,
}
impl ObjectMemoryBank {
    /// Score an unmatched proposal must strictly exceed to start a new slot.
    const SPAWN_SCORE_THRESHOLD: f32 = 0.3;
    /// Weight of a slot's previous confidence when it is matched again.
    const CONFIDENCE_KEEP: f32 = 0.7;
    /// Weight of the matching proposal's score when a slot is matched again.
    const CONFIDENCE_BLEND: f32 = 0.3;

    /// Creates an empty bank.
    ///
    /// * `feat_channels` – channel count of the feature maps passed to [`Self::update`].
    /// * `hidden_size` – width of each slot's hidden state.
    /// * `roi_size` – side length of the square ROI-align output.
    ///
    /// The tunables start at `match_threshold = 0.3`, `max_missing = 10` and
    /// `decay_rate = 0.9`; slot identifiers start at 1.
    pub fn new(feat_channels: usize, hidden_size: usize, roi_size: usize) -> Self {
        let roi_feat_dim = feat_channels * roi_size * roi_size;
        Self {
            roi_project: Linear::new(roi_feat_dim, hidden_size),
            gru: GRUCell::new(hidden_size, hidden_size),
            hidden_size,
            roi_size,
            feat_channels,
            slots: Vec::new(),
            next_id: 1,
            match_threshold: 0.3,
            max_missing: 10,
            decay_rate: 0.9,
        }
    }

    /// Creates a bank for 96-channel features, a 64-wide hidden state and 3×3 ROIs.
    pub fn default_config() -> Self {
        Self::new(96, 64, 3)
    }

    /// Advances the bank by one frame and returns the hidden state of every
    /// slot that is tracked afterwards, in slot order.
    ///
    /// Steps, in order:
    /// 1. Existing slots are greedily matched to proposals by descending IoU;
    ///    only pairs with IoU strictly above `match_threshold` are considered and
    ///    each slot/proposal is used at most once. A matched slot takes the
    ///    proposal's box, blends its confidence with the proposal's score and has
    ///    its hidden state advanced through the GRU cell.
    /// 2. Unmatched slots get `frames_missing` incremented and their confidence
    ///    multiplied by `decay_rate`; slots whose `frames_missing` exceeds
    ///    `max_missing` are removed.
    /// 3. Every unmatched proposal whose score is strictly above the spawn
    ///    threshold (0.3) starts a new slot.
    ///
    /// `proposals[i]` and `scores[i]` describe the same detection. If the two
    /// slices differ in length, only the common prefix is used and the surplus
    /// entries are ignored, instead of indexing out of bounds.
    ///
    /// # Panics
    ///
    /// Panics if the tensor helpers reject the shapes built here, e.g. when the
    /// ROI-align output does not contain `feat_channels * roi_size * roi_size`
    /// values.
    pub fn update(
        &mut self,
        features: &Variable,
        proposals: &[[f32; 4]],
        scores: &[f32],
        spatial_scale: f32,
    ) -> Vec<Variable> {
        // A proposal without a score (or a score without a proposal) cannot be
        // used, so work on the common prefix only. This keeps every index below
        // valid for both slices.
        let n_proposals = proposals.len().min(scores.len());
        let proposals = &proposals[..n_proposals];
        let scores = &scores[..n_proposals];

        let n_slots = self.slots.len();
        let mut matched_prop = vec![false; n_proposals];
        let mut matched_slot = vec![false; n_slots];

        if n_slots > 0 && n_proposals > 0 {
            let slot_boxes: Vec<f32> = self.slots.iter().flat_map(|s| s.bbox).collect();
            let prop_boxes: Vec<f32> = proposals.iter().flat_map(|b| b.iter().copied()).collect();
            let slot_tensor = Tensor::from_vec(slot_boxes, &[n_slots, 4])
                .expect("slot boxes hold exactly 4 values per slot");
            let prop_tensor = Tensor::from_vec(prop_boxes, &[n_proposals, 4])
                .expect("proposal boxes hold exactly 4 values per proposal");

            // Read row-major: one row per slot, one column per proposal.
            let iou_data = box_iou(&slot_tensor, &prop_tensor).to_vec();

            let mut pairs: Vec<(usize, usize, f32)> = Vec::new();
            for si in 0..n_slots {
                for pi in 0..n_proposals {
                    let iou = iou_data[si * n_proposals + pi];
                    if iou > self.match_threshold {
                        pairs.push((si, pi, iou));
                    }
                }
            }

            // Best overlaps first. NaN never passes the `>` filter above, so the
            // comparison cannot fail.
            pairs.sort_by(|a, b| {
                b.2.partial_cmp(&a.2)
                    .expect("IoU values above the threshold are never NaN")
            });

            for (si, pi, _iou) in pairs {
                if matched_slot[si] || matched_prop[pi] {
                    continue;
                }
                matched_slot[si] = true;
                matched_prop[pi] = true;

                let roi_feat = self.extract_roi(features, &proposals[pi], spatial_scale);
                let projected = self.roi_project.forward(&roi_feat).relu();
                let new_hidden = self.gru.forward_step(&projected, &self.slots[si].hidden);

                let slot = &mut self.slots[si];
                slot.hidden = new_hidden;
                slot.bbox = proposals[pi];
                slot.confidence = (slot.confidence * Self::CONFIDENCE_KEEP
                    + scores[pi] * Self::CONFIDENCE_BLEND)
                    .min(1.0);
                slot.frames_tracked = slot.frames_tracked.saturating_add(1);
                slot.frames_missing = 0;
            }
        }

        // Age every slot that found no proposal this frame.
        let decay = self.decay_rate;
        for (slot, &was_matched) in self.slots.iter_mut().zip(&matched_slot) {
            if was_matched {
                continue;
            }
            slot.frames_missing = slot.frames_missing.saturating_add(1);
            slot.confidence *= decay;
        }

        // Matched slots have `frames_missing == 0`, so this only ever drops
        // slots that were just aged past the limit.
        let max_missing = self.max_missing;
        self.slots.retain(|slot| slot.frames_missing <= max_missing);

        // Start tracking confident detections that matched nothing.
        for (pi, (bbox, &score)) in proposals.iter().zip(scores).enumerate() {
            if !matched_prop[pi] && score > Self::SPAWN_SCORE_THRESHOLD {
                self.spawn_slot(bbox, score, features, spatial_scale);
            }
        }

        self.slots.iter().map(|s| s.hidden.clone()).collect()
    }

    /// Pools the features under `bbox` with ROI-align and returns them flattened
    /// to shape `[1, feat_channels * roi_size * roi_size]`.
    ///
    /// The returned `Variable` wraps a freshly built tensor and is created with
    /// `false` as the second constructor argument.
    fn extract_roi(&self, features: &Variable, bbox: &[f32; 4], spatial_scale: f32) -> Variable {
        let [b0, b1, b2, b3] = *bbox;
        // NOTE(review): the leading 0.0 is presumably the batch index expected by
        // `roi_align` — confirm against its documentation.
        let roi_tensor = Tensor::from_vec(vec![0.0, b0, b1, b2, b3], &[1, 5])
            .expect("a single ROI row holds exactly 5 values");
        let pooled = roi_align(
            &features.data(),
            &roi_tensor,
            (self.roi_size, self.roi_size),
            spatial_scale,
        );
        let flat_len = self.feat_channels * self.roi_size * self.roi_size;
        let flat = Tensor::from_vec(pooled.to_vec(), &[1, flat_len])
            .expect("ROI-align output must hold feat_channels * roi_size * roi_size values");
        Variable::new(flat, false)
    }

    /// Appends a new slot for `bbox`, taking the next free identifier.
    ///
    /// The slot's initial hidden state is one GRU step from an all-zero state,
    /// driven by the projected ROI features of the box.
    fn spawn_slot(
        &mut self,
        bbox: &[f32; 4],
        confidence: f32,
        features: &Variable,
        spatial_scale: f32,
    ) {
        let roi_feat = self.extract_roi(features, bbox, spatial_scale);
        let projected = self.roi_project.forward(&roi_feat).relu();

        let zeros = Tensor::from_vec(vec![0.0; self.hidden_size], &[1, self.hidden_size])
            .expect("zero state holds exactly hidden_size values");
        let zero_hidden = Variable::new(zeros, false);
        let initial_hidden = self.gru.forward_step(&projected, &zero_hidden);

        let id = self.next_id;
        self.next_id += 1;

        self.slots.push(ObjectSlot {
            id,
            bbox: *bbox,
            confidence,
            hidden: initial_hidden,
            class_logits: Vec::new(),
            frames_tracked: 1,
            frames_missing: 0,
        });
    }

    /// Returns the currently tracked slots.
    pub fn slots(&self) -> &[ObjectSlot] {
        &self.slots
    }

    /// Returns how many objects are currently tracked.
    pub fn num_tracked(&self) -> usize {
        self.slots.len()
    }

    /// Drops every slot and restarts identifier assignment at 1.
    ///
    /// The projection layer and GRU cell are left untouched.
    pub fn reset(&mut self) {
        self.slots.clear();
        self.next_id = 1;
    }

    /// Collects the parameters of the ROI projection followed by those of the GRU cell.
    pub fn parameters(&self) -> Vec<Parameter> {
        let mut params = Vec::new();
        params.extend(self.roi_project.parameters());
        params.extend(self.gru.parameters());
        params
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `[1, c, h, w]` feature map filled with the constant 0.1.
    fn make_features(c: usize, h: usize, w: usize) -> Variable {
        let values = vec![0.1; c * h * w];
        let tensor = Tensor::from_vec(values, &[1, c, h, w]).unwrap();
        Variable::new(tensor, false)
    }

    /// A freshly constructed bank tracks nothing.
    #[test]
    fn test_memory_bank_creation() {
        let bank = ObjectMemoryBank::default_config();
        assert_eq!(bank.num_tracked(), 0);
    }

    /// A confident proposal on an empty bank creates one slot with a `[1, 64]` hidden state.
    #[test]
    fn test_memory_bank_spawn() {
        let mut bank = ObjectMemoryBank::default_config();
        let features = make_features(96, 20, 20);

        let hiddens = bank.update(&features, &[[10.0, 10.0, 50.0, 50.0]], &[0.9], 0.25);

        assert_eq!(bank.num_tracked(), 1);
        assert_eq!(hiddens.len(), 1);
        assert_eq!(hiddens[0].shape(), vec![1, 64]);
    }

    /// A slightly shifted box on the next frame updates the existing slot
    /// instead of creating a second one.
    #[test]
    fn test_memory_bank_match_and_update() {
        let mut bank = ObjectMemoryBank::default_config();
        let features = make_features(96, 20, 20);

        // Frame 1: the object appears.
        bank.update(&features, &[[10.0, 10.0, 50.0, 50.0]], &[0.9], 0.25);
        assert_eq!(bank.num_tracked(), 1);
        assert_eq!(bank.slots()[0].frames_tracked, 1);

        // Frame 2: the same object, moved by two units.
        bank.update(&features, &[[12.0, 12.0, 52.0, 52.0]], &[0.85], 0.25);
        assert_eq!(bank.num_tracked(), 1);
        assert_eq!(bank.slots()[0].frames_tracked, 2);
        assert_eq!(bank.slots()[0].id, 1);
    }

    /// With `max_missing = 3`, a slot is gone after four frames without detections.
    #[test]
    fn test_memory_bank_decay() {
        let mut bank = ObjectMemoryBank::default_config();
        bank.max_missing = 3;
        let features = make_features(96, 10, 10);

        bank.update(&features, &[[5.0, 5.0, 15.0, 15.0]], &[0.8], 0.5);
        assert_eq!(bank.num_tracked(), 1);

        let mut empty_frames = 0;
        while empty_frames < 4 {
            bank.update(&features, &[], &[], 0.5);
            empty_frames += 1;
        }
        assert_eq!(bank.num_tracked(), 0);
    }

    /// `reset` discards every tracked slot.
    #[test]
    fn test_memory_bank_reset() {
        let mut bank = ObjectMemoryBank::default_config();
        let features = make_features(96, 10, 10);

        // Two non-overlapping detections on consecutive frames give two slots.
        bank.update(&features, &[[5.0, 5.0, 15.0, 15.0]], &[0.8], 0.5);
        bank.update(&features, &[[30.0, 30.0, 60.0, 60.0]], &[0.7], 0.5);
        assert_eq!(bank.num_tracked(), 2);

        bank.reset();
        assert_eq!(bank.num_tracked(), 0);
    }

    /// The default configuration has between 10k and 100k scalar parameters (exclusive).
    #[test]
    fn test_memory_bank_param_count() {
        let bank = ObjectMemoryBank::default_config();
        let total = bank
            .parameters()
            .iter()
            .fold(0usize, |acc, param| acc + param.numel());
        assert!(total > 10_000);
        assert!(total < 100_000);
    }
}