1use candle::{Result, Tensor, D};
39use candle_nn::{
40 batch_norm, conv2d, conv2d_no_bias, linear, ops::sigmoid, ops::softmax, Conv2dConfig, Func,
41 VarBuilder,
42};
43
/// Architecture hyper-parameters of an EfficientViT (MSRA) variant.
///
/// Construct one with the presets [`Config::m0`] through [`Config::m5`].
/// `channels`, `blocks` and `heads` are indexed by stage (three stages);
/// `kernels` is indexed by attention head.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Config {
    /// Channel width of each stage.
    channels: [usize; 3],
    /// Number of EfficientViT blocks in each stage.
    blocks: [usize; 3],
    /// Number of cascaded attention heads in each stage.
    heads: [usize; 3],
    /// Kernel size of the query depthwise conv for attention head `i`
    /// (the same table is used by every stage).
    kernels: [usize; 4],
}

impl Config {
    /// EfficientViT-M0 preset.
    pub fn m0() -> Self {
        Self {
            channels: [64, 128, 192],
            blocks: [1, 2, 3],
            heads: [4, 4, 4],
            kernels: [5, 5, 5, 5],
        }
    }

    /// EfficientViT-M1 preset.
    pub fn m1() -> Self {
        Self {
            channels: [128, 144, 192],
            blocks: [1, 2, 3],
            heads: [2, 3, 3],
            kernels: [7, 5, 3, 3],
        }
    }

    /// EfficientViT-M2 preset.
    pub fn m2() -> Self {
        Self {
            channels: [128, 192, 224],
            blocks: [1, 2, 3],
            heads: [4, 3, 2],
            kernels: [7, 5, 3, 3],
        }
    }

    /// EfficientViT-M3 preset.
    pub fn m3() -> Self {
        Self {
            channels: [128, 240, 320],
            blocks: [1, 2, 3],
            heads: [4, 3, 4],
            kernels: [5, 5, 5, 5],
        }
    }

    /// EfficientViT-M4 preset.
    pub fn m4() -> Self {
        Self {
            channels: [128, 256, 384],
            blocks: [1, 2, 3],
            heads: [4, 4, 4],
            kernels: [7, 5, 3, 3],
        }
    }

    /// EfficientViT-M5 preset.
    pub fn m5() -> Self {
        Self {
            channels: [192, 288, 384],
            blocks: [1, 3, 4],
            heads: [3, 3, 4],
            kernels: [7, 5, 3, 3],
        }
    }
}
103
104fn efficientvit_stemblock(
105 in_channels: usize,
106 out_channels: usize,
107 vb: VarBuilder,
108) -> Result<Func<'static>> {
109 let conv2d_cfg = Conv2dConfig {
110 stride: 2,
111 padding: 1,
112 ..Default::default()
113 };
114
115 let bn = batch_norm(out_channels, 1e-5, vb.pp("bn"))?;
116 let conv = conv2d_no_bias(in_channels, out_channels, 3, conv2d_cfg, vb.pp("conv"))?;
117
118 Ok(Func::new(move |xs| {
119 let xs = xs.apply(&conv)?.apply_t(&bn, false)?;
120 Ok(xs)
121 }))
122}
123
124fn efficientvit_stem(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
125 let conv1 = efficientvit_stemblock(3, dim / 8, vb.pp("conv1"))?;
126 let conv2 = efficientvit_stemblock(dim / 8, dim / 4, vb.pp("conv2"))?;
127 let conv3 = efficientvit_stemblock(dim / 4, dim / 2, vb.pp("conv3"))?;
128 let conv4 = efficientvit_stemblock(dim / 2, dim, vb.pp("conv4"))?;
129
130 Ok(Func::new(move |xs| {
131 let xs = xs
132 .apply(&conv1)?
133 .relu()?
134 .apply(&conv2)?
135 .relu()?
136 .apply(&conv3)?
137 .relu()?
138 .apply(&conv4)?;
139
140 Ok(xs)
141 }))
142}
143
144fn depthwise_conv(
145 channels: usize,
146 kernel: usize,
147 stride: usize,
148 padding: usize,
149 vb: VarBuilder,
150) -> Result<Func<'static>> {
151 let conv2d_cfg = Conv2dConfig {
152 stride,
153 padding,
154 groups: channels,
155 ..Default::default()
156 };
157
158 let bn = batch_norm(channels, 1e-5, vb.pp("bn"))?;
159 let conv = conv2d_no_bias(channels, channels, kernel, conv2d_cfg, vb.pp("conv"))?;
160
161 Ok(Func::new(move |xs| xs.apply(&conv)?.apply_t(&bn, false)))
162}
163
164fn pointwise_conv(
165 in_channels: usize,
166 out_channels: usize,
167 vb: VarBuilder,
168) -> Result<Func<'static>> {
169 let conv2d_cfg = Conv2dConfig {
170 ..Default::default()
171 };
172
173 let bn = batch_norm(out_channels, 1e-5, vb.pp("bn"))?;
174 let conv = conv2d_no_bias(in_channels, out_channels, 1, conv2d_cfg, vb.pp("conv"))?;
175
176 Ok(Func::new(move |xs| xs.apply(&conv)?.apply_t(&bn, false)))
177}
178
179fn conv_mlp(in_channels: usize, out_channels: usize, vb: VarBuilder) -> Result<Func<'static>> {
180 let pw1 = pointwise_conv(in_channels, out_channels, vb.pp("pw1"))?;
181 let pw2 = pointwise_conv(out_channels, in_channels, vb.pp("pw2"))?;
182
183 Ok(Func::new(move |xs| {
184 let xs = xs.apply(&pw1)?.relu()?.apply(&pw2)?;
185 Ok(xs)
186 }))
187}
188
/// Side length of the square feature map entering each stage.
///
/// The stem's four stride-2 convolutions divide the input side by 16. Each stage
/// transition (the kernel-3, stride-2, padding-1 depthwise conv in `patchmerge`)
/// maps a side `r` to `(r - 1) / 2 + 1`. Starting from a 224x224 input this gives
/// 14 -> 7 -> 4. These values are fixed and do not follow the actual input size.
const RESOLUTIONS: [usize; 3] = {
    const fn downsampled(r: usize) -> usize {
        (r - 1) / 2 + 1
    }
    let stage0 = 224 / 16;
    let stage1 = downsampled(stage0);
    [stage0, stage1, downsampled(stage1)]
};
191
192fn efficientvit_attn(
194 cfg: &Config,
195 stage: usize,
196 in_channels: usize,
197 vb: VarBuilder,
198) -> Result<Func<'static>> {
199 let cga = cascaded_group_attn(cfg, stage, in_channels, vb)?;
200
201 Ok(Func::new(move |xs| {
202 let mut xs = xs.clone();
203
204 let (b, c, h, w) = xs.dims4()?;
205 let win_res = 7; let pad_b = (win_res - h % win_res) % win_res;
207 let pad_r = (win_res - w % win_res) % win_res;
208 let ph = h + pad_b;
209 let pw = w + pad_r;
210 let nh = ph / win_res;
211 let nw = pw / win_res;
212
213 if RESOLUTIONS[stage] > win_res {
214 xs = xs.permute((0, 2, 3, 1))?;
215 xs = xs.pad_with_zeros(D::Minus1, 0, pad_r)?;
216 xs = xs.pad_with_zeros(D::Minus2, 0, pad_b)?;
217 xs = xs
218 .reshape((b, nh, win_res, nw, win_res, c))?
219 .transpose(2, 3)?;
220 xs = xs
221 .reshape((b * nh * nw, win_res, win_res, c))?
222 .permute((0, 3, 1, 2))?;
223 }
224
225 xs = xs.apply(&cga)?;
226
227 if RESOLUTIONS[stage] > win_res {
228 xs = xs
229 .permute((0, 2, 3, 1))?
230 .reshape((b, nh, nw, win_res, win_res, c))?;
231 xs = xs.transpose(2, 3)?.reshape((b, ph, pw, c))?;
232 xs = xs.permute((0, 3, 1, 2))?;
233 }
234
235 Ok(xs)
236 }))
237}
238
239fn cascaded_group_attn(
241 cfg: &Config,
242 stage: usize,
243 in_channels: usize,
244 vb: VarBuilder,
245) -> Result<Func<'static>> {
246 let heads = cfg.heads[stage];
247 let key_dim = 16;
248
249 let val_dim = in_channels / heads;
250
251 let scale = (key_dim as f64).powf(-0.5);
252
253 let mut dws = Vec::with_capacity(heads);
254 let mut qkvs = Vec::with_capacity(heads);
255 for i in 0..heads {
256 dws.push(depthwise_conv(
257 key_dim,
258 cfg.kernels[i],
259 1,
260 cfg.kernels[i] / 2,
261 vb.pp(format!("dws.{i}")),
262 )?);
263
264 qkvs.push(pointwise_conv(
265 in_channels / heads,
266 in_channels / heads + 2 * key_dim,
267 vb.pp(format!("qkvs.{i}")),
268 )?);
269 }
270 let proj = pointwise_conv(in_channels, in_channels, vb.pp("proj.1"))?;
271
272 Ok(Func::new(move |xs| {
273 let (b, _, h, w) = xs.dims4()?;
274 let feats_in = xs.chunk(heads, 1)?;
275 let mut feats_out = Vec::with_capacity(heads);
276 let mut feat = feats_in[0].clone();
277
278 for i in 0..heads {
279 if i > 0 {
280 feat = (&feat + &feats_in[i])?;
281 }
282 feat = feat.apply(&qkvs[i])?;
283 let res = feat.reshape((b, (), h, w))?;
284 let q = res.narrow(1, 0, key_dim)?;
285 let k = res.narrow(1, key_dim, key_dim)?;
286 let v = res.narrow(1, 2 * key_dim, val_dim)?;
287
288 let q = q.apply(&dws[i])?;
289
290 let q = q.flatten_from(2)?;
291 let k = k.flatten_from(2)?;
292 let v = v.flatten_from(2)?;
293 let q = (q * scale)?;
294
295 let att = q.transpose(D::Minus2, D::Minus1)?.matmul(&k)?;
296 let att = softmax(&att, D::Minus1)?;
297 feat = v.matmul(&att.transpose(D::Minus2, D::Minus1)?)?;
298 feat = feat.reshape((b, val_dim, h, w))?;
299 feats_out.push(feat.clone());
300 }
301
302 let xs = Tensor::cat(&feats_out, 1)?;
303 let xs = xs.relu()?.apply(&proj)?;
304
305 Ok(xs)
306 }))
307}
308
309fn squeeze_and_excitation(
311 in_channels: usize,
312 squeeze_channels: usize,
313 vb: VarBuilder,
314) -> Result<Func<'static>> {
315 let conv2d_cfg = Conv2dConfig {
316 ..Default::default()
317 };
318 let fc1 = conv2d(in_channels, squeeze_channels, 1, conv2d_cfg, vb.pp("fc1"))?;
319 let fc2 = conv2d(squeeze_channels, in_channels, 1, conv2d_cfg, vb.pp("fc2"))?;
320
321 Ok(Func::new(move |xs| {
322 let residual = xs;
323 let xs = xs.mean_keepdim(D::Minus2)?.mean_keepdim(D::Minus1)?;
324 let xs = sigmoid(&xs.apply(&fc1)?.relu()?.apply(&fc2)?)?;
325
326 residual.broadcast_mul(&xs)
327 }))
328}
329
330fn patchmerge(in_channels: usize, out_channels: usize, vb: VarBuilder) -> Result<Func<'static>> {
332 let dim = in_channels;
333 let hid_dim = in_channels * 4;
334 let conv1 = pointwise_conv(dim, hid_dim, vb.pp("conv1"))?;
335 let conv2 = depthwise_conv(hid_dim, 3, 2, 1, vb.pp("conv2"))?;
336 let conv3 = pointwise_conv(hid_dim, out_channels, vb.pp("conv3"))?;
337 let se = squeeze_and_excitation(hid_dim, hid_dim / 4, vb.pp("se"))?;
338 Ok(Func::new(move |xs| {
339 let xs = xs
340 .apply(&conv1)?
341 .relu()?
342 .apply(&conv2)?
343 .relu()?
344 .apply(&se)?
345 .apply(&conv3)?;
346 Ok(xs)
347 }))
348}
349
350fn res(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
352 let dw = depthwise_conv(dim, 3, 1, 1, vb.pp("0.m"))?;
353 let mlp = conv_mlp(dim, dim * 2, vb.pp("1.m"))?;
354 Ok(Func::new(move |xs| {
355 let mut xs = xs.clone();
356 xs = (&xs + &xs.apply(&dw)?)?;
357 xs = (&xs + &xs.apply(&mlp)?)?;
358 Ok(xs)
359 }))
360}
361
362fn efficientvit_downsample(
364 in_channels: usize,
365 out_channels: usize,
366 vb: VarBuilder,
367) -> Result<Func<'static>> {
368 let res1 = res(in_channels, vb.pp("res1"))?;
369 let res2 = res(out_channels, vb.pp("res2"))?;
370 let patchmerge = patchmerge(in_channels, out_channels, vb.pp("patchmerge"))?;
371 Ok(Func::new(move |xs| {
372 let xs = xs.apply(&res1)?.apply(&patchmerge)?.apply(&res2)?;
373 Ok(xs)
374 }))
375}
376
377fn efficientvit_block(
378 cfg: &Config,
379 stage: usize,
380 dim: usize,
381 vb: VarBuilder,
382) -> Result<Func<'static>> {
383 let dw0 = depthwise_conv(dim, 3, 1, 1, vb.pp("dw0.m"))?;
384 let dw1 = depthwise_conv(dim, 3, 1, 1, vb.pp("dw1.m"))?;
385 let ffn0 = conv_mlp(dim, dim * 2, vb.pp("ffn0.m"))?;
386 let ffn1 = conv_mlp(dim, dim * 2, vb.pp("ffn1.m"))?;
387 let attn = efficientvit_attn(cfg, stage, dim, vb.pp("mixer.m.attn"))?;
388 Ok(Func::new(move |xs| {
389 let mut xs = xs.clone();
390 xs = (&xs + &xs.apply(&dw0)?)?;
391 xs = (&xs + &xs.apply(&ffn0)?)?;
392 xs = (&xs + &xs.apply(&attn)?)?;
393 xs = (&xs + &xs.apply(&dw1)?)?;
394 xs = (&xs + &xs.apply(&ffn1)?)?;
395 Ok(xs)
396 }))
397}
398
399fn efficientvit_stage(cfg: &Config, stage: usize, vb: VarBuilder) -> Result<Func<'static>> {
401 let nblocks = cfg.blocks[stage];
402 let mut blocks = Vec::with_capacity(nblocks + 1);
403
404 let in_channels = if stage > 0 {
405 cfg.channels[stage - 1]
406 } else {
407 cfg.channels[0]
408 };
409 let out_channels = cfg.channels[stage];
410
411 if stage > 0 {
412 blocks.push(efficientvit_downsample(
413 in_channels,
414 out_channels,
415 vb.pp("downsample"),
416 )?);
417 }
418
419 for i in 0..nblocks {
420 blocks.push(efficientvit_block(
421 cfg,
422 stage,
423 out_channels,
424 vb.pp(format!("blocks.{i}")),
425 )?);
426 }
427
428 Ok(Func::new(move |xs| {
429 let mut xs = xs.clone();
430 for block in blocks.iter() {
431 xs = xs.apply(block)?
432 }
433 Ok(xs)
434 }))
435}
436
437fn efficientvit_head(outputs: usize, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
439 let norm = batch_norm(outputs, 1e-6, vb.pp("bn"))?;
440 let linear = linear(outputs, nclasses, vb.pp("linear"))?;
441 Ok(Func::new(move |xs| {
442 xs.apply_t(&norm, false)?.apply(&linear)
443 }))
444}
445
446fn efficientvit_model(
448 config: &Config,
449 nclasses: Option<usize>,
450 vb: VarBuilder,
451) -> Result<Func<'static>> {
452 let cls = match nclasses {
453 None => None,
454 Some(nclasses) => {
455 let outputs = config.channels[2];
456 let head = efficientvit_head(outputs, nclasses, vb.pp("head"))?;
457 Some(head)
458 }
459 };
460
461 let stem_dim = config.channels[0];
462 let stem = efficientvit_stem(stem_dim, vb.pp("patch_embed"))?;
463
464 let vb = vb.pp("stages");
465 let stage1 = efficientvit_stage(config, 0, vb.pp(0))?;
466 let stage2 = efficientvit_stage(config, 1, vb.pp(1))?;
467 let stage3 = efficientvit_stage(config, 2, vb.pp(2))?;
468
469 Ok(Func::new(move |xs| {
470 let xs = xs
471 .apply(&stem)?
472 .apply(&stage1)?
473 .apply(&stage2)?
474 .apply(&stage3)?
475 .mean(D::Minus2)?
476 .mean(D::Minus1)?;
477 match &cls {
478 None => Ok(xs),
479 Some(cls) => xs.apply(cls),
480 }
481 }))
482}
483
484pub fn efficientvit(cfg: &Config, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
485 efficientvit_model(cfg, Some(nclasses), vb)
486}
487
488pub fn efficientvit_no_final_layer(cfg: &Config, vb: VarBuilder) -> Result<Func<'static>> {
489 efficientvit_model(cfg, None, vb)
490}