use candle::{DType, IndexOp, Result, Tensor, D};
use candle_nn::VarBuilder;

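/// Random-Fourier-feature positional encoding, following the
/// `PositionEmbeddingRandom` module of the Segment Anything reference code.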
#[derive(Debug)]
struct PositionEmbeddingRandom {
    positional_encoding_gaussian_matrix: Tensor,
}

impl PositionEmbeddingRandom {
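    /// Loads the fixed Gaussian projection matrix of shape `(2, num_pos_feats)`
    /// from the checkpoint.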
    fn new(num_pos_feats: usize, vb: VarBuilder) -> Result<Self> {
        let positional_encoding_gaussian_matrix =
            vb.get((2, num_pos_feats), "positional_encoding_gaussian_matrix")?;
        Ok(Self {
            positional_encoding_gaussian_matrix,
        })
    }

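    /// Maps coordinates normalized to `[0, 1]` into `[-1, 1]`, projects them with the
    /// Gaussian matrix, and returns the concatenated sine/cosine Fourier features.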
    fn pe_encoding(&self, coords: &Tensor) -> Result<Tensor> {
        let coords = coords.affine(2., -1.)?;
        let coords = coords.broadcast_matmul(&self.positional_encoding_gaussian_matrix)?;
        let coords = (coords * (2. * std::f64::consts::PI))?;
        Tensor::cat(&[coords.sin()?, coords.cos()?], D::Minus1)
    }

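    /// Builds the dense positional encoding for an `h x w` grid of pixel centers,
    /// returned with shape `(channels, h, w)`.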
    fn forward(&self, h: usize, w: usize) -> Result<Tensor> {
        let device = self.positional_encoding_gaussian_matrix.device();
        let x_embed = (Tensor::arange(0u32, w as u32, device)?.to_dtype(DType::F32)? + 0.5)?;
        let y_embed = (Tensor::arange(0u32, h as u32, device)?.to_dtype(DType::F32)? + 0.5)?;
        let x_embed = (x_embed / w as f64)?
            .reshape((1, ()))?
            .broadcast_as((h, w))?;
        let y_embed = (y_embed / h as f64)?
            .reshape(((), 1))?
            .broadcast_as((h, w))?;
        let coords = Tensor::stack(&[&x_embed, &y_embed], D::Minus1)?;
        self.pe_encoding(&coords)?.permute((2, 0, 1))
    }

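    /// Positionally encodes arbitrary point coordinates, normalizing x by the image
    /// width and y by the image height (`image_size` is `(h, w)`).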
    fn forward_with_coords(
        &self,
        coords_input: &Tensor,
        image_size: (usize, usize),
    ) -> Result<Tensor> {
        let coords0 = (coords_input.narrow(D::Minus1, 0, 1)? / image_size.1 as f64)?;
        let coords1 = (coords_input.narrow(D::Minus1, 1, 1)? / image_size.0 as f64)?;
        let c = coords_input.dim(D::Minus1)?;
        let coords_rest = coords_input.narrow(D::Minus1, 2, c - 2)?;
        let coords = Tensor::cat(&[&coords0, &coords1, &coords_rest], D::Minus1)?;
        self.pe_encoding(&coords)
    }
}

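/// Encoder for the sparse (points, boxes) and dense (masks) prompts fed to the
/// Segment Anything mask decoder.
///
/// A rough usage sketch. The sizes below match the common 1024x1024 SAM setup but are
/// only illustrative, and `vb`/`device` are assumed to come from a loaded checkpoint:
///
/// ```ignore
/// let encoder = PromptEncoder::new(256, (64, 64), (1024, 1024), 16, vb)?;
/// // One positive click at pixel (512, 512): coords are (x, y), label 1 = foreground.
/// let coords = Tensor::new(&[[[512f32, 512f32]]], &device)?;
/// let labels = Tensor::new(&[[1f32]], &device)?;
/// let (sparse, dense) = encoder.forward(Some((&coords, &labels)), None, None)?;
/// ```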
#[derive(Debug)]
pub struct PromptEncoder {
    pe_layer: PositionEmbeddingRandom,
    point_embeddings: Vec<candle_nn::Embedding>,
    not_a_point_embed: candle_nn::Embedding,
    mask_downscaling_conv1: candle_nn::Conv2d,
    mask_downscaling_ln1: super::LayerNorm2d,
    mask_downscaling_conv2: candle_nn::Conv2d,
    mask_downscaling_ln2: super::LayerNorm2d,
    mask_downscaling_conv3: candle_nn::Conv2d,
    no_mask_embed: candle_nn::Embedding,
    image_embedding_size: (usize, usize),
    input_image_size: (usize, usize),
    embed_dim: usize,
    span: tracing::Span,
}

impl PromptEncoder {
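    /// Creates the prompt encoder, loading the point embeddings, the no-mask embedding,
    /// and the mask-downscaling convolution stack from the checkpoint.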
    pub fn new(
        embed_dim: usize,
        image_embedding_size: (usize, usize),
        input_image_size: (usize, usize),
        mask_in_chans: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        let num_points_embeddings = 4;
        let pe_layer = PositionEmbeddingRandom::new(embed_dim / 2, vb.pp("pe_layer"))?;
        let not_a_point_embed = candle_nn::embedding(1, embed_dim, vb.pp("not_a_point_embed"))?;
        let no_mask_embed = candle_nn::embedding(1, embed_dim, vb.pp("no_mask_embed"))?;
        let cfg = candle_nn::Conv2dConfig {
            stride: 2,
            ..Default::default()
        };
        let mask_downscaling_conv1 =
            candle_nn::conv2d(1, mask_in_chans / 4, 2, cfg, vb.pp("mask_downscaling.0"))?;
        let mask_downscaling_conv2 = candle_nn::conv2d(
            mask_in_chans / 4,
            mask_in_chans,
            2,
            cfg,
            vb.pp("mask_downscaling.3"),
        )?;
        let mask_downscaling_conv3 = candle_nn::conv2d(
            mask_in_chans,
            embed_dim,
            1,
            Default::default(),
            vb.pp("mask_downscaling.6"),
        )?;
        let mask_downscaling_ln1 =
            super::LayerNorm2d::new(mask_in_chans / 4, 1e-6, vb.pp("mask_downscaling.1"))?;
        let mask_downscaling_ln2 =
            super::LayerNorm2d::new(mask_in_chans, 1e-6, vb.pp("mask_downscaling.4"))?;
        let mut point_embeddings = Vec::with_capacity(num_points_embeddings);
        let vb_e = vb.pp("point_embeddings");
        for i in 0..num_points_embeddings {
            let emb = candle_nn::embedding(1, embed_dim, vb_e.pp(i))?;
            point_embeddings.push(emb)
        }
        let span = tracing::span!(tracing::Level::TRACE, "prompt-encoder");
        Ok(Self {
            pe_layer,
            point_embeddings,
            not_a_point_embed,
            mask_downscaling_conv1,
            mask_downscaling_ln1,
            mask_downscaling_conv2,
            mask_downscaling_ln2,
            mask_downscaling_conv3,
            no_mask_embed,
            image_embedding_size,
            input_image_size,
            embed_dim,
            span,
        })
    }

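    /// Returns the dense positional encoding over the image embedding grid, with a
    /// leading batch dimension of 1.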
    pub fn get_dense_pe(&self) -> Result<Tensor> {
        self.pe_layer
            .forward(self.image_embedding_size.0, self.image_embedding_size.1)?
            .unsqueeze(0)
    }

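    /// Downscales a mask prompt to the image embedding resolution via two strided
    /// convolutions (each followed by layer norm and GELU) and a final 1x1 convolution
    /// up to `embed_dim` channels.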
    fn embed_masks(&self, masks: &Tensor) -> Result<Tensor> {
        masks
            .apply(&self.mask_downscaling_conv1)?
            .apply(&self.mask_downscaling_ln1)?
            .gelu()?
            .apply(&self.mask_downscaling_conv2)?
            .apply(&self.mask_downscaling_ln2)?
            .gelu()?
            .apply(&self.mask_downscaling_conv3)
    }

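    /// Embeds point prompts. Labels follow the Segment Anything convention:
    /// -1 marks a padding point, 0 a negative (background) click, and 1 a positive
    /// (foreground) click. When `pad` is true a padding point is appended, standing
    /// in for a missing box prompt.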
    fn embed_points(&self, points: &Tensor, labels: &Tensor, pad: bool) -> Result<Tensor> {
        // Shift to pixel centers.
        let points = (points + 0.5)?;
        let dev = points.device();
        let (points, labels) = if pad {
            let padding_point = Tensor::zeros((points.dim(0)?, 1, 2), DType::F32, dev)?;
            let padding_label = (Tensor::ones((labels.dim(0)?, 1), DType::F32, dev)? * (-1f64))?;
            let points = Tensor::cat(&[&points, &padding_point], 1)?;
            let labels = Tensor::cat(&[labels, &padding_label], 1)?;
            (points, labels)
        } else {
            (points, labels.clone())
        };
        let point_embedding = self
            .pe_layer
            .forward_with_coords(&points, self.input_image_size)?;
        let labels = labels.unsqueeze(2)?.broadcast_as(point_embedding.shape())?;
        let zeros = point_embedding.zeros_like()?;
        // label < 0: replace the positional encoding with the not-a-point embedding.
        let point_embedding = labels.lt(0f32)?.where_cond(
            &self
                .not_a_point_embed
                .embeddings()
                .broadcast_as(zeros.shape())?,
            &point_embedding,
        )?;
        // label == 0: add the negative-point embedding.
        let labels0 = labels.eq(0f32)?.where_cond(
            &self.point_embeddings[0]
                .embeddings()
                .broadcast_as(zeros.shape())?,
            &zeros,
        )?;
        let point_embedding = (point_embedding + labels0)?;
        // label == 1: add the positive-point embedding.
        let labels1 = labels.eq(1f32)?.where_cond(
            &self.point_embeddings[1]
                .embeddings()
                .broadcast_as(zeros.shape())?,
            &zeros,
        )?;
        let point_embedding = (point_embedding + labels1)?;
        Ok(point_embedding)
    }

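    /// Embeds box prompts by encoding the two corners as points and adding the
    /// dedicated corner embeddings to each of them.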
    fn embed_boxes(&self, boxes: &Tensor) -> Result<Tensor> {
        let boxes = (boxes + 0.5)?;
        let coords = boxes.reshape(((), 2, 2))?;
        let corner_embedding = self
            .pe_layer
            .forward_with_coords(&coords, self.input_image_size)?;
        let ce1 = corner_embedding.i((.., 0))?;
        let ce2 = corner_embedding.i((.., 1))?;
        let ce1 = (ce1 + self.point_embeddings[2].embeddings())?;
        let ce2 = (ce2 + self.point_embeddings[3].embeddings())?;
        Tensor::cat(&[&ce1, &ce2], 1)
    }

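    /// Encodes the optional point, box, and mask prompts, returning the sparse and
    /// dense embeddings expected by the mask decoder.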
    pub fn forward(
        &self,
        points: Option<(&Tensor, &Tensor)>,
        boxes: Option<&Tensor>,
        masks: Option<&Tensor>,
    ) -> Result<(Tensor, Tensor)> {
        let _enter = self.span.enter();
        let se_points = match points {
            Some((coords, labels)) => Some(self.embed_points(coords, labels, boxes.is_none())?),
            None => None,
        };
        let se_boxes = match boxes {
            Some(boxes) => Some(self.embed_boxes(boxes)?),
            None => None,
        };
        let sparse_embeddings = match (se_points, se_boxes) {
            (Some(se_points), Some(se_boxes)) => Tensor::cat(&[se_points, se_boxes], 1)?,
            (Some(se_points), None) => se_points,
            (None, Some(se_boxes)) => se_boxes,
            (None, None) => {
                // No sparse prompts at all: return an empty (zero-length) embedding.
                let dev = self.no_mask_embed.embeddings().device();
                Tensor::zeros((1, 0, self.embed_dim), DType::F32, dev)?
            }
        };

        let dense_embeddings = match masks {
            None => {
                // Without a mask prompt, broadcast the learned no-mask embedding over the grid.
                let emb = self.no_mask_embed.embeddings();
                emb.reshape((1, (), 1, 1))?.expand((
                    1,
                    emb.elem_count(),
                    self.image_embedding_size.0,
                    self.image_embedding_size.1,
                ))?
            }
            Some(masks) => self.embed_masks(masks)?,
        };
        Ok((sparse_embeddings, dense_embeddings))
    }
}