// candle_transformers/models/efficientvit.rs
//! EfficientViT (MSRA) inference implementation based on timm.
//!
//! This crate provides an implementation of the EfficientViT model from Microsoft Research Asia
//! for efficient image classification. The model uses cascaded group attention modules
//! to achieve strong performance while maintaining low memory usage.
//!
//! The model was originally described in the paper:
//! ["EfficientViT: Memory Efficient Vision Transformer with Cascaded Group Attention"](https://arxiv.org/abs/2305.07027)
//!
//! This implementation is based on the reference implementation from
//! [pytorch-image-models](https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/efficientvit_msra.py).
//!
//! # Example Usage
//!
//! This candle implementation uses a pre-trained EfficientViT (from Microsoft Research Asia) network for inference.
//! The classification head has been trained on the ImageNet dataset and returns the probabilities for the top-5 classes.
//!
//! ```bash
//! cargo run \
//!   --example efficientvit \
//!   --release -- \
//!   --image candle-examples/examples/yolo-v8/assets/bike.jpg --which m1
//!
//! > loaded image Tensor[dims 3, 224, 224; f32]
//! > model built
//! > mountain bike, all-terrain bike, off-roader: 69.80%
//! > unicycle, monocycle     : 13.03%
//! > bicycle-built-for-two, tandem bicycle, tandem: 9.28%
//! > crash helmet            : 2.25%
//! > alp                     : 0.46%
//! ```
//!
//! <div align=center>
//!   <img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/yolo-v8/assets/bike.jpg" alt="" width=640>
//! </div>
//!
38use candle::{Result, Tensor, D};
39use candle_nn::{
40    batch_norm, conv2d, conv2d_no_bias, linear, ops::sigmoid, ops::softmax, Conv2dConfig, Func,
41    VarBuilder,
42};
43
#[derive(Clone)]
pub struct Config {
    // Embedding dimension of each of the three stages.
    channels: [usize; 3],
    // Number of EfficientViT blocks in each stage.
    blocks: [usize; 3],
    // Number of attention heads in each stage.
    heads: [usize; 3],
    // Depthwise-conv kernel size applied to each head's query in the
    // cascaded group attention (indexed by head, up to 4 heads).
    kernels: [usize; 4],
}
51
52impl Config {
53    pub fn m0() -> Self {
54        Self {
55            channels: [64, 128, 192],
56            blocks: [1, 2, 3],
57            heads: [4, 4, 4],
58            kernels: [5, 5, 5, 5],
59        }
60    }
61    pub fn m1() -> Self {
62        Self {
63            channels: [128, 144, 192],
64            blocks: [1, 2, 3],
65            heads: [2, 3, 3],
66            kernels: [7, 5, 3, 3],
67        }
68    }
69    pub fn m2() -> Self {
70        Self {
71            channels: [128, 192, 224],
72            blocks: [1, 2, 3],
73            heads: [4, 3, 2],
74            kernels: [7, 5, 3, 3],
75        }
76    }
77    pub fn m3() -> Self {
78        Self {
79            channels: [128, 240, 320],
80            blocks: [1, 2, 3],
81            heads: [4, 3, 4],
82            kernels: [5, 5, 5, 5],
83        }
84    }
85    pub fn m4() -> Self {
86        Self {
87            channels: [128, 256, 384],
88            blocks: [1, 2, 3],
89            heads: [4, 4, 4],
90            kernels: [7, 5, 3, 3],
91        }
92    }
93
94    pub fn m5() -> Self {
95        Self {
96            channels: [192, 288, 384],
97            blocks: [1, 3, 4],
98            heads: [3, 3, 4],
99            kernels: [7, 5, 3, 3],
100        }
101    }
102}
103
104fn efficientvit_stemblock(
105    in_channels: usize,
106    out_channels: usize,
107    vb: VarBuilder,
108) -> Result<Func<'static>> {
109    let conv2d_cfg = Conv2dConfig {
110        stride: 2,
111        padding: 1,
112        ..Default::default()
113    };
114
115    let bn = batch_norm(out_channels, 1e-5, vb.pp("bn"))?;
116    let conv = conv2d_no_bias(in_channels, out_channels, 3, conv2d_cfg, vb.pp("conv"))?;
117
118    Ok(Func::new(move |xs| {
119        let xs = xs.apply(&conv)?.apply_t(&bn, false)?;
120        Ok(xs)
121    }))
122}
123
124fn efficientvit_stem(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
125    let conv1 = efficientvit_stemblock(3, dim / 8, vb.pp("conv1"))?;
126    let conv2 = efficientvit_stemblock(dim / 8, dim / 4, vb.pp("conv2"))?;
127    let conv3 = efficientvit_stemblock(dim / 4, dim / 2, vb.pp("conv3"))?;
128    let conv4 = efficientvit_stemblock(dim / 2, dim, vb.pp("conv4"))?;
129
130    Ok(Func::new(move |xs| {
131        let xs = xs
132            .apply(&conv1)?
133            .relu()?
134            .apply(&conv2)?
135            .relu()?
136            .apply(&conv3)?
137            .relu()?
138            .apply(&conv4)?;
139
140        Ok(xs)
141    }))
142}
143
144fn depthwise_conv(
145    channels: usize,
146    kernel: usize,
147    stride: usize,
148    padding: usize,
149    vb: VarBuilder,
150) -> Result<Func<'static>> {
151    let conv2d_cfg = Conv2dConfig {
152        stride,
153        padding,
154        groups: channels,
155        ..Default::default()
156    };
157
158    let bn = batch_norm(channels, 1e-5, vb.pp("bn"))?;
159    let conv = conv2d_no_bias(channels, channels, kernel, conv2d_cfg, vb.pp("conv"))?;
160
161    Ok(Func::new(move |xs| xs.apply(&conv)?.apply_t(&bn, false)))
162}
163
164fn pointwise_conv(
165    in_channels: usize,
166    out_channels: usize,
167    vb: VarBuilder,
168) -> Result<Func<'static>> {
169    let conv2d_cfg = Conv2dConfig {
170        ..Default::default()
171    };
172
173    let bn = batch_norm(out_channels, 1e-5, vb.pp("bn"))?;
174    let conv = conv2d_no_bias(in_channels, out_channels, 1, conv2d_cfg, vb.pp("conv"))?;
175
176    Ok(Func::new(move |xs| xs.apply(&conv)?.apply_t(&bn, false)))
177}
178
179fn conv_mlp(in_channels: usize, out_channels: usize, vb: VarBuilder) -> Result<Func<'static>> {
180    let pw1 = pointwise_conv(in_channels, out_channels, vb.pp("pw1"))?;
181    let pw2 = pointwise_conv(out_channels, in_channels, vb.pp("pw2"))?;
182
183    Ok(Func::new(move |xs| {
184        let xs = xs.apply(&pw1)?.relu()?.apply(&pw2)?;
185        Ok(xs)
186    }))
187}
188
// Fixed per-stage feature-map resolutions (14x14, 7x7, 4x4 for a 224x224
// input); used to decide whether window partitioning is needed in the
// attention block.
const RESOLUTIONS: [usize; 3] = [14, 7, 4];
191
192// Attention block
193fn efficientvit_attn(
194    cfg: &Config,
195    stage: usize,
196    in_channels: usize,
197    vb: VarBuilder,
198) -> Result<Func<'static>> {
199    let cga = cascaded_group_attn(cfg, stage, in_channels, vb)?;
200
201    Ok(Func::new(move |xs| {
202        let mut xs = xs.clone();
203
204        let (b, c, h, w) = xs.dims4()?;
205        let win_res = 7; // Fixed window resolution
206        let pad_b = (win_res - h % win_res) % win_res;
207        let pad_r = (win_res - w % win_res) % win_res;
208        let ph = h + pad_b;
209        let pw = w + pad_r;
210        let nh = ph / win_res;
211        let nw = pw / win_res;
212
213        if RESOLUTIONS[stage] > win_res {
214            xs = xs.permute((0, 2, 3, 1))?;
215            xs = xs.pad_with_zeros(D::Minus1, 0, pad_r)?;
216            xs = xs.pad_with_zeros(D::Minus2, 0, pad_b)?;
217            xs = xs
218                .reshape((b, nh, win_res, nw, win_res, c))?
219                .transpose(2, 3)?;
220            xs = xs
221                .reshape((b * nh * nw, win_res, win_res, c))?
222                .permute((0, 3, 1, 2))?;
223        }
224
225        xs = xs.apply(&cga)?;
226
227        if RESOLUTIONS[stage] > win_res {
228            xs = xs
229                .permute((0, 2, 3, 1))?
230                .reshape((b, nh, nw, win_res, win_res, c))?;
231            xs = xs.transpose(2, 3)?.reshape((b, ph, pw, c))?;
232            xs = xs.permute((0, 3, 1, 2))?;
233        }
234
235        Ok(xs)
236    }))
237}
238
// Cascaded group attention
//
// Splits the input channels into one group per head. Each head attends over
// the flattened spatial positions, and its output is added to the next
// head's input, forming the cascade. The concatenated head outputs go
// through a final pointwise projection.
fn cascaded_group_attn(
    cfg: &Config,
    stage: usize,
    in_channels: usize,
    vb: VarBuilder,
) -> Result<Func<'static>> {
    let heads = cfg.heads[stage];
    // Fixed query/key channel count per head.
    let key_dim = 16;

    // Value channel count equals the per-head share of the input channels.
    let val_dim = in_channels / heads;

    // Standard 1/sqrt(key_dim) attention scaling.
    let scale = (key_dim as f64).powf(-0.5);

    let mut dws = Vec::with_capacity(heads);
    let mut qkvs = Vec::with_capacity(heads);
    for i in 0..heads {
        // Depthwise conv applied to head i's query; kernel size comes from
        // the per-head config, padding keeps the spatial size unchanged.
        dws.push(depthwise_conv(
            key_dim,
            cfg.kernels[i],
            1,
            cfg.kernels[i] / 2,
            vb.pp(format!("dws.{i}")),
        )?);

        // Joint projection producing q (key_dim), k (key_dim) and
        // v (val_dim) channels for head i.
        qkvs.push(pointwise_conv(
            in_channels / heads,
            in_channels / heads + 2 * key_dim,
            vb.pp(format!("qkvs.{i}")),
        )?);
    }
    let proj = pointwise_conv(in_channels, in_channels, vb.pp("proj.1"))?;

    Ok(Func::new(move |xs| {
        let (b, _, h, w) = xs.dims4()?;
        // One channel chunk of the input per head.
        let feats_in = xs.chunk(heads, 1)?;
        let mut feats_out = Vec::with_capacity(heads);
        let mut feat = feats_in[0].clone();

        for i in 0..heads {
            // Cascade: feed the previous head's output into this head.
            if i > 0 {
                feat = (&feat + &feats_in[i])?;
            }
            feat = feat.apply(&qkvs[i])?;
            // Split the joint qkv projection along the channel dim.
            let res = feat.reshape((b, (), h, w))?;
            let q = res.narrow(1, 0, key_dim)?;
            let k = res.narrow(1, key_dim, key_dim)?;
            let v = res.narrow(1, 2 * key_dim, val_dim)?;

            // Local spatial mixing on the query.
            let q = q.apply(&dws[i])?;

            // Flatten spatial dims: (b, c, h, w) -> (b, c, h*w).
            let q = q.flatten_from(2)?;
            let k = k.flatten_from(2)?;
            let v = v.flatten_from(2)?;
            let q = (q * scale)?;

            // Attention weights over the h*w positions.
            let att = q.transpose(D::Minus2, D::Minus1)?.matmul(&k)?;
            let att = softmax(&att, D::Minus1)?;
            feat = v.matmul(&att.transpose(D::Minus2, D::Minus1)?)?;
            feat = feat.reshape((b, val_dim, h, w))?;
            feats_out.push(feat.clone());
        }

        // Concatenate head outputs and apply the output projection.
        let xs = Tensor::cat(&feats_out, 1)?;
        let xs = xs.relu()?.apply(&proj)?;

        Ok(xs)
    }))
}
308
309// Used by the downsampling layer
310fn squeeze_and_excitation(
311    in_channels: usize,
312    squeeze_channels: usize,
313    vb: VarBuilder,
314) -> Result<Func<'static>> {
315    let conv2d_cfg = Conv2dConfig {
316        ..Default::default()
317    };
318    let fc1 = conv2d(in_channels, squeeze_channels, 1, conv2d_cfg, vb.pp("fc1"))?;
319    let fc2 = conv2d(squeeze_channels, in_channels, 1, conv2d_cfg, vb.pp("fc2"))?;
320
321    Ok(Func::new(move |xs| {
322        let residual = xs;
323        let xs = xs.mean_keepdim(D::Minus2)?.mean_keepdim(D::Minus1)?;
324        let xs = sigmoid(&xs.apply(&fc1)?.relu()?.apply(&fc2)?)?;
325
326        residual.broadcast_mul(&xs)
327    }))
328}
329
330// Used by the downsampling layer
331fn patchmerge(in_channels: usize, out_channels: usize, vb: VarBuilder) -> Result<Func<'static>> {
332    let dim = in_channels;
333    let hid_dim = in_channels * 4;
334    let conv1 = pointwise_conv(dim, hid_dim, vb.pp("conv1"))?;
335    let conv2 = depthwise_conv(hid_dim, 3, 2, 1, vb.pp("conv2"))?;
336    let conv3 = pointwise_conv(hid_dim, out_channels, vb.pp("conv3"))?;
337    let se = squeeze_and_excitation(hid_dim, hid_dim / 4, vb.pp("se"))?;
338    Ok(Func::new(move |xs| {
339        let xs = xs
340            .apply(&conv1)?
341            .relu()?
342            .apply(&conv2)?
343            .relu()?
344            .apply(&se)?
345            .apply(&conv3)?;
346        Ok(xs)
347    }))
348}
349
350// Used by the downsampling layer
351fn res(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
352    let dw = depthwise_conv(dim, 3, 1, 1, vb.pp("0.m"))?;
353    let mlp = conv_mlp(dim, dim * 2, vb.pp("1.m"))?;
354    Ok(Func::new(move |xs| {
355        let mut xs = xs.clone();
356        xs = (&xs + &xs.apply(&dw)?)?;
357        xs = (&xs + &xs.apply(&mlp)?)?;
358        Ok(xs)
359    }))
360}
361
362// Downsampling
363fn efficientvit_downsample(
364    in_channels: usize,
365    out_channels: usize,
366    vb: VarBuilder,
367) -> Result<Func<'static>> {
368    let res1 = res(in_channels, vb.pp("res1"))?;
369    let res2 = res(out_channels, vb.pp("res2"))?;
370    let patchmerge = patchmerge(in_channels, out_channels, vb.pp("patchmerge"))?;
371    Ok(Func::new(move |xs| {
372        let xs = xs.apply(&res1)?.apply(&patchmerge)?.apply(&res2)?;
373        Ok(xs)
374    }))
375}
376
377fn efficientvit_block(
378    cfg: &Config,
379    stage: usize,
380    dim: usize,
381    vb: VarBuilder,
382) -> Result<Func<'static>> {
383    let dw0 = depthwise_conv(dim, 3, 1, 1, vb.pp("dw0.m"))?;
384    let dw1 = depthwise_conv(dim, 3, 1, 1, vb.pp("dw1.m"))?;
385    let ffn0 = conv_mlp(dim, dim * 2, vb.pp("ffn0.m"))?;
386    let ffn1 = conv_mlp(dim, dim * 2, vb.pp("ffn1.m"))?;
387    let attn = efficientvit_attn(cfg, stage, dim, vb.pp("mixer.m.attn"))?;
388    Ok(Func::new(move |xs| {
389        let mut xs = xs.clone();
390        xs = (&xs + &xs.apply(&dw0)?)?;
391        xs = (&xs + &xs.apply(&ffn0)?)?;
392        xs = (&xs + &xs.apply(&attn)?)?;
393        xs = (&xs + &xs.apply(&dw1)?)?;
394        xs = (&xs + &xs.apply(&ffn1)?)?;
395        Ok(xs)
396    }))
397}
398
399// Each stage is made of blocks. There is a downsampling layer between stages.
400fn efficientvit_stage(cfg: &Config, stage: usize, vb: VarBuilder) -> Result<Func<'static>> {
401    let nblocks = cfg.blocks[stage];
402    let mut blocks = Vec::with_capacity(nblocks + 1);
403
404    let in_channels = if stage > 0 {
405        cfg.channels[stage - 1]
406    } else {
407        cfg.channels[0]
408    };
409    let out_channels = cfg.channels[stage];
410
411    if stage > 0 {
412        blocks.push(efficientvit_downsample(
413            in_channels,
414            out_channels,
415            vb.pp("downsample"),
416        )?);
417    }
418
419    for i in 0..nblocks {
420        blocks.push(efficientvit_block(
421            cfg,
422            stage,
423            out_channels,
424            vb.pp(format!("blocks.{i}")),
425        )?);
426    }
427
428    Ok(Func::new(move |xs| {
429        let mut xs = xs.clone();
430        for block in blocks.iter() {
431            xs = xs.apply(block)?
432        }
433        Ok(xs)
434    }))
435}
436
437// Classification head.
438fn efficientvit_head(outputs: usize, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
439    let norm = batch_norm(outputs, 1e-6, vb.pp("bn"))?;
440    let linear = linear(outputs, nclasses, vb.pp("linear"))?;
441    Ok(Func::new(move |xs| {
442        xs.apply_t(&norm, false)?.apply(&linear)
443    }))
444}
445
446// Build a efficientvit model for a given configuration.
447fn efficientvit_model(
448    config: &Config,
449    nclasses: Option<usize>,
450    vb: VarBuilder,
451) -> Result<Func<'static>> {
452    let cls = match nclasses {
453        None => None,
454        Some(nclasses) => {
455            let outputs = config.channels[2];
456            let head = efficientvit_head(outputs, nclasses, vb.pp("head"))?;
457            Some(head)
458        }
459    };
460
461    let stem_dim = config.channels[0];
462    let stem = efficientvit_stem(stem_dim, vb.pp("patch_embed"))?;
463
464    let vb = vb.pp("stages");
465    let stage1 = efficientvit_stage(config, 0, vb.pp(0))?;
466    let stage2 = efficientvit_stage(config, 1, vb.pp(1))?;
467    let stage3 = efficientvit_stage(config, 2, vb.pp(2))?;
468
469    Ok(Func::new(move |xs| {
470        let xs = xs
471            .apply(&stem)?
472            .apply(&stage1)?
473            .apply(&stage2)?
474            .apply(&stage3)?
475            .mean(D::Minus2)?
476            .mean(D::Minus1)?;
477        match &cls {
478            None => Ok(xs),
479            Some(cls) => xs.apply(cls),
480        }
481    }))
482}
483
/// Builds an EfficientViT classifier with `nclasses` output classes.
pub fn efficientvit(cfg: &Config, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
    efficientvit_model(cfg, Some(nclasses), vb)
}
487
/// Builds an EfficientViT feature extractor: same backbone as
/// [`efficientvit`] but without the classification head, returning the
/// pooled feature vector.
pub fn efficientvit_no_final_layer(cfg: &Config, vb: VarBuilder) -> Result<Func<'static>> {
    efficientvit_model(cfg, None, vb)
}