candle_transformers/models/
efficientnet.rs

//! Implementation of EfficientNet, an efficient convolutional network for computer vision tasks.
//!
//! See:
//! - ["EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks"](https://arxiv.org/abs/1905.11946)
//!
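//! A minimal construction sketch (illustrative only: it uses zero-initialized weights via
//! `VarBuilder::zeros` and an assumed 1000-class head; real use would load pretrained
//! weights instead):
//!
//! ```no_run
//! use candle::{DType, Device, Result, Tensor};
//! use candle_nn::{Module, VarBuilder};
//! use candle_transformers::models::efficientnet::{EfficientNet, MBConvConfig};
//!
//! fn run() -> Result<()> {
//!     let device = Device::Cpu;
//!     // Zero-initialized weights, for illustration only.
//!     let vb = VarBuilder::zeros(DType::F32, &device);
//!     let model = EfficientNet::new(vb, MBConvConfig::b0(), 1000)?;
//!     // A single 224x224 RGB image in NCHW layout.
//!     let image = Tensor::zeros((1, 3, 224, 224), DType::F32, &device)?;
//!     let _logits = model.forward(&image)?; // shape (1, 1000)
//!     Ok(())
//! }
//! ```
//!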
use candle::{Context, Result, Tensor, D};
use candle_nn as nn;
use nn::{Module, VarBuilder};

// Based on the Python version from torchvision.
// https://github.com/pytorch/vision/blob/0d75d9e5516f446c9c0ef93bd4ed9fea13992d06/torchvision/models/efficientnet.py#L47
#[derive(Debug, Clone, Copy)]
pub struct MBConvConfig {
    expand_ratio: f64,
    kernel: usize,
    stride: usize,
    input_channels: usize,
    out_channels: usize,
    num_layers: usize,
}

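/// Rounds `v` to a multiple of `divisor`, never going below `divisor` and bumping the result
/// up whenever rounding would drop more than 10% of `v`. Hand-computed illustration: with the
/// b2 width multiplier of 1.1, `make_divisible(32.0 * 1.1, 8)` returns 32 and
/// `make_divisible(16.0 * 1.1, 8)` returns 16.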
fn make_divisible(v: f64, divisor: usize) -> usize {
    let min_value = divisor;
    let new_v = usize::max(
        min_value,
        (v + divisor as f64 * 0.5) as usize / divisor * divisor,
    );
    if (new_v as f64) < 0.9 * v {
        new_v + divisor
    } else {
        new_v
    }
}

fn bneck_confs(width_mult: f64, depth_mult: f64) -> Vec<MBConvConfig> {
    let bneck_conf = |e, k, s, i, o, n| {
        let input_channels = make_divisible(i as f64 * width_mult, 8);
        let out_channels = make_divisible(o as f64 * width_mult, 8);
        let num_layers = (n as f64 * depth_mult).ceil() as usize;
        MBConvConfig {
            expand_ratio: e,
            kernel: k,
            stride: s,
            input_channels,
            out_channels,
            num_layers,
        }
    };
    vec![
        bneck_conf(1., 3, 1, 32, 16, 1),
        bneck_conf(6., 3, 2, 16, 24, 2),
        bneck_conf(6., 5, 2, 24, 40, 2),
        bneck_conf(6., 3, 2, 40, 80, 3),
        bneck_conf(6., 5, 1, 80, 112, 3),
        bneck_conf(6., 5, 2, 112, 192, 4),
        bneck_conf(6., 3, 1, 192, 320, 1),
    ]
}

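// The width / depth multipliers below follow the compound-scaling coefficients of the
// reference EfficientNet variants (b0 is the baseline, b7 the largest).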
impl MBConvConfig {
    pub fn b0() -> Vec<Self> {
        bneck_confs(1.0, 1.0)
    }
    pub fn b1() -> Vec<Self> {
        bneck_confs(1.0, 1.1)
    }
    pub fn b2() -> Vec<Self> {
        bneck_confs(1.1, 1.2)
    }
    pub fn b3() -> Vec<Self> {
        bneck_confs(1.2, 1.4)
    }
    pub fn b4() -> Vec<Self> {
        bneck_confs(1.4, 1.8)
    }
    pub fn b5() -> Vec<Self> {
        bneck_confs(1.6, 2.2)
    }
    pub fn b6() -> Vec<Self> {
        bneck_confs(1.8, 2.6)
    }
    pub fn b7() -> Vec<Self> {
        bneck_confs(2.0, 3.1)
    }
}

/// Conv2D with same padding.
#[derive(Debug)]
struct Conv2DSame {
    conv2d: nn::Conv2d,
    s: usize,
    k: usize,
}

impl Conv2DSame {
    fn new(
        vb: VarBuilder,
        i: usize,
        o: usize,
        k: usize,
        stride: usize,
        groups: usize,
        bias: bool,
    ) -> Result<Self> {
        let conv_config = nn::Conv2dConfig {
            stride,
            groups,
            ..Default::default()
        };
        let conv2d = if bias {
            nn::conv2d(i, o, k, conv_config, vb)?
        } else {
            nn::conv2d_no_bias(i, o, k, conv_config, vb)?
        };
        Ok(Self {
            conv2d,
            s: stride,
            k,
        })
    }
}

impl Module for Conv2DSame {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let s = self.s;
        let k = self.k;
        let (_, _, ih, iw) = xs.dims4()?;
        let oh = ih.div_ceil(s);
        let ow = iw.div_ceil(s);
        let pad_h = usize::max((oh - 1) * s + k - ih, 0);
        let pad_w = usize::max((ow - 1) * s + k - iw, 0);
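        // Worked example (illustrative): for a 224x224 input with stride 2 and kernel 3,
        // oh = ceil(224 / 2) = 112 and pad_h = (112 - 1) * 2 + 3 - 224 = 1, split as 0 rows
        // above and 1 row below, matching TensorFlow-style asymmetric "SAME" padding.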
        if pad_h > 0 || pad_w > 0 {
            let xs = xs.pad_with_zeros(2, pad_h / 2, pad_h - pad_h / 2)?;
            let xs = xs.pad_with_zeros(3, pad_w / 2, pad_w - pad_w / 2)?;
            self.conv2d.forward(&xs)
        } else {
            self.conv2d.forward(xs)
        }
    }
}

#[derive(Debug)]
struct ConvNormActivation {
    conv2d: Conv2DSame,
    bn2d: nn::BatchNorm,
    activation: bool,
}

impl ConvNormActivation {
    fn new(
        vb: VarBuilder,
        i: usize,
        o: usize,
        k: usize,
        stride: usize,
        groups: usize,
    ) -> Result<Self> {
        let conv2d = Conv2DSame::new(vb.pp("0"), i, o, k, stride, groups, false)?;
        let bn2d = nn::batch_norm(o, 1e-3, vb.pp("1"))?;
        Ok(Self {
            conv2d,
            bn2d,
            activation: true,
        })
    }

    fn no_activation(self) -> Self {
        Self {
            activation: false,
            ..self
        }
    }
}

impl Module for ConvNormActivation {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let xs = self.conv2d.forward(xs)?.apply_t(&self.bn2d, false)?;
        if self.activation {
            swish(&xs)
        } else {
            Ok(xs)
        }
    }
}

#[derive(Debug)]
struct SqueezeExcitation {
    fc1: Conv2DSame,
    fc2: Conv2DSame,
}

impl SqueezeExcitation {
    fn new(vb: VarBuilder, in_channels: usize, squeeze_channels: usize) -> Result<Self> {
        let fc1 = Conv2DSame::new(vb.pp("fc1"), in_channels, squeeze_channels, 1, 1, 1, true)?;
        let fc2 = Conv2DSame::new(vb.pp("fc2"), squeeze_channels, in_channels, 1, 1, 1, true)?;
        Ok(Self { fc1, fc2 })
    }
}

impl Module for SqueezeExcitation {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let residual = xs;
        // equivalent to adaptive_avg_pool2d([1, 1])
        let xs = xs.mean_keepdim(D::Minus2)?.mean_keepdim(D::Minus1)?;
        let xs = self.fc1.forward(&xs)?;
        let xs = swish(&xs)?;
        let xs = self.fc2.forward(&xs)?;
        let xs = nn::ops::sigmoid(&xs)?;
        residual.broadcast_mul(&xs)
    }
}

#[derive(Debug)]
struct MBConv {
    expand_cna: Option<ConvNormActivation>,
    depthwise_cna: ConvNormActivation,
    squeeze_excitation: SqueezeExcitation,
    project_cna: ConvNormActivation,
    config: MBConvConfig,
}

impl MBConv {
    fn new(vb: VarBuilder, c: MBConvConfig) -> Result<Self> {
        let vb = vb.pp("block");
        let exp = make_divisible(c.input_channels as f64 * c.expand_ratio, 8);
        let expand_cna = if exp != c.input_channels {
            Some(ConvNormActivation::new(
                vb.pp("0"),
                c.input_channels,
                exp,
                1,
                1,
                1,
            )?)
        } else {
            None
        };
        let start_index = if expand_cna.is_some() { 1 } else { 0 };
        let depthwise_cna =
            ConvNormActivation::new(vb.pp(start_index), exp, exp, c.kernel, c.stride, exp)?;
        let squeeze_channels = usize::max(1, c.input_channels / 4);
        let squeeze_excitation =
            SqueezeExcitation::new(vb.pp(start_index + 1), exp, squeeze_channels)?;
        let project_cna =
            ConvNormActivation::new(vb.pp(start_index + 2), exp, c.out_channels, 1, 1, 1)?
                .no_activation();
        Ok(Self {
            expand_cna,
            depthwise_cna,
            squeeze_excitation,
            project_cna,
            config: c,
        })
    }
}

impl Module for MBConv {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let use_res_connect =
            self.config.stride == 1 && self.config.input_channels == self.config.out_channels;
        let ys = match &self.expand_cna {
            Some(expand_cna) => expand_cna.forward(xs)?,
            None => xs.clone(),
        };
        let ys = self.depthwise_cna.forward(&ys)?;
        let ys = self.squeeze_excitation.forward(&ys)?;
        let ys = self.project_cna.forward(&ys)?;
        if use_res_connect {
            ys + xs
        } else {
            Ok(ys)
        }
    }
}

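// Swish activation: `x * sigmoid(x)`, also known as SiLU.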
fn swish(s: &Tensor) -> Result<Tensor> {
    s * nn::ops::sigmoid(s)?
}

#[derive(Debug)]
pub struct EfficientNet {
    init_cna: ConvNormActivation,
    blocks: Vec<MBConv>,
    final_cna: ConvNormActivation,
    classifier: nn::Linear,
}

impl EfficientNet {
    pub fn new(p: VarBuilder, configs: Vec<MBConvConfig>, nclasses: usize) -> Result<Self> {
        let f_p = p.pp("features");
        let first_in_c = configs[0].input_channels;
        let last_out_c = configs.last().context("no last")?.out_channels;
        let final_out_c = 4 * last_out_c;
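        // torchvision checkpoint layout: the stem conv is stored under `features.0`, stage
        // `i` of the MBConv blocks under `features.{i + 1}`, and the final 1x1 conv under
        // `features.{nconfigs + 1}`.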
        let init_cna = ConvNormActivation::new(f_p.pp(0), 3, first_in_c, 3, 2, 1)?;
        let nconfigs = configs.len();
        let mut blocks = vec![];
        for (index, cnf) in configs.into_iter().enumerate() {
            let f_p = f_p.pp(index + 1);
            for r_index in 0..cnf.num_layers {
                let cnf = if r_index == 0 {
                    cnf
                } else {
                    MBConvConfig {
                        input_channels: cnf.out_channels,
                        stride: 1,
                        ..cnf
                    }
                };
                blocks.push(MBConv::new(f_p.pp(r_index), cnf)?)
            }
        }
        let final_cna =
            ConvNormActivation::new(f_p.pp(nconfigs + 1), last_out_c, final_out_c, 1, 1, 1)?;
        let classifier = nn::linear(final_out_c, nclasses, p.pp("classifier.1"))?;
        Ok(Self {
            init_cna,
            blocks,
            final_cna,
            classifier,
        })
    }
}

impl Module for EfficientNet {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let mut xs = self.init_cna.forward(xs)?;
        for block in self.blocks.iter() {
            xs = block.forward(&xs)?
        }
        let xs = self.final_cna.forward(&xs)?;
        // Equivalent to adaptive_avg_pool2d([1, 1]) -> squeeze(-1) -> squeeze(-1)
        let xs = xs.mean(D::Minus1)?.mean(D::Minus1)?;
        self.classifier.forward(&xs)
    }
}