// candle_transformers/models/segment_anything/prompt_encoder.rs

1use candle::{DType, IndexOp, Result, Tensor, D};
2use candle_nn::VarBuilder;
3
/// Positional encoding based on a fixed random Gaussian projection.
///
/// Holds a `(2, num_pos_feats)` matrix loaded from the checkpoint; 2D
/// coordinates are projected through it and mapped to sin/cos features.
#[derive(Debug)]
struct PositionEmbeddingRandom {
    // Random projection matrix of shape (2, num_pos_feats), loaded (not trained here).
    positional_encoding_gaussian_matrix: Tensor,
}
8
9impl PositionEmbeddingRandom {
10    fn new(num_pos_feats: usize, vb: VarBuilder) -> Result<Self> {
11        let positional_encoding_gaussian_matrix =
12            vb.get((2, num_pos_feats), "positional_encoding_gaussian_matrix")?;
13        Ok(Self {
14            positional_encoding_gaussian_matrix,
15        })
16    }
17
18    fn pe_encoding(&self, coords: &Tensor) -> Result<Tensor> {
19        let coords = coords.affine(2., -1.)?;
20        let coords = coords.broadcast_matmul(&self.positional_encoding_gaussian_matrix)?;
21        let coords = (coords * (2. * std::f64::consts::PI))?;
22        Tensor::cat(&[coords.sin()?, coords.cos()?], D::Minus1)
23    }
24
25    fn forward(&self, h: usize, w: usize) -> Result<Tensor> {
26        let device = self.positional_encoding_gaussian_matrix.device();
27        let x_embed = (Tensor::arange(0u32, w as u32, device)?.to_dtype(DType::F32)? + 0.5)?;
28        let y_embed = (Tensor::arange(0u32, h as u32, device)?.to_dtype(DType::F32)? + 0.5)?;
29        let x_embed = (x_embed / w as f64)?
30            .reshape((1, ()))?
31            .broadcast_as((h, w))?;
32        let y_embed = (y_embed / h as f64)?
33            .reshape(((), 1))?
34            .broadcast_as((h, w))?;
35        let coords = Tensor::stack(&[&x_embed, &y_embed], D::Minus1)?;
36        self.pe_encoding(&coords)?.permute((2, 0, 1))
37    }
38
39    fn forward_with_coords(
40        &self,
41        coords_input: &Tensor,
42        image_size: (usize, usize),
43    ) -> Result<Tensor> {
44        let coords0 = (coords_input.narrow(D::Minus1, 0, 1)? / image_size.1 as f64)?;
45        let coords1 = (coords_input.narrow(D::Minus1, 1, 1)? / image_size.0 as f64)?;
46        let c = coords_input.dim(D::Minus1)?;
47        let coords_rest = coords_input.narrow(D::Minus1, 2, c - 2)?;
48        let coords = Tensor::cat(&[&coords0, &coords1, &coords_rest], D::Minus1)?;
49        self.pe_encoding(&coords)
50    }
51}
52
/// Encodes point, box and mask prompts into sparse and dense embeddings.
#[derive(Debug)]
pub struct PromptEncoder {
    // Random sinusoidal positional encoding shared by all prompt types.
    pe_layer: PositionEmbeddingRandom,
    // Four learned embeddings: indices 0/1 for point labels 0/1, 2/3 for box corners.
    point_embeddings: Vec<candle_nn::Embedding>,
    // Embedding used for padded points (label < 0).
    not_a_point_embed: candle_nn::Embedding,
    // Mask downscaling stack: conv -> LN -> gelu -> conv -> LN -> gelu -> conv.
    mask_downscaling_conv1: candle_nn::Conv2d,
    mask_downscaling_ln1: super::LayerNorm2d,
    mask_downscaling_conv2: candle_nn::Conv2d,
    mask_downscaling_ln2: super::LayerNorm2d,
    mask_downscaling_conv3: candle_nn::Conv2d,
    // Dense embedding broadcast over the grid when no mask prompt is given.
    no_mask_embed: candle_nn::Embedding,
    // (h, w) of the image embedding grid produced by the image encoder.
    image_embedding_size: (usize, usize),
    // (h, w) of the padded input image; used to normalize prompt coordinates.
    input_image_size: (usize, usize),
    embed_dim: usize,
    span: tracing::Span,
}
69
impl PromptEncoder {
    /// Builds the prompt encoder, loading all weights from `vb`.
    ///
    /// * `embed_dim` - dimension of the output embeddings.
    /// * `image_embedding_size` - (h, w) of the image embedding grid.
    /// * `input_image_size` - (h, w) of the padded input image, used to
    ///   normalize point and box coordinates.
    /// * `mask_in_chans` - hidden channel count of the mask downscaling stack.
    pub fn new(
        embed_dim: usize,
        image_embedding_size: (usize, usize),
        input_image_size: (usize, usize),
        mask_in_chans: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        // 4 learned point embeddings: two for point labels 0/1 (used in
        // embed_points), two for the box corners (used in embed_boxes).
        let num_points_embeddings = 4;
        let pe_layer = PositionEmbeddingRandom::new(embed_dim / 2, vb.pp("pe_layer"))?;
        let not_a_point_embed = candle_nn::embedding(1, embed_dim, vb.pp("not_a_point_embed"))?;
        let no_mask_embed = candle_nn::embedding(1, embed_dim, vb.pp("no_mask_embed"))?;
        // Both downscaling convs use a 2x2 kernel with stride 2, halving h and w.
        let cfg = candle_nn::Conv2dConfig {
            stride: 2,
            ..Default::default()
        };
        // Checkpoint keys follow the original sequential layout:
        // mask_downscaling.{0,3,6} = convs, mask_downscaling.{1,4} = layer norms.
        let mask_downscaling_conv1 =
            candle_nn::conv2d(1, mask_in_chans / 4, 2, cfg, vb.pp("mask_downscaling.0"))?;
        let mask_downscaling_conv2 = candle_nn::conv2d(
            mask_in_chans / 4,
            mask_in_chans,
            2,
            cfg,
            vb.pp("mask_downscaling.3"),
        )?;
        // Final 1x1 conv projects to the embedding dimension.
        let mask_downscaling_conv3 = candle_nn::conv2d(
            mask_in_chans,
            embed_dim,
            1,
            Default::default(),
            vb.pp("mask_downscaling.6"),
        )?;
        let mask_downscaling_ln1 =
            super::LayerNorm2d::new(mask_in_chans / 4, 1e-6, vb.pp("mask_downscaling.1"))?;
        let mask_downscaling_ln2 =
            super::LayerNorm2d::new(mask_in_chans, 1e-6, vb.pp("mask_downscaling.4"))?;
        let mut point_embeddings = Vec::with_capacity(num_points_embeddings);
        let vb_e = vb.pp("point_embeddings");
        for i in 0..num_points_embeddings {
            let emb = candle_nn::embedding(1, embed_dim, vb_e.pp(i))?;
            point_embeddings.push(emb)
        }
        let span = tracing::span!(tracing::Level::TRACE, "prompt-encoder");
        Ok(Self {
            pe_layer,
            point_embeddings,
            not_a_point_embed,
            mask_downscaling_conv1,
            mask_downscaling_ln1,
            mask_downscaling_conv2,
            mask_downscaling_ln2,
            mask_downscaling_conv3,
            no_mask_embed,
            image_embedding_size,
            input_image_size,
            embed_dim,
            span,
        })
    }

    /// Dense positional encoding for the full image embedding grid, with a
    /// leading batch dimension: (1, c, h, w).
    pub fn get_dense_pe(&self) -> Result<Tensor> {
        self.pe_layer
            .forward(self.image_embedding_size.0, self.image_embedding_size.1)?
            .unsqueeze(0)
    }

    /// Downscales a mask prompt to the embedding grid through the
    /// conv -> LN -> gelu -> conv -> LN -> gelu -> conv stack.
    fn embed_masks(&self, masks: &Tensor) -> Result<Tensor> {
        masks
            .apply(&self.mask_downscaling_conv1)?
            .apply(&self.mask_downscaling_ln1)?
            .gelu()?
            .apply(&self.mask_downscaling_conv2)?
            .apply(&self.mask_downscaling_ln2)?
            .gelu()?
            .apply(&self.mask_downscaling_conv3)
    }

    /// Embeds point prompts.
    ///
    /// `points` is (batch, n, 2) pixel coordinates, `labels` is (batch, n).
    /// Points with label < 0 map to `not_a_point_embed`; labels 0 and 1 add
    /// `point_embeddings[0]` / `point_embeddings[1]` respectively (in SAM the
    /// convention is 0 = background, 1 = foreground — inferred, not shown here).
    /// When `pad` is set (no box prompt), a (0, 0) point with label -1 is
    /// appended to each batch entry.
    fn embed_points(&self, points: &Tensor, labels: &Tensor, pad: bool) -> Result<Tensor> {
        // Shift to pixel centers before normalization.
        let points = (points + 0.5)?;
        let dev = points.device();
        let (points, labels) = if pad {
            let padding_point = Tensor::zeros((points.dim(0)?, 1, 2), DType::F32, dev)?;
            // Padding points carry label -1 so they hit the not-a-point branch below.
            let padding_label = (Tensor::ones((labels.dim(0)?, 1), DType::F32, dev)? * (-1f64))?;
            let points = Tensor::cat(&[&points, &padding_point], 1)?;
            let labels = Tensor::cat(&[labels, &padding_label], 1)?;
            (points, labels)
        } else {
            (points, labels.clone())
        };
        let point_embedding = self
            .pe_layer
            .forward_with_coords(&points, self.input_image_size)?;
        // Broadcast labels along the embedding dimension for element-wise selects.
        let labels = labels.unsqueeze(2)?.broadcast_as(point_embedding.shape())?;
        let zeros = point_embedding.zeros_like()?;
        // label < 0: replace the positional encoding with not_a_point_embed.
        let point_embedding = labels.lt(0f32)?.where_cond(
            &self
                .not_a_point_embed
                .embeddings()
                .broadcast_as(zeros.shape())?,
            &point_embedding,
        )?;
        // label == 0: add point_embeddings[0] (zeros elsewhere).
        let labels0 = labels.eq(0f32)?.where_cond(
            &self.point_embeddings[0]
                .embeddings()
                .broadcast_as(zeros.shape())?,
            &zeros,
        )?;
        let point_embedding = (point_embedding + labels0)?;
        // label == 1: add point_embeddings[1] (zeros elsewhere).
        let labels1 = labels.eq(1f32)?.where_cond(
            &self.point_embeddings[1]
                .embeddings()
                .broadcast_as(zeros.shape())?,
            &zeros,
        )?;
        let point_embedding = (point_embedding + labels1)?;
        Ok(point_embedding)
    }

    /// Embeds box prompts. `boxes` is reshaped to (n, 2, 2) corner pairs; each
    /// corner gets its positional encoding plus a learned corner embedding
    /// (`point_embeddings[2]` for the first corner, `[3]` for the second).
    fn embed_boxes(&self, boxes: &Tensor) -> Result<Tensor> {
        // Shift to pixel centers before normalization.
        let boxes = (boxes + 0.5)?;
        let coords = boxes.reshape(((), 2, 2))?;
        let corner_embedding = self
            .pe_layer
            .forward_with_coords(&coords, self.input_image_size)?;
        let ce1 = corner_embedding.i((.., 0))?;
        let ce2 = corner_embedding.i((.., 1))?;
        let ce1 = (ce1 + self.point_embeddings[2].embeddings())?;
        let ce2 = (ce2 + self.point_embeddings[3].embeddings())?;
        Tensor::cat(&[&ce1, &ce2], 1)
    }

    /// Encodes the given prompts into `(sparse_embeddings, dense_embeddings)`.
    ///
    /// * `points` - optional (coords, labels) pair; padded when no box is given.
    /// * `boxes` - optional box prompts.
    /// * `masks` - optional mask prompts; when absent, `no_mask_embed` is
    ///   broadcast over the whole embedding grid.
    pub fn forward(
        &self,
        points: Option<(&Tensor, &Tensor)>,
        boxes: Option<&Tensor>,
        masks: Option<&Tensor>,
    ) -> Result<(Tensor, Tensor)> {
        let _enter = self.span.enter();
        let se_points = match points {
            // Pad points only when there is no box prompt.
            Some((coords, labels)) => Some(self.embed_points(coords, labels, boxes.is_none())?),
            None => None,
        };
        let se_boxes = match boxes {
            Some(boxes) => Some(self.embed_boxes(boxes)?),
            None => None,
        };
        // Sparse embeddings: concatenation of point and box embeddings along
        // the prompt dimension; empty (batch size hard-coded to 1) when neither
        // prompt kind is present.
        let sparse_embeddings = match (se_points, se_boxes) {
            (Some(se_points), Some(se_boxes)) => Tensor::cat(&[se_points, se_boxes], 1)?,
            (Some(se_points), None) => se_points,
            (None, Some(se_boxes)) => se_boxes,
            (None, None) => {
                let dev = self.no_mask_embed.embeddings().device();
                Tensor::zeros((1, 0, self.embed_dim), DType::F32, dev)?
            }
        };

        let dense_embeddings = match masks {
            None => {
                // No mask prompt: tile the learned no-mask embedding over the grid.
                let emb = self.no_mask_embed.embeddings();
                emb.reshape((1, (), 1, 1))?.expand((
                    1,
                    emb.elem_count(),
                    self.image_embedding_size.0,
                    self.image_embedding_size.1,
                ))?
            }
            Some(masks) => self.embed_masks(masks)?,
        };
        Ok((sparse_embeddings, dense_embeddings))
    }
}