Skip to main content

rlx_sam/
prompt_encoder.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! SAM v1 prompt encoder — Fourier/point embeddings host-side; mask
17//! downscaling stack compiled via [`super::prompt_mask_ir`].
18//!
19//! Mirrors `candle-transformers/src/models/segment_anything/prompt_encoder.rs`
20//! exactly so the same `lmz/candle-sam` safetensors checkpoint loads
21//! without remapping.
22
23use super::config::{SAM_EMBED_HW, SAM_IMG_SIZE, SAM_PROMPT_EMBED_DIM};
24use super::prompt_mask_ir::SamPromptMaskCompiled;
25use anyhow::{Result, ensure};
26use rlx_core::weight_map::WeightMap;
27
/// All weights consumed by [`prompt_encoder_forward`]. Loaded once
/// from the safetensors file and then reused per prompt.
///
/// Every tensor is stored as a flat, row-major `Vec<f32>`; the shape
/// given on each field is the logical shape of that buffer.
#[derive(Debug, Clone)]
pub struct PromptEncoderWeights {
    /// `[2, embed_dim/2]` Gaussian random projection used by the
    /// Fourier positional encoder (row 0 multiplies x, row 1 multiplies y).
    pub pe_gaussian: Vec<f32>,
    /// `[embed_dim]` learned token for the "not a point" padding label
    /// (used when there are points but no boxes — labels of -1).
    pub not_a_point_embed: Vec<f32>,
    /// `[4, embed_dim]` learned per-label embeddings:
    /// 0 → background point, 1 → foreground point,
    /// 2 → box top-left corner, 3 → box bottom-right corner.
    pub point_embeddings: Vec<f32>,
    /// Mask downscaling stack (Conv2d → LN2d → GELU → Conv2d → LN2d
    /// → GELU → Conv2d). First conv weight, `[mask_in_chans/4, 1, 2, 2]`.
    pub mask_conv1_w: Vec<f32>,
    /// First conv bias.
    pub mask_conv1_b: Vec<f32>,
    /// First LayerNorm2d scale.
    pub mask_ln1_g: Vec<f32>,
    /// First LayerNorm2d shift.
    pub mask_ln1_b: Vec<f32>,
    /// Second conv weight, `[mask_in_chans, mask_in_chans/4, 2, 2]`.
    pub mask_conv2_w: Vec<f32>,
    /// Second conv bias.
    pub mask_conv2_b: Vec<f32>,
    /// Second LayerNorm2d scale.
    pub mask_ln2_g: Vec<f32>,
    /// Second LayerNorm2d shift.
    pub mask_ln2_b: Vec<f32>,
    /// Final 1×1 conv weight, `[embed_dim, mask_in_chans, 1, 1]`.
    pub mask_conv3_w: Vec<f32>,
    /// Final 1×1 conv bias.
    pub mask_conv3_b: Vec<f32>,
    /// `[embed_dim]` learned token broadcast over the image grid when
    /// no mask prompt is supplied.
    pub no_mask_embed: Vec<f32>,
    /// Width of every prompt embedding.
    pub embed_dim: usize,
    /// `mask_in_chans` from candle's `Sam::new` (= 16 for ViT-B).
    pub mask_in_chans: usize,
}
60
61pub(super) fn extract_prompt_encoder_weights(
62    weights: &mut WeightMap,
63    embed_dim: usize,
64    mask_in_chans: usize,
65) -> Result<PromptEncoderWeights> {
66    let half = embed_dim / 2;
67    let (pe_gaussian, sh) =
68        weights.take("prompt_encoder.pe_layer.positional_encoding_gaussian_matrix")?;
69    ensure!(
70        sh == vec![2, half],
71        "pe_gaussian expected [2, {half}], got {sh:?}"
72    );
73
74    let (not_a_point_embed, _) = weights.take("prompt_encoder.not_a_point_embed.weight")?;
75    let (no_mask_embed, _) = weights.take("prompt_encoder.no_mask_embed.weight")?;
76
77    // 4 separate point embeddings: indices 0..4. Each `[1, embed_dim]`.
78    let mut point_embeddings = vec![0f32; 4 * embed_dim];
79    for i in 0..4 {
80        let (data, _) = weights.take(&format!("prompt_encoder.point_embeddings.{i}.weight"))?;
81        point_embeddings[i * embed_dim..(i + 1) * embed_dim].copy_from_slice(&data);
82    }
83
84    let q = mask_in_chans / 4;
85    let (mask_conv1_w, sh1) = weights.take("prompt_encoder.mask_downscaling.0.weight")?;
86    ensure!(
87        sh1 == vec![q, 1, 2, 2],
88        "mask_downscaling.0.weight expected [{q}, 1, 2, 2], got {sh1:?}"
89    );
90    let (mask_conv1_b, _) = weights.take("prompt_encoder.mask_downscaling.0.bias")?;
91    let (mask_ln1_g, _) = weights.take("prompt_encoder.mask_downscaling.1.weight")?;
92    let (mask_ln1_b, _) = weights.take("prompt_encoder.mask_downscaling.1.bias")?;
93
94    let (mask_conv2_w, sh2) = weights.take("prompt_encoder.mask_downscaling.3.weight")?;
95    ensure!(
96        sh2 == vec![mask_in_chans, q, 2, 2],
97        "mask_downscaling.3.weight expected [{mask_in_chans}, {q}, 2, 2], got {sh2:?}"
98    );
99    let (mask_conv2_b, _) = weights.take("prompt_encoder.mask_downscaling.3.bias")?;
100    let (mask_ln2_g, _) = weights.take("prompt_encoder.mask_downscaling.4.weight")?;
101    let (mask_ln2_b, _) = weights.take("prompt_encoder.mask_downscaling.4.bias")?;
102
103    let (mask_conv3_w, sh3) = weights.take("prompt_encoder.mask_downscaling.6.weight")?;
104    ensure!(
105        sh3 == vec![embed_dim, mask_in_chans, 1, 1],
106        "mask_downscaling.6.weight expected [{embed_dim}, {mask_in_chans}, 1, 1], got {sh3:?}"
107    );
108    let (mask_conv3_b, _) = weights.take("prompt_encoder.mask_downscaling.6.bias")?;
109
110    Ok(PromptEncoderWeights {
111        pe_gaussian,
112        not_a_point_embed,
113        point_embeddings,
114        mask_conv1_w,
115        mask_conv1_b,
116        mask_ln1_g,
117        mask_ln1_b,
118        mask_conv2_w,
119        mask_conv2_b,
120        mask_ln2_g,
121        mask_ln2_b,
122        mask_conv3_w,
123        mask_conv3_b,
124        no_mask_embed,
125        embed_dim,
126        mask_in_chans,
127    })
128}
129
/// Output of [`prompt_encoder_forward`] — fed straight into the mask
/// decoder. All host-side `Vec<f32>`, flat and row-major.
#[derive(Debug, Clone)]
pub struct PromptEncoderOutput {
    /// `[num_tokens, embed_dim]` — point embeddings followed by box
    /// corner embeddings. Empty (`num_tokens = 0`) for the "no prompt"
    /// case.
    pub sparse_embeddings: Vec<f32>,
    /// Number of rows in `sparse_embeddings`.
    pub num_sparse_tokens: usize,
    /// `[embed_dim, hw, hw]` — dense pixel-wise embedding. Either the
    /// mask-downscaled signal or the broadcast `no_mask_embed`.
    pub dense_embeddings: Vec<f32>,
    /// `[embed_dim, hw, hw]` — image positional encoding (the dense PE
    /// fed into the mask decoder).
    pub image_pe: Vec<f32>,
}
144
145/// Run the prompt encoder. Mirrors candle's `PromptEncoder::forward`.
146///
147/// `points`: optional `(coords, labels)` where coords is `[N, 2]`
148///   (x, y in input-image pixels, 0..`SAM_IMG_SIZE`) and labels is
149///   `[N]` (1 = foreground, 0 = background, -1 = padding).
150/// `boxes`: optional `[M, 4]` boxes (x0, y0, x1, y1).
151/// `masks`: optional `[1, 4·hw, 4·hw]` mask prompt (a logits map
152///   at 4× the embedding resolution); pre-resized to 256×256 for
153///   ViT-B (where hw=64).
154pub fn prompt_encoder_forward(
155    w: &PromptEncoderWeights,
156    mask_stack: &mut SamPromptMaskCompiled,
157    points: Option<(&[f32], &[f32])>,
158    boxes: Option<&[f32]>,
159    masks: Option<&[f32]>,
160) -> Result<PromptEncoderOutput> {
161    let e = w.embed_dim;
162    let hw = SAM_EMBED_HW;
163
164    // ── Sparse embeddings ─────────────────────────────────────────
165    let pad_points = boxes.is_none();
166    let mut sparse = Vec::new();
167
168    if let Some((coords, labels)) = points {
169        let n = labels.len();
170        ensure!(
171            coords.len() == n * 2,
172            "points coords len {} ≠ N·2 ({}·2)",
173            coords.len(),
174            n
175        );
176        // Candle adds 0.5 to point coords.
177        let mut pts: Vec<f32> = coords.iter().map(|c| c + 0.5).collect();
178        let mut lbls = labels.to_vec();
179        if pad_points {
180            // Pad with a single "not-a-point" sentinel (label -1).
181            pts.push(0.0);
182            pts.push(0.0);
183            lbls.push(-1.0);
184        }
185        let n_padded = lbls.len();
186        let emb = embed_points_and_boxes(w, &pts, n_padded, /*is_box=*/ false, Some(&lbls))?;
187        sparse.extend_from_slice(&emb);
188    }
189    if let Some(box_coords) = boxes {
190        let m = box_coords.len() / 4;
191        ensure!(box_coords.len() == m * 4, "boxes len must be multiple of 4");
192        let coords_with_half: Vec<f32> = box_coords.iter().map(|c| c + 0.5).collect();
193        let emb = embed_points_and_boxes(w, &coords_with_half, m * 2, /*is_box=*/ true, None)?;
194        sparse.extend_from_slice(&emb);
195    }
196    let num_sparse_tokens = if sparse.is_empty() {
197        0
198    } else {
199        sparse.len() / e
200    };
201
202    // ── Dense embeddings ──────────────────────────────────────────
203    let dense_embeddings = match masks {
204        Some(m) => embed_mask(mask_stack, m, hw)?,
205        None => {
206            // Broadcast no_mask_embed [E] to [E, hw, hw].
207            let mut out = vec![0f32; e * hw * hw];
208            for c in 0..e {
209                let v = w.no_mask_embed[c];
210                let plane = &mut out[c * hw * hw..(c + 1) * hw * hw];
211                plane.fill(v);
212            }
213            out
214        }
215    };
216
217    // ── Image PE: random-Fourier encoding of a hw·hw normalized grid ──
218    let image_pe = compute_image_pe(w, hw, hw);
219
220    Ok(PromptEncoderOutput {
221        sparse_embeddings: sparse,
222        num_sparse_tokens,
223        dense_embeddings,
224        image_pe,
225    })
226}
227
228/// Random-Fourier positional encoding for a `(h, w)` grid.
229/// Output shape `[embed_dim, h, w]`.
230fn compute_image_pe(w: &PromptEncoderWeights, h: usize, ww: usize) -> Vec<f32> {
231    let e = w.embed_dim;
232    let half = e / 2;
233    let mut out = vec![0f32; e * h * ww];
234    // For each (y, x) cell, normalize to (x+0.5)/w, (y+0.5)/h, then map
235    // through the Gaussian + sin/cos pe_encoding.
236    for y in 0..h {
237        let fy = (y as f32 + 0.5) / h as f32;
238        for x in 0..ww {
239            let fx = (x as f32 + 0.5) / ww as f32;
240            // Candle's pe_encoding: coords = 2*coords - 1, then @ M, then *2π
241            let cx = fx * 2.0 - 1.0;
242            let cy = fy * 2.0 - 1.0;
243            // [cx, cy] @ gaussian [2, half]  → [half]
244            for k in 0..half {
245                let mut acc = cx * w.pe_gaussian[k] + cy * w.pe_gaussian[half + k];
246                acc *= 2.0 * std::f32::consts::PI;
247                out[k * h * ww + y * ww + x] = acc.sin();
248                out[(half + k) * h * ww + y * ww + x] = acc.cos();
249            }
250        }
251    }
252    out
253}
254
255/// Apply the Gaussian + sin/cos PE to arbitrary `[N, 2]` coords already
256/// in `[0, 1]` (or with extra padding columns that are passed through
257/// unchanged via the candle convention). Returns `[N, embed_dim]`.
258fn pe_encode_normalized(w: &PromptEncoderWeights, coords: &[f32], n: usize) -> Vec<f32> {
259    let e = w.embed_dim;
260    let half = e / 2;
261    let mut out = vec![0f32; n * e];
262    for i in 0..n {
263        let cx = coords[i * 2] * 2.0 - 1.0;
264        let cy = coords[i * 2 + 1] * 2.0 - 1.0;
265        for k in 0..half {
266            let mut acc = cx * w.pe_gaussian[k] + cy * w.pe_gaussian[half + k];
267            acc *= 2.0 * std::f32::consts::PI;
268            out[i * e + k] = acc.sin();
269            out[i * e + half + k] = acc.cos();
270        }
271    }
272    out
273}
274
275/// Embed N points or N boxes (each box becomes 2 corner points).
276///
277/// For points: applies labels to add per-label embeddings.
278/// For boxes: adds `point_embeddings[2]` to first corner, `[3]` to second.
279fn embed_points_and_boxes(
280    w: &PromptEncoderWeights,
281    coords_in_pixels: &[f32], // [n*2]
282    n: usize,
283    is_box: bool,
284    labels: Option<&[f32]>,
285) -> Result<Vec<f32>> {
286    let e = w.embed_dim;
287    // Normalize pixel coords to [0,1] using SAM_IMG_SIZE.
288    let img = SAM_IMG_SIZE as f32;
289    let normed: Vec<f32> = coords_in_pixels.iter().map(|c| c / img).collect();
290    let mut emb = pe_encode_normalized(w, &normed, n);
291
292    if is_box {
293        // Box: 2 corners per box. Alternate [2], [3] per pair.
294        for i in 0..n {
295            let pe_idx = if i % 2 == 0 { 2 } else { 3 };
296            for k in 0..e {
297                emb[i * e + k] += w.point_embeddings[pe_idx * e + k];
298            }
299        }
300    } else if let Some(lbls) = labels {
301        ensure!(lbls.len() == n, "labels len {} ≠ n {n}", lbls.len());
302        for i in 0..n {
303            let label = lbls[i];
304            if label < 0.0 {
305                // "not-a-point" padding token replaces the PE entirely.
306                for k in 0..e {
307                    emb[i * e + k] = w.not_a_point_embed[k];
308                }
309            } else if label == 0.0 {
310                for k in 0..e {
311                    emb[i * e + k] += w.point_embeddings[k];
312                }
313            } else {
314                // label == 1.0 (foreground)
315                for k in 0..e {
316                    emb[i * e + k] += w.point_embeddings[e + k];
317                }
318            }
319        }
320    }
321    Ok(emb)
322}
323
324/// Mask downscaling: Conv(k=2, s=2) → LN2d → GELU → Conv(k=2, s=2)
325/// → LN2d → GELU → Conv(k=1) → `[embed_dim, hw, hw]`.
326///
327/// Input `mask`: `[1, 4·hw, 4·hw]` (256×256 for ViT-B).
328fn embed_mask(stack: &mut SamPromptMaskCompiled, mask: &[f32], hw: usize) -> Result<Vec<f32>> {
329    let in_h = 4 * hw;
330    let in_w = 4 * hw;
331    ensure!(
332        mask.len() == in_h * in_w,
333        "mask must be [1, {in_h}, {in_w}], got len {}",
334        mask.len()
335    );
336    // NCHW `[1, 1, H, W]` for the compiled mask-downscaling graph.
337    stack.run(mask, in_h, in_w)
338}
339
340// ─── Tiny CPU kernels (host-side) ────────────────────────────────
341
/// 2-D convolution with kernel=2, stride=2, padding=0 over a CHW input.
///
/// Every output pixel consumes one 2×2 input patch per input channel,
/// and patches never overlap. `weight` is `[out_c, in_c, 2, 2]`, `bias`
/// is `[out_c]`; the result is `[out_c, in_h / 2, in_w / 2]`. An odd
/// trailing row or column of the input is ignored.
// Not referenced from the code visible in this file.
#[allow(dead_code)]
fn conv2d_stride2_k2_pad0(
    input: &[f32],
    in_c: usize,
    out_c: usize,
    in_h: usize,
    in_w: usize,
    weight: &[f32], // [out_c, in_c, 2, 2]
    bias: &[f32],   // [out_c]
) -> Vec<f32> {
    let (out_h, out_w) = (in_h / 2, in_w / 2);
    let in_plane = in_h * in_w;
    let mut out = Vec::with_capacity(out_c * out_h * out_w);

    for oc in 0..out_c {
        for oy in 0..out_h {
            for ox in 0..out_w {
                // Offset of the patch's top-left pixel within a channel.
                let top_left = 2 * oy * in_w + 2 * ox;
                let mut sum = bias[oc];
                for ic in 0..in_c {
                    let base = ic * in_plane + top_left;
                    // Kernel taps in (ky, kx) order: 00, 01, 10, 11.
                    let taps = &weight[(oc * in_c + ic) * 4..][..4];
                    sum += input[base] * taps[0];
                    sum += input[base + 1] * taps[1];
                    sum += input[base + in_w] * taps[2];
                    sum += input[base + in_w + 1] * taps[3];
                }
                out.push(sum);
            }
        }
    }
    out
}
380
/// 1×1 Conv2d over a CHW input, i.e. a per-pixel matrix–vector product.
///
/// `weight` is `[out_c, in_c, 1, 1]`, `bias` is `[out_c]`; the result is
/// `[out_c, h, w]`.
// Not referenced from the code visible in this file.
#[allow(dead_code)]
fn conv2d_1x1(
    input: &[f32],
    in_c: usize,
    out_c: usize,
    h: usize,
    w: usize,
    weight: &[f32], // [out_c, in_c, 1, 1]
    bias: &[f32],   // [out_c]
) -> Vec<f32> {
    let plane = h * w;
    let mut out = vec![0f32; out_c * plane];

    for oc in 0..out_c {
        let offset = bias[oc];
        for pixel in 0..plane {
            // Start from the bias and accumulate channels in order.
            out[oc * plane + pixel] = (0..in_c).fold(offset, |acc, ic| {
                acc + input[ic * plane + pixel] * weight[oc * in_c + ic]
            });
        }
    }
    out
}
407
/// In-place LayerNorm over the channel axis of a `[c, h, w]` buffer,
/// applied independently at every spatial position.
///
/// Uses the biased variance (divides by `c`) and then applies the
/// per-channel affine `gamma[k] * x̂ + beta[k]`. Intended to match
/// candle's `LayerNorm2d`.
// Not referenced from the code visible in this file.
#[allow(dead_code)]
fn layernorm2d_nchw(
    data: &mut [f32],
    c: usize,
    h: usize,
    w: usize,
    gamma: &[f32],
    beta: &[f32],
    eps: f32,
) {
    let plane = h * w;
    let channels = c as f32;

    for pos in 0..plane {
        // Channel `k` of this pixel lives at `k * plane + pos`.
        let at = |k: usize| k * plane + pos;

        let mean = (0..c).fold(0f32, |sum, k| sum + data[at(k)]) / channels;
        let variance = (0..c).fold(0f32, |sum, k| {
            let delta = data[at(k)] - mean;
            sum + delta * delta
        }) / channels;
        let inv_std = 1.0 / (variance + eps).sqrt();

        for k in 0..c {
            let normalised = (data[at(k)] - mean) * inv_std;
            data[at(k)] = normalised * gamma[k] + beta[k];
        }
    }
}
440
441/// Exact erf-based GELU (candle's `Activation::Gelu` → `gelu_erf`).
442#[allow(dead_code)]
443pub(super) fn gelu_erf_inplace(data: &mut [f32]) {
444    const INV_SQRT2: f32 = std::f32::consts::FRAC_1_SQRT_2;
445    for v in data.iter_mut() {
446        // Abramowitz & Stegun erf approximation — same constants as
447        // the rlx-cpu NEON kernel, so numerical agreement is exact.
448        let x = *v;
449        let s = (x * INV_SQRT2).abs();
450        let p = 0.327_591_1;
451        let a1 = 0.254_829_6;
452        let a2 = -0.284_496_7;
453        let a3 = 1.421_413_8;
454        let a4 = -1.453_152_1;
455        let a5 = 1.061_405_4;
456        let t = 1.0 / (1.0 + p * s);
457        let y = ((((a5 * t + a4) * t + a3) * t + a2) * t + a1) * t;
458        let erf_abs = 1.0 - y * (-s * s).exp();
459        let erf = if x >= 0.0 { erf_abs } else { -erf_abs };
460        *v = 0.5 * x * (1.0 + erf);
461    }
462}
463
/// Test-only helper: panics unless `actual == expected`, with `label`
/// prefixed to the failure message so the offending dimension can be
/// identified.
#[cfg(test)]
// Only compiled for tests; no caller is visible in this file.
#[allow(dead_code)]
pub(super) fn assert_shape(label: &str, actual: usize, expected: usize) {
    assert_eq!(actual, expected, "{label}: {actual} ≠ {expected}");
}
469
/// Reads `SAM_PROMPT_EMBED_DIM` and discards it. Nothing else in this
/// file uses that import, so this presumably exists only to stop the
/// compiler flagging it as unused — confirm before removing either.
// No caller is visible in this file.
#[allow(dead_code)]
fn _silence_constant() {
    let _ = SAM_PROMPT_EMBED_DIM;
}