1use candle::{DType, IndexOp, Result, Tensor};
2use candle_nn::{Module, VarBuilder};
3
4use super::image_encoder::ImageEncoderViT;
5use super::mask_decoder::MaskDecoder;
6use super::prompt_encoder::PromptEncoder;
7use super::tiny_vit::{tiny_vit_5m, TinyViT};
8
/// Embedding width handed to the prompt encoder and mask decoder constructors
/// (and to the original ViT image encoder in `Sam::new`).
const PROMPT_EMBED_DIM: usize = 256;
/// Side length of the square that `Sam::preprocess` zero-pads images to; larger
/// images are rejected.
pub const IMAGE_SIZE: usize = 1024;
/// Patch size passed to the ViT image encoder; `IMAGE_SIZE / VIT_PATCH_SIZE` is
/// used as the image-embedding grid side.
const VIT_PATCH_SIZE: usize = 16;
/// Masks whose predicted IoU is below this value are skipped in `Sam::process_crop`.
const PRED_IOU_THRESH: f32 = 0.88;
/// Half-width of the band around `MODEL_MASK_THRESHOLD` used to compute the
/// stability score in `Sam::process_crop`.
const STABILITY_SCORE_OFFSET: f32 = 1.0;
/// Masks whose stability score is below this value are skipped in `Sam::process_crop`.
const STABILITY_SCORE_THRESHOLD: f32 = 0.95;
/// Centre value the stability-score bounds are built around
/// (`MODEL_MASK_THRESHOLD ± STABILITY_SCORE_OFFSET`).
const MODEL_MASK_THRESHOLD: f32 = 0.0;
/// Threshold passed to non-maximum suppression at the end of `Sam::process_crop`.
const CROP_NMS_THRESH: f32 = 0.7;
17
/// The image encoder backing a [`Sam`] model: either the original ViT encoder
/// (built by `Sam::new`) or the TinyViT encoder (built by `Sam::new_tiny`).
#[derive(Debug)]
enum ImageEncoder {
    /// Encoder constructed through `ImageEncoderViT::new`.
    Original(ImageEncoderViT),
    /// Encoder constructed through `tiny_vit_5m`.
    TinyViT(TinyViT),
}
23
impl Module for ImageEncoder {
    /// Runs `xs` through whichever encoder variant is held and returns its output.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        match self {
            Self::Original(vit) => vit.forward(xs),
            Self::TinyViT(vit) => vit.forward(xs),
        }
    }
}
32
/// Segment-anything model: an image encoder, a prompt encoder and a mask decoder,
/// together with the per-channel statistics used to normalize input images.
#[derive(Debug)]
pub struct Sam {
    /// Produces image embeddings from a preprocessed image.
    image_encoder: ImageEncoder,
    /// Encodes point prompts and provides the dense positional encoding.
    prompt_encoder: PromptEncoder,
    /// Turns image and prompt embeddings into low-resolution masks and IoU predictions.
    mask_decoder: MaskDecoder,
    /// Per-channel mean, shaped `(3, 1, 1)`, subtracted in `preprocess`.
    pixel_mean: Tensor,
    /// Per-channel standard deviation, shaped `(3, 1, 1)`, divided out in `preprocess`.
    pixel_std: Tensor,
}
41
42impl Sam {
43 pub fn new(
44 encoder_embed_dim: usize,
45 encoder_depth: usize,
46 encoder_num_heads: usize,
47 encoder_global_attn_indexes: &[usize],
48 vb: VarBuilder,
49 ) -> Result<Self> {
50 let image_embedding_size = IMAGE_SIZE / VIT_PATCH_SIZE;
51
52 let image_encoder = ImageEncoderViT::new(
53 IMAGE_SIZE,
54 VIT_PATCH_SIZE,
55 3,
56 encoder_embed_dim,
57 encoder_depth,
58 encoder_num_heads,
59 PROMPT_EMBED_DIM,
60 true,
61 true,
62 true,
63 14,
64 encoder_global_attn_indexes,
65 vb.pp("image_encoder"),
66 )?;
67 let prompt_encoder = PromptEncoder::new(
68 PROMPT_EMBED_DIM,
69 (image_embedding_size, image_embedding_size),
70 (IMAGE_SIZE, IMAGE_SIZE),
71 16,
72 vb.pp("prompt_encoder"),
73 )?;
74 let mask_decoder = MaskDecoder::new(
75 PROMPT_EMBED_DIM,
76 3,
77 3,
78 256,
79 vb.pp("mask_decoder"),
80 )?;
81 let pixel_mean =
82 Tensor::new(&[123.675f32, 116.28, 103.53], vb.device())?.reshape((3, 1, 1))?;
83 let pixel_std =
84 Tensor::new(&[58.395f32, 57.12, 57.375], vb.device())?.reshape((3, 1, 1))?;
85 Ok(Self {
86 image_encoder: ImageEncoder::Original(image_encoder),
87 prompt_encoder,
88 mask_decoder,
89 pixel_std,
90 pixel_mean,
91 })
92 }
93
94 pub fn new_tiny(vb: VarBuilder) -> Result<Self> {
95 let image_embedding_size = IMAGE_SIZE / VIT_PATCH_SIZE;
96
97 let image_encoder = tiny_vit_5m(vb.pp("image_encoder"))?;
98 let prompt_encoder = PromptEncoder::new(
99 PROMPT_EMBED_DIM,
100 (image_embedding_size, image_embedding_size),
101 (IMAGE_SIZE, IMAGE_SIZE),
102 16,
103 vb.pp("prompt_encoder"),
104 )?;
105 let mask_decoder = MaskDecoder::new(
106 PROMPT_EMBED_DIM,
107 3,
108 3,
109 256,
110 vb.pp("mask_decoder"),
111 )?;
112 let pixel_mean =
113 Tensor::new(&[123.675f32, 116.28, 103.53], vb.device())?.reshape((3, 1, 1))?;
114 let pixel_std =
115 Tensor::new(&[58.395f32, 57.12, 57.375], vb.device())?.reshape((3, 1, 1))?;
116 Ok(Self {
117 image_encoder: ImageEncoder::TinyViT(image_encoder),
118 prompt_encoder,
119 mask_decoder,
120 pixel_std,
121 pixel_mean,
122 })
123 }
124
125 pub fn embeddings(&self, img: &Tensor) -> Result<Tensor> {
126 let img = self.preprocess(img)?.unsqueeze(0)?;
127 self.image_encoder.forward(&img)
128 }
129
130 pub fn forward(
131 &self,
132 img: &Tensor,
133 points: &[(f64, f64, bool)],
134 multimask_output: bool,
135 ) -> Result<(Tensor, Tensor)> {
136 let (_c, original_h, original_w) = img.dims3()?;
137 let img = self.preprocess(img)?.unsqueeze(0)?;
138 let img_embeddings = self.image_encoder.forward(&img)?;
139 let (low_res_mask, iou) = self.forward_for_embeddings(
140 &img_embeddings,
141 original_h,
142 original_w,
143 points,
144 multimask_output,
145 )?;
146 let mask = low_res_mask
147 .upsample_nearest2d(IMAGE_SIZE, IMAGE_SIZE)?
148 .get(0)?
149 .i((.., ..original_h, ..original_w))?;
150 Ok((mask, iou))
151 }
152
153 pub fn forward_for_embeddings(
159 &self,
160 img_embeddings: &Tensor,
161 original_h: usize,
162 original_w: usize,
163 points: &[(f64, f64, bool)],
164 multimask_output: bool,
165 ) -> Result<(Tensor, Tensor)> {
166 let image_pe = self.prompt_encoder.get_dense_pe()?;
167 let points = if points.is_empty() {
168 None
169 } else {
170 let n_points = points.len();
171 let xys = points
172 .iter()
173 .flat_map(|(x, y, _b)| {
174 let x = (*x as f32) * (original_w as f32);
175 let y = (*y as f32) * (original_h as f32);
176 [x, y]
177 })
178 .collect::<Vec<_>>();
179 let labels = points
180 .iter()
181 .map(|(_x, _y, b)| if *b { 1f32 } else { 0f32 })
182 .collect::<Vec<_>>();
183 let points = Tensor::from_vec(xys, (1, n_points, 2), img_embeddings.device())?;
184 let labels = Tensor::from_vec(labels, (1, n_points), img_embeddings.device())?;
185 Some((points, labels))
186 };
187 let points = points.as_ref().map(|xy| (&xy.0, &xy.1));
188 let (sparse_prompt_embeddings, dense_prompt_embeddings) =
189 self.prompt_encoder.forward(points, None, None)?;
190 self.mask_decoder.forward(
191 img_embeddings,
192 &image_pe,
193 &sparse_prompt_embeddings,
194 &dense_prompt_embeddings,
195 multimask_output,
196 )
197 }
198
199 pub fn unpreprocess(&self, img: &Tensor) -> Result<Tensor> {
200 let img = img
201 .broadcast_mul(&self.pixel_std)?
202 .broadcast_add(&self.pixel_mean)?;
203 img.maximum(&img.zeros_like()?)?
204 .minimum(&(img.ones_like()? * 255.)?)
205 }
206
207 pub fn preprocess(&self, img: &Tensor) -> Result<Tensor> {
208 let (_c, h, w) = img.dims3()?;
209 let img = img
210 .to_dtype(DType::F32)?
211 .broadcast_sub(&self.pixel_mean)?
212 .broadcast_div(&self.pixel_std)?;
213 if h > IMAGE_SIZE || w > IMAGE_SIZE {
214 candle::bail!("image is too large ({w}, {h}), maximum size {IMAGE_SIZE}")
215 }
216 let img = img.pad_with_zeros(1, 0, IMAGE_SIZE - h)?;
217 img.pad_with_zeros(2, 0, IMAGE_SIZE - w)
218 }
219
    /// Generates candidate masks for one crop of `img`.
    ///
    /// The crop described by `cb` is sliced out of `img`, preprocessed and encoded
    /// once. Every point of `point_grids` (scaled by the crop width / height) is
    /// then used as a single-point prompt, in batches of 64. Masks are kept only if
    /// their predicted IoU reaches `PRED_IOU_THRESH` and their stability score
    /// reaches `STABILITY_SCORE_THRESHOLD`. The survivors go through non-maximum
    /// suppression with `CROP_NMS_THRESH`.
    ///
    /// Each returned box carries the binarized low-resolution mask as `data`, the
    /// predicted IoU as `confidence`, and bounds expressed as indexes into that
    /// low-resolution mask (they are not rescaled to the crop or the image here).
    ///
    /// # Errors
    /// Propagates any slicing, preprocessing, model or tensor-conversion error.
    fn process_crop(
        &self,
        img: &Tensor,
        cb: CropBox,
        point_grids: &[(f64, f64)],
    ) -> Result<Vec<crate::object_detection::Bbox<Tensor>>> {
        // Slice the crop out of the image and embed it once for all point prompts.
        let img = img.i((.., cb.y0..cb.y1, cb.x0..cb.x1))?;
        let img = self.preprocess(&img)?.unsqueeze(0)?;
        let img_embeddings = self.image_encoder.forward(&img)?;

        let crop_w = cb.x1 - cb.x0;
        let crop_h = cb.y1 - cb.y0;

        let image_pe = self.prompt_encoder.get_dense_pe()?;
        // Scale the grid points by the crop size; each point becomes an [x, y] row.
        let points = point_grids
            .iter()
            .map(|&(x, y)| vec![x as f32 * crop_w as f32, y as f32 * crop_h as f32])
            .collect::<Vec<_>>();

        let mut bboxes = Vec::new();
        for points in points.chunks(64) {
            let points_len = points.len();
            // One prompt per row: (points_len, 1, 2) coordinates with all-ones labels.
            let in_points = Tensor::new(points.to_vec(), img.device())?.unsqueeze(1)?;
            let in_labels = Tensor::ones((points_len, 1), DType::F32, img.device())?;
            let (sparse_prompt_embeddings, dense_prompt_embeddings) =
                self.prompt_encoder
                    .forward(Some((&in_points, &in_labels)), None, None)?;

            // The trailing `true` sits where `forward_for_embeddings` passes
            // `multimask_output`.
            let (low_res_mask, iou_predictions) = self.mask_decoder.forward(
                &img_embeddings,
                &image_pe,
                &sparse_prompt_embeddings,
                &dense_prompt_embeddings,
                true,
            )?;
            // Merge the first two dimensions so masks and IoU values share one index.
            let low_res_mask = low_res_mask.flatten(0, 1)?;
            let iou_predictions = iou_predictions.flatten(0, 1)?.to_vec1::<f32>()?;
            let dev = low_res_mask.device();

            for (i, iou) in iou_predictions.iter().enumerate() {
                // Filter on the predicted IoU first, as it needs no tensor work.
                if *iou < PRED_IOU_THRESH {
                    continue;
                }
                let low_res_mask = low_res_mask.get(i)?;

                // Stability score: number of elements above the upper bound divided
                // by the number of elements above the lower bound.
                let bound = Tensor::new(MODEL_MASK_THRESHOLD + STABILITY_SCORE_OFFSET, dev)?
                    .broadcast_as(low_res_mask.shape())?;
                let intersections = low_res_mask
                    .ge(&bound)?
                    .to_dtype(DType::F32)?
                    .sum_all()?
                    .to_vec0::<f32>()?;
                let bound = Tensor::new(MODEL_MASK_THRESHOLD - STABILITY_SCORE_OFFSET, dev)?
                    .broadcast_as(low_res_mask.shape())?;
                let unions = low_res_mask
                    .ge(&bound)?
                    .to_dtype(DType::F32)?
                    .sum_all()?
                    .to_vec0::<f32>()?;
                let stability_score = intersections / unions;
                if stability_score < STABILITY_SCORE_THRESHOLD {
                    continue;
                }

                // Binarize at 0 (the same value as `MODEL_MASK_THRESHOLD`).
                let low_res_mask = low_res_mask
                    .ge(&Tensor::new(0f32, dev)?.broadcast_as(low_res_mask.shape())?)?
                    .to_dtype(DType::U32)?;
                // Sums along dimension 0 and dimension 1 give the per-x and per-y
                // counts; the first and last non-zero entries bound the mask.
                // NOTE(review): assumes the mask is rank 2 here, (rows, columns) — confirm.
                let low_res_mask_per_x = low_res_mask.sum(0)?.to_vec1::<u32>()?;
                let low_res_mask_per_y = low_res_mask.sum(1)?.to_vec1::<u32>()?;
                let min_max_x = min_max_indexes(&low_res_mask_per_x);
                let min_max_y = min_max_indexes(&low_res_mask_per_y);
                // Empty masks yield `None` and are silently dropped.
                if let Some(((x0, x1), (y0, y1))) = min_max_x.zip(min_max_y) {
                    let bbox = crate::object_detection::Bbox {
                        xmin: x0 as f32,
                        ymin: y0 as f32,
                        xmax: x1 as f32,
                        ymax: y1 as f32,
                        confidence: *iou,
                        data: low_res_mask,
                    };
                    bboxes.push(bbox);
                }
            }
        }

        // `non_maximum_suppression` takes a list of box lists, so wrap the single
        // list and unwrap it again afterwards.
        let mut bboxes = vec![bboxes];
        crate::object_detection::non_maximum_suppression(&mut bboxes, CROP_NMS_THRESH);

        Ok(bboxes.remove(0))
    }
321
322 pub fn generate_masks(
323 &self,
324 img: &Tensor,
325 points_per_side: usize,
326 crop_n_layer: usize,
327 crop_overlap_ratio: f64,
328 crop_n_points_downscale_factor: usize,
329 ) -> Result<Vec<crate::object_detection::Bbox<Tensor>>> {
330 let (_c, h, w) = img.dims3()?;
331 let point_grids = build_all_layer_point_grids(
332 points_per_side,
333 crop_n_layer,
334 crop_n_points_downscale_factor,
335 );
336 let crop_boxes = generate_crop_boxes((h, w), crop_n_layer, crop_overlap_ratio);
337 let mut bboxes = Vec::new();
338 for crop_box in crop_boxes.into_iter() {
339 let layer_idx = crop_box.layer_idx;
340 let b = self.process_crop(img, crop_box, &point_grids[layer_idx])?;
341 bboxes.extend(b)
342 }
343 Ok(bboxes)
345 }
346}
347
/// Returns the indexes of the first and last non-zero entries of `values`, or
/// `None` when every entry is zero (or the slice is empty).
fn min_max_indexes(values: &[u32]) -> Option<(usize, usize)> {
    let is_set = |count: &u32| *count != 0;
    let first = values.iter().position(is_set)?;
    // A first hit guarantees a last hit, so this `?` never actually bails.
    let last = values.iter().rposition(is_set)?;
    Some((first, last))
}
364
/// A rectangular crop of an image in pixel coordinates, tagged with the crop
/// layer it belongs to. The `x0..x1` and `y0..y1` ranges are used as half-open
/// slices when the crop is cut out of the image.
#[derive(Debug)]
struct CropBox {
    /// Left edge (inclusive).
    x0: usize,
    /// Top edge (inclusive).
    y0: usize,
    /// Right edge (exclusive).
    x1: usize,
    /// Bottom edge (exclusive).
    y1: usize,
    /// Crop layer index; 0 is the full image. Also selects the point grid to use.
    layer_idx: usize,
}
373
impl CropBox {
    /// Creates a crop box from its corner coordinates and layer index; no
    /// validation is performed on the arguments.
    fn new(x0: usize, y0: usize, x1: usize, y1: usize, layer_idx: usize) -> Self {
        Self {
            x0,
            y0,
            x1,
            y1,
            layer_idx,
        }
    }
}
385
386fn generate_crop_boxes(
387 (im_h, im_w): (usize, usize),
388 n_layers: usize,
389 overlap_ratio: f64,
390) -> Vec<CropBox> {
391 fn crop_len(orig_len: usize, n_crops: usize, overlap: usize) -> usize {
392 f64::ceil((overlap * (n_crops - 1) + orig_len) as f64 / n_crops as f64) as usize
393 }
394
395 let short_side = usize::min(im_h, im_w);
396
397 let mut crop_boxes = Vec::new();
398
399 crop_boxes.push(CropBox::new(0, 0, im_w, im_h, 0));
401
402 for layer_idx in 1..=n_layers {
403 let n_crops_per_side = 1 << layer_idx;
404 let overlap = (overlap_ratio * short_side as f64 * 2. / n_crops_per_side as f64) as usize;
405 let crop_w = crop_len(im_w, n_crops_per_side, overlap);
406 let crop_h = crop_len(im_w, n_crops_per_side, overlap);
407
408 for i_x in 0..n_crops_per_side {
409 let x0 = (crop_w - overlap) * i_x;
410 for i_y in 0..n_crops_per_side {
411 let y0 = (crop_h - overlap) * i_y;
412 let x1 = usize::min(im_w, x0 + crop_w);
413 let y1 = usize::min(im_h, y0 + crop_h);
414 crop_boxes.push(CropBox::new(x0, y0, x1, y1, layer_idx));
415 }
416 }
417 }
418
419 crop_boxes
420}
421
/// Builds an `n_per_side x n_per_side` grid of points in the unit square.
///
/// Points sit at the centres of equally sized cells: coordinate `i` along an axis
/// is `1 / (2 * n_per_side) + i / n_per_side`. The output is ordered with `x` as
/// the slow axis and `y` as the fast axis. An `n_per_side` of zero yields an
/// empty grid.
fn build_point_grid(n_per_side: usize) -> Vec<(f64, f64)> {
    let half_cell = 1f64 / (2 * n_per_side) as f64;
    let side = n_per_side as f64;
    (0..n_per_side)
        .flat_map(|i_x| {
            let x = half_cell + i_x as f64 / side;
            (0..n_per_side).map(move |i_y| (x, half_cell + i_y as f64 / side))
        })
        .collect()
}
435
436fn build_all_layer_point_grids(
437 n_per_side: usize,
438 n_layers: usize,
439 scale_per_layer: usize,
440) -> Vec<Vec<(f64, f64)>> {
441 let mut points_by_layer = Vec::with_capacity(n_layers + 1);
442 for i in 0..=n_layers {
443 let n_points = n_per_side / scale_per_layer.pow(i as u32);
444 points_by_layer.push(build_point_grid(n_points))
445 }
446 points_by_layer
447}