Skip to main content

rlx_sam/
mask_decoder.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! SAM v1 mask decoder — transformer host-side; upscaling via IR graph.
17//!
18//! Two-way transformer over (point tokens, image embeddings) → mask
19//! token outputs → ConvTranspose2d ×2 upscaling → hypernetwork MLPs
20//! → mask logits + IoU predictions. Mirrors candle's
21//! `mask_decoder.rs`.
22
23use super::config::SAM_EMBED_HW;
24use super::transformer::{
25    TwoWayTransformerWeights, extract_two_way_transformer_weights, linear,
26    two_way_transformer_forward,
27};
28use super::upscale_ir::SamMaskUpscaleCompiled;
29use anyhow::{Result, ensure};
30use rlx_core::weight_map::WeightMap;
31use rlx_sam_ir::mask_hyper_matmul_ir::MaskHyperMatmulCompiled;
32use rlx_sam_ir::mlp_relu_ir::MlpReluCompiled;
33use rlx_sam_ir::twoway_transformer_ir::TwoWayTransformerCompiled;
34
35pub struct MaskDecoderWeights {
36    pub iou_token: Vec<f32>,   // [1, transformer_dim]
37    pub mask_tokens: Vec<f32>, // [num_mask_tokens, transformer_dim]
38    pub transformer: TwoWayTransformerWeights,
39
40    /// ConvTranspose2d: in=transformer_dim, out=transformer_dim/4,
41    /// kernel=2, stride=2. Weight shape `[in, out, 2, 2]`.
42    pub upscale_conv1_w: Vec<f32>,
43    pub upscale_conv1_b: Vec<f32>,
44    /// LayerNorm2d on the upscaled feature.
45    pub upscale_ln_g: Vec<f32>,
46    pub upscale_ln_b: Vec<f32>,
47    /// ConvTranspose2d: in=transformer_dim/4, out=transformer_dim/8.
48    pub upscale_conv2_w: Vec<f32>,
49    pub upscale_conv2_b: Vec<f32>,
50
51    /// `num_mask_tokens` × 3-layer ReLU MLPs (`transformer_dim → transformer_dim
52    /// → transformer_dim → transformer_dim/8`). Each MLP's flat
53    /// weights+biases stored sequentially in `hyper_mlps_*`.
54    pub hyper_mlps: Vec<HypernetMlp>,
55
56    /// IoU prediction head: 3-layer ReLU MLP `transformer_dim →
57    /// iou_head_hidden_dim → iou_head_hidden_dim → num_mask_tokens`.
58    pub iou_head: HypernetMlp,
59
60    pub transformer_dim: usize,
61    pub num_mask_tokens: usize,
62}
63
64pub struct HypernetMlp {
65    pub layers: Vec<MlpLayer>,
66}
67
/// A single dense (fully connected) layer.
pub struct MlpLayer {
    /// Flattened weight matrix; checkpoint shape is `[out_d, in_d]`.
    pub w: Vec<f32>,
    /// Bias vector for this layer.
    pub b: Vec<f32>,
    /// Number of input features.
    pub in_d: usize,
    /// Number of output features.
    pub out_d: usize,
}
74
75pub(super) fn extract_mask_decoder_weights(
76    weights: &mut WeightMap,
77    transformer_dim: usize,
78    num_mask_tokens: usize,
79    iou_head_depth: usize,
80    iou_head_hidden_dim: usize,
81    transformer_depth: usize,
82    transformer_num_heads: usize,
83    transformer_mlp_dim: usize,
84) -> Result<MaskDecoderWeights> {
85    let (iou_token, sh) = weights.take("mask_decoder.iou_token.weight")?;
86    ensure!(
87        sh == vec![1, transformer_dim],
88        "iou_token shape {sh:?} not [1, {transformer_dim}]"
89    );
90    let (mask_tokens, sh) = weights.take("mask_decoder.mask_tokens.weight")?;
91    ensure!(
92        sh == vec![num_mask_tokens, transformer_dim],
93        "mask_tokens shape {sh:?} not [{num_mask_tokens}, {transformer_dim}]"
94    );
95
96    // ConvTranspose2d weight convention in PyTorch: [in, out, kH, kW].
97    let q4 = transformer_dim / 4;
98    let q8 = transformer_dim / 8;
99    let (upscale_conv1_w, sh) = weights.take("mask_decoder.output_upscaling.0.weight")?;
100    ensure!(
101        sh == vec![transformer_dim, q4, 2, 2],
102        "output_upscaling.0.weight shape {sh:?} not [{transformer_dim}, {q4}, 2, 2]"
103    );
104    let (upscale_conv1_b, _) = weights.take("mask_decoder.output_upscaling.0.bias")?;
105    let (upscale_ln_g, _) = weights.take("mask_decoder.output_upscaling.1.weight")?;
106    let (upscale_ln_b, _) = weights.take("mask_decoder.output_upscaling.1.bias")?;
107    let (upscale_conv2_w, sh) = weights.take("mask_decoder.output_upscaling.3.weight")?;
108    ensure!(
109        sh == vec![q4, q8, 2, 2],
110        "output_upscaling.3.weight shape {sh:?} not [{q4}, {q8}, 2, 2]"
111    );
112    let (upscale_conv2_b, _) = weights.take("mask_decoder.output_upscaling.3.bias")?;
113
114    // Each hypernetwork MLP: 3-layer (transformer_dim → transformer_dim
115    // → transformer_dim → transformer_dim/8).
116    let mut hyper_mlps = Vec::with_capacity(num_mask_tokens);
117    for i in 0..num_mask_tokens {
118        let mlp = extract_mlp(
119            weights,
120            &format!("mask_decoder.output_hypernetworks_mlps.{i}"),
121            transformer_dim,
122            transformer_dim,
123            q8,
124            3,
125        )?;
126        hyper_mlps.push(mlp);
127    }
128
129    let iou_head = extract_mlp(
130        weights,
131        "mask_decoder.iou_prediction_head",
132        transformer_dim,
133        iou_head_hidden_dim,
134        num_mask_tokens,
135        iou_head_depth,
136    )?;
137
138    let transformer = extract_two_way_transformer_weights(
139        weights,
140        transformer_dim,
141        transformer_depth,
142        transformer_num_heads,
143        transformer_mlp_dim,
144    )?;
145
146    Ok(MaskDecoderWeights {
147        iou_token,
148        mask_tokens,
149        transformer,
150        upscale_conv1_w,
151        upscale_conv1_b,
152        upscale_ln_g,
153        upscale_ln_b,
154        upscale_conv2_w,
155        upscale_conv2_b,
156        hyper_mlps,
157        iou_head,
158        transformer_dim,
159        num_mask_tokens,
160    })
161}
162
163fn extract_mlp(
164    weights: &mut WeightMap,
165    prefix: &str,
166    input_dim: usize,
167    hidden_dim: usize,
168    output_dim: usize,
169    num_layers: usize,
170) -> Result<HypernetMlp> {
171    let mut layers = Vec::with_capacity(num_layers);
172    for i in 0..num_layers {
173        let in_d = if i == 0 { input_dim } else { hidden_dim };
174        let out_d = if i + 1 == num_layers {
175            output_dim
176        } else {
177            hidden_dim
178        };
179        let (w, sh) = weights.take(&format!("{prefix}.layers.{i}.weight"))?;
180        ensure!(
181            sh == vec![out_d, in_d],
182            "{prefix}.layers.{i}.weight shape {sh:?} not [{out_d}, {in_d}]"
183        );
184        let (b, _) = weights.take(&format!("{prefix}.layers.{i}.bias"))?;
185        layers.push(MlpLayer { w, b, in_d, out_d });
186    }
187    Ok(HypernetMlp { layers })
188}
189
190/// Forward through a ReLU MLP. Input `[rows, layer0.in_d]`, output
191/// `[rows, last_layer.out_d]`. The final layer is NOT followed by ReLU.
192pub fn mlp_forward(mlp: &HypernetMlp, x: &[f32], rows: usize) -> Vec<f32> {
193    let mut cur = x.to_vec();
194    let n = mlp.layers.len();
195    for (i, layer) in mlp.layers.iter().enumerate() {
196        cur = linear(&cur, &layer.w, &layer.b, rows, layer.in_d, layer.out_d);
197        if i + 1 < n {
198            for v in cur.iter_mut() {
199                if *v < 0.0 {
200                    *v = 0.0;
201                }
202            }
203        }
204    }
205    cur
206}
207
208/// Forward through the mask decoder, returning (masks, iou_pred).
209///
210/// `image_embeddings`: NCHW `[1, C=256, hw, hw]`.
211/// `image_pe`: NCHW `[1, C=256, hw, hw]`.
212/// `sparse_prompt_embeddings`: `[1, num_sparse, E]` (may have 0 sparse tokens).
213/// `dense_prompt_embeddings`: `[1, E, hw, hw]`.
214///
215/// `multimask_output`: if true, return masks[..., 1:4] (3 candidates);
216/// else return masks[..., 0:1] (the single "best" output).
217///
218/// Output shapes:
219///   - masks: `[1, num_masks, 4·hw, 4·hw]`
220///     (num_masks = 3 if multimask_output else 1).
221///   - iou_pred: `[1, num_masks]`.
222pub fn mask_decoder_forward(
223    w: &MaskDecoderWeights,
224    upscale: &mut SamMaskUpscaleCompiled,
225    hyper_matmul: Option<&mut MaskHyperMatmulCompiled>,
226    hyper_mlps_ir: Option<&mut [MlpReluCompiled]>,
227    iou_head_ir: Option<&mut MlpReluCompiled>,
228    tw_ir: Option<&mut TwoWayTransformerCompiled>,
229    image_embeddings: &[f32],
230    image_pe: &[f32],
231    sparse_prompt_embeddings: &[f32],
232    num_sparse_tokens: usize,
233    dense_prompt_embeddings: &[f32],
234    multimask_output: bool,
235) -> Result<(Vec<f32>, Vec<f32>, usize, usize)> {
236    let e = w.transformer_dim;
237    let hw = SAM_EMBED_HW;
238    ensure!(
239        image_embeddings.len() == e * hw * hw,
240        "image_embeddings len {} ≠ E·hw·hw ({e}·{hw}·{hw})",
241        image_embeddings.len()
242    );
243    ensure!(
244        image_pe.len() == e * hw * hw,
245        "image_pe len {} ≠ E·hw·hw",
246        image_pe.len()
247    );
248    ensure!(
249        dense_prompt_embeddings.len() == e * hw * hw,
250        "dense_prompt_embeddings len {} ≠ E·hw·hw",
251        dense_prompt_embeddings.len()
252    );
253    ensure!(
254        sparse_prompt_embeddings.len() == num_sparse_tokens * e,
255        "sparse_prompt_embeddings len {} ≠ num_sparse·E ({num_sparse_tokens}·{e})",
256        sparse_prompt_embeddings.len()
257    );
258
259    // ── Build `tokens` = cat(iou_token, mask_tokens, sparse_prompts) ──
260    // Output tokens (iou + mask): shape [1 + num_mask_tokens, E]
261    let nm = w.num_mask_tokens;
262    let n_out_tokens = 1 + nm;
263    let q_n = n_out_tokens + num_sparse_tokens;
264    let mut tokens = Vec::with_capacity(q_n * e);
265    tokens.extend_from_slice(&w.iou_token); // [1, E]
266    tokens.extend_from_slice(&w.mask_tokens); // [nm, E]
267    tokens.extend_from_slice(sparse_prompt_embeddings); // [num_sparse, E]
268    // Shape [1, q_n, E].
269
270    // ── src = image_embeddings + dense_prompt_embeddings ──
271    let mut src = image_embeddings.to_vec();
272    for i in 0..src.len() {
273        src[i] += dense_prompt_embeddings[i];
274    }
275    let pos_src = image_pe.to_vec();
276
277    // ── Run the two-way transformer ──
278    let k_n = hw * hw;
279    let (hs, src_post) = if let Some(tw) = tw_ir {
280        if tw.masked && q_n <= tw.max_q_n && tw.k_n == k_n {
281            tw.run_nchw_masked(&tokens, q_n, &src, &pos_src, hw)?
282        } else if !tw.masked && q_n == tw.max_q_n && tw.k_n == k_n {
283            tw.run_nchw(&tokens, &src, &pos_src, hw)?
284        } else {
285            two_way_transformer_forward(&w.transformer, &src, &pos_src, &tokens, 1, e, hw, hw, q_n)
286        }
287    } else {
288        two_way_transformer_forward(&w.transformer, &src, &pos_src, &tokens, 1, e, hw, hw, q_n)
289    };
290    // hs: [1, q_n, E]; src_post: [1, hw*hw, E]
291
292    // iou_token_out = hs[:, 0, :] → [1, E]
293    let iou_token_out: Vec<f32> = hs[..e].to_vec();
294    // mask_tokens_out = hs[:, 1..1+nm, :] → [1, nm, E]
295    let mask_tokens_out = &hs[e..e * (1 + nm)];
296
297    // src reshape to [B, C, H, W] (BCHW). src_post is [1, hw*hw, E];
298    // transpose to [1, E, hw*hw] then reshape to [1, E, hw, hw].
299    let mut src_nchw = vec![0f32; e * hw * hw];
300    for s in 0..hw * hw {
301        for c in 0..e {
302            src_nchw[c * hw * hw + s] = src_post[s * e + c];
303        }
304    }
305
306    // ── Upscaling via compiled IR (ConvTranspose2d → LN2d → GELU ×2) ──
307    let q8 = e / 8;
308    let h2 = hw * 4;
309    let w2 = hw * 4;
310    let up2 = upscale.run(&src_nchw)?;
311
312    // ── Per-mask hypernetwork MLPs → [nm, q8] ──
313    let mut hyper_in = vec![0f32; nm * q8];
314    if let Some(mlps) = hyper_mlps_ir {
315        ensure!(
316            mlps.len() == nm,
317            "hyper_mlps_ir len {} ≠ num_mask_tokens {}",
318            mlps.len(),
319            nm
320        );
321        for i in 0..nm {
322            let token = &mask_tokens_out[i * e..(i + 1) * e];
323            let h = mlps[i].run(token, 1)?;
324            hyper_in[i * q8..(i + 1) * q8].copy_from_slice(&h);
325        }
326    } else {
327        for i in 0..nm {
328            let token = &mask_tokens_out[i * e..(i + 1) * e];
329            let h = mlp_forward(&w.hyper_mlps[i], token, 1);
330            hyper_in[i * q8..(i + 1) * q8].copy_from_slice(&h);
331        }
332    }
333    // hyper_in: [nm, q8]. up2 flat [q8, spat].
334    // masks = hyper_in @ up2   shape [nm, spat]. BLAS-backed.
335    let spat = h2 * w2;
336    let mut masks_all = vec![0f32; nm * spat];
337    if let Some(hm) = hyper_matmul {
338        hm.run(&hyper_in, &up2, &mut masks_all)?;
339    } else {
340        rlx_cpu::blas::sgemm_auto(&hyper_in, &up2, &mut masks_all, nm, q8, spat);
341    }
342
343    // ── IoU prediction head ──
344    let iou_pred_all = if let Some(head) = iou_head_ir {
345        head.run(&iou_token_out, 1)?
346    } else {
347        mlp_forward(&w.iou_head, &iou_token_out, 1)
348    };
349
350    // ── Slice for multimask vs single ──
351    let (masks, iou_pred, num_masks) = if multimask_output {
352        // [1, 1..nm, h2, w2] = [1, nm-1, h2, w2] (3 masks for nm=4)
353        let mut masks = vec![0f32; (nm - 1) * spat];
354        masks.copy_from_slice(&masks_all[spat..]);
355        let mut iou = vec![0f32; nm - 1];
356        iou.copy_from_slice(&iou_pred_all[1..]);
357        (masks, iou, nm - 1)
358    } else {
359        let masks = masks_all[..spat].to_vec();
360        let iou = iou_pred_all[..1].to_vec();
361        (masks, iou, 1)
362    };
363
364    Ok((masks, iou_pred, num_masks, h2))
365}