1use candle::{DType, IndexOp, Result, Tensor};
2use candle_nn::{layer_norm, LayerNorm, Module, VarBuilder};
3
4#[derive(Debug)]
5struct PatchEmbed {
6 proj: candle_nn::Conv2d,
7 span: tracing::Span,
8}
9
10impl PatchEmbed {
11 fn new(
12 in_chans: usize,
13 embed_dim: usize,
14 k_size: usize,
15 stride: usize,
16 padding: usize,
17 vb: VarBuilder,
18 ) -> Result<Self> {
19 let cfg = candle_nn::Conv2dConfig {
20 stride,
21 padding,
22 ..Default::default()
23 };
24 let proj = candle_nn::conv2d(in_chans, embed_dim, k_size, cfg, vb.pp("proj"))?;
25 let span = tracing::span!(tracing::Level::TRACE, "patch-embed");
26 Ok(Self { proj, span })
27 }
28}
29
30impl Module for PatchEmbed {
31 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
32 let _enter = self.span.enter();
33 xs.apply(&self.proj)?.permute((0, 2, 3, 1))
34 }
35}
36
37struct Add3(usize, usize, usize, usize, usize);
45impl candle::CustomOp3 for Add3 {
46 fn name(&self) -> &'static str {
47 "add3"
48 }
49
50 fn cpu_fwd(
51 &self,
52 s1: &candle::CpuStorage,
53 l1: &candle::Layout,
54 s2: &candle::CpuStorage,
55 l2: &candle::Layout,
56 s3: &candle::CpuStorage,
57 l3: &candle::Layout,
58 ) -> Result<(candle::CpuStorage, candle::Shape)> {
59 use rayon::prelude::*;
60
61 let Add3(b, q_h, q_w, k_h, k_w) = *self;
62 let s1 = s1.as_slice::<f32>()?;
63 let s1 = match l1.contiguous_offsets() {
64 None => candle::bail!("input1 has to be contiguous"),
65 Some((o1, o2)) => &s1[o1..o2],
66 };
67 let s2 = s2.as_slice::<f32>()?;
68 let s2 = match l2.contiguous_offsets() {
69 None => candle::bail!("input2 has to be contiguous"),
70 Some((o1, o2)) => &s2[o1..o2],
71 };
72 let s3 = s3.as_slice::<f32>()?;
73 let s3 = match l3.contiguous_offsets() {
74 None => candle::bail!("input3 has to be contiguous"),
75 Some((o1, o2)) => &s3[o1..o2],
76 };
77 let mut dst = vec![0f32; b * q_h * q_w * k_h * k_w];
78 dst.par_chunks_exact_mut(k_h * k_w)
79 .enumerate()
80 .for_each(|(b_idx, dst)| {
81 let s1_idx = b_idx * k_h * k_w;
82 let s2_idx = b_idx * k_h;
83 let s3_idx = b_idx * k_w;
84 for h_idx in 0..k_h {
85 let s1_idx = s1_idx + h_idx * k_w;
86 let s2_idx = s2_idx + h_idx;
87 let dst_idx = h_idx * k_w;
88 for w_idx in 0..k_w {
89 let s1_idx = s1_idx + w_idx;
90 let s3_idx = s3_idx + w_idx;
91 let dst_idx = dst_idx + w_idx;
92 dst[dst_idx] = s1[s1_idx] + s2[s2_idx] + s3[s3_idx]
93 }
94 }
95 });
96 let dst = candle::WithDType::to_cpu_storage_owned(dst);
97 Ok((dst, (b, q_h * q_w, k_h * k_w).into()))
98 }
99}
100
101#[derive(Debug)]
102struct Attention {
103 qkv: super::Linear,
104 proj: super::Linear,
105 num_heads: usize,
106 scale: f64,
107 rel_pos_hw: Option<(Tensor, Tensor)>,
108 span: tracing::Span,
109 span_matmul: tracing::Span,
110 span_rel_pos: tracing::Span,
111 span_softmax: tracing::Span,
112}
113
114impl Attention {
115 fn new(
116 dim: usize,
117 num_heads: usize,
118 qkv_bias: bool,
119 use_rel_pos: bool,
120 input_size: (usize, usize),
121 vb: VarBuilder,
122 ) -> Result<Self> {
123 let span = tracing::span!(tracing::Level::TRACE, "attention");
124 let span_matmul = tracing::span!(tracing::Level::TRACE, "attn-matmul");
125 let span_rel_pos = tracing::span!(tracing::Level::TRACE, "attn-rel-pos");
126 let span_softmax = tracing::span!(tracing::Level::TRACE, "attn-sm");
127 let qkv = super::linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
128 let proj = super::linear(vb.pp("proj"), dim, dim, true)?;
129 let head_dim = dim / num_heads;
130 let scale = 1. / (head_dim as f64).sqrt();
131 let rel_pos_hw = if use_rel_pos {
132 let h = vb.get((2 * input_size.0 - 1, head_dim), "rel_pos_h")?;
133 let w = vb.get((2 * input_size.1 - 1, head_dim), "rel_pos_w")?;
134 Some((h, w))
135 } else {
136 None
137 };
138 Ok(Self {
139 qkv,
140 proj,
141 num_heads,
142 scale,
143 rel_pos_hw,
144 span,
145 span_matmul,
146 span_rel_pos,
147 span_softmax,
148 })
149 }
150
151 fn add_decomposed_rel_pos(
152 &self,
153 attn: Tensor,
154 q: &Tensor,
155 (q_h, q_w): (usize, usize),
156 (k_h, k_w): (usize, usize),
157 ) -> Result<Tensor> {
158 match &self.rel_pos_hw {
159 Some((rel_pos_h, rel_pos_w)) => {
160 let r_h = get_rel_pos(q_h, k_h, rel_pos_h)?;
161 let r_w = get_rel_pos(q_w, k_w, rel_pos_w)?;
162 let (b, _, dim) = q.dims3()?;
163 let r_q = q.reshape((b, q_h, q_w, dim))?;
164 let rel_h = r_q.matmul(&r_h.broadcast_left(b)?.t()?.contiguous()?)?;
166 let rel_w = r_q
168 .transpose(1, 2)? .contiguous()?
170 .matmul(&r_w.broadcast_left(b)?.t()?.contiguous()?)? .transpose(1, 2)?
172 .contiguous()?;
173 if attn.device().is_cpu() {
174 let op = Add3(b, q_h, q_w, k_h, k_w);
175 attn.apply_op3_no_bwd(&rel_h, &rel_w, &op)
176 } else {
177 (attn.reshape((b, q_h, q_w, k_h, k_w))?
178 + rel_h.unsqueeze(4)?.broadcast_add(&rel_w.unsqueeze(3)?)?)?
179 .reshape((b, q_h * q_w, k_h * k_w))
180 }
181 }
182 None => Ok(attn),
183 }
184 }
185}
186
187fn get_rel_pos(q_size: usize, k_size: usize, rel_pos: &Tensor) -> Result<Tensor> {
188 let max_rel_dist = 2 * usize::max(q_size, k_size) - 1;
189 let dev = rel_pos.device();
190 let rel_pos_resized = if rel_pos.dim(0)? != max_rel_dist {
191 todo!("interpolation")
192 } else {
193 rel_pos
194 };
195 let q_coords = Tensor::arange(0u32, q_size as u32, dev)?
196 .reshape((q_size, 1))?
197 .to_dtype(DType::F32)?;
198 let k_coords = Tensor::arange(0u32, k_size as u32, dev)?
199 .reshape((1, k_size))?
200 .to_dtype(DType::F32)?;
201 let q_coords = (q_coords * f64::max(1f64, k_size as f64 / q_size as f64))?;
202 let k_coords = (k_coords * f64::max(1f64, q_size as f64 / k_size as f64))?;
203 let relative_coords = (q_coords.broadcast_sub(&k_coords)?
204 + (k_size as f64 - 1.) * f64::max(1f64, q_size as f64 / k_size as f64))?;
205 let (d1, d2) = relative_coords.dims2()?;
206 let relative_coords = relative_coords.to_dtype(DType::U32)?;
207 rel_pos_resized
208 .index_select(&relative_coords.reshape(d1 * d2)?, 0)?
209 .reshape((d1, d2, ()))
210}
211
212impl Module for Attention {
213 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
214 let _enter = self.span.enter();
215 let (b, h, w, c) = xs.dims4()?;
216 let qkv = self
217 .qkv
218 .forward(&xs.flatten_to(1)?)?
219 .reshape((b, h * w, 3, self.num_heads, c / self.num_heads))?
220 .permute((2, 0, 3, 1, 4))?
221 .reshape((3, b * self.num_heads, h * w, c / self.num_heads))?;
222 let q = qkv.i(0)?;
223 let k = qkv.i(1)?;
224 let v = qkv.i(2)?;
225 let attn = {
226 let _enter = self.span_matmul.enter();
227 (&q * self.scale)?.matmul(&k.t()?)?
228 };
229 let attn = {
230 let _enter = self.span_rel_pos.enter();
231 self.add_decomposed_rel_pos(attn, &q, (h, w), (h, w))?
232 };
233 let attn = {
234 let _enter = self.span_softmax.enter();
235 candle_nn::ops::softmax_last_dim(&attn)?
236 };
237 let attn = {
238 let _enter = self.span_matmul.enter();
239 attn.matmul(&v)?
240 };
241 let attn = attn
242 .reshape((b, self.num_heads, h, w, c / self.num_heads))?
243 .permute((0, 2, 3, 1, 4))?
244 .reshape((b, h * w, c))?;
245 self.proj.forward(&attn)?.reshape((b, h, w, c))
246 }
247}
248
249#[derive(Debug)]
250struct Block {
251 norm1: LayerNorm,
252 attn: Attention,
253 norm2: LayerNorm,
254 mlp: super::MlpBlock,
255 window_size: usize,
256 span: tracing::Span,
257}
258
259impl Block {
260 fn new(
261 dim: usize,
262 num_heads: usize,
263 qkv_bias: bool,
264 use_rel_pos: bool,
265 window_size: usize,
266 input_size: (usize, usize),
267 vb: VarBuilder,
268 ) -> Result<Self> {
269 let norm1 = layer_norm(dim, 1e-6, vb.pp("norm1"))?;
270 let norm2 = layer_norm(dim, 1e-6, vb.pp("norm2"))?;
271 let input_size_attn = if window_size == 0 {
272 input_size
273 } else {
274 (window_size, window_size)
275 };
276 let attn = Attention::new(
277 dim,
278 num_heads,
279 qkv_bias,
280 use_rel_pos,
281 input_size_attn,
282 vb.pp("attn"),
283 )?;
284 let mlp = super::MlpBlock::new(dim, dim * 4, candle_nn::Activation::Gelu, vb.pp("mlp"))?;
285 let span = tracing::span!(tracing::Level::TRACE, "ie-block");
286 Ok(Self {
287 norm1,
288 attn,
289 norm2,
290 mlp,
291 window_size,
292 span,
293 })
294 }
295}
296
297fn window_partition(xs: Tensor, window_size: usize) -> Result<(Tensor, (usize, usize))> {
298 let (b, h, w, c) = xs.dims4()?;
299 let pad_h = (window_size - h % window_size) % window_size;
300 let pad_w = (window_size - w % window_size) % window_size;
301 let xs = if pad_h > 0 {
302 xs.pad_with_zeros(1, 0, pad_h)?
303 } else {
304 xs
305 };
306 let xs = if pad_w > 0 {
307 xs.pad_with_zeros(2, 0, pad_w)?
308 } else {
309 xs
310 };
311 let (h_p, w_p) = (h + pad_h, w + pad_w);
312 let windows = xs
313 .reshape((
314 b,
315 h_p / window_size,
316 window_size,
317 w_p / window_size,
318 window_size,
319 c,
320 ))?
321 .transpose(2, 3)?
322 .contiguous()?
323 .flatten_to(2)?;
324 Ok((windows, (h_p, w_p)))
325}
326
327fn window_unpartition(
328 windows: Tensor,
329 window_size: usize,
330 (h_p, w_p): (usize, usize),
331 (h, w): (usize, usize),
332) -> Result<Tensor> {
333 let b = windows.dim(0)? / (h_p * w_p / window_size / window_size);
334 let xs = windows
335 .reshape((
336 b,
337 h_p / window_size,
338 w_p / window_size,
339 window_size,
340 window_size,
341 windows.elem_count() / b / h_p / w_p,
342 ))?
343 .transpose(2, 3)?
344 .contiguous()?
345 .reshape((b, h_p, w_p, ()))?;
346 let xs = if h_p > h { xs.narrow(1, 0, h)? } else { xs };
347 let xs = if w_p > w { xs.narrow(2, 0, w)? } else { xs };
348 Ok(xs)
349}
350
351impl Module for Block {
352 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
353 let _enter = self.span.enter();
354 let shortcut = xs;
355 let xs = self.norm1.forward(xs)?;
356 let hw = (xs.dim(1)?, xs.dim(2)?);
357 let (xs, pad_hw) = if self.window_size > 0 {
358 window_partition(xs, self.window_size)?
359 } else {
360 (xs, (0, 0))
361 };
362 let xs = self.attn.forward(&xs)?;
363 let xs = if self.window_size > 0 {
364 window_unpartition(xs, self.window_size, pad_hw, hw)?
365 } else {
366 xs
367 };
368 let xs = (xs + shortcut)?;
369 &xs + xs.apply(&self.norm2)?.apply(&self.mlp)?
370 }
371}
372
373#[derive(Debug)]
374pub struct ImageEncoderViT {
375 patch_embed: PatchEmbed,
376 blocks: Vec<Block>,
377 neck_conv1: candle_nn::Conv2d,
378 neck_ln1: super::LayerNorm2d,
379 neck_conv2: candle_nn::Conv2d,
380 neck_ln2: super::LayerNorm2d,
381 pos_embed: Option<Tensor>,
382 span: tracing::Span,
383}
384
385impl ImageEncoderViT {
386 #[allow(clippy::too_many_arguments)]
387 pub fn new(
388 img_size: usize,
389 patch_size: usize,
390 in_chans: usize,
391 embed_dim: usize,
392 depth: usize,
393 num_heads: usize,
394 out_chans: usize,
395 qkv_bias: bool,
396 use_rel_pos: bool,
397 use_abs_pos: bool,
398 window_size: usize,
399 global_attn_indexes: &[usize],
400 vb: VarBuilder,
401 ) -> Result<Self> {
402 let patch_embed = PatchEmbed::new(
403 in_chans,
404 embed_dim,
405 patch_size,
406 patch_size,
407 0,
408 vb.pp("patch_embed"),
409 )?;
410 let mut blocks = Vec::with_capacity(depth);
411 let vb_b = vb.pp("blocks");
412 for i in 0..depth {
413 let window_size = if global_attn_indexes.contains(&i) {
414 0
415 } else {
416 window_size
417 };
418 let block = Block::new(
419 embed_dim,
420 num_heads,
421 qkv_bias,
422 use_rel_pos,
423 window_size,
424 (img_size / patch_size, img_size / patch_size),
425 vb_b.pp(i),
426 )?;
427 blocks.push(block)
428 }
429 let neck_conv1 = candle_nn::conv2d_no_bias(
430 embed_dim,
431 out_chans,
432 1,
433 Default::default(),
434 vb.pp("neck.0"),
435 )?;
436 let neck_ln1 = super::LayerNorm2d::new(out_chans, 1e-6, vb.pp("neck.1"))?;
437 let cfg = candle_nn::Conv2dConfig {
438 padding: 1,
439 ..Default::default()
440 };
441 let neck_conv2 = candle_nn::conv2d_no_bias(out_chans, out_chans, 3, cfg, vb.pp("neck.2"))?;
442 let neck_ln2 = super::LayerNorm2d::new(out_chans, 1e-6, vb.pp("neck.3"))?;
443 let pos_embed = if use_abs_pos {
444 let p = vb.get(
445 (1, img_size / patch_size, img_size / patch_size, embed_dim),
446 "pos_embed",
447 )?;
448 Some(p)
449 } else {
450 None
451 };
452 let span = tracing::span!(tracing::Level::TRACE, "image-encoder-vit");
453 Ok(Self {
454 patch_embed,
455 blocks,
456 neck_conv1,
457 neck_ln1,
458 neck_conv2,
459 neck_ln2,
460 pos_embed,
461 span,
462 })
463 }
464}
465
466impl Module for ImageEncoderViT {
467 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
468 let _enter = self.span.enter();
469 let xs = self.patch_embed.forward(xs)?;
470 let mut xs = match &self.pos_embed {
471 Some(pos_embed) => (xs + pos_embed)?,
472 None => xs,
473 };
474 for block in self.blocks.iter() {
475 xs = block.forward(&xs)?
476 }
477 xs.permute((0, 3, 1, 2))?
478 .apply(&self.neck_conv1)?
479 .apply(&self.neck_ln1)?
480 .apply(&self.neck_conv2)?
481 .apply(&self.neck_ln2)
482 }
483}