use candle::{Context, Result, Tensor, D};
use candle_nn as nn;
use nn::{Module, VarBuilder};

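/// Configuration of a single MBConv (mobile inverted bottleneck) stage:
/// expansion ratio, depthwise kernel size, stride, channel counts, and how many
/// times the block is repeated.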
#[derive(Debug, Clone, Copy)]
pub struct MBConvConfig {
    expand_ratio: f64,
    kernel: usize,
    stride: usize,
    input_channels: usize,
    out_channels: usize,
    num_layers: usize,
}

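/// Rounds `v` to the nearest multiple of `divisor` (at least `divisor`), adding
/// one extra `divisor` if rounding would drop below 90% of the original value.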
fn make_divisible(v: f64, divisor: usize) -> usize {
    let min_value = divisor;
    let new_v = usize::max(
        min_value,
        (v + divisor as f64 * 0.5) as usize / divisor * divisor,
    );
    if (new_v as f64) < 0.9 * v {
        new_v + divisor
    } else {
        new_v
    }
}

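/// Scales the baseline stage table by the given width and depth multipliers:
/// channel counts are rounded to multiples of 8 and layer counts are rounded up.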
fn bneck_confs(width_mult: f64, depth_mult: f64) -> Vec<MBConvConfig> {
    let bneck_conf = |e, k, s, i, o, n| {
        let input_channels = make_divisible(i as f64 * width_mult, 8);
        let out_channels = make_divisible(o as f64 * width_mult, 8);
        let num_layers = (n as f64 * depth_mult).ceil() as usize;
        MBConvConfig {
            expand_ratio: e,
            kernel: k,
            stride: s,
            input_channels,
            out_channels,
            num_layers,
        }
    };
    vec![
        bneck_conf(1., 3, 1, 32, 16, 1),
        bneck_conf(6., 3, 2, 16, 24, 2),
        bneck_conf(6., 5, 2, 24, 40, 2),
        bneck_conf(6., 3, 2, 40, 80, 3),
        bneck_conf(6., 5, 1, 80, 112, 3),
        bneck_conf(6., 5, 2, 112, 192, 4),
        bneck_conf(6., 3, 1, 192, 320, 1),
    ]
}

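// Width/depth multipliers for the standard EfficientNet variants B0 through B7.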
impl MBConvConfig {
    pub fn b0() -> Vec<Self> {
        bneck_confs(1.0, 1.0)
    }
    pub fn b1() -> Vec<Self> {
        bneck_confs(1.0, 1.1)
    }
    pub fn b2() -> Vec<Self> {
        bneck_confs(1.1, 1.2)
    }
    pub fn b3() -> Vec<Self> {
        bneck_confs(1.2, 1.4)
    }
    pub fn b4() -> Vec<Self> {
        bneck_confs(1.4, 1.8)
    }
    pub fn b5() -> Vec<Self> {
        bneck_confs(1.6, 2.2)
    }
    pub fn b6() -> Vec<Self> {
        bneck_confs(1.8, 2.6)
    }
    pub fn b7() -> Vec<Self> {
        bneck_confs(2.0, 3.1)
    }
}

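/// 2D convolution with TensorFlow-style "same" padding: the input is zero-padded
/// so that the output spatial size is `ceil(input_size / stride)`.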
#[derive(Debug)]
struct Conv2DSame {
    conv2d: nn::Conv2d,
    s: usize,
    k: usize,
}

impl Conv2DSame {
    fn new(
        vb: VarBuilder,
        i: usize,
        o: usize,
        k: usize,
        stride: usize,
        groups: usize,
        bias: bool,
    ) -> Result<Self> {
        let conv_config = nn::Conv2dConfig {
            stride,
            groups,
            ..Default::default()
        };
        let conv2d = if bias {
            nn::conv2d(i, o, k, conv_config, vb)?
        } else {
            nn::conv2d_no_bias(i, o, k, conv_config, vb)?
        };
        Ok(Self {
            conv2d,
            s: stride,
            k,
        })
    }
}

impl Module for Conv2DSame {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let s = self.s;
        let k = self.k;
        let (_, _, ih, iw) = xs.dims4()?;
        // Compute the asymmetric zero-padding needed so that the output spatial
        // size is ceil(input / stride), as in TensorFlow's "SAME" padding.
        let oh = ih.div_ceil(s);
        let ow = iw.div_ceil(s);
        let pad_h = usize::max((oh - 1) * s + k - ih, 0);
        let pad_w = usize::max((ow - 1) * s + k - iw, 0);
        if pad_h > 0 || pad_w > 0 {
            let xs = xs.pad_with_zeros(2, pad_h / 2, pad_h - pad_h / 2)?;
            let xs = xs.pad_with_zeros(3, pad_w / 2, pad_w - pad_w / 2)?;
            self.conv2d.forward(&xs)
        } else {
            self.conv2d.forward(xs)
        }
    }
}

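/// Convolution followed by batch normalization and, optionally, a swish activation.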
#[derive(Debug)]
struct ConvNormActivation {
    conv2d: Conv2DSame,
    bn2d: nn::BatchNorm,
    activation: bool,
}

impl ConvNormActivation {
    fn new(
        vb: VarBuilder,
        i: usize,
        o: usize,
        k: usize,
        stride: usize,
        groups: usize,
    ) -> Result<Self> {
        let conv2d = Conv2DSame::new(vb.pp("0"), i, o, k, stride, groups, false)?;
        let bn2d = nn::batch_norm(o, 1e-3, vb.pp("1"))?;
        Ok(Self {
            conv2d,
            bn2d,
            activation: true,
        })
    }

    fn no_activation(self) -> Self {
        Self {
            activation: false,
            ..self
        }
    }
}

impl Module for ConvNormActivation {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // Batch norm runs in inference mode (train = false).
        let xs = self.conv2d.forward(xs)?.apply_t(&self.bn2d, false)?;
        if self.activation {
            swish(&xs)
        } else {
            Ok(xs)
        }
    }
}

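/// Squeeze-and-excitation block: channel-wise attention computed from globally
/// pooled features via two 1x1 convolutions.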
#[derive(Debug)]
struct SqueezeExcitation {
    fc1: Conv2DSame,
    fc2: Conv2DSame,
}

impl SqueezeExcitation {
    fn new(vb: VarBuilder, in_channels: usize, squeeze_channels: usize) -> Result<Self> {
        let fc1 = Conv2DSame::new(vb.pp("fc1"), in_channels, squeeze_channels, 1, 1, 1, true)?;
        let fc2 = Conv2DSame::new(vb.pp("fc2"), squeeze_channels, in_channels, 1, 1, 1, true)?;
        Ok(Self { fc1, fc2 })
    }
}

impl Module for SqueezeExcitation {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let residual = xs;
        // Global average pool over the spatial dimensions, then squeeze/excite
        // through two 1x1 convolutions and rescale the input channel-wise.
        let xs = xs.mean_keepdim(D::Minus2)?.mean_keepdim(D::Minus1)?;
        let xs = self.fc1.forward(&xs)?;
        let xs = swish(&xs)?;
        let xs = self.fc2.forward(&xs)?;
        let xs = nn::ops::sigmoid(&xs)?;
        residual.broadcast_mul(&xs)
    }
}

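/// Mobile inverted bottleneck block: optional 1x1 expansion, depthwise
/// convolution, squeeze-and-excitation, and a 1x1 projection, with a residual
/// connection when shapes allow.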
#[derive(Debug)]
struct MBConv {
    expand_cna: Option<ConvNormActivation>,
    depthwise_cna: ConvNormActivation,
    squeeze_excitation: SqueezeExcitation,
    project_cna: ConvNormActivation,
    config: MBConvConfig,
}

impl MBConv {
    fn new(vb: VarBuilder, c: MBConvConfig) -> Result<Self> {
        let vb = vb.pp("block");
        let exp = make_divisible(c.input_channels as f64 * c.expand_ratio, 8);
        // 1x1 expansion conv, skipped when the expanded width equals the input
        // width (expand ratio 1).
        let expand_cna = if exp != c.input_channels {
            Some(ConvNormActivation::new(
                vb.pp("0"),
                c.input_channels,
                exp,
                1,
                1,
                1,
            )?)
        } else {
            None
        };
        // The weight sub-paths shift by one when the expansion conv is present.
        let start_index = if expand_cna.is_some() { 1 } else { 0 };
        let depthwise_cna =
            ConvNormActivation::new(vb.pp(start_index), exp, exp, c.kernel, c.stride, exp)?;
        let squeeze_channels = usize::max(1, c.input_channels / 4);
        let squeeze_excitation =
            SqueezeExcitation::new(vb.pp(start_index + 1), exp, squeeze_channels)?;
        let project_cna =
            ConvNormActivation::new(vb.pp(start_index + 2), exp, c.out_channels, 1, 1, 1)?
                .no_activation();
        Ok(Self {
            expand_cna,
            depthwise_cna,
            squeeze_excitation,
            project_cna,
            config: c,
        })
    }
}

impl Module for MBConv {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // The residual connection is only used when the block keeps both the
        // spatial resolution and the channel count unchanged.
        let use_res_connect =
            self.config.stride == 1 && self.config.input_channels == self.config.out_channels;
        let ys = match &self.expand_cna {
            Some(expand_cna) => expand_cna.forward(xs)?,
            None => xs.clone(),
        };
        let ys = self.depthwise_cna.forward(&ys)?;
        let ys = self.squeeze_excitation.forward(&ys)?;
        let ys = self.project_cna.forward(&ys)?;
        if use_res_connect {
            ys + xs
        } else {
            Ok(ys)
        }
    }
}

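/// Swish / SiLU activation: `x * sigmoid(x)`.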
fn swish(s: &Tensor) -> Result<Tensor> {
    s * nn::ops::sigmoid(s)?
}

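/// EfficientNet classifier: stem convolution, a sequence of MBConv blocks, a
/// final 1x1 convolution, global average pooling, and a linear head.
///
/// A minimal usage sketch (assuming `vb` is a `VarBuilder` that holds weights in
/// the torchvision-style `features.*` / `classifier.1` layout expected below, and
/// `images` is an NCHW f32 tensor):
///
/// ```ignore
/// let model = EfficientNet::new(vb, MBConvConfig::b0(), 1000)?;
/// let logits = model.forward(&images)?;
/// ```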
#[derive(Debug)]
pub struct EfficientNet {
    init_cna: ConvNormActivation,
    blocks: Vec<MBConv>,
    final_cna: ConvNormActivation,
    classifier: nn::Linear,
}

impl EfficientNet {
    pub fn new(p: VarBuilder, configs: Vec<MBConvConfig>, nclasses: usize) -> Result<Self> {
        let f_p = p.pp("features");
        let first_in_c = configs[0].input_channels;
        let last_out_c = configs.last().context("no last")?.out_channels;
        let final_out_c = 4 * last_out_c;
        // Stem: 3x3 stride-2 convolution from the 3 input channels.
        let init_cna = ConvNormActivation::new(f_p.pp(0), 3, first_in_c, 3, 2, 1)?;
        let nconfigs = configs.len();
        let mut blocks = vec![];
        for (index, cnf) in configs.into_iter().enumerate() {
            let f_p = f_p.pp(index + 1);
            for r_index in 0..cnf.num_layers {
                // Within a stage, only the first block downsamples and changes
                // the channel count; the repeats keep stride 1.
                let cnf = if r_index == 0 {
                    cnf
                } else {
                    MBConvConfig {
                        input_channels: cnf.out_channels,
                        stride: 1,
                        ..cnf
                    }
                };
                blocks.push(MBConv::new(f_p.pp(r_index), cnf)?)
            }
        }
        let final_cna =
            ConvNormActivation::new(f_p.pp(nconfigs + 1), last_out_c, final_out_c, 1, 1, 1)?;
        let classifier = nn::linear(final_out_c, nclasses, p.pp("classifier.1"))?;
        Ok(Self {
            init_cna,
            blocks,
            final_cna,
            classifier,
        })
    }
}

impl Module for EfficientNet {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let mut xs = self.init_cna.forward(xs)?;
        for block in self.blocks.iter() {
            xs = block.forward(&xs)?
        }
        let xs = self.final_cna.forward(&xs)?;
        // Global average pooling over the two spatial dimensions before the
        // linear classification head.
        let xs = xs.mean(D::Minus1)?.mean(D::Minus1)?;
        self.classifier.forward(&xs)
    }
}