Skip to main content

oxicuda_vision/segmentation/
sam.rs

1//! SAM — a compact, faithful CPU reference of the *Segment Anything Model*
2//! (Kirillov et al. 2023, *"Segment Anything"*).
3//!
4//! The model has three cooperating parts, all implemented here with real
5//! computation (no shape-only stubs):
6//!
//! 1. **Image encoder** — a ViT patch-embedder followed by a configurable stack
//!    (`enc_depth`) of transformer blocks and a per-token linear neck, reshaped
//!    into a dense image embedding `(C, H', W')`.
//!    Reuses [`crate::patch_embed::PatchEmbed`] and [`crate::vit::ViTBlock`].
10//! 2. **Prompt encoder** — points / boxes are turned into **sparse** embeddings
11//!    via a random-Fourier positional encoding plus learned per-type embeddings;
12//!    a coarse mask is turned into a **dense** embedding added to the image
13//!    embedding.
14//! 3. **Two-way transformer mask decoder** — output tokens and the image
15//!    embedding attend to *each other* in alternating directions (tokens→image
16//!    and image→tokens), after which the mask tokens form a spatial filter that
17//!    is applied to the up-scaled image embedding to produce masks, and an
18//!    IoU-token MLP predicts a quality score per mask.
19//!
20//! ## Tensor layout
21//! Token sequences are flat `[n_tokens · embed_dim]` row-major. Dense maps use
22//! the channel-major [`FeatureMap`] layout (`[C, H, W]`).
23
24use crate::{
25    error::{VisionError, VisionResult},
26    fpn::top_down::FeatureMap,
27    handle::LcgRng,
28    patch_embed::{PatchEmbed, PatchEmbedConfig, pos_2d_sincos},
29    vit::vit_block::{gelu_exact, layer_norm, linear, softmax_rows},
30    vit::{ViTBlock, ViTBlockConfig},
31};
32
33use std::f32::consts::PI;
34
35// ─── Small reusable parameter blocks ───────────────────────────────────────────
36
37/// Fill a buffer with `N(0, scale)` samples.
38fn filled(n: usize, scale: f32, rng: &mut LcgRng) -> Vec<f32> {
39    let mut v = vec![0.0f32; n];
40    rng.fill_normal(&mut v);
41    for x in &mut v {
42        *x *= scale;
43    }
44    v
45}
46
47/// A learned LayerNorm affine (`γ` ones, `β` zeros), applied per row.
48#[derive(Debug, Clone)]
49struct LayerNormParams {
50    weight: Vec<f32>,
51    bias: Vec<f32>,
52}
53
54impl LayerNormParams {
55    fn new(d: usize) -> Self {
56        Self {
57            weight: vec![1.0f32; d],
58            bias: vec![0.0f32; d],
59        }
60    }
61
62    fn apply(&self, x: &[f32], n: usize, d: usize) -> Vec<f32> {
63        layer_norm(x, &self.weight, &self.bias, n, d, 1e-6)
64    }
65}
66
67/// A two-layer MLP `Linear → GELU → Linear` (the transformer / hypernetwork FFN).
68#[derive(Debug, Clone)]
69struct Mlp {
70    w1: Vec<f32>,
71    b1: Vec<f32>,
72    w2: Vec<f32>,
73    b2: Vec<f32>,
74    d_in: usize,
75    hidden: usize,
76    d_out: usize,
77}
78
79impl Mlp {
80    fn new(d_in: usize, hidden: usize, d_out: usize, rng: &mut LcgRng) -> Self {
81        Self {
82            w1: filled(hidden * d_in, (2.0 / d_in as f32).sqrt(), rng),
83            b1: vec![0.0f32; hidden],
84            w2: filled(d_out * hidden, (2.0 / hidden as f32).sqrt(), rng),
85            b2: vec![0.0f32; d_out],
86            d_in,
87            hidden,
88            d_out,
89        }
90    }
91
92    /// Apply to a `[batch · d_in]` buffer → `[batch · d_out]`.
93    fn apply(&self, x: &[f32]) -> Vec<f32> {
94        let mut h = linear(x, &self.w1, &self.b1, self.d_in, self.hidden);
95        for v in &mut h {
96            *v = gelu_exact(*v);
97        }
98        linear(&h, &self.w2, &self.b2, self.hidden, self.d_out)
99    }
100}
101
102// ─── Multi-head attention with explicit weights ────────────────────────────────
103
104/// General multi-head attention with **separate** Q / K / V projections so that
105/// the query and key/value sequences may differ (cross-attention). The forward
106/// pass returns both the output and the per-head attention weights, which the
107/// two-way transformer tests use to verify the softmax normalisation.
108pub struct MultiHeadAttention {
109    wq: Vec<f32>,
110    bq: Vec<f32>,
111    wk: Vec<f32>,
112    bk: Vec<f32>,
113    wv: Vec<f32>,
114    bv: Vec<f32>,
115    wo: Vec<f32>,
116    bo: Vec<f32>,
117    embed_dim: usize,
118    n_heads: usize,
119    head_dim: usize,
120}
121
122impl MultiHeadAttention {
123    /// Construct an attention module.
124    ///
125    /// # Errors
126    /// - [`VisionError::InvalidNumHeads`] if `n_heads == 0`.
127    /// - [`VisionError::HeadDimMismatch`] if `n_heads` does not divide
128    ///   `embed_dim`.
129    pub fn new(embed_dim: usize, n_heads: usize, rng: &mut LcgRng) -> VisionResult<Self> {
130        if n_heads == 0 {
131            return Err(VisionError::InvalidNumHeads(n_heads));
132        }
133        if embed_dim % n_heads != 0 {
134            return Err(VisionError::HeadDimMismatch { n_heads, embed_dim });
135        }
136        let scale = 1.0 / (embed_dim as f32).sqrt();
137        Ok(Self {
138            wq: filled(embed_dim * embed_dim, scale, rng),
139            bq: vec![0.0f32; embed_dim],
140            wk: filled(embed_dim * embed_dim, scale, rng),
141            bk: vec![0.0f32; embed_dim],
142            wv: filled(embed_dim * embed_dim, scale, rng),
143            bv: vec![0.0f32; embed_dim],
144            wo: filled(embed_dim * embed_dim, scale, rng),
145            bo: vec![0.0f32; embed_dim],
146            embed_dim,
147            n_heads,
148            head_dim: embed_dim / n_heads,
149        })
150    }
151
152    /// Attention forward. `q_in` is `[n_q · e]`; `k_in`, `v_in` are `[n_k · e]`.
153    ///
154    /// Returns `(output [n_q · e], weights [n_heads · n_q · n_k])`.
155    ///
156    /// # Errors
157    /// - [`VisionError::DimensionMismatch`] on inconsistent input lengths.
158    /// - [`VisionError::NonFinite`] if the output is non-finite.
159    pub fn forward(
160        &self,
161        q_in: &[f32],
162        k_in: &[f32],
163        v_in: &[f32],
164        n_q: usize,
165        n_k: usize,
166    ) -> VisionResult<(Vec<f32>, Vec<f32>)> {
167        let e = self.embed_dim;
168        if q_in.len() != n_q * e {
169            return Err(VisionError::DimensionMismatch {
170                expected: n_q * e,
171                got: q_in.len(),
172            });
173        }
174        if k_in.len() != n_k * e || v_in.len() != n_k * e {
175            return Err(VisionError::DimensionMismatch {
176                expected: n_k * e,
177                got: k_in.len(),
178            });
179        }
180
181        let q = linear(q_in, &self.wq, &self.bq, e, e);
182        let k = linear(k_in, &self.wk, &self.bk, e, e);
183        let v = linear(v_in, &self.wv, &self.bv, e, e);
184
185        let scale = 1.0 / (self.head_dim as f32).sqrt();
186        let mut concat = vec![0.0f32; n_q * e];
187        let mut weights = vec![0.0f32; self.n_heads * n_q * n_k];
188        let mut scores = vec![0.0f32; n_q * n_k];
189
190        for h in 0..self.n_heads {
191            let off = h * self.head_dim;
192            for i in 0..n_q {
193                for j in 0..n_k {
194                    let mut dot = 0.0f32;
195                    for d in 0..self.head_dim {
196                        dot += q[i * e + off + d] * k[j * e + off + d];
197                    }
198                    scores[i * n_k + j] = dot * scale;
199                }
200            }
201            softmax_rows(&mut scores, n_q, n_k);
202
203            // Record weights and aggregate values.
204            for i in 0..n_q {
205                let w_row = (h * n_q + i) * n_k;
206                let s_row = i * n_k;
207                weights[w_row..w_row + n_k].copy_from_slice(&scores[s_row..s_row + n_k]);
208                for d in 0..self.head_dim {
209                    let mut acc = 0.0f32;
210                    for j in 0..n_k {
211                        acc += scores[s_row + j] * v[j * e + off + d];
212                    }
213                    concat[i * e + off + d] = acc;
214                }
215            }
216        }
217
218        let out = linear(&concat, &self.wo, &self.bo, e, e);
219        if out.iter().any(|x| !x.is_finite()) {
220            return Err(VisionError::NonFinite("SAM attention output"));
221        }
222        Ok((out, weights))
223    }
224}
225
226// ─── Two-way attention block ───────────────────────────────────────────────────
227
/// Everything produced by one [`TwoWayAttentionBlock`] forward pass: the
/// updated token and image sequences plus the attention weights of each of the
/// block's three attention stages.
#[derive(Clone, Debug)]
pub struct TwoWayBlockOutput {
    /// Token embeddings after the block, `[n_tokens · e]`.
    pub tokens: Vec<f32>,
    /// Image embeddings after the block, `[n_image · e]`.
    pub image: Vec<f32>,
    /// Token self-attention weights, `[n_heads · n_tokens · n_tokens]`.
    pub self_weights: Vec<f32>,
    /// Token→image cross-attention weights, `[n_heads · n_tokens · n_image]`.
    pub token_to_image_weights: Vec<f32>,
    /// Image→token cross-attention weights, `[n_heads · n_image · n_tokens]`.
    pub image_to_token_weights: Vec<f32>,
}
242
243/// A two-way attention block: tokens attend to themselves and to the image, then
244/// the image attends back to the tokens — so **both** representations are
245/// updated in a single block.
246pub struct TwoWayAttentionBlock {
247    self_attn: MultiHeadAttention,
248    cross_token_to_image: MultiHeadAttention,
249    cross_image_to_token: MultiHeadAttention,
250    mlp: Mlp,
251    norm1: LayerNormParams,
252    norm2: LayerNormParams,
253    norm3: LayerNormParams,
254    norm4: LayerNormParams,
255    embed_dim: usize,
256}
257
258impl TwoWayAttentionBlock {
259    /// Construct a block.
260    ///
261    /// # Errors
262    /// Propagates [`MultiHeadAttention::new`] validation.
263    pub fn new(
264        embed_dim: usize,
265        n_heads: usize,
266        mlp_dim: usize,
267        rng: &mut LcgRng,
268    ) -> VisionResult<Self> {
269        Ok(Self {
270            self_attn: MultiHeadAttention::new(embed_dim, n_heads, rng)?,
271            cross_token_to_image: MultiHeadAttention::new(embed_dim, n_heads, rng)?,
272            cross_image_to_token: MultiHeadAttention::new(embed_dim, n_heads, rng)?,
273            mlp: Mlp::new(embed_dim, mlp_dim, embed_dim, rng),
274            norm1: LayerNormParams::new(embed_dim),
275            norm2: LayerNormParams::new(embed_dim),
276            norm3: LayerNormParams::new(embed_dim),
277            norm4: LayerNormParams::new(embed_dim),
278            embed_dim,
279        })
280    }
281
282    /// Forward pass.
283    ///
284    /// `tokens`/`query_pe` are `[n_t · e]`; `image`/`key_pe` are `[n_i · e]`.
285    /// The positional encodings are *added to queries and keys only* (never to
286    /// the values), exactly as in SAM.
287    ///
288    /// # Errors
289    /// - [`VisionError::DimensionMismatch`] on inconsistent lengths.
290    /// - Propagates attention errors.
291    pub fn forward(
292        &self,
293        tokens: &[f32],
294        image: &[f32],
295        query_pe: &[f32],
296        key_pe: &[f32],
297    ) -> VisionResult<TwoWayBlockOutput> {
298        let e = self.embed_dim;
299        if tokens.len() != query_pe.len() {
300            return Err(VisionError::DimensionMismatch {
301                expected: tokens.len(),
302                got: query_pe.len(),
303            });
304        }
305        if image.len() != key_pe.len() {
306            return Err(VisionError::DimensionMismatch {
307                expected: image.len(),
308                got: key_pe.len(),
309            });
310        }
311        if tokens.len() % e != 0 || image.len() % e != 0 {
312            return Err(VisionError::DimensionMismatch {
313                expected: e,
314                got: tokens.len() % e,
315            });
316        }
317        let n_t = tokens.len() / e;
318        let n_i = image.len() / e;
319
320        // 1) Token self-attention (Q = K = tokens + pe, V = tokens).
321        let q = add_vec(tokens, query_pe);
322        let (sa, self_w) = self_attn_or_err(&self.self_attn, &q, tokens, n_t)?;
323        let mut tokens_cur = add_vec(tokens, &sa);
324        tokens_cur = self.norm1.apply(&tokens_cur, n_t, e);
325
326        // 2) Token → image cross-attention.
327        let q = add_vec(&tokens_cur, query_pe);
328        let k = add_vec(image, key_pe);
329        let (ca, t2i_w) = self.cross_token_to_image.forward(&q, &k, image, n_t, n_i)?;
330        tokens_cur = add_vec(&tokens_cur, &ca);
331        tokens_cur = self.norm2.apply(&tokens_cur, n_t, e);
332
333        // 3) Token MLP.
334        let m = self.mlp.apply(&tokens_cur);
335        tokens_cur = add_vec(&tokens_cur, &m);
336        tokens_cur = self.norm3.apply(&tokens_cur, n_t, e);
337
338        // 4) Image → token cross-attention (image is the query now).
339        let q = add_vec(image, key_pe);
340        let k = add_vec(&tokens_cur, query_pe);
341        let (ca2, i2t_w) = self
342            .cross_image_to_token
343            .forward(&q, &k, &tokens_cur, n_i, n_t)?;
344        let mut image_cur = add_vec(image, &ca2);
345        image_cur = self.norm4.apply(&image_cur, n_i, e);
346
347        Ok(TwoWayBlockOutput {
348            tokens: tokens_cur,
349            image: image_cur,
350            self_weights: self_w,
351            token_to_image_weights: t2i_w,
352            image_to_token_weights: i2t_w,
353        })
354    }
355}
356
357/// Self-attention helper (Q/K share, V is the raw tokens).
358fn self_attn_or_err(
359    attn: &MultiHeadAttention,
360    qk: &[f32],
361    v: &[f32],
362    n: usize,
363) -> VisionResult<(Vec<f32>, Vec<f32>)> {
364    attn.forward(qk, qk, v, n, n)
365}
366
367// ─── Two-way transformer ───────────────────────────────────────────────────────
368
369/// Stacks several [`TwoWayAttentionBlock`]s, followed by a final token→image
370/// attention as in SAM's `TwoWayTransformer`.
371pub struct TwoWayTransformer {
372    blocks: Vec<TwoWayAttentionBlock>,
373    final_attn: MultiHeadAttention,
374    final_norm: LayerNormParams,
375    embed_dim: usize,
376}
377
378impl TwoWayTransformer {
379    fn new(
380        embed_dim: usize,
381        n_heads: usize,
382        depth: usize,
383        mlp_dim: usize,
384        rng: &mut LcgRng,
385    ) -> VisionResult<Self> {
386        let mut blocks = Vec::with_capacity(depth);
387        for _ in 0..depth {
388            blocks.push(TwoWayAttentionBlock::new(embed_dim, n_heads, mlp_dim, rng)?);
389        }
390        Ok(Self {
391            blocks,
392            final_attn: MultiHeadAttention::new(embed_dim, n_heads, rng)?,
393            final_norm: LayerNormParams::new(embed_dim),
394            embed_dim,
395        })
396    }
397
398    /// Run the transformer. `image`/`image_pe` are `[n_i · e]`, `point_tokens`
399    /// is `[n_t · e]`. The point tokens double as the query positional encoding,
400    /// exactly as in SAM.
401    ///
402    /// Returns `(tokens [n_t · e], image [n_i · e])`.
403    ///
404    /// # Errors
405    /// Propagates block / attention errors.
406    pub fn forward(
407        &self,
408        image: &[f32],
409        image_pe: &[f32],
410        point_tokens: &[f32],
411    ) -> VisionResult<(Vec<f32>, Vec<f32>)> {
412        let e = self.embed_dim;
413        let n_t = point_tokens.len() / e;
414        let n_i = image.len() / e;
415        let query_pe = point_tokens.to_vec();
416
417        let mut tokens = point_tokens.to_vec();
418        let mut img = image.to_vec();
419        for block in &self.blocks {
420            let out = block.forward(&tokens, &img, &query_pe, image_pe)?;
421            tokens = out.tokens;
422            img = out.image;
423        }
424
425        // Final token → image attention.
426        let q = add_vec(&tokens, &query_pe);
427        let k = add_vec(&img, image_pe);
428        let (attn, _w) = self.final_attn.forward(&q, &k, &img, n_t, n_i)?;
429        tokens = add_vec(&tokens, &attn);
430        tokens = self.final_norm.apply(&tokens, n_t, e);
431
432        Ok((tokens, img))
433    }
434}
435
436// ─── Image encoder ─────────────────────────────────────────────────────────────
437
438/// ViT image encoder producing a dense image embedding.
439pub struct ImageEncoder {
440    patch_embed: PatchEmbed,
441    pos_embed: Vec<f32>,
442    blocks: Vec<ViTBlock>,
443    neck_w: Vec<f32>,
444    neck_b: Vec<f32>,
445    grid: usize,
446    embed_dim: usize,
447}
448
449impl ImageEncoder {
450    fn new(cfg: &SamConfig, rng: &mut LcgRng) -> VisionResult<Self> {
451        let pe_cfg =
452            PatchEmbedConfig::new(cfg.img_size, cfg.patch_size, cfg.in_chans, cfg.embed_dim)?;
453        let grid = pe_cfg.grid_size();
454        let patch_embed = PatchEmbed::new(pe_cfg, rng);
455        let pos_embed = pos_2d_sincos(grid, grid, cfg.embed_dim)?;
456        let block_cfg = ViTBlockConfig::new(cfg.embed_dim, cfg.enc_heads, cfg.enc_mlp_ratio)?;
457        let mut blocks = Vec::with_capacity(cfg.enc_depth);
458        for _ in 0..cfg.enc_depth {
459            blocks.push(ViTBlock::new(block_cfg.clone(), rng));
460        }
461        let scale = 1.0 / (cfg.embed_dim as f32).sqrt();
462        Ok(Self {
463            patch_embed,
464            pos_embed,
465            blocks,
466            neck_w: filled(cfg.embed_dim * cfg.embed_dim, scale, rng),
467            neck_b: vec![0.0f32; cfg.embed_dim],
468            grid,
469            embed_dim: cfg.embed_dim,
470        })
471    }
472
473    /// Encode a flat CHW image into a dense `(embed_dim, grid, grid)` map.
474    ///
475    /// # Errors
476    /// Propagates patch-embed / block errors.
477    pub fn forward(&self, image: &[f32]) -> VisionResult<FeatureMap> {
478        let e = self.embed_dim;
479        let n_patches = self.grid * self.grid;
480        let mut tokens = self.patch_embed.forward(image)?;
481        // Add positional embedding.
482        for (t, p) in tokens.iter_mut().zip(self.pos_embed.iter()) {
483            *t += *p;
484        }
485        for block in &self.blocks {
486            tokens = block.forward(&tokens, n_patches)?;
487        }
488        // 1×1 neck (per-token linear).
489        let tokens = linear(&tokens, &self.neck_w, &self.neck_b, e, e);
490        // Reshape [n_patches, e] (patch-major) → (e, grid, grid).
491        let chw = tokens_to_chw(&tokens, e, self.grid, self.grid);
492        FeatureMap::new(chw, e, self.grid, self.grid)
493    }
494}
495
496// ─── Positional encoding (random Fourier) ──────────────────────────────────────
497
498/// SAM's `PositionEmbeddingRandom`: a fixed Gaussian matrix `[2 · num_freq]` maps
499/// a normalised coordinate `(x, y) ∈ [0, 1]²` to `[sin, cos]` Fourier features.
500pub struct PositionEmbeddingRandom {
501    gaussian: Vec<f32>,
502    num_freq: usize,
503}
504
505impl PositionEmbeddingRandom {
506    fn new(num_freq: usize, rng: &mut LcgRng) -> Self {
507        // Scale 1.0 keeps the projection well-conditioned for the tiny ref.
508        Self {
509            gaussian: filled(2 * num_freq, 1.0, rng),
510            num_freq,
511        }
512    }
513
514    /// Encode a normalised `(x, y)` into a `[2 · num_freq]` embedding.
515    fn encode_point(&self, x: f32, y: f32) -> Vec<f32> {
516        let nf = self.num_freq;
517        let mut out = vec![0.0f32; 2 * nf];
518        for f in 0..nf {
519            let proj = 2.0 * PI * (x * self.gaussian[f] + y * self.gaussian[nf + f]);
520            out[f] = proj.sin();
521            out[nf + f] = proj.cos();
522        }
523        out
524    }
525
526    /// Encode an `h × w` grid into a dense `(2·num_freq, h, w)` positional map,
527    /// using cell-centre coordinates normalised to `[0, 1]`.
528    fn encode_grid(&self, h: usize, w: usize) -> Vec<f32> {
529        let dim = 2 * self.num_freq;
530        let mut out = vec![0.0f32; dim * h * w];
531        for i in 0..h {
532            for j in 0..w {
533                let x = (j as f32 + 0.5) / w as f32;
534                let y = (i as f32 + 0.5) / h as f32;
535                let enc = self.encode_point(x, y);
536                for (c, &val) in enc.iter().enumerate() {
537                    out[(c * h + i) * w + j] = val;
538                }
539            }
540        }
541        out
542    }
543}
544
545// ─── Prompt encoder ────────────────────────────────────────────────────────────
546
547/// Encodes geometric prompts (points / boxes) into sparse embeddings and a
548/// coarse mask into a dense embedding.
549pub struct PromptEncoder {
550    pe_layer: PositionEmbeddingRandom,
551    /// Learned per-label point embeddings: `[2 · e]` (index 0 = background,
552    /// 1 = foreground).
553    point_embeddings: Vec<f32>,
554    /// Learned box-corner embeddings: `[2 · e]` (top-left, bottom-right).
555    corner_embeddings: Vec<f32>,
556    /// Embedding for a padding / "not a point" entry: `[e]`.
557    not_a_point: Vec<f32>,
558    /// Embedding used when no mask prompt is supplied: `[e]`.
559    no_mask_embed: Vec<f32>,
560    /// 1×1 projection `1 → e` for the coarse mask.
561    mask_w: Vec<f32>,
562    mask_b: Vec<f32>,
563    embed_dim: usize,
564    grid: usize,
565    input_size: f32,
566}
567
568impl PromptEncoder {
569    fn new(cfg: &SamConfig, rng: &mut LcgRng) -> Self {
570        let e = cfg.embed_dim;
571        let scale = 1.0 / (e as f32).sqrt();
572        Self {
573            pe_layer: PositionEmbeddingRandom::new(e / 2, rng),
574            point_embeddings: filled(2 * e, 0.1, rng),
575            corner_embeddings: filled(2 * e, 0.1, rng),
576            not_a_point: filled(e, 0.1, rng),
577            no_mask_embed: filled(e, 0.1, rng),
578            mask_w: filled(e, scale, rng),
579            mask_b: vec![0.0f32; e],
580            embed_dim: e,
581            grid: cfg.img_size / cfg.patch_size,
582            input_size: cfg.img_size as f32,
583        }
584    }
585
586    /// Image-grid positional encoding used by the decoder, `(e, grid, grid)`.
587    #[must_use]
588    pub fn dense_positional_encoding(&self) -> Vec<f32> {
589        self.pe_layer.encode_grid(self.grid, self.grid)
590    }
591
592    /// Encode point prompts.
593    ///
594    /// `coords` is `[n · 2]` pixel coordinates; `labels[i]` is `1` for a
595    /// foreground point, `0` for background, and any negative value for a
596    /// padding ("not a point") entry. Returns sparse embeddings `[n · e]`.
597    ///
598    /// # Errors
599    /// - [`VisionError::DimensionMismatch`] if `coords.len() != 2·labels.len()`.
600    pub fn encode_points(&self, coords: &[f32], labels: &[i32]) -> VisionResult<Vec<f32>> {
601        let n = labels.len();
602        if coords.len() != n * 2 {
603            return Err(VisionError::DimensionMismatch {
604                expected: n * 2,
605                got: coords.len(),
606            });
607        }
608        let e = self.embed_dim;
609        let mut out = vec![0.0f32; n * e];
610        for p in 0..n {
611            let x = coords[p * 2] / self.input_size;
612            let y = coords[p * 2 + 1] / self.input_size;
613            let pe = self.pe_layer.encode_point(x, y);
614            let dst = &mut out[p * e..(p + 1) * e];
615            if labels[p] < 0 {
616                // Padding point: zeroed PE plus the "not a point" embedding.
617                for (d, slot) in dst.iter_mut().enumerate() {
618                    *slot = self.not_a_point[d];
619                }
620            } else {
621                let label_off = if labels[p] >= 1 { e } else { 0 };
622                for (d, slot) in dst.iter_mut().enumerate() {
623                    *slot = pe[d] + self.point_embeddings[label_off + d];
624                }
625            }
626        }
627        Ok(out)
628    }
629
630    /// Encode a box prompt `[x1, y1, x2, y2]` (pixels) into two corner
631    /// embeddings `[2 · e]`.
632    ///
633    /// # Errors
634    /// - [`VisionError::DimensionMismatch`] if `box4.len() != 4`.
635    pub fn encode_box(&self, box4: &[f32]) -> VisionResult<Vec<f32>> {
636        if box4.len() != 4 {
637            return Err(VisionError::DimensionMismatch {
638                expected: 4,
639                got: box4.len(),
640            });
641        }
642        let e = self.embed_dim;
643        let corners = [(box4[0], box4[1], 0usize), (box4[2], box4[3], 1usize)];
644        let mut out = vec![0.0f32; 2 * e];
645        for (idx, &(cx, cy, corner)) in corners.iter().enumerate() {
646            let pe = self
647                .pe_layer
648                .encode_point(cx / self.input_size, cy / self.input_size);
649            let dst = &mut out[idx * e..(idx + 1) * e];
650            for (d, slot) in dst.iter_mut().enumerate() {
651                *slot = pe[d] + self.corner_embeddings[corner * e + d];
652            }
653        }
654        Ok(out)
655    }
656
657    /// Encode a coarse mask `[grid · grid]` into a dense embedding
658    /// `(e, grid, grid)`. When `mask` is `None`, the learned `no_mask` embedding
659    /// is broadcast over the grid.
660    ///
661    /// # Errors
662    /// - [`VisionError::DimensionMismatch`] if `mask` is `Some` but not
663    ///   `grid · grid` long.
664    pub fn encode_mask(&self, mask: Option<&[f32]>) -> VisionResult<Vec<f32>> {
665        let e = self.embed_dim;
666        let hw = self.grid * self.grid;
667        let mut out = vec![0.0f32; e * hw];
668        match mask {
669            None => {
670                for c in 0..e {
671                    let val = self.no_mask_embed[c];
672                    for p in 0..hw {
673                        out[c * hw + p] = val;
674                    }
675                }
676            }
677            Some(m) => {
678                if m.len() != hw {
679                    return Err(VisionError::DimensionMismatch {
680                        expected: hw,
681                        got: m.len(),
682                    });
683                }
684                // 1×1 conv with one input channel: out[c,p] = w[c]·m[p] + b[c].
685                for c in 0..e {
686                    let w = self.mask_w[c];
687                    let b = self.mask_b[c];
688                    for p in 0..hw {
689                        out[c * hw + p] = w * m[p] + b;
690                    }
691                }
692            }
693        }
694        Ok(out)
695    }
696}
697
698// ─── Mask decoder ──────────────────────────────────────────────────────────────
699
/// Output of the mask decoder: `n_mask` spatial masks at twice the embedding
/// resolution, plus one predicted-IoU quality score per mask.
#[derive(Clone, Debug)]
pub struct MaskPrediction {
    /// Mask logits, flat `[n_mask · height · width]`.
    pub masks: Vec<f32>,
    /// Predicted IoU score for each mask, `[n_mask]`.
    pub iou: Vec<f32>,
    /// How many masks `masks` holds.
    pub n_mask: usize,
    /// Rows per mask (= 2 · grid).
    pub height: usize,
    /// Columns per mask (= 2 · grid).
    pub width: usize,
}
714
715/// The SAM mask decoder.
716pub struct MaskDecoder {
717    transformer: TwoWayTransformer,
718    iou_token: Vec<f32>,
719    mask_tokens: Vec<f32>,
720    upscale_w: Vec<f32>,
721    upscale_b: Vec<f32>,
722    hypernets: Vec<Mlp>,
723    iou_head: Mlp,
724    n_mask: usize,
725    embed_dim: usize,
726}
727
728impl MaskDecoder {
729    fn new(cfg: &SamConfig, rng: &mut LcgRng) -> VisionResult<Self> {
730        let e = cfg.embed_dim;
731        let transformer =
732            TwoWayTransformer::new(e, cfg.dec_heads, cfg.dec_depth, cfg.dec_mlp_dim, rng)?;
733        let scale = 1.0 / (e as f32).sqrt();
734        let hypernets = (0..cfg.n_mask).map(|_| Mlp::new(e, e, e, rng)).collect();
735        Ok(Self {
736            transformer,
737            iou_token: filled(e, 0.02, rng),
738            mask_tokens: filled(cfg.n_mask * e, 0.02, rng),
739            upscale_w: filled(e * e, scale, rng),
740            upscale_b: vec![0.0f32; e],
741            hypernets,
742            iou_head: Mlp::new(e, e, cfg.n_mask, rng),
743            n_mask: cfg.n_mask,
744            embed_dim: e,
745        })
746    }
747
748    /// Decode masks from an image embedding and prompt embeddings.
749    ///
750    /// - `image_embedding`: dense `(e, grid, grid)`.
751    /// - `image_pe`: dense positional encoding `(e, grid, grid)` (flat).
752    /// - `sparse_prompt`: `[n_sparse · e]` (may be empty).
753    /// - `dense_prompt`: dense `(e, grid, grid)` (flat).
754    ///
755    /// # Errors
756    /// - [`VisionError::DimensionMismatch`] on shape mismatches.
757    /// - [`VisionError::NonFinite`] if any output is non-finite.
758    pub fn forward(
759        &self,
760        image_embedding: &FeatureMap,
761        image_pe: &[f32],
762        sparse_prompt: &[f32],
763        dense_prompt: &[f32],
764    ) -> VisionResult<MaskPrediction> {
765        let e = self.embed_dim;
766        let (h, w) = (image_embedding.height, image_embedding.width);
767        let hw = h * w;
768        if image_embedding.channels != e || image_embedding.data.len() != e * hw {
769            return Err(VisionError::DimensionMismatch {
770                expected: e * hw,
771                got: image_embedding.data.len(),
772            });
773        }
774        if image_pe.len() != e * hw || dense_prompt.len() != e * hw {
775            return Err(VisionError::DimensionMismatch {
776                expected: e * hw,
777                got: image_pe.len(),
778            });
779        }
780        if sparse_prompt.len() % e != 0 {
781            return Err(VisionError::DimensionMismatch {
782                expected: e,
783                got: sparse_prompt.len() % e,
784            });
785        }
786        let n_sparse = sparse_prompt.len() / e;
787
788        // Assemble output tokens: [iou_token, mask_tokens..., sparse...].
789        let n_tokens = 1 + self.n_mask + n_sparse;
790        let mut tokens = Vec::with_capacity(n_tokens * e);
791        tokens.extend_from_slice(&self.iou_token);
792        tokens.extend_from_slice(&self.mask_tokens);
793        tokens.extend_from_slice(sparse_prompt);
794
795        // Inject the dense prompt into the image embedding, then flatten to
796        // token order [hw, e].
797        let mut src_chw = image_embedding.data.clone();
798        for (s, d) in src_chw.iter_mut().zip(dense_prompt.iter()) {
799            *s += *d;
800        }
801        let src_tokens = chw_to_tokens(&src_chw, e, h, w);
802        let pe_tokens = chw_to_tokens(image_pe, e, h, w);
803
804        // Two-way transformer.
805        let (tokens_out, src_out) = self.transformer.forward(&src_tokens, &pe_tokens, &tokens)?;
806
807        // Reshape the updated image tokens back to (e, h, w) and upscale 2×.
808        let src_img = tokens_to_chw(&src_out, e, h, w);
809        let up = upsample2x_chw(&src_img, e, h, w);
810        let (uh, uw) = (h * 2, w * 2);
811        // 1×1 conv on the upscaled features (per-pixel linear e → e).
812        let up_tokens = chw_to_tokens(&up, e, uh, uw);
813        let up_tokens = linear(&up_tokens, &self.upscale_w, &self.upscale_b, e, e);
814        let up = tokens_to_chw(&up_tokens, e, uh, uw);
815
816        // Hypernetwork filters → spatial masks.
817        let mut masks = vec![0.0f32; self.n_mask * uh * uw];
818        for m in 0..self.n_mask {
819            let token = &tokens_out[(1 + m) * e..(2 + m) * e];
820            let filter = self.hypernets[m].apply(token); // [e]
821            for p in 0..(uh * uw) {
822                let mut acc = 0.0f32;
823                for c in 0..e {
824                    acc += filter[c] * up[c * uh * uw + p];
825                }
826                masks[m * uh * uw + p] = acc;
827            }
828        }
829
830        // IoU prediction from the IoU token.
831        let iou_token = &tokens_out[0..e];
832        let iou = self.iou_head.apply(iou_token);
833
834        if masks.iter().chain(iou.iter()).any(|v| !v.is_finite()) {
835            return Err(VisionError::NonFinite("SAM mask decoder output"));
836        }
837
838        Ok(MaskPrediction {
839            masks,
840            iou,
841            n_mask: self.n_mask,
842            height: uh,
843            width: uw,
844        })
845    }
846}
847
848// ─── Sam (top level) ───────────────────────────────────────────────────────────
849
/// SAM hyper-parameters.
///
/// All fields are public, so a value can be assembled with a struct literal;
/// such a literal is *not* validated. Use [`SamConfig::new`] for the checked
/// constructor, or [`SamConfig::tiny`] for the small preset used by the unit
/// tests.
#[derive(Debug, Clone, PartialEq)]
pub struct SamConfig {
    /// Input image channels.
    pub in_chans: usize,
    /// Square input size (the image is `img_size × img_size`).
    pub img_size: usize,
    /// Patch size of the image encoder. [`SamConfig::new`] requires it to be
    /// non-zero and to divide `img_size`; the encoder grid is then
    /// `img_size / patch_size` cells per side.
    pub patch_size: usize,
    /// Embedding dimension (shared across encoder / prompt / decoder).
    /// [`SamConfig::new`] requires it to be non-zero, even, and divisible by
    /// both `enc_heads` and `dec_heads`.
    pub embed_dim: usize,
    /// Encoder transformer depth.
    pub enc_depth: usize,
    /// Encoder attention heads (must divide `embed_dim`).
    pub enc_heads: usize,
    /// Encoder MLP ratio.
    pub enc_mlp_ratio: usize,
    /// Decoder two-way transformer depth.
    pub dec_depth: usize,
    /// Decoder attention heads (must divide `embed_dim`).
    pub dec_heads: usize,
    /// Decoder MLP hidden dimension.
    pub dec_mlp_dim: usize,
    /// Number of output mask tokens (must be at least 1).
    pub n_mask: usize,
}
876
877impl SamConfig {
878    /// Create and validate a configuration.
879    ///
880    /// # Errors
881    /// - [`VisionError::InvalidEmbedDim`] if `embed_dim` is 0 or odd.
882    /// - [`VisionError::InvalidPatchSize`] if `patch_size` does not divide
883    ///   `img_size`.
884    /// - [`VisionError::HeadDimMismatch`] if a head count does not divide
885    ///   `embed_dim`.
886    /// - [`VisionError::EmptyInput`] if `n_mask == 0`.
887    pub fn new(
888        in_chans: usize,
889        img_size: usize,
890        patch_size: usize,
891        embed_dim: usize,
892        enc_depth: usize,
893        enc_heads: usize,
894        enc_mlp_ratio: usize,
895        dec_depth: usize,
896        dec_heads: usize,
897        dec_mlp_dim: usize,
898        n_mask: usize,
899    ) -> VisionResult<Self> {
900        if embed_dim == 0 || embed_dim % 2 != 0 {
901            return Err(VisionError::InvalidEmbedDim(embed_dim));
902        }
903        if patch_size == 0 || img_size % patch_size != 0 {
904            return Err(VisionError::InvalidPatchSize {
905                patch_size,
906                img_size,
907            });
908        }
909        if enc_heads == 0 || embed_dim % enc_heads != 0 {
910            return Err(VisionError::HeadDimMismatch {
911                n_heads: enc_heads,
912                embed_dim,
913            });
914        }
915        if dec_heads == 0 || embed_dim % dec_heads != 0 {
916            return Err(VisionError::HeadDimMismatch {
917                n_heads: dec_heads,
918                embed_dim,
919            });
920        }
921        if n_mask == 0 {
922            return Err(VisionError::EmptyInput("sam n_mask"));
923        }
924        Ok(Self {
925            in_chans,
926            img_size,
927            patch_size,
928            embed_dim,
929            enc_depth,
930            enc_heads,
931            enc_mlp_ratio,
932            dec_depth,
933            dec_heads,
934            dec_mlp_dim,
935            n_mask,
936        })
937    }
938
939    /// A tiny configuration for unit tests: 3×32×32 input, patch 8, embed 16,
940    /// 2 encoder blocks (2 heads), 2-deep decoder (2 heads), 3 masks.
941    #[must_use]
942    pub fn tiny() -> Self {
943        Self {
944            in_chans: 3,
945            img_size: 32,
946            patch_size: 8,
947            embed_dim: 16,
948            enc_depth: 2,
949            enc_heads: 2,
950            enc_mlp_ratio: 2,
951            dec_depth: 2,
952            dec_heads: 2,
953            dec_mlp_dim: 32,
954            n_mask: 3,
955        }
956    }
957}
958
/// The full Segment Anything Model.
///
/// Owns the three stages described in the module docs (image encoder, prompt
/// encoder and mask decoder), plus the configuration they were built from.
/// Construct it with [`Sam::new`].
pub struct Sam {
    /// Hyper-parameters the sub-modules were built from (see [`Sam::config`]).
    cfg: SamConfig,
    /// ViT image encoder producing the dense `(embed_dim, grid, grid)` embedding.
    image_encoder: ImageEncoder,
    /// Encodes point / box / mask prompts and supplies the dense positional
    /// encoding consumed by the decoder.
    prompt_encoder: PromptEncoder,
    /// Two-way transformer decoder producing masks and per-mask IoU scores.
    mask_decoder: MaskDecoder,
}
966
967impl Sam {
968    /// Build SAM with randomly-initialised weights.
969    ///
970    /// # Errors
971    /// Propagates configuration validation.
972    pub fn new(cfg: SamConfig, rng: &mut LcgRng) -> VisionResult<Self> {
973        let image_encoder = ImageEncoder::new(&cfg, rng)?;
974        let prompt_encoder = PromptEncoder::new(&cfg, rng);
975        let mask_decoder = MaskDecoder::new(&cfg, rng)?;
976        Ok(Self {
977            cfg,
978            image_encoder,
979            prompt_encoder,
980            mask_decoder,
981        })
982    }
983
984    /// Read-only configuration access.
985    #[must_use]
986    #[inline]
987    pub fn config(&self) -> &SamConfig {
988        &self.cfg
989    }
990
991    /// Read-only access to the prompt encoder.
992    #[must_use]
993    #[inline]
994    pub fn prompt_encoder(&self) -> &PromptEncoder {
995        &self.prompt_encoder
996    }
997
998    /// Encode an image into its dense embedding `(embed_dim, grid, grid)`.
999    ///
1000    /// # Errors
1001    /// Propagates encoder errors.
1002    pub fn encode_image(&self, image: &[f32]) -> VisionResult<FeatureMap> {
1003        self.image_encoder.forward(image)
1004    }
1005
1006    /// End-to-end prediction from an image and a set of point prompts (plus an
1007    /// optional coarse mask prompt).
1008    ///
1009    /// # Errors
1010    /// Propagates encoder / prompt / decoder errors.
1011    pub fn predict(
1012        &self,
1013        image: &[f32],
1014        point_coords: &[f32],
1015        point_labels: &[i32],
1016        mask: Option<&[f32]>,
1017    ) -> VisionResult<MaskPrediction> {
1018        let embedding = self.encode_image(image)?;
1019        let sparse = self
1020            .prompt_encoder
1021            .encode_points(point_coords, point_labels)?;
1022        let dense = self.prompt_encoder.encode_mask(mask)?;
1023        let image_pe = self.prompt_encoder.dense_positional_encoding();
1024        self.mask_decoder
1025            .forward(&embedding, &image_pe, &sparse, &dense)
1026    }
1027}
1028
1029// ─── Shared helpers ────────────────────────────────────────────────────────────
1030
/// Element-wise sum of two equal-length buffers.
///
/// Elements are paired positionally. If the lengths differ, the surplus tail of
/// the longer buffer is ignored, so the result has the shorter length.
fn add_vec(a: &[f32], b: &[f32]) -> Vec<f32> {
    let mut sum = Vec::with_capacity(a.len().min(b.len()));
    for (lhs, rhs) in a.iter().zip(b) {
        sum.push(lhs + rhs);
    }
    sum
}
1035
/// `(C, H, W)` → `[H·W, C]` (token-major).
///
/// Transposes a channel-major map so that each spatial position becomes one
/// contiguous token of `c` channel values:
/// `out[pixel * c + channel] = chw[channel * H·W + pixel]`.
fn chw_to_tokens(chw: &[f32], c: usize, h: usize, w: usize) -> Vec<f32> {
    let plane = h * w;
    (0..plane)
        .flat_map(|pixel| (0..c).map(move |channel| chw[channel * plane + pixel]))
        .collect()
}
1047
/// `[H·W, C]` (token-major) → `(C, H, W)`.
///
/// Inverse of the token transpose: gathers channel `ch` of every token into one
/// contiguous spatial plane, i.e. `out[ch * H·W + pixel] = tokens[pixel * c + ch]`.
fn tokens_to_chw(tokens: &[f32], c: usize, h: usize, w: usize) -> Vec<f32> {
    let plane = h * w;
    let mut chw = Vec::with_capacity(c * plane);
    for channel in 0..c {
        chw.extend((0..plane).map(|pixel| tokens[pixel * c + channel]));
    }
    chw
}
1059
/// Nearest-neighbour 2× upsampling of a `(C, H, W)` buffer.
///
/// Each input pixel is replicated into a 2×2 block, so the result has shape
/// `(C, 2H, 2W)`. Output pixel `(oi, oj)` of a plane copies input pixel
/// `(oi / 2, oj / 2)` of the same plane.
fn upsample2x_chw(chw: &[f32], c: usize, h: usize, w: usize) -> Vec<f32> {
    let (out_h, out_w) = (2 * h, 2 * w);
    let mut out = Vec::with_capacity(c * out_h * out_w);
    for plane in 0..c {
        for oi in 0..out_h {
            // Start of the source row feeding output row `oi`.
            let src_row = (plane * h + oi / 2) * w;
            out.extend((0..out_w).map(|oj| chw[src_row + oj / 2]));
        }
    }
    out
}
1079
1080// ─── Tests ───────────────────────────────────────────────────────────────────
1081
#[cfg(test)]
mod tests {
    // Unit tests for the SAM reference model. Every test seeds its own `LcgRng`,
    // so results are reproducible. The exact seeds, and the order in which the
    // rng is consumed, determine the weights and inputs the assertions see.
    use super::*;

    /// Build a test image of `in_chans · img_size · img_size` values, filled by
    /// `LcgRng::fill_normal` from a generator seeded with `seed`.
    fn random_image(cfg: &SamConfig, seed: u64) -> Vec<f32> {
        let mut rng = LcgRng::new(seed);
        let mut img = vec![0.0f32; cfg.in_chans * cfg.img_size * cfg.img_size];
        rng.fill_normal(&mut img);
        img
    }

    // ── Config ─────────────────────────────────────────────────────────────────

    // Smoke check of the `tiny()` preset's key dimensions. It does not route
    // the preset through `SamConfig::new`.
    #[test]
    fn config_tiny_valid() {
        let cfg = SamConfig::tiny();
        assert_eq!(cfg.embed_dim, 16);
        assert_eq!(cfg.n_mask, 3);
    }

    // `enc_heads = 3` (6th argument) does not divide `embed_dim = 16`.
    #[test]
    fn config_bad_heads_errors() {
        // 16 % 3 != 0
        let r = SamConfig::new(3, 32, 8, 16, 2, 3, 2, 2, 2, 32, 3);
        assert!(matches!(r, Err(VisionError::HeadDimMismatch { .. })));
    }

    // ── Image embedding shape ─────────────────────────────────────────────────

    // The encoder must return a finite `(embed_dim, grid, grid)` map with
    // `grid = img_size / patch_size`.
    #[test]
    fn image_embedding_shape() {
        let cfg = SamConfig::tiny();
        let mut rng = LcgRng::new(1);
        let sam = Sam::new(cfg.clone(), &mut rng).expect("ok");
        let img = random_image(&cfg, 2);
        let emb = sam.encode_image(&img).expect("ok");
        let grid = cfg.img_size / cfg.patch_size; // 4
        assert_eq!(
            (emb.channels, emb.height, emb.width),
            (cfg.embed_dim, grid, grid)
        );
        assert!(emb.data.iter().all(|v| v.is_finite()));
    }

    // ── Point prompt encoding ─────────────────────────────────────────────────

    // Point embeddings must depend on both the location (positional part) and
    // the label (learned label embedding). One point yields `embed_dim` values.
    #[test]
    fn different_points_give_different_sparse_embeddings() {
        let cfg = SamConfig::tiny();
        let mut rng = LcgRng::new(3);
        let sam = Sam::new(cfg.clone(), &mut rng).expect("ok");
        let pe = sam.prompt_encoder();
        let a = pe.encode_points(&[4.0, 4.0], &[1]).expect("ok");
        let b = pe.encode_points(&[28.0, 20.0], &[1]).expect("ok");
        assert_eq!(a.len(), cfg.embed_dim);
        let diff: f32 = a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum();
        assert!(
            diff > 1e-3,
            "different points must encode differently, diff={diff}"
        );
        // Foreground vs background label at the same location must also differ
        // (different learned label embedding).
        let fg = pe.encode_points(&[4.0, 4.0], &[1]).expect("ok");
        let bg = pe.encode_points(&[4.0, 4.0], &[0]).expect("ok");
        let label_diff: f32 = fg.iter().zip(bg.iter()).map(|(x, y)| (x - y).abs()).sum();
        assert!(label_diff > 1e-4, "fg/bg labels must differ");
        // Positional encoding present (not all zeros).
        assert!(a.iter().any(|&v| v.abs() > 1e-6));
    }

    // A box prompt (four coordinates) encodes to two corner tokens of
    // `embed_dim` values each.
    #[test]
    fn box_prompt_encodes_two_corners() {
        let cfg = SamConfig::tiny();
        let mut rng = LcgRng::new(4);
        let sam = Sam::new(cfg.clone(), &mut rng).expect("ok");
        let emb = sam
            .prompt_encoder()
            .encode_box(&[2.0, 3.0, 20.0, 25.0])
            .expect("ok");
        assert_eq!(emb.len(), 2 * cfg.embed_dim, "box → 2 corner embeddings");
        assert!(emb.iter().all(|v| v.is_finite()));
    }

    // ── Two-way attention updates BOTH and weights sum to 1 ───────────────────

    // A single two-way block must modify both the token stream and the image
    // stream. All three returned attention-weight tensors must be row-stochastic.
    #[test]
    fn two_way_block_updates_both_and_weights_normalised() {
        let e = 16;
        let n_heads = 2;
        let mut rng = LcgRng::new(5);
        // Third argument is presumably the MLP hidden width (cf. `dec_mlp_dim`);
        // TODO confirm against `TwoWayAttentionBlock::new`.
        let block = TwoWayAttentionBlock::new(e, n_heads, 32, &mut rng).expect("ok");

        let n_t = 4;
        let n_i = 9;
        let mut tokens = vec![0.0f32; n_t * e];
        let mut image = vec![0.0f32; n_i * e];
        let mut qpe = vec![0.0f32; n_t * e];
        let mut kpe = vec![0.0f32; n_i * e];
        // Fill order matters: all four buffers come from the same `rng` stream
        // that built `block`.
        rng.fill_normal(&mut tokens);
        rng.fill_normal(&mut image);
        rng.fill_normal(&mut qpe);
        rng.fill_normal(&mut kpe);

        let out = block.forward(&tokens, &image, &qpe, &kpe).expect("ok");

        // BOTH representations must change.
        let tok_diff: f32 = out
            .tokens
            .iter()
            .zip(tokens.iter())
            .map(|(a, b)| (a - b).abs())
            .sum();
        let img_diff: f32 = out
            .image
            .iter()
            .zip(image.iter())
            .map(|(a, b)| (a - b).abs())
            .sum();
        assert!(tok_diff > 1e-4, "tokens must be updated, diff={tok_diff}");
        assert!(img_diff > 1e-4, "image must be updated, diff={img_diff}");

        // Self-attention weights: each (head, query) row sums to 1 over keys.
        check_rows_sum_to_one(&out.self_weights, n_heads, n_t, n_t);
        // Token→image weights: rows over n_i sum to 1.
        check_rows_sum_to_one(&out.token_to_image_weights, n_heads, n_t, n_i);
        // Image→token weights: rows over n_t sum to 1.
        check_rows_sum_to_one(&out.image_to_token_weights, n_heads, n_i, n_t);
    }

    /// Assert that an attention-weight buffer laid out as
    /// `[head][query][key]` (flat, row-major) has non-negative rows summing to
    /// 1 within `1e-4`.
    fn check_rows_sum_to_one(weights: &[f32], n_heads: usize, n_q: usize, n_k: usize) {
        for h in 0..n_heads {
            for i in 0..n_q {
                let row = &weights[(h * n_q + i) * n_k..(h * n_q + i + 1) * n_k];
                let sum: f32 = row.iter().sum();
                assert!(
                    row.iter().all(|&w| w >= 0.0),
                    "weights must be non-negative"
                );
                assert!((sum - 1.0).abs() < 1e-4, "attention row sum {sum} != 1");
            }
        }
    }

    // ── Prompt changes the mask ───────────────────────────────────────────────

    // Same image and weights, different point location: the predicted masks
    // must differ.
    #[test]
    fn changing_prompt_changes_mask() {
        let cfg = SamConfig::tiny();
        let mut rng = LcgRng::new(6);
        let sam = Sam::new(cfg.clone(), &mut rng).expect("ok");
        let img = random_image(&cfg, 7);
        let pred_a = sam.predict(&img, &[4.0, 4.0], &[1], None).expect("ok");
        let pred_b = sam.predict(&img, &[28.0, 26.0], &[1], None).expect("ok");
        let diff: f32 = pred_a
            .masks
            .iter()
            .zip(pred_b.masks.iter())
            .map(|(a, b)| (a - b).abs())
            .sum();
        assert!(
            diff > 1e-4,
            "different prompts must change the mask, diff={diff}"
        );
    }

    // ── Mask output dims & IoU finite ─────────────────────────────────────────

    // Masks are `n_mask` planes at twice the encoder grid resolution, with one
    // IoU score per mask. All outputs must be finite.
    #[test]
    fn mask_output_dims_and_iou_finite() {
        let cfg = SamConfig::tiny();
        let mut rng = LcgRng::new(8);
        let sam = Sam::new(cfg.clone(), &mut rng).expect("ok");
        let img = random_image(&cfg, 9);
        let pred = sam.predict(&img, &[10.0, 10.0], &[1], None).expect("ok");
        let grid = cfg.img_size / cfg.patch_size; // 4
        assert_eq!(pred.n_mask, cfg.n_mask);
        assert_eq!((pred.height, pred.width), (grid * 2, grid * 2));
        assert_eq!(pred.masks.len(), cfg.n_mask * (grid * 2) * (grid * 2));
        assert_eq!(pred.iou.len(), cfg.n_mask);
        assert!(
            pred.iou.iter().all(|v| v.is_finite()),
            "IoU scores must be finite"
        );
        assert!(pred.masks.iter().all(|v| v.is_finite()));
    }

    // The coarse mask prompt is passed as `grid · grid` values.
    #[test]
    fn mask_prompt_changes_output() {
        // Supplying a coarse mask prompt (dense path) must change the result vs.
        // the no-mask broadcast embedding.
        let cfg = SamConfig::tiny();
        let mut rng = LcgRng::new(10);
        let sam = Sam::new(cfg.clone(), &mut rng).expect("ok");
        let img = random_image(&cfg, 11);
        let grid = cfg.img_size / cfg.patch_size;
        let mut coarse = vec![0.0f32; grid * grid];
        let mut mrng = LcgRng::new(12);
        mrng.fill_normal(&mut coarse);
        let with_mask = sam
            .predict(&img, &[10.0, 10.0], &[1], Some(&coarse))
            .expect("ok");
        let without = sam.predict(&img, &[10.0, 10.0], &[1], None).expect("ok");
        let diff: f32 = with_mask
            .masks
            .iter()
            .zip(without.masks.iter())
            .map(|(a, b)| (a - b).abs())
            .sum();
        assert!(
            diff > 1e-4,
            "mask prompt must influence the output, diff={diff}"
        );
    }

    // ── Determinism ───────────────────────────────────────────────────────────

    // Two models built from the same seed must produce bit-identical outputs
    // (exact `assert_eq!` on the f32 buffers, no tolerance).
    #[test]
    fn deterministic_same_seed() {
        let cfg = SamConfig::tiny();
        let img = random_image(&cfg, 13);
        let mut ra = LcgRng::new(77);
        let mut rb = LcgRng::new(77);
        let sa = Sam::new(cfg.clone(), &mut ra).expect("ok");
        let sb = Sam::new(cfg, &mut rb).expect("ok");
        let pa = sa.predict(&img, &[10.0, 10.0], &[1], None).expect("ok");
        let pb = sb.predict(&img, &[10.0, 10.0], &[1], None).expect("ok");
        assert_eq!(pa.masks, pb.masks, "same seed → identical masks");
        assert_eq!(pa.iou, pb.iou, "same seed → identical IoU");
    }
}