Skip to main content

rlx_sam2/
prompt_encoder.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! SAM 2 prompt encoder — Fourier/point embeddings host-side; mask stack via IR.
17//!
18//! Mirrors `sam2/modeling/sam/prompt_encoder.py::PromptEncoder` exactly.
19//! Structurally identical to SAM v1 (random-Fourier positional
20//! encoding + per-label embeddings + 2-stage Conv mask downscale +
21//! `no_mask_embed` broadcast). Two integration differences vs v1:
22//!
23//!   1. Weight key prefix is `sam_prompt_encoder.*` instead of
24//!      `prompt_encoder.*` — SAM 2 nests the prompt encoder under
25//!      `sam_prompt_encoder` inside the published checkpoints.
26//!   2. The embedding grid resolution comes from the *finest* FPN
27//!      level (stride 16, 64×64 for 1024 input) — the reference's
28//!      `image_embedding_size = (64, 64)`. Constant for every Hiera
29//!      variant.
30//!
31//! The prompt encoder is < 1 % of total compute, so keeping it on the
32//! CPU keeps Phase 2 self-contained (no IR-surface growth for
33//! Gaussian-PE / Conv2d k=2 s=2 / etc.).
34
35use super::config::SAM2_IMG_SIZE;
36use super::prompt_mask_ir::Sam2PromptMaskCompiled;
37use anyhow::{Result, ensure};
38use rlx_core::weight_map::WeightMap;
39
/// Edge length of the square dense image-embedding grid handed to the
/// mask decoder.
///
/// SAM 2 hardcodes a 64×64 grid (a 1024-pixel input at stride 16),
/// which is the `image_embedding_size=(64, 64)` setting found in every
/// published `sam2_hiera_*.yaml`.
pub const SAM2_PROMPT_GRID: usize = 64;
45
/// Channel width of the mask-downscaling stack (`mask_in_chans` in the
/// reference YAML).
///
/// Fixed at 16 for every Hiera variant — the same value SAM v1 uses.
pub const SAM2_MASK_IN_CHANS: usize = 16;
49
/// All weights consumed by [`prompt_encoder_forward`]. Loaded once
/// from the safetensors file and then reused per prompt.
///
/// Every tensor is stored flattened in row-major order; the shapes
/// quoted on the fields are the logical shapes the host-side code in
/// this module indexes with.
#[derive(Debug, Clone)]
pub struct Sam2PromptEncoderWeights {
    /// `[2, embed_dim/2]` Gaussian random projection used by the
    /// Fourier positional encoder. Row 0 multiplies the x coordinate,
    /// row 1 the y coordinate.
    pub pe_gaussian: Vec<f32>,
    /// `[embed_dim]` learned "not a point" token. It replaces the
    /// embedding of every point whose label is negative, including the
    /// padding point appended when points are given without boxes.
    pub not_a_point_embed: Vec<f32>,
    /// `[4, embed_dim]` learned per-label embeddings:
    /// 0 → background point, 1 → foreground point,
    /// 2 → box top-left corner, 3 → box bottom-right corner.
    pub point_embeddings: Vec<f32>,
    /// First conv of the mask downscaling stack (Conv2d → LN2d → GELU
    /// → Conv2d → LN2d → GELU → Conv2d). Weight shape
    /// `[mask_in_chans/4, 1, 2, 2]`.
    pub mask_conv1_w: Vec<f32>,
    /// Bias of the first mask conv.
    pub mask_conv1_b: Vec<f32>,
    /// Scale (gamma) of the LayerNorm2d that follows the first conv.
    pub mask_ln1_g: Vec<f32>,
    /// Shift (beta) of the LayerNorm2d that follows the first conv.
    pub mask_ln1_b: Vec<f32>,
    /// Second mask conv, weight shape
    /// `[mask_in_chans, mask_in_chans/4, 2, 2]`.
    pub mask_conv2_w: Vec<f32>,
    /// Bias of the second mask conv.
    pub mask_conv2_b: Vec<f32>,
    /// Scale (gamma) of the LayerNorm2d that follows the second conv.
    pub mask_ln2_g: Vec<f32>,
    /// Shift (beta) of the LayerNorm2d that follows the second conv.
    pub mask_ln2_b: Vec<f32>,
    /// Final 1×1 mask conv, weight shape
    /// `[embed_dim, mask_in_chans, 1, 1]`.
    pub mask_conv3_w: Vec<f32>,
    /// Bias of the final 1×1 mask conv.
    pub mask_conv3_b: Vec<f32>,
    /// `[embed_dim]` learned token broadcast over the image grid when
    /// no mask prompt is supplied.
    pub no_mask_embed: Vec<f32>,
    /// Width of every embedding produced by the prompt encoder.
    pub embed_dim: usize,
    /// `mask_in_chans` (16 for all SAM 2 variants).
    pub mask_in_chans: usize,
    /// Grid edge length (64 for SAM 2's stride-16 path).
    pub grid: usize,
}
84
85/// Drain the prompt-encoder weights from the safetensors map. Returns
86/// `Sam2PromptEncoderWeights` and consumes the corresponding keys.
87pub fn extract_prompt_encoder_weights(
88    weights: &mut WeightMap,
89    embed_dim: usize,
90    mask_in_chans: usize,
91) -> Result<Sam2PromptEncoderWeights> {
92    let half = embed_dim / 2;
93    let (pe_gaussian, sh) =
94        weights.take("sam_prompt_encoder.pe_layer.positional_encoding_gaussian_matrix")?;
95    ensure!(
96        sh == vec![2, half],
97        "pe_gaussian expected [2, {half}], got {sh:?}"
98    );
99
100    let (not_a_point_embed, _) = weights.take("sam_prompt_encoder.not_a_point_embed.weight")?;
101    let (no_mask_embed, _) = weights.take("sam_prompt_encoder.no_mask_embed.weight")?;
102
103    let mut point_embeddings = vec![0f32; 4 * embed_dim];
104    for i in 0..4 {
105        let (data, _) = weights.take(&format!("sam_prompt_encoder.point_embeddings.{i}.weight"))?;
106        point_embeddings[i * embed_dim..(i + 1) * embed_dim].copy_from_slice(&data);
107    }
108
109    let q = mask_in_chans / 4;
110    let (mask_conv1_w, sh1) = weights.take("sam_prompt_encoder.mask_downscaling.0.weight")?;
111    ensure!(
112        sh1 == vec![q, 1, 2, 2],
113        "mask_downscaling.0.weight expected [{q}, 1, 2, 2], got {sh1:?}"
114    );
115    let (mask_conv1_b, _) = weights.take("sam_prompt_encoder.mask_downscaling.0.bias")?;
116    let (mask_ln1_g, _) = weights.take("sam_prompt_encoder.mask_downscaling.1.weight")?;
117    let (mask_ln1_b, _) = weights.take("sam_prompt_encoder.mask_downscaling.1.bias")?;
118
119    let (mask_conv2_w, sh2) = weights.take("sam_prompt_encoder.mask_downscaling.3.weight")?;
120    ensure!(
121        sh2 == vec![mask_in_chans, q, 2, 2],
122        "mask_downscaling.3.weight expected [{mask_in_chans}, {q}, 2, 2], got {sh2:?}"
123    );
124    let (mask_conv2_b, _) = weights.take("sam_prompt_encoder.mask_downscaling.3.bias")?;
125    let (mask_ln2_g, _) = weights.take("sam_prompt_encoder.mask_downscaling.4.weight")?;
126    let (mask_ln2_b, _) = weights.take("sam_prompt_encoder.mask_downscaling.4.bias")?;
127
128    let (mask_conv3_w, sh3) = weights.take("sam_prompt_encoder.mask_downscaling.6.weight")?;
129    ensure!(
130        sh3 == vec![embed_dim, mask_in_chans, 1, 1],
131        "mask_downscaling.6.weight expected [{embed_dim}, {mask_in_chans}, 1, 1], got {sh3:?}"
132    );
133    let (mask_conv3_b, _) = weights.take("sam_prompt_encoder.mask_downscaling.6.bias")?;
134
135    Ok(Sam2PromptEncoderWeights {
136        pe_gaussian,
137        not_a_point_embed,
138        point_embeddings,
139        mask_conv1_w,
140        mask_conv1_b,
141        mask_ln1_g,
142        mask_ln1_b,
143        mask_conv2_w,
144        mask_conv2_b,
145        mask_ln2_g,
146        mask_ln2_b,
147        mask_conv3_w,
148        mask_conv3_b,
149        no_mask_embed,
150        embed_dim,
151        mask_in_chans,
152        grid: SAM2_PROMPT_GRID,
153    })
154}
155
/// Output of [`prompt_encoder_forward`] — fed straight into the mask
/// decoder. All host-side `Vec<f32>`, flattened in row-major order.
#[derive(Debug, Clone)]
pub struct Sam2PromptEncoderOutput {
    /// `[num_tokens, embed_dim]` — point embeddings followed by box
    /// corner embeddings. Empty for the "no prompt" case.
    pub sparse_embeddings: Vec<f32>,
    /// Number of rows in `sparse_embeddings` (0 when there is no
    /// sparse prompt).
    pub num_sparse_tokens: usize,
    /// `[embed_dim, grid, grid]` — dense pixel-wise embedding: either
    /// the output of the mask stack or the broadcast `no_mask_embed`.
    pub dense_embeddings: Vec<f32>,
    /// `[embed_dim, grid, grid]` — image positional encoding (the
    /// dense PE fed into the mask decoder).
    pub image_pe: Vec<f32>,
}
169
170/// Run the SAM 2 prompt encoder. Mirrors
171/// `sam2.modeling.sam.prompt_encoder.PromptEncoder.forward`.
172///
173/// `points`: optional `(coords, labels)` where coords is `[N, 2]`
174///   (x, y in input-image pixels, 0..`SAM2_IMG_SIZE`) and labels is
175///   `[N]` (1 = foreground, 0 = background, -1 = padding).
176/// `boxes`: optional `[M, 4]` boxes (x0, y0, x1, y1).
177/// `masks`: optional `[1, 4·grid, 4·grid]` mask prompt (a logits map
178///   at 4× the embedding resolution); 256×256 for SAM 2's 64×64 grid.
179pub fn prompt_encoder_forward(
180    w: &Sam2PromptEncoderWeights,
181    mask_stack: &mut Sam2PromptMaskCompiled,
182    points: Option<(&[f32], &[f32])>,
183    boxes: Option<&[f32]>,
184    masks: Option<&[f32]>,
185) -> Result<Sam2PromptEncoderOutput> {
186    let e = w.embed_dim;
187    let g = w.grid;
188
189    // ── Sparse embeddings ──
190    let pad_points = boxes.is_none();
191    let mut sparse = Vec::new();
192
193    if let Some((coords, labels)) = points {
194        let n = labels.len();
195        ensure!(
196            coords.len() == n * 2,
197            "points coords len {} ≠ N·2 ({}·2)",
198            coords.len(),
199            n
200        );
201        let mut pts: Vec<f32> = coords.iter().map(|c| c + 0.5).collect();
202        let mut lbls = labels.to_vec();
203        if pad_points {
204            pts.push(0.0);
205            pts.push(0.0);
206            lbls.push(-1.0);
207        }
208        let n_padded = lbls.len();
209        let emb = embed_points_and_boxes(w, &pts, n_padded, /*is_box=*/ false, Some(&lbls))?;
210        sparse.extend_from_slice(&emb);
211    }
212    if let Some(box_coords) = boxes {
213        let m = box_coords.len() / 4;
214        ensure!(box_coords.len() == m * 4, "boxes len must be multiple of 4");
215        let coords_with_half: Vec<f32> = box_coords.iter().map(|c| c + 0.5).collect();
216        let emb = embed_points_and_boxes(w, &coords_with_half, m * 2, /*is_box=*/ true, None)?;
217        sparse.extend_from_slice(&emb);
218    }
219    let num_sparse_tokens = if sparse.is_empty() {
220        0
221    } else {
222        sparse.len() / e
223    };
224
225    // ── Dense embeddings ──
226    let dense_embeddings = match masks {
227        Some(m) => mask_stack.run(m)?,
228        None => {
229            // Broadcast no_mask_embed [E] to [E, g, g].
230            let mut out = vec![0f32; e * g * g];
231            for c in 0..e {
232                let v = w.no_mask_embed[c];
233                out[c * g * g..(c + 1) * g * g].fill(v);
234            }
235            out
236        }
237    };
238
239    // ── Image PE: random-Fourier encoding of a g·g normalised grid ──
240    let image_pe = compute_image_pe(w, g, g);
241
242    Ok(Sam2PromptEncoderOutput {
243        sparse_embeddings: sparse,
244        num_sparse_tokens,
245        dense_embeddings,
246        image_pe,
247    })
248}
249
250/// Random-Fourier positional encoding for a `(h, w)` grid.
251/// Output shape `[embed_dim, h, w]`.
252pub fn compute_image_pe(w: &Sam2PromptEncoderWeights, h: usize, ww: usize) -> Vec<f32> {
253    let e = w.embed_dim;
254    let half = e / 2;
255    let mut out = vec![0f32; e * h * ww];
256    for y in 0..h {
257        let fy = (y as f32 + 0.5) / h as f32;
258        for x in 0..ww {
259            let fx = (x as f32 + 0.5) / ww as f32;
260            let cx = fx * 2.0 - 1.0;
261            let cy = fy * 2.0 - 1.0;
262            for k in 0..half {
263                let mut acc = cx * w.pe_gaussian[k] + cy * w.pe_gaussian[half + k];
264                acc *= 2.0 * std::f32::consts::PI;
265                out[k * h * ww + y * ww + x] = acc.sin();
266                out[(half + k) * h * ww + y * ww + x] = acc.cos();
267            }
268        }
269    }
270    out
271}
272
273/// Apply the Gaussian + sin/cos PE to coords already in `[0, 1]`.
274/// Returns `[N, embed_dim]`.
275fn pe_encode_normalized(w: &Sam2PromptEncoderWeights, coords: &[f32], n: usize) -> Vec<f32> {
276    let e = w.embed_dim;
277    let half = e / 2;
278    let mut out = vec![0f32; n * e];
279    for i in 0..n {
280        let cx = coords[i * 2] * 2.0 - 1.0;
281        let cy = coords[i * 2 + 1] * 2.0 - 1.0;
282        for k in 0..half {
283            let mut acc = cx * w.pe_gaussian[k] + cy * w.pe_gaussian[half + k];
284            acc *= 2.0 * std::f32::consts::PI;
285            out[i * e + k] = acc.sin();
286            out[i * e + half + k] = acc.cos();
287        }
288    }
289    out
290}
291
292/// Embed N points or N boxes (each box becomes 2 corner points).
293fn embed_points_and_boxes(
294    w: &Sam2PromptEncoderWeights,
295    coords_in_pixels: &[f32],
296    n: usize,
297    is_box: bool,
298    labels: Option<&[f32]>,
299) -> Result<Vec<f32>> {
300    let e = w.embed_dim;
301    let img = SAM2_IMG_SIZE as f32;
302    let normed: Vec<f32> = coords_in_pixels.iter().map(|c| c / img).collect();
303    let mut emb = pe_encode_normalized(w, &normed, n);
304
305    if is_box {
306        for i in 0..n {
307            let pe_idx = if i % 2 == 0 { 2 } else { 3 };
308            for k in 0..e {
309                emb[i * e + k] += w.point_embeddings[pe_idx * e + k];
310            }
311        }
312    } else if let Some(lbls) = labels {
313        ensure!(lbls.len() == n, "labels len {} ≠ n {n}", lbls.len());
314        for i in 0..n {
315            let label = lbls[i];
316            if label < 0.0 {
317                for k in 0..e {
318                    emb[i * e + k] = w.not_a_point_embed[k];
319                }
320            } else if label == 0.0 {
321                for k in 0..e {
322                    emb[i * e + k] += w.point_embeddings[k];
323                }
324            } else {
325                for k in 0..e {
326                    emb[i * e + k] += w.point_embeddings[e + k];
327                }
328            }
329        }
330    }
331    Ok(emb)
332}
333
334// ─── Shared host-side helpers (also re-used by mask_decoder.rs / memory_encoder.rs) ────────
335
336/// 2-D conv with kernel=2 stride=2 padding=0, NCHW.
337#[allow(dead_code)]
338pub(super) fn conv2d_stride2_k2_pad0(
339    input: &[f32],
340    in_c: usize,
341    out_c: usize,
342    in_h: usize,
343    in_w: usize,
344    weight: &[f32], // [out_c, in_c, 2, 2]
345    bias: &[f32],   // [out_c]
346) -> Vec<f32> {
347    let out_h = in_h / 2;
348    let out_w = in_w / 2;
349    let mut out = vec![0f32; out_c * out_h * out_w];
350    for oc in 0..out_c {
351        for oy in 0..out_h {
352            for ox in 0..out_w {
353                let mut acc = bias[oc];
354                for ic in 0..in_c {
355                    for ky in 0..2 {
356                        let iy = oy * 2 + ky;
357                        for kx in 0..2 {
358                            let ix = ox * 2 + kx;
359                            let v = input[ic * in_h * in_w + iy * in_w + ix];
360                            let w_idx = ((oc * in_c + ic) * 2 + ky) * 2 + kx;
361                            acc += v * weight[w_idx];
362                        }
363                    }
364                }
365                out[oc * out_h * out_w + oy * out_w + ox] = acc;
366            }
367        }
368    }
369    out
370}
371
372/// 1×1 Conv2d = per-pixel matmul.
373pub(super) fn conv2d_1x1(
374    input: &[f32],
375    in_c: usize,
376    out_c: usize,
377    h: usize,
378    w: usize,
379    weight: &[f32], // [out_c, in_c, 1, 1]
380    bias: &[f32],   // [out_c]
381) -> Vec<f32> {
382    let mut out = vec![0f32; out_c * h * w];
383    for oc in 0..out_c {
384        let b = bias[oc];
385        for y in 0..h {
386            for x in 0..w {
387                let mut acc = b;
388                for ic in 0..in_c {
389                    acc += input[ic * h * w + y * w + x] * weight[oc * in_c + ic];
390                }
391                out[oc * h * w + y * w + x] = acc;
392            }
393        }
394    }
395    out
396}
397
398/// LayerNorm over the channel axis of NCHW (per spatial pos).
399pub(super) fn layernorm2d_nchw(
400    data: &mut [f32],
401    c: usize,
402    h: usize,
403    w: usize,
404    gamma: &[f32],
405    beta: &[f32],
406    eps: f32,
407) {
408    let n = h * w;
409    for i in 0..n {
410        let mut mean = 0f32;
411        for k in 0..c {
412            mean += data[k * n + i];
413        }
414        mean /= c as f32;
415        let mut var = 0f32;
416        for k in 0..c {
417            let d = data[k * n + i] - mean;
418            var += d * d;
419        }
420        var /= c as f32;
421        let inv = 1.0 / (var + eps).sqrt();
422        for k in 0..c {
423            let v = (data[k * n + i] - mean) * inv;
424            data[k * n + i] = v * gamma[k] + beta[k];
425        }
426    }
427}
428
429/// Exact erf-based GELU (matches `nn.GELU()` default).
430pub(super) fn gelu_erf_inplace(data: &mut [f32]) {
431    const INV_SQRT2: f32 = std::f32::consts::FRAC_1_SQRT_2;
432    for v in data.iter_mut() {
433        let x = *v;
434        let s = (x * INV_SQRT2).abs();
435        let p = 0.327_591_1;
436        let a1 = 0.254_829_6;
437        let a2 = -0.284_496_7;
438        let a3 = 1.421_413_8;
439        let a4 = -1.453_152_1;
440        let a5 = 1.061_405_4;
441        let t = 1.0 / (1.0 + p * s);
442        let y = ((((a5 * t + a4) * t + a3) * t + a2) * t + a1) * t;
443        let erf_abs = 1.0 - y * (-s * s).exp();
444        let erf = if x >= 0.0 { erf_abs } else { -erf_abs };
445        *v = 0.5 * x * (1.0 + erf);
446    }
447}
448
449/// Sigmoid in place.
450pub(super) fn sigmoid_inplace(x: &mut [f32]) {
451    for v in x.iter_mut() {
452        *v = 1.0 / (1.0 + (-*v).exp());
453    }
454}