candle_transformers/models/segment_anything/transformer.rs

use candle::{Result, Tensor};
use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};

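/// Multi-head attention layer. `downsample_rate` shrinks the internal projection
/// dimension to `embedding_dim / downsample_rate` before attention is applied.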
#[derive(Debug)]
struct Attention {
    q_proj: Linear,
    k_proj: Linear,
    v_proj: Linear,
    out_proj: Linear,
    num_heads: usize,
}

impl Attention {
    fn new(
        embedding_dim: usize,
        num_heads: usize,
        downsample_rate: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        let internal_dim = embedding_dim / downsample_rate;
        let q_proj = candle_nn::linear(embedding_dim, internal_dim, vb.pp("q_proj"))?;
        let k_proj = candle_nn::linear(embedding_dim, internal_dim, vb.pp("k_proj"))?;
        let v_proj = candle_nn::linear(embedding_dim, internal_dim, vb.pp("v_proj"))?;
        let out_proj = candle_nn::linear(internal_dim, embedding_dim, vb.pp("out_proj"))?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            out_proj,
            num_heads,
        })
    }

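    /// Split `(b, n, c)` into `(b, num_heads, n, c / num_heads)` for per-head attention.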
    fn separate_heads(&self, x: &Tensor) -> Result<Tensor> {
        let (b, n, c) = x.dims3()?;
        x.reshape((b, n, self.num_heads, c / self.num_heads))?
            .transpose(1, 2)?
            .contiguous()
    }

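    /// Merge the heads back: `(b, n_heads, n_tokens, c_per_head)` -> `(b, n_tokens, n_heads * c_per_head)`.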
    fn recombine_heads(&self, x: &Tensor) -> Result<Tensor> {
        let (b, n_heads, n_tokens, c_per_head) = x.dims4()?;
        x.transpose(1, 2)?
            .reshape((b, n_tokens, n_heads * c_per_head))
    }

    fn forward(&self, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
        let q = self.q_proj.forward(&q.contiguous()?)?;
        let k = self.k_proj.forward(&k.contiguous()?)?;
        let v = self.v_proj.forward(&v.contiguous()?)?;

        let q = self.separate_heads(&q)?;
        let k = self.separate_heads(&k)?;
        let v = self.separate_heads(&v)?;

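        // Scaled dot-product attention over the per-head channels.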
        let (_, _, _, c_per_head) = q.dims4()?;
        let attn = (q.matmul(&k.t()?)? / (c_per_head as f64).sqrt())?;
        let attn = candle_nn::ops::softmax_last_dim(&attn)?;

        let out = attn.matmul(&v)?;
        self.recombine_heads(&out)?.apply(&self.out_proj)
    }
}

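/// Transformer block that alternates self-attention on the (point) tokens,
/// cross-attention from the tokens to the image embedding, an MLP, and
/// cross-attention from the image embedding back to the tokens.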
#[derive(Debug)]
struct TwoWayAttentionBlock {
    self_attn: Attention,
    norm1: LayerNorm,
    cross_attn_token_to_image: Attention,
    norm2: LayerNorm,
    mlp: super::MlpBlock,
    norm3: LayerNorm,
    norm4: LayerNorm,
    cross_attn_image_to_token: Attention,
    skip_first_layer_pe: bool,
}

impl TwoWayAttentionBlock {
    fn new(
        embedding_dim: usize,
        num_heads: usize,
        mlp_dim: usize,
        skip_first_layer_pe: bool,
        vb: VarBuilder,
    ) -> Result<Self> {
        let norm1 = layer_norm(embedding_dim, 1e-5, vb.pp("norm1"))?;
        let norm2 = layer_norm(embedding_dim, 1e-5, vb.pp("norm2"))?;
        let norm3 = layer_norm(embedding_dim, 1e-5, vb.pp("norm3"))?;
        let norm4 = layer_norm(embedding_dim, 1e-5, vb.pp("norm4"))?;
        let self_attn = Attention::new(embedding_dim, num_heads, 1, vb.pp("self_attn"))?;
        let cross_attn_token_to_image = Attention::new(
            embedding_dim,
            num_heads,
            2,
            vb.pp("cross_attn_token_to_image"),
        )?;
        let cross_attn_image_to_token = Attention::new(
            embedding_dim,
            num_heads,
            2,
            vb.pp("cross_attn_image_to_token"),
        )?;
        let mlp = super::MlpBlock::new(
            embedding_dim,
            mlp_dim,
            candle_nn::Activation::Relu,
            vb.pp("mlp"),
        )?;
        Ok(Self {
            self_attn,
            norm1,
            cross_attn_image_to_token,
            norm2,
            mlp,
            norm3,
            norm4,
            cross_attn_token_to_image,
            skip_first_layer_pe,
        })
    }

121
122 fn forward(
123 &self,
124 queries: &Tensor,
125 keys: &Tensor,
126 query_pe: &Tensor,
127 key_pe: &Tensor,
128 ) -> Result<(Tensor, Tensor)> {
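        // Self-attention on the tokens; the first layer can skip adding the
        // positional encoding to the queries.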
        let queries = if self.skip_first_layer_pe {
            self.self_attn.forward(queries, queries, queries)?
        } else {
            let q = (queries + query_pe)?;
            let attn_out = self.self_attn.forward(&q, &q, queries)?;
            (queries + attn_out)?
        };
        let queries = self.norm1.forward(&queries)?;

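        // Cross-attention: tokens attend to the image embedding.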
        let q = (&queries + query_pe)?;
        let k = (keys + key_pe)?;
        let attn_out = self.cross_attn_token_to_image.forward(&q, &k, keys)?;
        let queries = (&queries + attn_out)?;
        let queries = self.norm2.forward(&queries)?;

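        // MLP block on the tokens.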
        let mlp_out = self.mlp.forward(&queries)?;
        let queries = (queries + mlp_out)?;
        let queries = self.norm3.forward(&queries)?;

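        // Cross-attention: the image embedding attends to the tokens.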
        let q = (&queries + query_pe)?;
        let k = (keys + key_pe)?;
        let attn_out = self.cross_attn_image_to_token.forward(&k, &q, &queries)?;
        let keys = (keys + attn_out)?;
        let keys = self.norm4.forward(&keys)?;

        Ok((queries, keys))
    }
}

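/// Stack of `TwoWayAttentionBlock`s followed by a final token-to-image attention,
/// as used by the Segment Anything mask decoder.
///
/// A minimal usage sketch (not part of the original file), assuming SAM's usual
/// decoder hyper-parameters (depth 2, 256-dim embeddings, 8 heads, 2048-dim MLP)
/// and a zero-initialized `VarBuilder` standing in for real weights:
///
/// ```ignore
/// let vb = candle_nn::VarBuilder::zeros(candle::DType::F32, &candle::Device::Cpu);
/// let transformer = TwoWayTransformer::new(2, 256, 8, 2048, vb)?;
/// // image_embedding: (B, 256, H, W), image_pe: (B, 256, H, W),
/// // point_embedding: (B, n_tokens, 256)
/// let (queries, keys) = transformer.forward(&image_embedding, &image_pe, &point_embedding)?;
/// ```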
#[derive(Debug)]
pub struct TwoWayTransformer {
    layers: Vec<TwoWayAttentionBlock>,
    final_attn_token_to_image: Attention,
    norm_final_attn: LayerNorm,
}

impl TwoWayTransformer {
    pub fn new(
        depth: usize,
        embedding_dim: usize,
        num_heads: usize,
        mlp_dim: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        let vb_l = vb.pp("layers");
        let mut layers = Vec::with_capacity(depth);
        for i in 0..depth {
            let layer =
                TwoWayAttentionBlock::new(embedding_dim, num_heads, mlp_dim, i == 0, vb_l.pp(i))?;
            layers.push(layer)
        }
        let final_attn_token_to_image = Attention::new(
            embedding_dim,
            num_heads,
            2,
            vb.pp("final_attn_token_to_image"),
        )?;
        let norm_final_attn = layer_norm(embedding_dim, 1e-5, vb.pp("norm_final_attn"))?;
        Ok(Self {
            layers,
            final_attn_token_to_image,
            norm_final_attn,
        })
    }

    pub fn forward(
        &self,
        image_embedding: &Tensor,
        image_pe: &Tensor,
        point_embedding: &Tensor,
    ) -> Result<(Tensor, Tensor)> {
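        // Flatten the image embedding and its positional encoding from
        // (B, C, H, W) to (B, H * W, C).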
        let image_embedding = image_embedding.flatten_from(2)?.permute((0, 2, 1))?;
        let image_pe = image_pe.flatten_from(2)?.permute((0, 2, 1))?;

        let mut queries = point_embedding.clone();
        let mut keys = image_embedding;

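        // Run the two-way attention blocks: the point embeddings act as queries,
        // the flattened image embedding as keys.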
        for layer in self.layers.iter() {
            (queries, keys) = layer.forward(&queries, &keys, point_embedding, &image_pe)?
        }

213
214 let q = (&queries + point_embedding)?;
215 let k = (&keys + image_pe)?;
216 let attn_out = self.final_attn_token_to_image.forward(&q, &k, &keys)?;
217 let queries = (queries + attn_out)?.apply(&self.norm_final_attn)?;
218
219 Ok((queries, keys))
220 }
221}