1use burn::module::Module;
21use burn::nn;
22use burn::tensor::activation::{relu, sigmoid};
23use burn::tensor::Tensor;
24
/// Hyper-parameters for [`LinearClassifier`].
#[derive(Clone, Debug)]
pub struct LinearClassifierConfig {
    /// Output width of the first linear layer and input width of the second.
    pub hidden: usize,
}
29
30impl Default for LinearClassifierConfig {
31 fn default() -> Self {
32 Self { hidden: 64 }
33 }
34}
35
36#[derive(Debug, Module)]
37pub struct LinearClassifier<B: burn::tensor::backend::Backend> {
38 linear1: nn::Linear<B>,
39 linear2: nn::Linear<B>,
40}
41
42impl<B: burn::tensor::backend::Backend> LinearClassifier<B> {
43 pub fn new(cfg: LinearClassifierConfig, device: &B::Device) -> Self {
44 let linear1 = nn::LinearConfig::new(4, cfg.hidden).init(device);
45 let linear2 = nn::LinearConfig::new(cfg.hidden, 1).init(device);
46 Self { linear1, linear2 }
47 }
48
49 pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
50 let x = self.linear1.forward(input);
51 let x = relu(x);
52 self.linear2.forward(x)
53 }
54}
55
/// Hyper-parameters for [`MultiboxModel`].
#[derive(Clone, Debug)]
pub struct MultiboxModelConfig {
    /// Width of the stem output and of every hidden block.
    pub hidden: usize,
    /// Number of hidden `hidden -> hidden` blocks placed after the stem.
    pub depth: usize,
    /// Number of boxes (and scores) predicted per sample; the model raises this to at least 1.
    pub max_boxes: usize,
    /// Number of input features; the model uses 4 when this is `None`.
    pub input_dim: Option<usize>,
}
63
64impl Default for MultiboxModelConfig {
65 fn default() -> Self {
66 Self {
67 hidden: 128,
68 depth: 2,
69 max_boxes: 64,
70 input_dim: None,
71 }
72 }
73}
74
75#[derive(Debug, Module)]
76pub struct MultiboxModel<B: burn::tensor::backend::Backend> {
77 stem: nn::Linear<B>,
78 blocks: Vec<nn::Linear<B>>,
79 box_head: nn::Linear<B>,
80 score_head: nn::Linear<B>,
81 max_boxes: usize,
82 input_dim: usize,
83}
84
85impl<B: burn::tensor::backend::Backend> MultiboxModel<B> {
86 pub fn new(cfg: MultiboxModelConfig, device: &B::Device) -> Self {
87 let input_dim = cfg.input_dim.unwrap_or(4);
88 let stem = nn::LinearConfig::new(input_dim, cfg.hidden).init(device);
89 let mut blocks = Vec::new();
90 for _ in 0..cfg.depth {
91 blocks.push(nn::LinearConfig::new(cfg.hidden, cfg.hidden).init(device));
92 }
93 let box_head = nn::LinearConfig::new(cfg.hidden, cfg.max_boxes.max(1) * 4).init(device);
94 let score_head = nn::LinearConfig::new(cfg.hidden, cfg.max_boxes.max(1)).init(device);
95 Self {
96 stem,
97 blocks,
98 box_head,
99 score_head,
100 max_boxes: cfg.max_boxes.max(1),
101 input_dim,
102 }
103 }
104
105 pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
106 let mut x = relu(self.stem.forward(input));
107 for block in &self.blocks {
108 x = relu(block.forward(x));
109 }
110 self.score_head.forward(x)
111 }
112
113 pub fn forward_multibox(&self, input: Tensor<B, 2>) -> (Tensor<B, 3>, Tensor<B, 2>) {
116 let mut x = relu(self.stem.forward(input));
117 for block in &self.blocks {
118 x = relu(block.forward(x));
119 }
120 let boxes_flat = sigmoid(self.box_head.forward(x.clone()));
121 let scores = sigmoid(self.score_head.forward(x));
122 let batch = boxes_flat.dims()[0];
123 let boxes = boxes_flat.reshape([batch, self.max_boxes, 4]);
124
125 let x0 = boxes.clone().slice([0..batch, 0..self.max_boxes, 0..1]);
127 let y0 = boxes.clone().slice([0..batch, 0..self.max_boxes, 1..2]);
128 let x1 = boxes.clone().slice([0..batch, 0..self.max_boxes, 2..3]);
129 let y1 = boxes.clone().slice([0..batch, 0..self.max_boxes, 3..4]);
130
131 let dx = x0.clone() - x1.clone();
132 let dy = y0.clone() - y1.clone();
133 let half = 0.5;
134
135 let x_min = (x0.clone() + x1.clone() - dx.clone().abs()) * half;
136 let x_max = (x0 + x1 + dx.abs()) * half;
137 let y_min = (y0.clone() + y1.clone() - dy.clone().abs()) * half;
138 let y_max = (y0 + y1 + dy.abs()) * half;
139
140 let x_min = x_min.clamp(0.0, 1.0);
141 let x_max = x_max.clamp(0.0, 1.0);
142 let y_min = y_min.clamp(0.0, 1.0);
143 let y_max = y_max.clamp(0.0, 1.0);
144
145 let boxes_ordered = burn::tensor::Tensor::cat(vec![x_min, y_min, x_max, y_max], 2);
146
147 (boxes_ordered, scores)
148 }
149}
150
151pub mod prelude {
152 pub use super::{LinearClassifier, LinearClassifierConfig, MultiboxModel, MultiboxModelConfig};
153}