//! Segment Anything Model (SAM).
//!
//! File: candle_transformers/models/segment_anything/sam.rs

1use candle::{DType, IndexOp, Result, Tensor};
2use candle_nn::{Module, VarBuilder};
3
4use super::image_encoder::ImageEncoderViT;
5use super::mask_decoder::MaskDecoder;
6use super::prompt_encoder::PromptEncoder;
7use super::tiny_vit::{tiny_vit_5m, TinyViT};
8
/// Embedding dimension shared by the image-encoder output, the prompt encoder
/// and the mask decoder.
const PROMPT_EMBED_DIM: usize = 256;
/// Side length (in pixels) of the square input expected by the image encoder;
/// images are padded up to this size in `Sam::preprocess`.
pub const IMAGE_SIZE: usize = 1024;
/// Patch size of the ViT image encoder; the embedding grid has
/// IMAGE_SIZE / VIT_PATCH_SIZE cells per side.
const VIT_PATCH_SIZE: usize = 16;
/// Masks whose predicted IoU falls below this value are discarded during
/// automatic mask generation.
const PRED_IOU_THRESH: f32 = 0.88;
/// Offset applied around the mask logit threshold when computing the
/// stability score (area ratio between the tightened and loosened masks).
const STABILITY_SCORE_OFFSET: f32 = 1.0;
/// Masks whose stability score falls below this value are discarded.
const STABILITY_SCORE_THRESHOLD: f32 = 0.95;
/// Logit threshold above which a pixel is considered part of the mask.
const MODEL_MASK_THRESHOLD: f32 = 0.0;
/// IoU threshold used for non-maximum suppression within a single crop.
const CROP_NMS_THRESH: f32 = 0.7;
17
/// Image backbone: either the original SAM ViT encoder (built by `Sam::new`)
/// or the lighter TinyViT encoder (built by `Sam::new_tiny`).
#[derive(Debug)]
enum ImageEncoder {
    Original(ImageEncoderViT),
    TinyViT(TinyViT),
}
23
impl Module for ImageEncoder {
    /// Dispatch the forward pass to whichever encoder variant is in use.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        match self {
            Self::Original(vit) => vit.forward(xs),
            Self::TinyViT(vit) => vit.forward(xs),
        }
    }
}
32
/// The Segment Anything Model: an image encoder producing embeddings, a
/// prompt encoder for point prompts, and a mask decoder combining the two
/// into mask and IoU predictions.
#[derive(Debug)]
pub struct Sam {
    image_encoder: ImageEncoder,
    prompt_encoder: PromptEncoder,
    mask_decoder: MaskDecoder,
    /// Per-channel normalization constants on a 0-255 scale, shaped
    /// (3, 1, 1) so they broadcast over (c, h, w) images.
    pixel_mean: Tensor,
    pixel_std: Tensor,
}
41
impl Sam {
    /// Build a SAM model using the original ViT image encoder.
    ///
    /// The `encoder_*` parameters select the ViT variant (base/large/huge);
    /// weights are loaded from `vb` under the `image_encoder`,
    /// `prompt_encoder` and `mask_decoder` prefixes.
    pub fn new(
        encoder_embed_dim: usize,
        encoder_depth: usize,
        encoder_num_heads: usize,
        encoder_global_attn_indexes: &[usize],
        vb: VarBuilder,
    ) -> Result<Self> {
        // Side length of the image-embedding grid: 1024 / 16 = 64.
        let image_embedding_size = IMAGE_SIZE / VIT_PATCH_SIZE;

        let image_encoder = ImageEncoderViT::new(
            IMAGE_SIZE,
            VIT_PATCH_SIZE,
            3,
            encoder_embed_dim,
            encoder_depth,
            encoder_num_heads,
            PROMPT_EMBED_DIM,
            /* qkv_bias */ true,
            /* use_rel_pos */ true,
            /* use_abs_pos */ true,
            /* window_size */ 14,
            /* global_attn_indexes */ encoder_global_attn_indexes,
            vb.pp("image_encoder"),
        )?;
        let prompt_encoder = PromptEncoder::new(
            PROMPT_EMBED_DIM,
            (image_embedding_size, image_embedding_size),
            (IMAGE_SIZE, IMAGE_SIZE),
            16,
            vb.pp("prompt_encoder"),
        )?;
        let mask_decoder = MaskDecoder::new(
            PROMPT_EMBED_DIM,
            /* num_multitask_outputs */ 3,
            /* iou_head_depth */ 3,
            /* iou_head_hidden_dim */ 256,
            vb.pp("mask_decoder"),
        )?;
        // Per-channel mean/std on a 0-255 scale, reshaped to (3, 1, 1) so they
        // broadcast over (c, h, w) images in preprocess/unpreprocess.
        let pixel_mean =
            Tensor::new(&[123.675f32, 116.28, 103.53], vb.device())?.reshape((3, 1, 1))?;
        let pixel_std =
            Tensor::new(&[58.395f32, 57.12, 57.375], vb.device())?.reshape((3, 1, 1))?;
        Ok(Self {
            image_encoder: ImageEncoder::Original(image_encoder),
            prompt_encoder,
            mask_decoder,
            pixel_std,
            pixel_mean,
        })
    }

    /// Build a SAM model using the TinyViT-5M image encoder instead of the
    /// full ViT; the prompt encoder and mask decoder are identical to `new`.
    pub fn new_tiny(vb: VarBuilder) -> Result<Self> {
        // Side length of the image-embedding grid: 1024 / 16 = 64.
        let image_embedding_size = IMAGE_SIZE / VIT_PATCH_SIZE;

        let image_encoder = tiny_vit_5m(vb.pp("image_encoder"))?;
        let prompt_encoder = PromptEncoder::new(
            PROMPT_EMBED_DIM,
            (image_embedding_size, image_embedding_size),
            (IMAGE_SIZE, IMAGE_SIZE),
            16,
            vb.pp("prompt_encoder"),
        )?;
        let mask_decoder = MaskDecoder::new(
            PROMPT_EMBED_DIM,
            /* num_multitask_outputs */ 3,
            /* iou_head_depth */ 3,
            /* iou_head_hidden_dim */ 256,
            vb.pp("mask_decoder"),
        )?;
        // Same normalization constants as in `new`.
        let pixel_mean =
            Tensor::new(&[123.675f32, 116.28, 103.53], vb.device())?.reshape((3, 1, 1))?;
        let pixel_std =
            Tensor::new(&[58.395f32, 57.12, 57.375], vb.device())?.reshape((3, 1, 1))?;
        Ok(Self {
            image_encoder: ImageEncoder::TinyViT(image_encoder),
            prompt_encoder,
            mask_decoder,
            pixel_std,
            pixel_mean,
        })
    }

    /// Normalize and pad a (c, h, w) image, then encode it, returning the
    /// image embeddings with a leading batch dimension of 1.
    pub fn embeddings(&self, img: &Tensor) -> Result<Tensor> {
        let img = self.preprocess(img)?.unsqueeze(0)?;
        self.image_encoder.forward(&img)
    }

    /// Run the full pipeline on a (c, h, w) image: encode it, decode masks
    /// for the given point prompts, and return the mask (cropped back to the
    /// original height/width) together with the IoU predictions.
    ///
    /// `points` uses the same `(x, y, is_foreground)` convention as
    /// `forward_for_embeddings`, with coordinates normalized to [0, 1].
    pub fn forward(
        &self,
        img: &Tensor,
        points: &[(f64, f64, bool)],
        multimask_output: bool,
    ) -> Result<(Tensor, Tensor)> {
        let (_c, original_h, original_w) = img.dims3()?;
        let img = self.preprocess(img)?.unsqueeze(0)?;
        let img_embeddings = self.image_encoder.forward(&img)?;
        let (low_res_mask, iou) = self.forward_for_embeddings(
            &img_embeddings,
            original_h,
            original_w,
            points,
            multimask_output,
        )?;
        // Upscale the low-resolution mask to the padded input size, then crop
        // away the padding to recover the original image frame.
        let mask = low_res_mask
            .upsample_nearest2d(IMAGE_SIZE, IMAGE_SIZE)?
            .get(0)?
            .i((.., ..original_h, ..original_w))?;
        Ok((mask, iou))
    }

    /// Generate the mask and IOU predictions from some image embeddings and prompt.
    ///
    /// The prompt is specified as a list of points `(x, y, b)`. `x` and `y` are the point
    /// coordinates (between 0 and 1) and `b` is `true` for points that should be part of the mask
    /// and `false` for points that should be part of the background and so excluded from the mask.
    pub fn forward_for_embeddings(
        &self,
        img_embeddings: &Tensor,
        original_h: usize,
        original_w: usize,
        points: &[(f64, f64, bool)],
        multimask_output: bool,
    ) -> Result<(Tensor, Tensor)> {
        let image_pe = self.prompt_encoder.get_dense_pe()?;
        let points = if points.is_empty() {
            None
        } else {
            let n_points = points.len();
            // Scale the normalized [0, 1] coordinates to pixel coordinates in
            // the original image frame.
            let xys = points
                .iter()
                .flat_map(|(x, y, _b)| {
                    let x = (*x as f32) * (original_w as f32);
                    let y = (*y as f32) * (original_h as f32);
                    [x, y]
                })
                .collect::<Vec<_>>();
            // Label 1 marks a foreground point, 0 a background point.
            let labels = points
                .iter()
                .map(|(_x, _y, b)| if *b { 1f32 } else { 0f32 })
                .collect::<Vec<_>>();
            let points = Tensor::from_vec(xys, (1, n_points, 2), img_embeddings.device())?;
            let labels = Tensor::from_vec(labels, (1, n_points), img_embeddings.device())?;
            Some((points, labels))
        };
        let points = points.as_ref().map(|xy| (&xy.0, &xy.1));
        let (sparse_prompt_embeddings, dense_prompt_embeddings) =
            self.prompt_encoder.forward(points, None, None)?;
        self.mask_decoder.forward(
            img_embeddings,
            &image_pe,
            &sparse_prompt_embeddings,
            &dense_prompt_embeddings,
            multimask_output,
        )
    }

    /// Invert `preprocess`: de-normalize back to the 0-255 scale and clamp
    /// the result to [0, 255].
    pub fn unpreprocess(&self, img: &Tensor) -> Result<Tensor> {
        let img = img
            .broadcast_mul(&self.pixel_std)?
            .broadcast_add(&self.pixel_mean)?;
        img.maximum(&img.zeros_like()?)?
            .minimum(&(img.ones_like()? * 255.)?)
    }

    /// Normalize a (c, h, w) image with the SAM pixel statistics and zero-pad
    /// it on the bottom/right up to IMAGE_SIZE x IMAGE_SIZE.
    ///
    /// Fails if either dimension exceeds IMAGE_SIZE — no resizing is
    /// performed here; callers are expected to resize beforehand.
    pub fn preprocess(&self, img: &Tensor) -> Result<Tensor> {
        let (_c, h, w) = img.dims3()?;
        let img = img
            .to_dtype(DType::F32)?
            .broadcast_sub(&self.pixel_mean)?
            .broadcast_div(&self.pixel_std)?;
        if h > IMAGE_SIZE || w > IMAGE_SIZE {
            candle::bail!("image is too large ({w}, {h}), maximum size {IMAGE_SIZE}")
        }
        let img = img.pad_with_zeros(1, 0, IMAGE_SIZE - h)?;
        img.pad_with_zeros(2, 0, IMAGE_SIZE - w)
    }

    /// Run automatic mask generation on a single crop of the image.
    ///
    /// The crop is encoded once; `point_grids` (coordinates normalized to
    /// [0, 1]) are scaled to the crop size and fed through the prompt/mask
    /// decoders in batches of 64. Candidate masks are filtered by predicted
    /// IoU and stability score, reduced to bounding boxes, and de-duplicated
    /// with non-maximum suppression.
    fn process_crop(
        &self,
        img: &Tensor,
        cb: CropBox,
        point_grids: &[(f64, f64)],
    ) -> Result<Vec<crate::object_detection::Bbox<Tensor>>> {
        // Crop the image and calculate embeddings.
        let img = img.i((.., cb.y0..cb.y1, cb.x0..cb.x1))?;
        let img = self.preprocess(&img)?.unsqueeze(0)?;
        let img_embeddings = self.image_encoder.forward(&img)?;

        let crop_w = cb.x1 - cb.x0;
        let crop_h = cb.y1 - cb.y0;

        // Generate masks for this crop.
        let image_pe = self.prompt_encoder.get_dense_pe()?;
        // Scale the normalized grid points to pixel coordinates in the crop.
        let points = point_grids
            .iter()
            .map(|&(x, y)| vec![x as f32 * crop_w as f32, y as f32 * crop_h as f32])
            .collect::<Vec<_>>();

        let mut bboxes = Vec::new();
        for points in points.chunks(64) {
            // Run the model on this batch.
            let points_len = points.len();
            let in_points = Tensor::new(points.to_vec(), img.device())?.unsqueeze(1)?;
            // Every grid point is a foreground prompt (label 1).
            let in_labels = Tensor::ones((points_len, 1), DType::F32, img.device())?;
            let (sparse_prompt_embeddings, dense_prompt_embeddings) =
                self.prompt_encoder
                    .forward(Some((&in_points, &in_labels)), None, None)?;

            let (low_res_mask, iou_predictions) = self.mask_decoder.forward(
                &img_embeddings,
                &image_pe,
                &sparse_prompt_embeddings,
                &dense_prompt_embeddings,
                /* multimask_output */ true,
            )?;
            // Flatten (batch, num_masks, ...) so each candidate mask is
            // indexed by a single position matching its IoU prediction.
            let low_res_mask = low_res_mask.flatten(0, 1)?;
            let iou_predictions = iou_predictions.flatten(0, 1)?.to_vec1::<f32>()?;
            let dev = low_res_mask.device();

            for (i, iou) in iou_predictions.iter().enumerate() {
                // Filter by predicted IoU.
                if *iou < PRED_IOU_THRESH {
                    continue;
                }
                let low_res_mask = low_res_mask.get(i)?;

                // Calculate stability score: the area ratio between the mask
                // thresholded at a raised logit bound (intersection) and at a
                // lowered one (union). Stable masks barely change.
                let bound = Tensor::new(MODEL_MASK_THRESHOLD + STABILITY_SCORE_OFFSET, dev)?
                    .broadcast_as(low_res_mask.shape())?;
                let intersections = low_res_mask
                    .ge(&bound)?
                    .to_dtype(DType::F32)?
                    .sum_all()?
                    .to_vec0::<f32>()?;
                let bound = Tensor::new(MODEL_MASK_THRESHOLD - STABILITY_SCORE_OFFSET, dev)?
                    .broadcast_as(low_res_mask.shape())?;
                let unions = low_res_mask
                    .ge(&bound)?
                    .to_dtype(DType::F32)?
                    .sum_all()?
                    .to_vec0::<f32>()?;
                let stability_score = intersections / unions;
                if stability_score < STABILITY_SCORE_THRESHOLD {
                    continue;
                }

                // Threshold masks and calculate boxes.
                let low_res_mask = low_res_mask
                    .ge(&Tensor::new(0f32, dev)?.broadcast_as(low_res_mask.shape())?)?
                    .to_dtype(DType::U32)?;
                // Column/row occupancy counts give the tight bounding box of
                // the thresholded mask.
                let low_res_mask_per_x = low_res_mask.sum(0)?.to_vec1::<u32>()?;
                let low_res_mask_per_y = low_res_mask.sum(1)?.to_vec1::<u32>()?;
                let min_max_x = min_max_indexes(&low_res_mask_per_x);
                let min_max_y = min_max_indexes(&low_res_mask_per_y);
                if let Some(((x0, x1), (y0, y1))) = min_max_x.zip(min_max_y) {
                    let bbox = crate::object_detection::Bbox {
                        xmin: x0 as f32,
                        ymin: y0 as f32,
                        xmax: x1 as f32,
                        ymax: y1 as f32,
                        confidence: *iou,
                        data: low_res_mask,
                    };
                    bboxes.push(bbox);
                }
                // TODO:
                // Filter boxes that touch crop boundaries
                // Compress to RLE.
            }
        }

        let mut bboxes = vec![bboxes];
        // Remove duplicates within this crop.
        crate::object_detection::non_maximum_suppression(&mut bboxes, CROP_NMS_THRESH);

        // TODO: Return to the original image frame.
        Ok(bboxes.remove(0))
    }

    /// Automatically generate masks for a whole (c, h, w) image by prompting
    /// the model with regular point grids over a pyramid of overlapping
    /// crops (`crop_n_layer` levels beyond the full image).
    ///
    /// `points_per_side` sets the grid density at layer 0;
    /// `crop_n_points_downscale_factor` reduces it at each deeper layer.
    pub fn generate_masks(
        &self,
        img: &Tensor,
        points_per_side: usize,
        crop_n_layer: usize,
        crop_overlap_ratio: f64,
        crop_n_points_downscale_factor: usize,
    ) -> Result<Vec<crate::object_detection::Bbox<Tensor>>> {
        let (_c, h, w) = img.dims3()?;
        let point_grids = build_all_layer_point_grids(
            points_per_side,
            crop_n_layer,
            crop_n_points_downscale_factor,
        );
        let crop_boxes = generate_crop_boxes((h, w), crop_n_layer, crop_overlap_ratio);
        let mut bboxes = Vec::new();
        for crop_box in crop_boxes.into_iter() {
            let layer_idx = crop_box.layer_idx;
            let b = self.process_crop(img, crop_box, &point_grids[layer_idx])?;
            bboxes.extend(b)
        }
        // TODO: remove duplicates
        Ok(bboxes)
    }
}
347
/// Return `Some((first, last))` where `first` and `last` are the smallest and
/// largest indexes `i` with `values[i] > 0`, or `None` when every entry is 0.
fn min_max_indexes(values: &[u32]) -> Option<(usize, usize)> {
    let mut positive = values.iter().enumerate().filter(|&(_, &v)| v > 0);
    let first = positive.next()?.0;
    // If `first` was the only positive entry, the min and max coincide.
    let last = positive.last().map_or(first, |(i, _)| i);
    Some((first, last))
}
364
/// A rectangular region of the image, given as half-open pixel bounds
/// `[x0, x1) x [y0, y1)`, tagged with the crop-pyramid layer it belongs to.
#[derive(Debug)]
struct CropBox {
    x0: usize,
    y0: usize,
    x1: usize,
    y1: usize,
    layer_idx: usize,
}
373
374impl CropBox {
375    fn new(x0: usize, y0: usize, x1: usize, y1: usize, layer_idx: usize) -> Self {
376        Self {
377            x0,
378            y0,
379            x1,
380            y1,
381            layer_idx,
382        }
383    }
384}
385
386fn generate_crop_boxes(
387    (im_h, im_w): (usize, usize),
388    n_layers: usize,
389    overlap_ratio: f64,
390) -> Vec<CropBox> {
391    fn crop_len(orig_len: usize, n_crops: usize, overlap: usize) -> usize {
392        f64::ceil((overlap * (n_crops - 1) + orig_len) as f64 / n_crops as f64) as usize
393    }
394
395    let short_side = usize::min(im_h, im_w);
396
397    let mut crop_boxes = Vec::new();
398
399    // Original image.
400    crop_boxes.push(CropBox::new(0, 0, im_w, im_h, 0));
401
402    for layer_idx in 1..=n_layers {
403        let n_crops_per_side = 1 << layer_idx;
404        let overlap = (overlap_ratio * short_side as f64 * 2. / n_crops_per_side as f64) as usize;
405        let crop_w = crop_len(im_w, n_crops_per_side, overlap);
406        let crop_h = crop_len(im_w, n_crops_per_side, overlap);
407
408        for i_x in 0..n_crops_per_side {
409            let x0 = (crop_w - overlap) * i_x;
410            for i_y in 0..n_crops_per_side {
411                let y0 = (crop_h - overlap) * i_y;
412                let x1 = usize::min(im_w, x0 + crop_w);
413                let y1 = usize::min(im_h, y0 + crop_h);
414                crop_boxes.push(CropBox::new(x0, y0, x1, y1, layer_idx));
415            }
416        }
417    }
418
419    crop_boxes
420}
421
// Builds an `n_per_side` x `n_per_side` grid of points evenly spaced in
// [0,1]x[0,1], each point sitting at the center of its grid cell.
fn build_point_grid(n_per_side: usize) -> Vec<(f64, f64)> {
    // Half a cell width, so points land on cell centers rather than corners.
    let offset = 1f64 / (2 * n_per_side) as f64;
    (0..n_per_side)
        .flat_map(|i_x| {
            let x = offset + i_x as f64 / n_per_side as f64;
            (0..n_per_side).map(move |i_y| (x, offset + i_y as f64 / n_per_side as f64))
        })
        .collect()
}
435
436fn build_all_layer_point_grids(
437    n_per_side: usize,
438    n_layers: usize,
439    scale_per_layer: usize,
440) -> Vec<Vec<(f64, f64)>> {
441    let mut points_by_layer = Vec::with_capacity(n_layers + 1);
442    for i in 0..=n_layers {
443        let n_points = n_per_side / scale_per_layer.pow(i as u32);
444        points_by_layer.push(build_point_grid(n_points))
445    }
446    points_by_layer
447}