// Crate root: models/lib.rs

1//! Burn ML models for object detection in the CortenForge stack.
2//!
3//! This crate defines the neural network architectures used for detection:
4//! - `LinearClassifier`: Simple feedforward network for binary classification.
5//! - `MultiboxModel`: Multi-box detection model with spatial output heads.
6//!
7//! These are pure Burn Modules with no awareness of the Detector trait. The `inference`
8//! crate wraps them into Detector implementations for runtime use.
9//!
10//! ## Design Note
11//! Model types use descriptive names (Classifier, Model) rather than "Detector" suffix,
12//! as they are architectural components, not full detector implementations.
13//!
14//! ## Stability
15//!
16//! Model architectures (`LinearClassifier`, `MultiboxModel`) and their config types are **stable**.
17//! The forward pass signatures and checkpoint format will not change in a backwards-incompatible
18//! way without a major version bump.
19
20use burn::module::Module;
21use burn::nn;
22use burn::tensor::activation::{relu, sigmoid};
23use burn::tensor::Tensor;
24
/// Configuration for [`LinearClassifier`]: width of the single hidden layer.
///
/// Plain-data config, so all value-semantics derives are free.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct LinearClassifierConfig {
    /// Number of units in the hidden layer.
    pub hidden: usize,
}

impl Default for LinearClassifierConfig {
    /// Defaults to a 64-unit hidden layer.
    fn default() -> Self {
        Self { hidden: 64 }
    }
}
35
/// Two-layer feedforward network for binary classification (4 -> hidden -> 1).
///
/// The forward pass applies ReLU between the layers and returns the raw
/// output of `linear2` — no activation is applied to the output.
#[derive(Debug, Module)]
pub struct LinearClassifier<B: burn::tensor::backend::Backend> {
    // Input projection: 4 features -> `hidden` units (sizes fixed in `new`).
    linear1: nn::Linear<B>,
    // Output layer: `hidden` units -> 1 logit.
    linear2: nn::Linear<B>,
}
41
42impl<B: burn::tensor::backend::Backend> LinearClassifier<B> {
43    pub fn new(cfg: LinearClassifierConfig, device: &B::Device) -> Self {
44        let linear1 = nn::LinearConfig::new(4, cfg.hidden).init(device);
45        let linear2 = nn::LinearConfig::new(cfg.hidden, 1).init(device);
46        Self { linear1, linear2 }
47    }
48
49    pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
50        let x = self.linear1.forward(input);
51        let x = relu(x);
52        self.linear2.forward(x)
53    }
54}
55
/// Configuration for [`MultiboxModel`].
///
/// Plain-data config, so all value-semantics derives are free.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct MultiboxModelConfig {
    /// Width of the stem and of every hidden block.
    pub hidden: usize,
    /// Number of hidden-to-hidden blocks after the stem.
    pub depth: usize,
    /// Maximum number of boxes predicted per input (clamped to >= 1 when the
    /// model is built).
    pub max_boxes: usize,
    /// Input feature dimension; `None` falls back to the default of 4.
    pub input_dim: Option<usize>,
}

impl Default for MultiboxModelConfig {
    /// 128-wide trunk, 2 hidden blocks, up to 64 boxes, default input dim.
    fn default() -> Self {
        Self {
            hidden: 128,
            depth: 2,
            max_boxes: 64,
            input_dim: None,
        }
    }
}
74
/// MLP trunk with two output heads: flattened box coordinates and per-box
/// scores. Built by `MultiboxModel::new`; see `forward_multibox` for the
/// (boxes, scores) output contract.
#[derive(Debug, Module)]
pub struct MultiboxModel<B: burn::tensor::backend::Backend> {
    // Input projection: input_dim -> hidden.
    stem: nn::Linear<B>,
    // `depth` hidden-to-hidden layers, each followed by ReLU in the forward pass.
    blocks: Vec<nn::Linear<B>>,
    // Predicts max_boxes * 4 flattened box coordinates.
    box_head: nn::Linear<B>,
    // Predicts one score per box.
    score_head: nn::Linear<B>,
    // Number of boxes both heads are sized for; always >= 1 (clamped in `new`).
    max_boxes: usize,
    // Input feature dimension the stem expects.
    input_dim: usize,
}
84
85impl<B: burn::tensor::backend::Backend> MultiboxModel<B> {
86    pub fn new(cfg: MultiboxModelConfig, device: &B::Device) -> Self {
87        let input_dim = cfg.input_dim.unwrap_or(4);
88        let stem = nn::LinearConfig::new(input_dim, cfg.hidden).init(device);
89        let mut blocks = Vec::new();
90        for _ in 0..cfg.depth {
91            blocks.push(nn::LinearConfig::new(cfg.hidden, cfg.hidden).init(device));
92        }
93        let box_head = nn::LinearConfig::new(cfg.hidden, cfg.max_boxes.max(1) * 4).init(device);
94        let score_head = nn::LinearConfig::new(cfg.hidden, cfg.max_boxes.max(1)).init(device);
95        Self {
96            stem,
97            blocks,
98            box_head,
99            score_head,
100            max_boxes: cfg.max_boxes.max(1),
101            input_dim,
102        }
103    }
104
105    pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
106        let mut x = relu(self.stem.forward(input));
107        for block in &self.blocks {
108            x = relu(block.forward(x));
109        }
110        self.score_head.forward(x)
111    }
112
113    /// Multibox forward: returns (boxes, scores) with shape [B, max_boxes, 4] and [B, max_boxes].
114    /// Boxes/scores are passed through sigmoid to keep them in a stable range.
115    pub fn forward_multibox(&self, input: Tensor<B, 2>) -> (Tensor<B, 3>, Tensor<B, 2>) {
116        let mut x = relu(self.stem.forward(input));
117        for block in &self.blocks {
118            x = relu(block.forward(x));
119        }
120        let boxes_flat = sigmoid(self.box_head.forward(x.clone()));
121        let scores = sigmoid(self.score_head.forward(x));
122        let batch = boxes_flat.dims()[0];
123        let boxes = boxes_flat.reshape([batch, self.max_boxes, 4]);
124
125        // Reorder/clamp to enforce x0 <= x1, y0 <= y1 within [0,1] using arithmetic.
126        let x0 = boxes.clone().slice([0..batch, 0..self.max_boxes, 0..1]);
127        let y0 = boxes.clone().slice([0..batch, 0..self.max_boxes, 1..2]);
128        let x1 = boxes.clone().slice([0..batch, 0..self.max_boxes, 2..3]);
129        let y1 = boxes.clone().slice([0..batch, 0..self.max_boxes, 3..4]);
130
131        let dx = x0.clone() - x1.clone();
132        let dy = y0.clone() - y1.clone();
133        let half = 0.5;
134
135        let x_min = (x0.clone() + x1.clone() - dx.clone().abs()) * half;
136        let x_max = (x0 + x1 + dx.abs()) * half;
137        let y_min = (y0.clone() + y1.clone() - dy.clone().abs()) * half;
138        let y_max = (y0 + y1 + dy.abs()) * half;
139
140        let x_min = x_min.clamp(0.0, 1.0);
141        let x_max = x_max.clamp(0.0, 1.0);
142        let y_min = y_min.clamp(0.0, 1.0);
143        let y_max = y_max.clamp(0.0, 1.0);
144
145        let boxes_ordered = burn::tensor::Tensor::cat(vec![x_min, y_min, x_max, y_max], 2);
146
147        (boxes_ordered, scores)
148    }
149}
150
/// Convenience re-exports of the crate's model and config types, for
/// `use models::prelude::*;` at call sites.
pub mod prelude {
    pub use super::{LinearClassifier, LinearClassifierConfig, MultiboxModel, MultiboxModelConfig};
}