models/
lib.rs

1use burn::module::Module;
2use burn::nn;
3use burn::tensor::activation::{relu, sigmoid};
4use burn::tensor::Tensor;
5
/// Hyper-parameters for [`TinyDet`].
#[derive(Clone, Debug)]
pub struct TinyDetConfig {
    /// Output width of the first linear layer (and input width of the second).
    pub hidden: usize,
}

impl Default for TinyDetConfig {
    /// Returns the stock configuration: a hidden width of 64.
    fn default() -> Self {
        TinyDetConfig { hidden: 64 }
    }
}
16
17#[derive(Debug, Module)]
18pub struct TinyDet<B: burn::tensor::backend::Backend> {
19    linear1: nn::Linear<B>,
20    linear2: nn::Linear<B>,
21}
22
23impl<B: burn::tensor::backend::Backend> TinyDet<B> {
24    pub fn new(cfg: TinyDetConfig, device: &B::Device) -> Self {
25        let linear1 = nn::LinearConfig::new(4, cfg.hidden).init(device);
26        let linear2 = nn::LinearConfig::new(cfg.hidden, 1).init(device);
27        Self { linear1, linear2 }
28    }
29
30    pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
31        let x = self.linear1.forward(input);
32        let x = relu(x);
33        self.linear2.forward(x)
34    }
35}
36
/// Hyper-parameters for [`BigDet`].
#[derive(Clone, Debug)]
pub struct BigDetConfig {
    /// Width of the stem output and of every hidden block.
    pub hidden: usize,
    /// Number of `hidden -> hidden` linear blocks stacked after the stem.
    pub depth: usize,
    /// Number of box/score slots predicted per sample.
    /// `BigDet::new` treats values below 1 as 1.
    pub max_boxes: usize,
    /// Number of input features; `BigDet::new` falls back to 4 when `None`.
    pub input_dim: Option<usize>,
}

impl Default for BigDetConfig {
    /// Returns the stock configuration: 128 hidden units, 2 blocks,
    /// 64 box slots, and no explicit input width.
    fn default() -> Self {
        BigDetConfig {
            hidden: 128,
            depth: 2,
            max_boxes: 64,
            input_dim: None,
        }
    }
}
55
56#[derive(Debug, Module)]
57pub struct BigDet<B: burn::tensor::backend::Backend> {
58    stem: nn::Linear<B>,
59    blocks: Vec<nn::Linear<B>>,
60    box_head: nn::Linear<B>,
61    score_head: nn::Linear<B>,
62    max_boxes: usize,
63    input_dim: usize,
64}
65
66impl<B: burn::tensor::backend::Backend> BigDet<B> {
67    pub fn new(cfg: BigDetConfig, device: &B::Device) -> Self {
68        let input_dim = cfg.input_dim.unwrap_or(4);
69        let stem = nn::LinearConfig::new(input_dim, cfg.hidden).init(device);
70        let mut blocks = Vec::new();
71        for _ in 0..cfg.depth {
72            blocks.push(nn::LinearConfig::new(cfg.hidden, cfg.hidden).init(device));
73        }
74        let box_head = nn::LinearConfig::new(cfg.hidden, cfg.max_boxes.max(1) * 4).init(device);
75        let score_head = nn::LinearConfig::new(cfg.hidden, cfg.max_boxes.max(1)).init(device);
76        Self {
77            stem,
78            blocks,
79            box_head,
80            score_head,
81            max_boxes: cfg.max_boxes.max(1),
82            input_dim,
83        }
84    }
85
86    pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
87        let mut x = relu(self.stem.forward(input));
88        for block in &self.blocks {
89            x = relu(block.forward(x));
90        }
91        self.score_head.forward(x)
92    }
93
94    /// Multibox forward: returns (boxes, scores) with shape [B, max_boxes, 4] and [B, max_boxes].
95    /// Boxes/scores are passed through sigmoid to keep them in a stable range.
96    pub fn forward_multibox(&self, input: Tensor<B, 2>) -> (Tensor<B, 3>, Tensor<B, 2>) {
97        let mut x = relu(self.stem.forward(input));
98        for block in &self.blocks {
99            x = relu(block.forward(x));
100        }
101        let boxes_flat = sigmoid(self.box_head.forward(x.clone()));
102        let scores = sigmoid(self.score_head.forward(x));
103        let batch = boxes_flat.dims()[0];
104        let boxes = boxes_flat.reshape([batch, self.max_boxes, 4]);
105
106        // Reorder/clamp to enforce x0 <= x1, y0 <= y1 within [0,1] using arithmetic.
107        let x0 = boxes.clone().slice([0..batch, 0..self.max_boxes, 0..1]);
108        let y0 = boxes.clone().slice([0..batch, 0..self.max_boxes, 1..2]);
109        let x1 = boxes.clone().slice([0..batch, 0..self.max_boxes, 2..3]);
110        let y1 = boxes.clone().slice([0..batch, 0..self.max_boxes, 3..4]);
111
112        let dx = x0.clone() - x1.clone();
113        let dy = y0.clone() - y1.clone();
114        let half = 0.5;
115
116        let x_min = (x0.clone() + x1.clone() - dx.clone().abs()) * half;
117        let x_max = (x0 + x1 + dx.abs()) * half;
118        let y_min = (y0.clone() + y1.clone() - dy.clone().abs()) * half;
119        let y_max = (y0 + y1 + dy.abs()) * half;
120
121        let x_min = x_min.clamp(0.0, 1.0);
122        let x_max = x_max.clamp(0.0, 1.0);
123        let y_min = y_min.clamp(0.0, 1.0);
124        let y_max = y_max.clamp(0.0, 1.0);
125
126        let boxes_ordered = burn::tensor::Tensor::cat(vec![x_min, y_min, x_max, y_max], 2);
127
128        (boxes_ordered, scores)
129    }
130}
131
132pub mod prelude {
133    pub use super::{BigDet, BigDetConfig, TinyDet, TinyDetConfig};
134}