candle_transformers/models/segment_anything/image_encoder.rs

1use candle::{DType, IndexOp, Result, Tensor};
2use candle_nn::{layer_norm, LayerNorm, Module, VarBuilder};
3
/// Patch embedding built from a single 2-D convolution; `forward` returns the result with the
/// channel dimension moved last.
#[derive(Debug)]
struct PatchEmbed {
    /// Convolution `in_chans -> embed_dim`, loaded from the `proj` weights.
    proj: candle_nn::Conv2d,
    /// Tracing span entered for the duration of `forward`.
    span: tracing::Span,
}
9
10impl PatchEmbed {
11    fn new(
12        in_chans: usize,
13        embed_dim: usize,
14        k_size: usize,
15        stride: usize,
16        padding: usize,
17        vb: VarBuilder,
18    ) -> Result<Self> {
19        let cfg = candle_nn::Conv2dConfig {
20            stride,
21            padding,
22            ..Default::default()
23        };
24        let proj = candle_nn::conv2d(in_chans, embed_dim, k_size, cfg, vb.pp("proj"))?;
25        let span = tracing::span!(tracing::Level::TRACE, "patch-embed");
26        Ok(Self { proj, span })
27    }
28}
29
30impl Module for PatchEmbed {
31    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
32        let _enter = self.span.enter();
33        xs.apply(&self.proj)?.permute((0, 2, 3, 1))
34    }
35}
36
// A custom op to make add_decomposed_rel_pos faster. Most of the time is spent on the final
// addition in the case where b = 12 and q_h * q_w = k_h * k_w = 4096 (presumably a 64x64 token
// grid, i.e. q_h = q_w = k_h = k_w = 64):
//   (attn.reshape((b, q_h, q_w, k_h, k_w))?
//       + rel_h.unsqueeze(4)?.broadcast_add(&rel_w.unsqueeze(3)?)?)?
//   .reshape((b, q_h * q_w, k_h * k_w))
// Ideally we would perform this operation in place but this is not supported in candle at the
// moment. We should also investigate using f16 rather than f32.
//
// Fields are (b, q_h, q_w, k_h, k_w). `b` is the leading dimension of all three inputs (at the
// call site in `add_decomposed_rel_pos` it is batch * num_heads). (q_h, q_w) is the query grid
// and (k_h, k_w) the key grid. The CPU kernel only accepts contiguous f32 inputs.
struct Add3(usize, usize, usize, usize, usize);
45impl candle::CustomOp3 for Add3 {
46    fn name(&self) -> &'static str {
47        "add3"
48    }
49
50    fn cpu_fwd(
51        &self,
52        s1: &candle::CpuStorage,
53        l1: &candle::Layout,
54        s2: &candle::CpuStorage,
55        l2: &candle::Layout,
56        s3: &candle::CpuStorage,
57        l3: &candle::Layout,
58    ) -> Result<(candle::CpuStorage, candle::Shape)> {
59        use rayon::prelude::*;
60
61        let Add3(b, q_h, q_w, k_h, k_w) = *self;
62        let s1 = s1.as_slice::<f32>()?;
63        let s1 = match l1.contiguous_offsets() {
64            None => candle::bail!("input1 has to be contiguous"),
65            Some((o1, o2)) => &s1[o1..o2],
66        };
67        let s2 = s2.as_slice::<f32>()?;
68        let s2 = match l2.contiguous_offsets() {
69            None => candle::bail!("input2 has to be contiguous"),
70            Some((o1, o2)) => &s2[o1..o2],
71        };
72        let s3 = s3.as_slice::<f32>()?;
73        let s3 = match l3.contiguous_offsets() {
74            None => candle::bail!("input3 has to be contiguous"),
75            Some((o1, o2)) => &s3[o1..o2],
76        };
77        let mut dst = vec![0f32; b * q_h * q_w * k_h * k_w];
78        dst.par_chunks_exact_mut(k_h * k_w)
79            .enumerate()
80            .for_each(|(b_idx, dst)| {
81                let s1_idx = b_idx * k_h * k_w;
82                let s2_idx = b_idx * k_h;
83                let s3_idx = b_idx * k_w;
84                for h_idx in 0..k_h {
85                    let s1_idx = s1_idx + h_idx * k_w;
86                    let s2_idx = s2_idx + h_idx;
87                    let dst_idx = h_idx * k_w;
88                    for w_idx in 0..k_w {
89                        let s1_idx = s1_idx + w_idx;
90                        let s3_idx = s3_idx + w_idx;
91                        let dst_idx = dst_idx + w_idx;
92                        dst[dst_idx] = s1[s1_idx] + s2[s2_idx] + s3[s3_idx]
93                    }
94                }
95            });
96        let dst = candle::WithDType::to_cpu_storage_owned(dst);
97        Ok((dst, (b, q_h * q_w, k_h * k_w).into()))
98    }
99}
100
/// Multi-head self-attention over a channels-last `(b, h, w, c)` token grid.
/// Decomposed relative position embeddings are optional.
#[derive(Debug)]
struct Attention {
    /// Fused query/key/value projection, built with `(dim, 3 * dim)`.
    qkv: super::Linear,
    /// Output projection, built with `(dim, dim)` and a bias.
    proj: super::Linear,
    num_heads: usize,
    /// Query scaling factor, `1 / sqrt(dim / num_heads)`.
    scale: f64,
    /// Relative position tables `(rel_pos_h, rel_pos_w)`.
    /// Present only when the layer was built with `use_rel_pos`.
    rel_pos_hw: Option<(Tensor, Tensor)>,
    /// Tracing spans for the whole forward pass and its matmul, rel-pos and softmax stages.
    span: tracing::Span,
    span_matmul: tracing::Span,
    span_rel_pos: tracing::Span,
    span_softmax: tracing::Span,
}
113
114impl Attention {
115    fn new(
116        dim: usize,
117        num_heads: usize,
118        qkv_bias: bool,
119        use_rel_pos: bool,
120        input_size: (usize, usize),
121        vb: VarBuilder,
122    ) -> Result<Self> {
123        let span = tracing::span!(tracing::Level::TRACE, "attention");
124        let span_matmul = tracing::span!(tracing::Level::TRACE, "attn-matmul");
125        let span_rel_pos = tracing::span!(tracing::Level::TRACE, "attn-rel-pos");
126        let span_softmax = tracing::span!(tracing::Level::TRACE, "attn-sm");
127        let qkv = super::linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
128        let proj = super::linear(vb.pp("proj"), dim, dim, true)?;
129        let head_dim = dim / num_heads;
130        let scale = 1. / (head_dim as f64).sqrt();
131        let rel_pos_hw = if use_rel_pos {
132            let h = vb.get((2 * input_size.0 - 1, head_dim), "rel_pos_h")?;
133            let w = vb.get((2 * input_size.1 - 1, head_dim), "rel_pos_w")?;
134            Some((h, w))
135        } else {
136            None
137        };
138        Ok(Self {
139            qkv,
140            proj,
141            num_heads,
142            scale,
143            rel_pos_hw,
144            span,
145            span_matmul,
146            span_rel_pos,
147            span_softmax,
148        })
149    }
150
151    fn add_decomposed_rel_pos(
152        &self,
153        attn: Tensor,
154        q: &Tensor,
155        (q_h, q_w): (usize, usize),
156        (k_h, k_w): (usize, usize),
157    ) -> Result<Tensor> {
158        match &self.rel_pos_hw {
159            Some((rel_pos_h, rel_pos_w)) => {
160                let r_h = get_rel_pos(q_h, k_h, rel_pos_h)?;
161                let r_w = get_rel_pos(q_w, k_w, rel_pos_w)?;
162                let (b, _, dim) = q.dims3()?;
163                let r_q = q.reshape((b, q_h, q_w, dim))?;
164                // rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
165                let rel_h = r_q.matmul(&r_h.broadcast_left(b)?.t()?.contiguous()?)?;
166                // rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
167                let rel_w = r_q
168                    .transpose(1, 2)? // -> bwhc
169                    .contiguous()?
170                    .matmul(&r_w.broadcast_left(b)?.t()?.contiguous()?)? // bwhc,bwck -> bwhk
171                    .transpose(1, 2)?
172                    .contiguous()?;
173                if attn.device().is_cpu() {
174                    let op = Add3(b, q_h, q_w, k_h, k_w);
175                    attn.apply_op3_no_bwd(&rel_h, &rel_w, &op)
176                } else {
177                    (attn.reshape((b, q_h, q_w, k_h, k_w))?
178                        + rel_h.unsqueeze(4)?.broadcast_add(&rel_w.unsqueeze(3)?)?)?
179                    .reshape((b, q_h * q_w, k_h * k_w))
180                }
181            }
182            None => Ok(attn),
183        }
184    }
185}
186
187fn get_rel_pos(q_size: usize, k_size: usize, rel_pos: &Tensor) -> Result<Tensor> {
188    let max_rel_dist = 2 * usize::max(q_size, k_size) - 1;
189    let dev = rel_pos.device();
190    let rel_pos_resized = if rel_pos.dim(0)? != max_rel_dist {
191        todo!("interpolation")
192    } else {
193        rel_pos
194    };
195    let q_coords = Tensor::arange(0u32, q_size as u32, dev)?
196        .reshape((q_size, 1))?
197        .to_dtype(DType::F32)?;
198    let k_coords = Tensor::arange(0u32, k_size as u32, dev)?
199        .reshape((1, k_size))?
200        .to_dtype(DType::F32)?;
201    let q_coords = (q_coords * f64::max(1f64, k_size as f64 / q_size as f64))?;
202    let k_coords = (k_coords * f64::max(1f64, q_size as f64 / k_size as f64))?;
203    let relative_coords = (q_coords.broadcast_sub(&k_coords)?
204        + (k_size as f64 - 1.) * f64::max(1f64, q_size as f64 / k_size as f64))?;
205    let (d1, d2) = relative_coords.dims2()?;
206    let relative_coords = relative_coords.to_dtype(DType::U32)?;
207    rel_pos_resized
208        .index_select(&relative_coords.reshape(d1 * d2)?, 0)?
209        .reshape((d1, d2, ()))
210}
211
212impl Module for Attention {
213    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
214        let _enter = self.span.enter();
215        let (b, h, w, c) = xs.dims4()?;
216        let qkv = self
217            .qkv
218            .forward(&xs.flatten_to(1)?)?
219            .reshape((b, h * w, 3, self.num_heads, c / self.num_heads))?
220            .permute((2, 0, 3, 1, 4))?
221            .reshape((3, b * self.num_heads, h * w, c / self.num_heads))?;
222        let q = qkv.i(0)?;
223        let k = qkv.i(1)?;
224        let v = qkv.i(2)?;
225        let attn = {
226            let _enter = self.span_matmul.enter();
227            (&q * self.scale)?.matmul(&k.t()?)?
228        };
229        let attn = {
230            let _enter = self.span_rel_pos.enter();
231            self.add_decomposed_rel_pos(attn, &q, (h, w), (h, w))?
232        };
233        let attn = {
234            let _enter = self.span_softmax.enter();
235            candle_nn::ops::softmax_last_dim(&attn)?
236        };
237        let attn = {
238            let _enter = self.span_matmul.enter();
239            attn.matmul(&v)?
240        };
241        let attn = attn
242            .reshape((b, self.num_heads, h, w, c / self.num_heads))?
243            .permute((0, 2, 3, 1, 4))?
244            .reshape((b, h * w, c))?;
245        self.proj.forward(&attn)?.reshape((b, h, w, c))
246    }
247}
248
/// Transformer block: attention (windowed or global) followed by an MLP.
/// Each sub-layer is preceded by a layer norm and wrapped in a residual connection.
#[derive(Debug)]
struct Block {
    /// Layer norm applied before attention.
    norm1: LayerNorm,
    attn: Attention,
    /// Layer norm applied before the MLP.
    norm2: LayerNorm,
    mlp: super::MlpBlock,
    /// Side length of the square attention windows; 0 means global attention over the whole grid.
    window_size: usize,
    span: tracing::Span,
}
258
259impl Block {
260    fn new(
261        dim: usize,
262        num_heads: usize,
263        qkv_bias: bool,
264        use_rel_pos: bool,
265        window_size: usize,
266        input_size: (usize, usize),
267        vb: VarBuilder,
268    ) -> Result<Self> {
269        let norm1 = layer_norm(dim, 1e-6, vb.pp("norm1"))?;
270        let norm2 = layer_norm(dim, 1e-6, vb.pp("norm2"))?;
271        let input_size_attn = if window_size == 0 {
272            input_size
273        } else {
274            (window_size, window_size)
275        };
276        let attn = Attention::new(
277            dim,
278            num_heads,
279            qkv_bias,
280            use_rel_pos,
281            input_size_attn,
282            vb.pp("attn"),
283        )?;
284        let mlp = super::MlpBlock::new(dim, dim * 4, candle_nn::Activation::Gelu, vb.pp("mlp"))?;
285        let span = tracing::span!(tracing::Level::TRACE, "ie-block");
286        Ok(Self {
287            norm1,
288            attn,
289            norm2,
290            mlp,
291            window_size,
292            span,
293        })
294    }
295}
296
297fn window_partition(xs: Tensor, window_size: usize) -> Result<(Tensor, (usize, usize))> {
298    let (b, h, w, c) = xs.dims4()?;
299    let pad_h = (window_size - h % window_size) % window_size;
300    let pad_w = (window_size - w % window_size) % window_size;
301    let xs = if pad_h > 0 {
302        xs.pad_with_zeros(1, 0, pad_h)?
303    } else {
304        xs
305    };
306    let xs = if pad_w > 0 {
307        xs.pad_with_zeros(2, 0, pad_w)?
308    } else {
309        xs
310    };
311    let (h_p, w_p) = (h + pad_h, w + pad_w);
312    let windows = xs
313        .reshape((
314            b,
315            h_p / window_size,
316            window_size,
317            w_p / window_size,
318            window_size,
319            c,
320        ))?
321        .transpose(2, 3)?
322        .contiguous()?
323        .flatten_to(2)?;
324    Ok((windows, (h_p, w_p)))
325}
326
327fn window_unpartition(
328    windows: Tensor,
329    window_size: usize,
330    (h_p, w_p): (usize, usize),
331    (h, w): (usize, usize),
332) -> Result<Tensor> {
333    let b = windows.dim(0)? / (h_p * w_p / window_size / window_size);
334    let xs = windows
335        .reshape((
336            b,
337            h_p / window_size,
338            w_p / window_size,
339            window_size,
340            window_size,
341            windows.elem_count() / b / h_p / w_p,
342        ))?
343        .transpose(2, 3)?
344        .contiguous()?
345        .reshape((b, h_p, w_p, ()))?;
346    let xs = if h_p > h { xs.narrow(1, 0, h)? } else { xs };
347    let xs = if w_p > w { xs.narrow(2, 0, w)? } else { xs };
348    Ok(xs)
349}
350
351impl Module for Block {
352    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
353        let _enter = self.span.enter();
354        let shortcut = xs;
355        let xs = self.norm1.forward(xs)?;
356        let hw = (xs.dim(1)?, xs.dim(2)?);
357        let (xs, pad_hw) = if self.window_size > 0 {
358            window_partition(xs, self.window_size)?
359        } else {
360            (xs, (0, 0))
361        };
362        let xs = self.attn.forward(&xs)?;
363        let xs = if self.window_size > 0 {
364            window_unpartition(xs, self.window_size, pad_hw, hw)?
365        } else {
366            xs
367        };
368        let xs = (xs + shortcut)?;
369        &xs + xs.apply(&self.norm2)?.apply(&self.mlp)?
370    }
371}
372
/// Vision-transformer image encoder.
///
/// Stages, in order:
/// * patch embedding;
/// * optional absolute position embedding;
/// * a stack of transformer blocks;
/// * a convolutional neck (1x1 conv -> LayerNorm2d -> 3x3 conv -> LayerNorm2d).
#[derive(Debug)]
pub struct ImageEncoderViT {
    /// Patch embedding producing a channels-last token grid.
    patch_embed: PatchEmbed,
    /// Transformer blocks, applied in order.
    blocks: Vec<Block>,
    /// 1x1 convolution without bias, `embed_dim -> out_chans`.
    neck_conv1: candle_nn::Conv2d,
    neck_ln1: super::LayerNorm2d,
    /// 3x3 convolution without bias and with padding 1, `out_chans -> out_chans`.
    neck_conv2: candle_nn::Conv2d,
    neck_ln2: super::LayerNorm2d,
    /// Learned absolute position embedding, present when built with `use_abs_pos`.
    /// Shape: `(1, img_size / patch_size, img_size / patch_size, embed_dim)`.
    pos_embed: Option<Tensor>,
    span: tracing::Span,
}
384
385impl ImageEncoderViT {
386    #[allow(clippy::too_many_arguments)]
387    pub fn new(
388        img_size: usize,
389        patch_size: usize,
390        in_chans: usize,
391        embed_dim: usize,
392        depth: usize,
393        num_heads: usize,
394        out_chans: usize,
395        qkv_bias: bool,
396        use_rel_pos: bool,
397        use_abs_pos: bool,
398        window_size: usize,
399        global_attn_indexes: &[usize],
400        vb: VarBuilder,
401    ) -> Result<Self> {
402        let patch_embed = PatchEmbed::new(
403            in_chans,
404            embed_dim,
405            patch_size,
406            patch_size,
407            0,
408            vb.pp("patch_embed"),
409        )?;
410        let mut blocks = Vec::with_capacity(depth);
411        let vb_b = vb.pp("blocks");
412        for i in 0..depth {
413            let window_size = if global_attn_indexes.contains(&i) {
414                0
415            } else {
416                window_size
417            };
418            let block = Block::new(
419                embed_dim,
420                num_heads,
421                qkv_bias,
422                use_rel_pos,
423                window_size,
424                (img_size / patch_size, img_size / patch_size),
425                vb_b.pp(i),
426            )?;
427            blocks.push(block)
428        }
429        let neck_conv1 = candle_nn::conv2d_no_bias(
430            embed_dim,
431            out_chans,
432            1,
433            Default::default(),
434            vb.pp("neck.0"),
435        )?;
436        let neck_ln1 = super::LayerNorm2d::new(out_chans, 1e-6, vb.pp("neck.1"))?;
437        let cfg = candle_nn::Conv2dConfig {
438            padding: 1,
439            ..Default::default()
440        };
441        let neck_conv2 = candle_nn::conv2d_no_bias(out_chans, out_chans, 3, cfg, vb.pp("neck.2"))?;
442        let neck_ln2 = super::LayerNorm2d::new(out_chans, 1e-6, vb.pp("neck.3"))?;
443        let pos_embed = if use_abs_pos {
444            let p = vb.get(
445                (1, img_size / patch_size, img_size / patch_size, embed_dim),
446                "pos_embed",
447            )?;
448            Some(p)
449        } else {
450            None
451        };
452        let span = tracing::span!(tracing::Level::TRACE, "image-encoder-vit");
453        Ok(Self {
454            patch_embed,
455            blocks,
456            neck_conv1,
457            neck_ln1,
458            neck_conv2,
459            neck_ln2,
460            pos_embed,
461            span,
462        })
463    }
464}
465
466impl Module for ImageEncoderViT {
467    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
468        let _enter = self.span.enter();
469        let xs = self.patch_embed.forward(xs)?;
470        let mut xs = match &self.pos_embed {
471            Some(pos_embed) => (xs + pos_embed)?,
472            None => xs,
473        };
474        for block in self.blocks.iter() {
475            xs = block.forward(&xs)?
476        }
477        xs.permute((0, 3, 1, 2))?
478            .apply(&self.neck_conv1)?
479            .apply(&self.neck_ln1)?
480            .apply(&self.neck_conv2)?
481            .apply(&self.neck_ln2)
482    }
483}