1use crate::{
25 error::{VisionError, VisionResult},
26 fpn::top_down::FeatureMap,
27 handle::LcgRng,
28 patch_embed::{PatchEmbed, PatchEmbedConfig, pos_2d_sincos},
29 vit::vit_block::{gelu_exact, layer_norm, linear, softmax_rows},
30 vit::{ViTBlock, ViTBlockConfig},
31};
32
33use std::f32::consts::PI;
34
35fn filled(n: usize, scale: f32, rng: &mut LcgRng) -> Vec<f32> {
39 let mut v = vec![0.0f32; n];
40 rng.fill_normal(&mut v);
41 for x in &mut v {
42 *x *= scale;
43 }
44 v
45}
46
47#[derive(Debug, Clone)]
49struct LayerNormParams {
50 weight: Vec<f32>,
51 bias: Vec<f32>,
52}
53
54impl LayerNormParams {
55 fn new(d: usize) -> Self {
56 Self {
57 weight: vec![1.0f32; d],
58 bias: vec![0.0f32; d],
59 }
60 }
61
62 fn apply(&self, x: &[f32], n: usize, d: usize) -> Vec<f32> {
63 layer_norm(x, &self.weight, &self.bias, n, d, 1e-6)
64 }
65}
66
67#[derive(Debug, Clone)]
69struct Mlp {
70 w1: Vec<f32>,
71 b1: Vec<f32>,
72 w2: Vec<f32>,
73 b2: Vec<f32>,
74 d_in: usize,
75 hidden: usize,
76 d_out: usize,
77}
78
79impl Mlp {
80 fn new(d_in: usize, hidden: usize, d_out: usize, rng: &mut LcgRng) -> Self {
81 Self {
82 w1: filled(hidden * d_in, (2.0 / d_in as f32).sqrt(), rng),
83 b1: vec![0.0f32; hidden],
84 w2: filled(d_out * hidden, (2.0 / hidden as f32).sqrt(), rng),
85 b2: vec![0.0f32; d_out],
86 d_in,
87 hidden,
88 d_out,
89 }
90 }
91
92 fn apply(&self, x: &[f32]) -> Vec<f32> {
94 let mut h = linear(x, &self.w1, &self.b1, self.d_in, self.hidden);
95 for v in &mut h {
96 *v = gelu_exact(*v);
97 }
98 linear(&h, &self.w2, &self.b2, self.hidden, self.d_out)
99 }
100}
101
102pub struct MultiHeadAttention {
109 wq: Vec<f32>,
110 bq: Vec<f32>,
111 wk: Vec<f32>,
112 bk: Vec<f32>,
113 wv: Vec<f32>,
114 bv: Vec<f32>,
115 wo: Vec<f32>,
116 bo: Vec<f32>,
117 embed_dim: usize,
118 n_heads: usize,
119 head_dim: usize,
120}
121
122impl MultiHeadAttention {
123 pub fn new(embed_dim: usize, n_heads: usize, rng: &mut LcgRng) -> VisionResult<Self> {
130 if n_heads == 0 {
131 return Err(VisionError::InvalidNumHeads(n_heads));
132 }
133 if embed_dim % n_heads != 0 {
134 return Err(VisionError::HeadDimMismatch { n_heads, embed_dim });
135 }
136 let scale = 1.0 / (embed_dim as f32).sqrt();
137 Ok(Self {
138 wq: filled(embed_dim * embed_dim, scale, rng),
139 bq: vec![0.0f32; embed_dim],
140 wk: filled(embed_dim * embed_dim, scale, rng),
141 bk: vec![0.0f32; embed_dim],
142 wv: filled(embed_dim * embed_dim, scale, rng),
143 bv: vec![0.0f32; embed_dim],
144 wo: filled(embed_dim * embed_dim, scale, rng),
145 bo: vec![0.0f32; embed_dim],
146 embed_dim,
147 n_heads,
148 head_dim: embed_dim / n_heads,
149 })
150 }
151
152 pub fn forward(
160 &self,
161 q_in: &[f32],
162 k_in: &[f32],
163 v_in: &[f32],
164 n_q: usize,
165 n_k: usize,
166 ) -> VisionResult<(Vec<f32>, Vec<f32>)> {
167 let e = self.embed_dim;
168 if q_in.len() != n_q * e {
169 return Err(VisionError::DimensionMismatch {
170 expected: n_q * e,
171 got: q_in.len(),
172 });
173 }
174 if k_in.len() != n_k * e || v_in.len() != n_k * e {
175 return Err(VisionError::DimensionMismatch {
176 expected: n_k * e,
177 got: k_in.len(),
178 });
179 }
180
181 let q = linear(q_in, &self.wq, &self.bq, e, e);
182 let k = linear(k_in, &self.wk, &self.bk, e, e);
183 let v = linear(v_in, &self.wv, &self.bv, e, e);
184
185 let scale = 1.0 / (self.head_dim as f32).sqrt();
186 let mut concat = vec![0.0f32; n_q * e];
187 let mut weights = vec![0.0f32; self.n_heads * n_q * n_k];
188 let mut scores = vec![0.0f32; n_q * n_k];
189
190 for h in 0..self.n_heads {
191 let off = h * self.head_dim;
192 for i in 0..n_q {
193 for j in 0..n_k {
194 let mut dot = 0.0f32;
195 for d in 0..self.head_dim {
196 dot += q[i * e + off + d] * k[j * e + off + d];
197 }
198 scores[i * n_k + j] = dot * scale;
199 }
200 }
201 softmax_rows(&mut scores, n_q, n_k);
202
203 for i in 0..n_q {
205 let w_row = (h * n_q + i) * n_k;
206 let s_row = i * n_k;
207 weights[w_row..w_row + n_k].copy_from_slice(&scores[s_row..s_row + n_k]);
208 for d in 0..self.head_dim {
209 let mut acc = 0.0f32;
210 for j in 0..n_k {
211 acc += scores[s_row + j] * v[j * e + off + d];
212 }
213 concat[i * e + off + d] = acc;
214 }
215 }
216 }
217
218 let out = linear(&concat, &self.wo, &self.bo, e, e);
219 if out.iter().any(|x| !x.is_finite()) {
220 return Err(VisionError::NonFinite("SAM attention output"));
221 }
222 Ok((out, weights))
223 }
224}
225
/// Everything produced by one [`TwoWayAttentionBlock::forward`] call.
#[derive(Debug, Clone)]
pub struct TwoWayBlockOutput {
    /// Updated token stream, `n_tokens x embed_dim`.
    pub tokens: Vec<f32>,
    /// Updated image stream, `n_image x embed_dim`.
    pub image: Vec<f32>,
    /// Token self-attention weights, `[n_heads][n_tokens][n_tokens]`.
    pub self_weights: Vec<f32>,
    /// Token-to-image attention weights, `[n_heads][n_tokens][n_image]`.
    pub token_to_image_weights: Vec<f32>,
    /// Image-to-token attention weights, `[n_heads][n_image][n_tokens]`.
    pub image_to_token_weights: Vec<f32>,
}
242
243pub struct TwoWayAttentionBlock {
247 self_attn: MultiHeadAttention,
248 cross_token_to_image: MultiHeadAttention,
249 cross_image_to_token: MultiHeadAttention,
250 mlp: Mlp,
251 norm1: LayerNormParams,
252 norm2: LayerNormParams,
253 norm3: LayerNormParams,
254 norm4: LayerNormParams,
255 embed_dim: usize,
256}
257
258impl TwoWayAttentionBlock {
259 pub fn new(
264 embed_dim: usize,
265 n_heads: usize,
266 mlp_dim: usize,
267 rng: &mut LcgRng,
268 ) -> VisionResult<Self> {
269 Ok(Self {
270 self_attn: MultiHeadAttention::new(embed_dim, n_heads, rng)?,
271 cross_token_to_image: MultiHeadAttention::new(embed_dim, n_heads, rng)?,
272 cross_image_to_token: MultiHeadAttention::new(embed_dim, n_heads, rng)?,
273 mlp: Mlp::new(embed_dim, mlp_dim, embed_dim, rng),
274 norm1: LayerNormParams::new(embed_dim),
275 norm2: LayerNormParams::new(embed_dim),
276 norm3: LayerNormParams::new(embed_dim),
277 norm4: LayerNormParams::new(embed_dim),
278 embed_dim,
279 })
280 }
281
282 pub fn forward(
292 &self,
293 tokens: &[f32],
294 image: &[f32],
295 query_pe: &[f32],
296 key_pe: &[f32],
297 ) -> VisionResult<TwoWayBlockOutput> {
298 let e = self.embed_dim;
299 if tokens.len() != query_pe.len() {
300 return Err(VisionError::DimensionMismatch {
301 expected: tokens.len(),
302 got: query_pe.len(),
303 });
304 }
305 if image.len() != key_pe.len() {
306 return Err(VisionError::DimensionMismatch {
307 expected: image.len(),
308 got: key_pe.len(),
309 });
310 }
311 if tokens.len() % e != 0 || image.len() % e != 0 {
312 return Err(VisionError::DimensionMismatch {
313 expected: e,
314 got: tokens.len() % e,
315 });
316 }
317 let n_t = tokens.len() / e;
318 let n_i = image.len() / e;
319
320 let q = add_vec(tokens, query_pe);
322 let (sa, self_w) = self_attn_or_err(&self.self_attn, &q, tokens, n_t)?;
323 let mut tokens_cur = add_vec(tokens, &sa);
324 tokens_cur = self.norm1.apply(&tokens_cur, n_t, e);
325
326 let q = add_vec(&tokens_cur, query_pe);
328 let k = add_vec(image, key_pe);
329 let (ca, t2i_w) = self.cross_token_to_image.forward(&q, &k, image, n_t, n_i)?;
330 tokens_cur = add_vec(&tokens_cur, &ca);
331 tokens_cur = self.norm2.apply(&tokens_cur, n_t, e);
332
333 let m = self.mlp.apply(&tokens_cur);
335 tokens_cur = add_vec(&tokens_cur, &m);
336 tokens_cur = self.norm3.apply(&tokens_cur, n_t, e);
337
338 let q = add_vec(image, key_pe);
340 let k = add_vec(&tokens_cur, query_pe);
341 let (ca2, i2t_w) = self
342 .cross_image_to_token
343 .forward(&q, &k, &tokens_cur, n_i, n_t)?;
344 let mut image_cur = add_vec(image, &ca2);
345 image_cur = self.norm4.apply(&image_cur, n_i, e);
346
347 Ok(TwoWayBlockOutput {
348 tokens: tokens_cur,
349 image: image_cur,
350 self_weights: self_w,
351 token_to_image_weights: t2i_w,
352 image_to_token_weights: i2t_w,
353 })
354 }
355}
356
357fn self_attn_or_err(
359 attn: &MultiHeadAttention,
360 qk: &[f32],
361 v: &[f32],
362 n: usize,
363) -> VisionResult<(Vec<f32>, Vec<f32>)> {
364 attn.forward(qk, qk, v, n, n)
365}
366
367pub struct TwoWayTransformer {
372 blocks: Vec<TwoWayAttentionBlock>,
373 final_attn: MultiHeadAttention,
374 final_norm: LayerNormParams,
375 embed_dim: usize,
376}
377
378impl TwoWayTransformer {
379 fn new(
380 embed_dim: usize,
381 n_heads: usize,
382 depth: usize,
383 mlp_dim: usize,
384 rng: &mut LcgRng,
385 ) -> VisionResult<Self> {
386 let mut blocks = Vec::with_capacity(depth);
387 for _ in 0..depth {
388 blocks.push(TwoWayAttentionBlock::new(embed_dim, n_heads, mlp_dim, rng)?);
389 }
390 Ok(Self {
391 blocks,
392 final_attn: MultiHeadAttention::new(embed_dim, n_heads, rng)?,
393 final_norm: LayerNormParams::new(embed_dim),
394 embed_dim,
395 })
396 }
397
398 pub fn forward(
407 &self,
408 image: &[f32],
409 image_pe: &[f32],
410 point_tokens: &[f32],
411 ) -> VisionResult<(Vec<f32>, Vec<f32>)> {
412 let e = self.embed_dim;
413 let n_t = point_tokens.len() / e;
414 let n_i = image.len() / e;
415 let query_pe = point_tokens.to_vec();
416
417 let mut tokens = point_tokens.to_vec();
418 let mut img = image.to_vec();
419 for block in &self.blocks {
420 let out = block.forward(&tokens, &img, &query_pe, image_pe)?;
421 tokens = out.tokens;
422 img = out.image;
423 }
424
425 let q = add_vec(&tokens, &query_pe);
427 let k = add_vec(&img, image_pe);
428 let (attn, _w) = self.final_attn.forward(&q, &k, &img, n_t, n_i)?;
429 tokens = add_vec(&tokens, &attn);
430 tokens = self.final_norm.apply(&tokens, n_t, e);
431
432 Ok((tokens, img))
433 }
434}
435
436pub struct ImageEncoder {
440 patch_embed: PatchEmbed,
441 pos_embed: Vec<f32>,
442 blocks: Vec<ViTBlock>,
443 neck_w: Vec<f32>,
444 neck_b: Vec<f32>,
445 grid: usize,
446 embed_dim: usize,
447}
448
449impl ImageEncoder {
450 fn new(cfg: &SamConfig, rng: &mut LcgRng) -> VisionResult<Self> {
451 let pe_cfg =
452 PatchEmbedConfig::new(cfg.img_size, cfg.patch_size, cfg.in_chans, cfg.embed_dim)?;
453 let grid = pe_cfg.grid_size();
454 let patch_embed = PatchEmbed::new(pe_cfg, rng);
455 let pos_embed = pos_2d_sincos(grid, grid, cfg.embed_dim)?;
456 let block_cfg = ViTBlockConfig::new(cfg.embed_dim, cfg.enc_heads, cfg.enc_mlp_ratio)?;
457 let mut blocks = Vec::with_capacity(cfg.enc_depth);
458 for _ in 0..cfg.enc_depth {
459 blocks.push(ViTBlock::new(block_cfg.clone(), rng));
460 }
461 let scale = 1.0 / (cfg.embed_dim as f32).sqrt();
462 Ok(Self {
463 patch_embed,
464 pos_embed,
465 blocks,
466 neck_w: filled(cfg.embed_dim * cfg.embed_dim, scale, rng),
467 neck_b: vec![0.0f32; cfg.embed_dim],
468 grid,
469 embed_dim: cfg.embed_dim,
470 })
471 }
472
473 pub fn forward(&self, image: &[f32]) -> VisionResult<FeatureMap> {
478 let e = self.embed_dim;
479 let n_patches = self.grid * self.grid;
480 let mut tokens = self.patch_embed.forward(image)?;
481 for (t, p) in tokens.iter_mut().zip(self.pos_embed.iter()) {
483 *t += *p;
484 }
485 for block in &self.blocks {
486 tokens = block.forward(&tokens, n_patches)?;
487 }
488 let tokens = linear(&tokens, &self.neck_w, &self.neck_b, e, e);
490 let chw = tokens_to_chw(&tokens, e, self.grid, self.grid);
492 FeatureMap::new(chw, e, self.grid, self.grid)
493 }
494}
495
496pub struct PositionEmbeddingRandom {
501 gaussian: Vec<f32>,
502 num_freq: usize,
503}
504
505impl PositionEmbeddingRandom {
506 fn new(num_freq: usize, rng: &mut LcgRng) -> Self {
507 Self {
509 gaussian: filled(2 * num_freq, 1.0, rng),
510 num_freq,
511 }
512 }
513
514 fn encode_point(&self, x: f32, y: f32) -> Vec<f32> {
516 let nf = self.num_freq;
517 let mut out = vec![0.0f32; 2 * nf];
518 for f in 0..nf {
519 let proj = 2.0 * PI * (x * self.gaussian[f] + y * self.gaussian[nf + f]);
520 out[f] = proj.sin();
521 out[nf + f] = proj.cos();
522 }
523 out
524 }
525
526 fn encode_grid(&self, h: usize, w: usize) -> Vec<f32> {
529 let dim = 2 * self.num_freq;
530 let mut out = vec![0.0f32; dim * h * w];
531 for i in 0..h {
532 for j in 0..w {
533 let x = (j as f32 + 0.5) / w as f32;
534 let y = (i as f32 + 0.5) / h as f32;
535 let enc = self.encode_point(x, y);
536 for (c, &val) in enc.iter().enumerate() {
537 out[(c * h + i) * w + j] = val;
538 }
539 }
540 }
541 out
542 }
543}
544
545pub struct PromptEncoder {
550 pe_layer: PositionEmbeddingRandom,
551 point_embeddings: Vec<f32>,
554 corner_embeddings: Vec<f32>,
556 not_a_point: Vec<f32>,
558 no_mask_embed: Vec<f32>,
560 mask_w: Vec<f32>,
562 mask_b: Vec<f32>,
563 embed_dim: usize,
564 grid: usize,
565 input_size: f32,
566}
567
568impl PromptEncoder {
569 fn new(cfg: &SamConfig, rng: &mut LcgRng) -> Self {
570 let e = cfg.embed_dim;
571 let scale = 1.0 / (e as f32).sqrt();
572 Self {
573 pe_layer: PositionEmbeddingRandom::new(e / 2, rng),
574 point_embeddings: filled(2 * e, 0.1, rng),
575 corner_embeddings: filled(2 * e, 0.1, rng),
576 not_a_point: filled(e, 0.1, rng),
577 no_mask_embed: filled(e, 0.1, rng),
578 mask_w: filled(e, scale, rng),
579 mask_b: vec![0.0f32; e],
580 embed_dim: e,
581 grid: cfg.img_size / cfg.patch_size,
582 input_size: cfg.img_size as f32,
583 }
584 }
585
586 #[must_use]
588 pub fn dense_positional_encoding(&self) -> Vec<f32> {
589 self.pe_layer.encode_grid(self.grid, self.grid)
590 }
591
592 pub fn encode_points(&self, coords: &[f32], labels: &[i32]) -> VisionResult<Vec<f32>> {
601 let n = labels.len();
602 if coords.len() != n * 2 {
603 return Err(VisionError::DimensionMismatch {
604 expected: n * 2,
605 got: coords.len(),
606 });
607 }
608 let e = self.embed_dim;
609 let mut out = vec![0.0f32; n * e];
610 for p in 0..n {
611 let x = coords[p * 2] / self.input_size;
612 let y = coords[p * 2 + 1] / self.input_size;
613 let pe = self.pe_layer.encode_point(x, y);
614 let dst = &mut out[p * e..(p + 1) * e];
615 if labels[p] < 0 {
616 for (d, slot) in dst.iter_mut().enumerate() {
618 *slot = self.not_a_point[d];
619 }
620 } else {
621 let label_off = if labels[p] >= 1 { e } else { 0 };
622 for (d, slot) in dst.iter_mut().enumerate() {
623 *slot = pe[d] + self.point_embeddings[label_off + d];
624 }
625 }
626 }
627 Ok(out)
628 }
629
630 pub fn encode_box(&self, box4: &[f32]) -> VisionResult<Vec<f32>> {
636 if box4.len() != 4 {
637 return Err(VisionError::DimensionMismatch {
638 expected: 4,
639 got: box4.len(),
640 });
641 }
642 let e = self.embed_dim;
643 let corners = [(box4[0], box4[1], 0usize), (box4[2], box4[3], 1usize)];
644 let mut out = vec![0.0f32; 2 * e];
645 for (idx, &(cx, cy, corner)) in corners.iter().enumerate() {
646 let pe = self
647 .pe_layer
648 .encode_point(cx / self.input_size, cy / self.input_size);
649 let dst = &mut out[idx * e..(idx + 1) * e];
650 for (d, slot) in dst.iter_mut().enumerate() {
651 *slot = pe[d] + self.corner_embeddings[corner * e + d];
652 }
653 }
654 Ok(out)
655 }
656
657 pub fn encode_mask(&self, mask: Option<&[f32]>) -> VisionResult<Vec<f32>> {
665 let e = self.embed_dim;
666 let hw = self.grid * self.grid;
667 let mut out = vec![0.0f32; e * hw];
668 match mask {
669 None => {
670 for c in 0..e {
671 let val = self.no_mask_embed[c];
672 for p in 0..hw {
673 out[c * hw + p] = val;
674 }
675 }
676 }
677 Some(m) => {
678 if m.len() != hw {
679 return Err(VisionError::DimensionMismatch {
680 expected: hw,
681 got: m.len(),
682 });
683 }
684 for c in 0..e {
686 let w = self.mask_w[c];
687 let b = self.mask_b[c];
688 for p in 0..hw {
689 out[c * hw + p] = w * m[p] + b;
690 }
691 }
692 }
693 }
694 Ok(out)
695 }
696}
697
/// Output of the mask decoder.
#[derive(Debug, Clone)]
pub struct MaskPrediction {
    /// Mask logits, `n_mask x height x width` (mask-major).
    pub masks: Vec<f32>,
    /// One predicted IoU score per mask.
    pub iou: Vec<f32>,
    /// Number of masks in `masks`.
    pub n_mask: usize,
    /// Mask height in cells.
    pub height: usize,
    /// Mask width in cells.
    pub width: usize,
}
714
715pub struct MaskDecoder {
717 transformer: TwoWayTransformer,
718 iou_token: Vec<f32>,
719 mask_tokens: Vec<f32>,
720 upscale_w: Vec<f32>,
721 upscale_b: Vec<f32>,
722 hypernets: Vec<Mlp>,
723 iou_head: Mlp,
724 n_mask: usize,
725 embed_dim: usize,
726}
727
728impl MaskDecoder {
729 fn new(cfg: &SamConfig, rng: &mut LcgRng) -> VisionResult<Self> {
730 let e = cfg.embed_dim;
731 let transformer =
732 TwoWayTransformer::new(e, cfg.dec_heads, cfg.dec_depth, cfg.dec_mlp_dim, rng)?;
733 let scale = 1.0 / (e as f32).sqrt();
734 let hypernets = (0..cfg.n_mask).map(|_| Mlp::new(e, e, e, rng)).collect();
735 Ok(Self {
736 transformer,
737 iou_token: filled(e, 0.02, rng),
738 mask_tokens: filled(cfg.n_mask * e, 0.02, rng),
739 upscale_w: filled(e * e, scale, rng),
740 upscale_b: vec![0.0f32; e],
741 hypernets,
742 iou_head: Mlp::new(e, e, cfg.n_mask, rng),
743 n_mask: cfg.n_mask,
744 embed_dim: e,
745 })
746 }
747
748 pub fn forward(
759 &self,
760 image_embedding: &FeatureMap,
761 image_pe: &[f32],
762 sparse_prompt: &[f32],
763 dense_prompt: &[f32],
764 ) -> VisionResult<MaskPrediction> {
765 let e = self.embed_dim;
766 let (h, w) = (image_embedding.height, image_embedding.width);
767 let hw = h * w;
768 if image_embedding.channels != e || image_embedding.data.len() != e * hw {
769 return Err(VisionError::DimensionMismatch {
770 expected: e * hw,
771 got: image_embedding.data.len(),
772 });
773 }
774 if image_pe.len() != e * hw || dense_prompt.len() != e * hw {
775 return Err(VisionError::DimensionMismatch {
776 expected: e * hw,
777 got: image_pe.len(),
778 });
779 }
780 if sparse_prompt.len() % e != 0 {
781 return Err(VisionError::DimensionMismatch {
782 expected: e,
783 got: sparse_prompt.len() % e,
784 });
785 }
786 let n_sparse = sparse_prompt.len() / e;
787
788 let n_tokens = 1 + self.n_mask + n_sparse;
790 let mut tokens = Vec::with_capacity(n_tokens * e);
791 tokens.extend_from_slice(&self.iou_token);
792 tokens.extend_from_slice(&self.mask_tokens);
793 tokens.extend_from_slice(sparse_prompt);
794
795 let mut src_chw = image_embedding.data.clone();
798 for (s, d) in src_chw.iter_mut().zip(dense_prompt.iter()) {
799 *s += *d;
800 }
801 let src_tokens = chw_to_tokens(&src_chw, e, h, w);
802 let pe_tokens = chw_to_tokens(image_pe, e, h, w);
803
804 let (tokens_out, src_out) = self.transformer.forward(&src_tokens, &pe_tokens, &tokens)?;
806
807 let src_img = tokens_to_chw(&src_out, e, h, w);
809 let up = upsample2x_chw(&src_img, e, h, w);
810 let (uh, uw) = (h * 2, w * 2);
811 let up_tokens = chw_to_tokens(&up, e, uh, uw);
813 let up_tokens = linear(&up_tokens, &self.upscale_w, &self.upscale_b, e, e);
814 let up = tokens_to_chw(&up_tokens, e, uh, uw);
815
816 let mut masks = vec![0.0f32; self.n_mask * uh * uw];
818 for m in 0..self.n_mask {
819 let token = &tokens_out[(1 + m) * e..(2 + m) * e];
820 let filter = self.hypernets[m].apply(token); for p in 0..(uh * uw) {
822 let mut acc = 0.0f32;
823 for c in 0..e {
824 acc += filter[c] * up[c * uh * uw + p];
825 }
826 masks[m * uh * uw + p] = acc;
827 }
828 }
829
830 let iou_token = &tokens_out[0..e];
832 let iou = self.iou_head.apply(iou_token);
833
834 if masks.iter().chain(iou.iter()).any(|v| !v.is_finite()) {
835 return Err(VisionError::NonFinite("SAM mask decoder output"));
836 }
837
838 Ok(MaskPrediction {
839 masks,
840 iou,
841 n_mask: self.n_mask,
842 height: uh,
843 width: uw,
844 })
845 }
846}
847
848#[derive(Debug, Clone, PartialEq)]
852pub struct SamConfig {
853 pub in_chans: usize,
855 pub img_size: usize,
857 pub patch_size: usize,
859 pub embed_dim: usize,
861 pub enc_depth: usize,
863 pub enc_heads: usize,
865 pub enc_mlp_ratio: usize,
867 pub dec_depth: usize,
869 pub dec_heads: usize,
871 pub dec_mlp_dim: usize,
873 pub n_mask: usize,
875}
876
877impl SamConfig {
878 pub fn new(
888 in_chans: usize,
889 img_size: usize,
890 patch_size: usize,
891 embed_dim: usize,
892 enc_depth: usize,
893 enc_heads: usize,
894 enc_mlp_ratio: usize,
895 dec_depth: usize,
896 dec_heads: usize,
897 dec_mlp_dim: usize,
898 n_mask: usize,
899 ) -> VisionResult<Self> {
900 if embed_dim == 0 || embed_dim % 2 != 0 {
901 return Err(VisionError::InvalidEmbedDim(embed_dim));
902 }
903 if patch_size == 0 || img_size % patch_size != 0 {
904 return Err(VisionError::InvalidPatchSize {
905 patch_size,
906 img_size,
907 });
908 }
909 if enc_heads == 0 || embed_dim % enc_heads != 0 {
910 return Err(VisionError::HeadDimMismatch {
911 n_heads: enc_heads,
912 embed_dim,
913 });
914 }
915 if dec_heads == 0 || embed_dim % dec_heads != 0 {
916 return Err(VisionError::HeadDimMismatch {
917 n_heads: dec_heads,
918 embed_dim,
919 });
920 }
921 if n_mask == 0 {
922 return Err(VisionError::EmptyInput("sam n_mask"));
923 }
924 Ok(Self {
925 in_chans,
926 img_size,
927 patch_size,
928 embed_dim,
929 enc_depth,
930 enc_heads,
931 enc_mlp_ratio,
932 dec_depth,
933 dec_heads,
934 dec_mlp_dim,
935 n_mask,
936 })
937 }
938
939 #[must_use]
942 pub fn tiny() -> Self {
943 Self {
944 in_chans: 3,
945 img_size: 32,
946 patch_size: 8,
947 embed_dim: 16,
948 enc_depth: 2,
949 enc_heads: 2,
950 enc_mlp_ratio: 2,
951 dec_depth: 2,
952 dec_heads: 2,
953 dec_mlp_dim: 32,
954 n_mask: 3,
955 }
956 }
957}
958
959pub struct Sam {
961 cfg: SamConfig,
962 image_encoder: ImageEncoder,
963 prompt_encoder: PromptEncoder,
964 mask_decoder: MaskDecoder,
965}
966
967impl Sam {
968 pub fn new(cfg: SamConfig, rng: &mut LcgRng) -> VisionResult<Self> {
973 let image_encoder = ImageEncoder::new(&cfg, rng)?;
974 let prompt_encoder = PromptEncoder::new(&cfg, rng);
975 let mask_decoder = MaskDecoder::new(&cfg, rng)?;
976 Ok(Self {
977 cfg,
978 image_encoder,
979 prompt_encoder,
980 mask_decoder,
981 })
982 }
983
984 #[must_use]
986 #[inline]
987 pub fn config(&self) -> &SamConfig {
988 &self.cfg
989 }
990
991 #[must_use]
993 #[inline]
994 pub fn prompt_encoder(&self) -> &PromptEncoder {
995 &self.prompt_encoder
996 }
997
998 pub fn encode_image(&self, image: &[f32]) -> VisionResult<FeatureMap> {
1003 self.image_encoder.forward(image)
1004 }
1005
1006 pub fn predict(
1012 &self,
1013 image: &[f32],
1014 point_coords: &[f32],
1015 point_labels: &[i32],
1016 mask: Option<&[f32]>,
1017 ) -> VisionResult<MaskPrediction> {
1018 let embedding = self.encode_image(image)?;
1019 let sparse = self
1020 .prompt_encoder
1021 .encode_points(point_coords, point_labels)?;
1022 let dense = self.prompt_encoder.encode_mask(mask)?;
1023 let image_pe = self.prompt_encoder.dense_positional_encoding();
1024 self.mask_decoder
1025 .forward(&embedding, &image_pe, &sparse, &dense)
1026 }
1027}
1028
/// Element-wise sum of two slices.
///
/// The result has `min(a.len(), b.len())` elements; any excess in the longer input is
/// ignored, so callers must validate lengths themselves.
fn add_vec(a: &[f32], b: &[f32]) -> Vec<f32> {
    let len = a.len().min(b.len());
    let mut sum = Vec::with_capacity(len);
    sum.extend(a.iter().zip(b).map(|(&lhs, &rhs)| lhs + rhs));
    sum
}
1035
/// Transposes a channel-major `c x (h*w)` buffer into token-major `(h*w) x c`.
fn chw_to_tokens(chw: &[f32], c: usize, h: usize, w: usize) -> Vec<f32> {
    let pixels = h * w;
    let mut tokens = Vec::with_capacity(pixels * c);
    for p in 0..pixels {
        tokens.extend((0..c).map(|ch| chw[ch * pixels + p]));
    }
    tokens
}
1047
/// Transposes a token-major `(h*w) x c` buffer into channel-major `c x (h*w)`.
fn tokens_to_chw(tokens: &[f32], c: usize, h: usize, w: usize) -> Vec<f32> {
    let pixels = h * w;
    let mut planes = Vec::with_capacity(c * pixels);
    for ch in 0..c {
        planes.extend((0..pixels).map(|p| tokens[p * c + ch]));
    }
    planes
}
1059
/// Nearest-neighbour 2x upsampling of a channel-major `c x h x w` buffer to
/// `c x 2h x 2w`.
fn upsample2x_chw(chw: &[f32], c: usize, h: usize, w: usize) -> Vec<f32> {
    let (h2, w2) = (2 * h, 2 * w);
    let mut out = vec![0.0f32; c * h2 * w2];
    for ch in 0..c {
        for oi in 0..h2 {
            // Every output row copies the source row at half its index.
            let src_row = (ch * h + oi / 2) * w;
            let dst_row = (ch * h2 + oi) * w2;
            for oj in 0..w2 {
                out[dst_row + oj] = chw[src_row + oj / 2];
            }
        }
    }
    out
}
1079
#[cfg(test)]
mod tests {
    use super::*;

    /// Test image of shape `in_chans x img_size x img_size`, filled via
    /// `LcgRng::fill_normal` from `seed`.
    fn random_image(cfg: &SamConfig, seed: u64) -> Vec<f32> {
        let len = cfg.in_chans * cfg.img_size * cfg.img_size;
        let mut pixels = vec![0.0f32; len];
        LcgRng::new(seed).fill_normal(&mut pixels);
        pixels
    }

    /// Sum of absolute element-wise differences over the common prefix of `a` and `b`.
    fn l1_distance(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b).map(|(x, y)| (x - y).abs()).sum()
    }

    /// The tiny config and a model seeded with `seed`.
    fn tiny_sam(seed: u64) -> (SamConfig, Sam) {
        let cfg = SamConfig::tiny();
        let sam = Sam::new(cfg.clone(), &mut LcgRng::new(seed)).expect("ok");
        (cfg, sam)
    }

    #[test]
    fn config_tiny_valid() {
        let cfg = SamConfig::tiny();
        assert_eq!(cfg.embed_dim, 16);
        assert_eq!(cfg.n_mask, 3);
    }

    #[test]
    fn config_bad_heads_errors() {
        // enc_heads = 3 does not divide embed_dim = 16.
        let result = SamConfig::new(3, 32, 8, 16, 2, 3, 2, 2, 2, 32, 3);
        assert!(matches!(result, Err(VisionError::HeadDimMismatch { .. })));
    }

    #[test]
    fn image_embedding_shape() {
        let (cfg, sam) = tiny_sam(1);
        let emb = sam.encode_image(&random_image(&cfg, 2)).expect("ok");
        let grid = cfg.img_size / cfg.patch_size;
        assert_eq!(
            (emb.channels, emb.height, emb.width),
            (cfg.embed_dim, grid, grid)
        );
        assert!(emb.data.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn different_points_give_different_sparse_embeddings() {
        let (cfg, sam) = tiny_sam(3);
        let pe = sam.prompt_encoder();
        let encode = |xy: [f32; 2], label: i32| pe.encode_points(&xy, &[label]).expect("ok");

        let a = encode([4.0, 4.0], 1);
        let b = encode([28.0, 20.0], 1);
        assert_eq!(a.len(), cfg.embed_dim);
        let diff = l1_distance(&a, &b);
        assert!(
            diff > 1e-3,
            "different points must encode differently, diff={diff}"
        );

        // Same location, foreground vs background label.
        let fg = encode([4.0, 4.0], 1);
        let bg = encode([4.0, 4.0], 0);
        let label_diff = l1_distance(&fg, &bg);
        assert!(label_diff > 1e-4, "fg/bg labels must differ");
        assert!(a.iter().any(|&v| v.abs() > 1e-6));
    }

    #[test]
    fn box_prompt_encodes_two_corners() {
        let (cfg, sam) = tiny_sam(4);
        let emb = sam
            .prompt_encoder()
            .encode_box(&[2.0, 3.0, 20.0, 25.0])
            .expect("ok");
        assert_eq!(emb.len(), 2 * cfg.embed_dim, "box → 2 corner embeddings");
        assert!(emb.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn two_way_block_updates_both_and_weights_normalised() {
        let (e, n_heads, n_t, n_i) = (16, 2, 4, 9);
        let mut rng = LcgRng::new(5);
        let block = TwoWayAttentionBlock::new(e, n_heads, 32, &mut rng).expect("ok");

        let mut normal = |len: usize| -> Vec<f32> {
            let mut v = vec![0.0f32; len];
            rng.fill_normal(&mut v);
            v
        };
        let tokens = normal(n_t * e);
        let image = normal(n_i * e);
        let qpe = normal(n_t * e);
        let kpe = normal(n_i * e);

        let out = block.forward(&tokens, &image, &qpe, &kpe).expect("ok");

        let tok_diff = l1_distance(&out.tokens, &tokens);
        let img_diff = l1_distance(&out.image, &image);
        assert!(tok_diff > 1e-4, "tokens must be updated, diff={tok_diff}");
        assert!(img_diff > 1e-4, "image must be updated, diff={img_diff}");

        check_rows_sum_to_one(&out.self_weights, n_heads, n_t, n_t);
        check_rows_sum_to_one(&out.token_to_image_weights, n_heads, n_t, n_i);
        check_rows_sum_to_one(&out.image_to_token_weights, n_heads, n_i, n_t);
    }

    /// Asserts that every row of a `[n_heads][n_q][n_k]` weight buffer is
    /// non-negative and sums to one (within 1e-4).
    fn check_rows_sum_to_one(weights: &[f32], n_heads: usize, n_q: usize, n_k: usize) {
        for r in 0..n_heads * n_q {
            let row = &weights[r * n_k..(r + 1) * n_k];
            let sum: f32 = row.iter().sum();
            assert!(
                row.iter().all(|&w| w >= 0.0),
                "weights must be non-negative"
            );
            assert!((sum - 1.0).abs() < 1e-4, "attention row sum {sum} != 1");
        }
    }

    #[test]
    fn changing_prompt_changes_mask() {
        let (cfg, sam) = tiny_sam(6);
        let img = random_image(&cfg, 7);
        let pred_a = sam.predict(&img, &[4.0, 4.0], &[1], None).expect("ok");
        let pred_b = sam.predict(&img, &[28.0, 26.0], &[1], None).expect("ok");
        let diff = l1_distance(&pred_a.masks, &pred_b.masks);
        assert!(
            diff > 1e-4,
            "different prompts must change the mask, diff={diff}"
        );
    }

    #[test]
    fn mask_output_dims_and_iou_finite() {
        let (cfg, sam) = tiny_sam(8);
        let img = random_image(&cfg, 9);
        let pred = sam.predict(&img, &[10.0, 10.0], &[1], None).expect("ok");
        let side = 2 * (cfg.img_size / cfg.patch_size);
        assert_eq!(pred.n_mask, cfg.n_mask);
        assert_eq!((pred.height, pred.width), (side, side));
        assert_eq!(pred.masks.len(), cfg.n_mask * side * side);
        assert_eq!(pred.iou.len(), cfg.n_mask);
        assert!(
            pred.iou.iter().all(|v| v.is_finite()),
            "IoU scores must be finite"
        );
        assert!(pred.masks.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn mask_prompt_changes_output() {
        let (cfg, sam) = tiny_sam(10);
        let img = random_image(&cfg, 11);
        let grid = cfg.img_size / cfg.patch_size;
        let mut coarse = vec![0.0f32; grid * grid];
        LcgRng::new(12).fill_normal(&mut coarse);

        let with_mask = sam
            .predict(&img, &[10.0, 10.0], &[1], Some(&coarse))
            .expect("ok");
        let without = sam.predict(&img, &[10.0, 10.0], &[1], None).expect("ok");
        let diff = l1_distance(&with_mask.masks, &without.masks);
        assert!(
            diff > 1e-4,
            "mask prompt must influence the output, diff={diff}"
        );
    }

    #[test]
    fn deterministic_same_seed() {
        let (cfg, sa) = tiny_sam(77);
        let (_, sb) = tiny_sam(77);
        let img = random_image(&cfg, 13);
        let pa = sa.predict(&img, &[10.0, 10.0], &[1], None).expect("ok");
        let pb = sb.predict(&img, &[10.0, 10.0], &[1], None).expect("ok");
        assert_eq!(pa.masks, pb.masks, "same seed → identical masks");
        assert_eq!(pa.iou, pb.iou, "same seed → identical IoU");
    }
}