use burn::module::Module;
use burn::nn;
use burn::tensor::activation::{relu, sigmoid};
use burn::tensor::Tensor;

/// Configuration for [`TinyDet`].
#[derive(Debug, Clone)]
pub struct TinyDetConfig {
    pub hidden: usize,
}

impl Default for TinyDetConfig {
    fn default() -> Self {
        Self { hidden: 64 }
    }
}

/// A minimal two-layer MLP detector head: 4 input features -> `hidden` -> 1 score.
#[derive(Debug, Module)]
pub struct TinyDet<B: burn::tensor::backend::Backend> {
    linear1: nn::Linear<B>,
    linear2: nn::Linear<B>,
}

impl<B: burn::tensor::backend::Backend> TinyDet<B> {
    pub fn new(cfg: TinyDetConfig, device: &B::Device) -> Self {
        let linear1 = nn::LinearConfig::new(4, cfg.hidden).init(device);
        let linear2 = nn::LinearConfig::new(cfg.hidden, 1).init(device);
        Self { linear1, linear2 }
    }

    /// Applies linear -> relu -> linear and returns one raw (unactivated) score per row.
    pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
        let x = self.linear1.forward(input);
        let x = relu(x);
        self.linear2.forward(x)
    }
}

/// Configuration for [`BigDet`].
#[derive(Debug, Clone)]
pub struct BigDetConfig {
    pub hidden: usize,
    pub depth: usize,
    pub max_boxes: usize,
    /// Number of input features; falls back to 4 when `None`.
    pub input_dim: Option<usize>,
}

impl Default for BigDetConfig {
    fn default() -> Self {
        Self {
            hidden: 128,
            depth: 2,
            max_boxes: 64,
            input_dim: None,
        }
    }
}

/// A deeper MLP detector that predicts up to `max_boxes` boxes plus a score per box.
#[derive(Debug, Module)]
pub struct BigDet<B: burn::tensor::backend::Backend> {
    stem: nn::Linear<B>,
    blocks: Vec<nn::Linear<B>>,
    box_head: nn::Linear<B>,
    score_head: nn::Linear<B>,
    max_boxes: usize,
    input_dim: usize,
}

impl<B: burn::tensor::backend::Backend> BigDet<B> {
    pub fn new(cfg: BigDetConfig, device: &B::Device) -> Self {
        let input_dim = cfg.input_dim.unwrap_or(4);
        let max_boxes = cfg.max_boxes.max(1);
        let stem = nn::LinearConfig::new(input_dim, cfg.hidden).init(device);
        let mut blocks = Vec::new();
        for _ in 0..cfg.depth {
            blocks.push(nn::LinearConfig::new(cfg.hidden, cfg.hidden).init(device));
        }
        // Each box slot gets four coordinate outputs and one score output.
        let box_head = nn::LinearConfig::new(cfg.hidden, max_boxes * 4).init(device);
        let score_head = nn::LinearConfig::new(cfg.hidden, max_boxes).init(device);
        Self {
            stem,
            blocks,
            box_head,
            score_head,
            max_boxes,
            input_dim,
        }
    }

    /// Runs the shared trunk and returns raw (unactivated) per-slot scores of shape
    /// `[batch, max_boxes]`.
    pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
        let mut x = relu(self.stem.forward(input));
        for block in &self.blocks {
            x = relu(block.forward(x));
        }
        self.score_head.forward(x)
    }

    /// Runs the trunk and predicts `max_boxes` candidate boxes plus scores.
    ///
    /// Coordinates and scores are squashed to `[0, 1]` with a sigmoid. The four raw
    /// outputs per slot are reordered so every returned box is `[x_min, y_min, x_max, y_max]`,
    /// using `min(a, b) = (a + b - |a - b|) / 2` and `max(a, b) = (a + b + |a - b|) / 2`.
    ///
    /// Returns `(boxes, scores)` with shapes `[batch, max_boxes, 4]` and `[batch, max_boxes]`.
    pub fn forward_multibox(&self, input: Tensor<B, 2>) -> (Tensor<B, 3>, Tensor<B, 2>) {
        let mut x = relu(self.stem.forward(input));
        for block in &self.blocks {
            x = relu(block.forward(x));
        }
        let boxes_flat = sigmoid(self.box_head.forward(x.clone()));
        let scores = sigmoid(self.score_head.forward(x));
        let batch = boxes_flat.dims()[0];
        let boxes = boxes_flat.reshape([batch, self.max_boxes, 4]);

        // Split the last dimension into the four raw corner predictions.
        let x0 = boxes.clone().slice([0..batch, 0..self.max_boxes, 0..1]);
        let y0 = boxes.clone().slice([0..batch, 0..self.max_boxes, 1..2]);
        let x1 = boxes.clone().slice([0..batch, 0..self.max_boxes, 2..3]);
        let y1 = boxes.slice([0..batch, 0..self.max_boxes, 3..4]);

        let dx = x0.clone() - x1.clone();
        let dy = y0.clone() - y1.clone();
        let half = 0.5;

        // min/max of the two corners, expressed purely with element-wise tensor
        // arithmetic (no comparison or select ops).
        let x_min = (x0.clone() + x1.clone() - dx.clone().abs()) * half;
        let x_max = (x0 + x1 + dx.abs()) * half;
        let y_min = (y0.clone() + y1.clone() - dy.clone().abs()) * half;
        let y_max = (y0 + y1 + dy.abs()) * half;

        // The sigmoid already bounds the coordinates; clamp defensively anyway.
        let x_min = x_min.clamp(0.0, 1.0);
        let x_max = x_max.clamp(0.0, 1.0);
        let y_min = y_min.clamp(0.0, 1.0);
        let y_max = y_max.clamp(0.0, 1.0);

        let boxes_ordered = Tensor::cat(vec![x_min, y_min, x_max, y_max], 2);

        (boxes_ordered, scores)
    }
}

pub mod prelude {
    pub use super::{BigDet, BigDetConfig, TinyDet, TinyDetConfig};
}
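
// A minimal smoke-test sketch of how these models are meant to be driven. The
// `NdArray` backend (and therefore burn's "ndarray" feature) is an assumption made
// only for the tests; substitute whichever backend this crate actually targets.
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;
    use burn::tensor::Distribution;

    type B = NdArray<f32>;

    #[test]
    fn tiny_det_scores_one_value_per_row() {
        let device = Default::default();
        let model = TinyDet::<B>::new(TinyDetConfig::default(), &device);
        let batch: usize = 3;
        let input = Tensor::<B, 2>::zeros([batch, 4], &device);
        assert_eq!(model.forward(input).dims(), [batch, 1]);
    }

    #[test]
    fn big_det_multibox_shapes_and_corner_ordering() {
        let device = Default::default();
        let cfg = BigDetConfig::default();
        let model = BigDet::<B>::new(cfg.clone(), &device);
        let batch: usize = 2;
        let input = Tensor::<B, 2>::random([batch, 4], Distribution::Default, &device);

        let (boxes, scores) = model.forward_multibox(input);
        assert_eq!(boxes.dims(), [batch, cfg.max_boxes, 4]);
        assert_eq!(scores.dims(), [batch, cfg.max_boxes]);

        // Corners come back ordered: x_max - x_min is non-negative for every box slot.
        let x_min = boxes.clone().slice([0..batch, 0..cfg.max_boxes, 0..1]);
        let x_max = boxes.slice([0..batch, 0..cfg.max_boxes, 2..3]);
        let smallest_gap: f32 = (x_max - x_min).min().into_scalar();
        assert!(smallest_gap >= 0.0);
    }
}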