candle_transformers/models/segment_anything/transformer.rs

use candle::{Result, Tensor};
use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};

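/// Multi-head attention where the internal (per-head) width can be smaller than the
/// embedding dimension, controlled by `downsample_rate`.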
#[derive(Debug)]
struct Attention {
    q_proj: Linear,
    k_proj: Linear,
    v_proj: Linear,
    out_proj: Linear,
    num_heads: usize,
}

impl Attention {
    fn new(
        embedding_dim: usize,
        num_heads: usize,
        downsample_rate: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
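        // The projections map to a smaller internal dimension when downsample_rate > 1,
        // as is done for the cross-attention layers further down in this file.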
        let internal_dim = embedding_dim / downsample_rate;
        let q_proj = candle_nn::linear(embedding_dim, internal_dim, vb.pp("q_proj"))?;
        let k_proj = candle_nn::linear(embedding_dim, internal_dim, vb.pp("k_proj"))?;
        let v_proj = candle_nn::linear(embedding_dim, internal_dim, vb.pp("v_proj"))?;
        let out_proj = candle_nn::linear(internal_dim, embedding_dim, vb.pp("out_proj"))?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            out_proj,
            num_heads,
        })
    }

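    /// Reshapes `(b, n, c)` into `(b, num_heads, n, c / num_heads)` so that attention
    /// is computed independently per head.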
    fn separate_heads(&self, x: &Tensor) -> Result<Tensor> {
        let (b, n, c) = x.dims3()?;
        x.reshape((b, n, self.num_heads, c / self.num_heads))?
            .transpose(1, 2)?
            .contiguous()
    }

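    /// Inverse of `separate_heads`: folds `(b, n_heads, n_tokens, c_per_head)` back
    /// into `(b, n_tokens, n_heads * c_per_head)`.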
    fn recombine_heads(&self, x: &Tensor) -> Result<Tensor> {
        let (b, n_heads, n_tokens, c_per_head) = x.dims4()?;
        x.transpose(1, 2)?
            .reshape((b, n_tokens, n_heads * c_per_head))
    }

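    /// Scaled dot-product attention over the projected queries, keys and values,
    /// followed by the output projection.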
    fn forward(&self, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
        let q = self.q_proj.forward(&q.contiguous()?)?;
        let k = self.k_proj.forward(&k.contiguous()?)?;
        let v = self.v_proj.forward(&v.contiguous()?)?;

        let q = self.separate_heads(&q)?;
        let k = self.separate_heads(&k)?;
        let v = self.separate_heads(&v)?;

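        // Attention scores are scaled by the square root of the per-head dimension.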
        let (_, _, _, c_per_head) = q.dims4()?;
        let attn = (q.matmul(&k.t()?)? / (c_per_head as f64).sqrt())?;
        let attn = candle_nn::ops::softmax_last_dim(&attn)?;

        let out = attn.matmul(&v)?;
        self.recombine_heads(&out)?.apply(&self.out_proj)
    }
}

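/// Transformer block with four layers: self-attention on the token queries,
/// cross-attention from tokens to the image embedding, an MLP on the tokens, and
/// cross-attention from the image embedding back to the tokens.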
#[derive(Debug)]
struct TwoWayAttentionBlock {
    self_attn: Attention,
    norm1: LayerNorm,
    cross_attn_token_to_image: Attention,
    norm2: LayerNorm,
    mlp: super::MlpBlock,
    norm3: LayerNorm,
    norm4: LayerNorm,
    cross_attn_image_to_token: Attention,
    skip_first_layer_pe: bool,
}

impl TwoWayAttentionBlock {
    fn new(
        embedding_dim: usize,
        num_heads: usize,
        mlp_dim: usize,
        skip_first_layer_pe: bool,
        vb: VarBuilder,
    ) -> Result<Self> {
        let norm1 = layer_norm(embedding_dim, 1e-5, vb.pp("norm1"))?;
        let norm2 = layer_norm(embedding_dim, 1e-5, vb.pp("norm2"))?;
        let norm3 = layer_norm(embedding_dim, 1e-5, vb.pp("norm3"))?;
        let norm4 = layer_norm(embedding_dim, 1e-5, vb.pp("norm4"))?;
        let self_attn = Attention::new(embedding_dim, num_heads, 1, vb.pp("self_attn"))?;
        let cross_attn_token_to_image = Attention::new(
            embedding_dim,
            num_heads,
            2,
            vb.pp("cross_attn_token_to_image"),
        )?;
        let cross_attn_image_to_token = Attention::new(
            embedding_dim,
            num_heads,
            2,
            vb.pp("cross_attn_image_to_token"),
        )?;
        let mlp = super::MlpBlock::new(
            embedding_dim,
            mlp_dim,
            candle_nn::Activation::Relu,
            vb.pp("mlp"),
        )?;
        Ok(Self {
            self_attn,
            norm1,
            cross_attn_image_to_token,
            norm2,
            mlp,
            norm3,
            norm4,
            cross_attn_token_to_image,
            skip_first_layer_pe,
        })
    }

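    /// Runs the four layers in sequence. `query_pe` and `key_pe` are the positional
    /// encodings added to queries and keys before the attention steps; the updated
    /// `(queries, keys)` pair is returned.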
    fn forward(
        &self,
        queries: &Tensor,
        keys: &Tensor,
        query_pe: &Tensor,
        key_pe: &Tensor,
    ) -> Result<(Tensor, Tensor)> {
        // Self attention block
        let queries = if self.skip_first_layer_pe {
            self.self_attn.forward(queries, queries, queries)?
        } else {
            let q = (queries + query_pe)?;
            let attn_out = self.self_attn.forward(&q, &q, queries)?;
            (queries + attn_out)?
        };
        let queries = self.norm1.forward(&queries)?;

        // Cross attention block, tokens attending to image embedding
        let q = (&queries + query_pe)?;
        let k = (keys + key_pe)?;
        let attn_out = self.cross_attn_token_to_image.forward(&q, &k, keys)?;
        let queries = (&queries + attn_out)?;
        let queries = self.norm2.forward(&queries)?;

        // MLP block
        let mlp_out = self.mlp.forward(&queries);
        let queries = (queries + mlp_out)?;
        let queries = self.norm3.forward(&queries)?;

        // Cross attention block, image embedding attending to tokens
        let q = (&queries + query_pe)?;
        let k = (keys + key_pe)?;
        let attn_out = self.cross_attn_image_to_token.forward(&k, &q, &queries)?;
        let keys = (keys + attn_out)?;
        let keys = self.norm4.forward(&keys)?;

        Ok((queries, keys))
    }
}

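/// Stack of `TwoWayAttentionBlock`s followed by a final token-to-image attention
/// layer; the first block skips adding the positional encoding to the query in its
/// self-attention step.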
#[derive(Debug)]
pub struct TwoWayTransformer {
    layers: Vec<TwoWayAttentionBlock>,
    final_attn_token_to_image: Attention,
    norm_final_attn: LayerNorm,
}

impl TwoWayTransformer {
    pub fn new(
        depth: usize,
        embedding_dim: usize,
        num_heads: usize,
        mlp_dim: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        let vb_l = vb.pp("layers");
        let mut layers = Vec::with_capacity(depth);
        for i in 0..depth {
            let layer =
                TwoWayAttentionBlock::new(embedding_dim, num_heads, mlp_dim, i == 0, vb_l.pp(i))?;
            layers.push(layer)
        }
        let final_attn_token_to_image = Attention::new(
            embedding_dim,
            num_heads,
            2,
            vb.pp("final_attn_token_to_image"),
        )?;
        let norm_final_attn = layer_norm(embedding_dim, 1e-5, vb.pp("norm_final_attn"))?;
        Ok(Self {
            layers,
            final_attn_token_to_image,
            norm_final_attn,
        })
    }

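    /// `image_embedding` and `image_pe` come in as `(b, embedding_dim, h, w)` and are
    /// flattened to `(b, h * w, embedding_dim)`; `point_embedding` is
    /// `(b, n_points, embedding_dim)`. Returns the processed point and image embeddings.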
    pub fn forward(
        &self,
        image_embedding: &Tensor,
        image_pe: &Tensor,
        point_embedding: &Tensor,
    ) -> Result<(Tensor, Tensor)> {
        let image_embedding = image_embedding.flatten_from(2)?.permute((0, 2, 1))?;
        let image_pe = image_pe.flatten_from(2)?.permute((0, 2, 1))?;

        let mut queries = point_embedding.clone();
        let mut keys = image_embedding;

        for layer in self.layers.iter() {
            (queries, keys) = layer.forward(&queries, &keys, point_embedding, &image_pe)?
        }

        let q = (&queries + point_embedding)?;
        let k = (&keys + image_pe)?;
        let attn_out = self.final_attn_token_to_image.forward(&q, &k, &keys)?;
        let queries = (queries + attn_out)?.apply(&self.norm_final_attn)?;

        Ok((queries, keys))
    }
}
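
// A minimal usage sketch, not part of the original file. In candle's SAM port the
// mask decoder constructs this transformer with depth 2, an embedding dim of 256,
// 8 heads and an MLP dim of 2048; the `vb.pp("transformer")` prefix below is an
// assumption.
//
//     let transformer = TwoWayTransformer::new(2, 256, 8, 2048, vb.pp("transformer"))?;
//     let (queries, keys) =
//         transformer.forward(&image_embedding, &image_pe, &point_embedding)?;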