Skip to main content

rlx_sam3/
detector_decoder.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Native SAM3 detector decoder (6 layers, 200 queries + presence token,
17//! box refinement, log boxRPB, text + image cross-attention).
18//!
19//! Mirrors `sam3.model.decoder.TransformerDecoder` configured by
20//! `model_builder._create_transformer_decoder`. Inference-time settings:
21//!
22//!   * `apply_dac=False` (DAC is training-only)
23//!   * `presence_token=True` (extra +1 query prepended in self-attn)
24//!   * `box_refine=True`, `boxRPB="log"`, `return_intermediate=True`
25//!   * `use_text_cross_attention=True`
26//!   * `num_queries=200`, `d_model=256`, `n_heads=8`, `dim_ff=2048`
27
28use super::tensor::{layer_norm, linear};
29use anyhow::{Result, ensure};
30use rlx_core::weight_map::WeightMap;
31use rlx_flow::GgufPackedParams;
32
33use crate::packed_gguf::{linear_maybe_gguf, take_or_gguf, take_transposed_with_gguf_key};
34
/// Decoder hidden width (`d_model`).
const D_MODEL: usize = 256;
/// Feed-forward inner width (`dim_ff`).
const DIM_FF: usize = 2048;
/// Attention heads per MHA block; also the output width of the boxRPB MLPs.
const N_HEADS: usize = 8;
/// Number of decoder layers loaded by `extract_decoder_weights`.
const N_LAYERS: usize = 6;
/// Number of object queries (the presence token is an extra, separate query).
const NUM_QUERIES: usize = 200;
40
/// Weights of one SAM3 detector decoder layer.
///
/// Each `*_w_t` field is a transposed F32 weight as returned by
/// `take_transposed_with_gguf_key`. When the matching `*_gguf_key` is `Some`,
/// the projection is evaluated from the packed GGUF tensor with that key
/// instead, and the F32 vector may be unused. The attention `*_in_*` fields
/// hold the packed Q/K/V projection (`in_proj_weight` / `in_proj_bias` in the
/// checkpoint).
///
/// Derives `Default` like the other weight containers in this module, so an
/// empty placeholder can be built without a checkpoint.
#[derive(Clone, Default)]
pub struct Sam3DecoderLayerWeights {
    /// Self-attention packed Q/K/V projection weight (transposed).
    pub self_attn_in_w_t: Vec<f32>,
    /// Self-attention packed Q/K/V projection bias.
    pub self_attn_in_b: Vec<f32>,
    /// GGUF key for the self-attention Q/K/V weight, if packed.
    pub self_attn_in_gguf_key: Option<String>,
    /// Self-attention output projection weight (transposed).
    pub self_attn_out_w_t: Vec<f32>,
    /// Self-attention output projection bias.
    pub self_attn_out_b: Vec<f32>,
    /// GGUF key for the self-attention output weight, if packed.
    pub self_attn_out_gguf_key: Option<String>,
    /// Text cross-attention packed Q/K/V projection weight (transposed).
    pub ca_text_in_w_t: Vec<f32>,
    /// Text cross-attention packed Q/K/V projection bias.
    pub ca_text_in_b: Vec<f32>,
    /// GGUF key for the text cross-attention Q/K/V weight, if packed.
    pub ca_text_in_gguf_key: Option<String>,
    /// Text cross-attention output projection weight (transposed).
    pub ca_text_out_w_t: Vec<f32>,
    /// Text cross-attention output projection bias.
    pub ca_text_out_b: Vec<f32>,
    /// GGUF key for the text cross-attention output weight, if packed.
    pub ca_text_out_gguf_key: Option<String>,
    /// Image cross-attention packed Q/K/V projection weight (transposed).
    pub cross_attn_in_w_t: Vec<f32>,
    /// Image cross-attention packed Q/K/V projection bias.
    pub cross_attn_in_b: Vec<f32>,
    /// GGUF key for the image cross-attention Q/K/V weight, if packed.
    pub cross_attn_in_gguf_key: Option<String>,
    /// Image cross-attention output projection weight (transposed).
    pub cross_attn_out_w_t: Vec<f32>,
    /// Image cross-attention output projection bias.
    pub cross_attn_out_b: Vec<f32>,
    /// GGUF key for the image cross-attention output weight, if packed.
    pub cross_attn_out_gguf_key: Option<String>,
    /// FFN first linear weight (transposed).
    pub linear1_w_t: Vec<f32>,
    /// FFN first linear bias.
    pub linear1_b: Vec<f32>,
    /// GGUF key for the FFN first linear weight, if packed.
    pub linear1_gguf_key: Option<String>,
    /// FFN second linear weight (transposed).
    pub linear2_w_t: Vec<f32>,
    /// FFN second linear bias.
    pub linear2_b: Vec<f32>,
    /// GGUF key for the FFN second linear weight, if packed.
    pub linear2_gguf_key: Option<String>,
    /// Norm applied after image cross-attention (weight).
    pub norm1_w: Vec<f32>,
    /// Norm applied after image cross-attention (bias).
    pub norm1_b: Vec<f32>,
    /// Norm applied after self-attention (weight).
    pub norm2_w: Vec<f32>,
    /// Norm applied after self-attention (bias).
    pub norm2_b: Vec<f32>,
    /// Norm applied after the FFN (weight).
    pub norm3_w: Vec<f32>,
    /// Norm applied after the FFN (bias).
    pub norm3_b: Vec<f32>,
    /// Norm applied after text cross-attention (weight).
    pub catext_norm_w: Vec<f32>,
    /// Norm applied after text cross-attention (bias).
    pub catext_norm_b: Vec<f32>,
}
76
77#[derive(Clone, Default)]
78pub struct Sam3DecoderWeights {
79    pub loaded: bool,
80    /// Checkpoint prefix (`detector.transformer.decoder`).
81    pub prefix: String,
82    pub layers: Vec<Sam3DecoderLayerWeights>,
83    pub query_embed: Vec<f32>,      // [num_queries, D]
84    pub reference_points: Vec<f32>, // [num_queries, 4]
85    pub norm_w: Vec<f32>,
86    pub norm_b: Vec<f32>,
87    pub bbox_embed: Mlp3,          // 256→256→256→4
88    pub ref_point_head: Mlp2,      // 512→256→256
89    pub boxrpb_x: Mlp2,            // 2→256→n_heads
90    pub boxrpb_y: Mlp2,            // 2→256→n_heads
91    pub presence_token: Vec<f32>,  // [1, D]
92    pub presence_token_head: Mlp3, // 256→256→256→1
93    pub presence_token_out_norm_w: Vec<f32>,
94    pub presence_token_out_norm_b: Vec<f32>,
95}
96
/// Two-layer MLP (`in_dim → hidden → out_dim`) with a ReLU between the layers.
///
/// Weights are stored transposed. A `Some` GGUF key means that layer's weight
/// is read from the packed GGUF tensor with that key.
#[derive(Default, Clone)]
pub struct Mlp2 {
    /// Layer 0 weight (transposed).
    pub w0_t: Vec<f32>,
    /// Layer 0 bias.
    pub b0: Vec<f32>,
    /// Layer 1 weight (transposed).
    pub w1_t: Vec<f32>,
    /// Layer 1 bias.
    pub b1: Vec<f32>,
    /// Input width.
    pub in_dim: usize,
    /// Hidden width.
    pub hidden: usize,
    /// Output width.
    pub out_dim: usize,
    /// GGUF key of the layer 0 weight, if packed.
    pub w0_gguf_key: Option<String>,
    /// GGUF key of the layer 1 weight, if packed.
    pub w1_gguf_key: Option<String>,
}
109
/// Three-layer MLP (`in_dim → hidden → hidden → out_dim`) with ReLU after the
/// first two layers.
///
/// Weights are stored transposed. A `Some` GGUF key means that layer's weight
/// is read from the packed GGUF tensor with that key.
#[derive(Default, Clone)]
pub struct Mlp3 {
    /// Layer 0 weight (transposed).
    pub w0_t: Vec<f32>,
    /// Layer 0 bias.
    pub b0: Vec<f32>,
    /// Layer 1 weight (transposed).
    pub w1_t: Vec<f32>,
    /// Layer 1 bias.
    pub b1: Vec<f32>,
    /// Layer 2 weight (transposed).
    pub w2_t: Vec<f32>,
    /// Layer 2 bias.
    pub b2: Vec<f32>,
    /// Input width.
    pub in_dim: usize,
    /// Hidden width (shared by both hidden layers).
    pub hidden: usize,
    /// Output width.
    pub out_dim: usize,
    /// GGUF key of the layer 0 weight, if packed.
    pub w0_gguf_key: Option<String>,
    /// GGUF key of the layer 1 weight, if packed.
    pub w1_gguf_key: Option<String>,
    /// GGUF key of the layer 2 weight, if packed.
    pub w2_gguf_key: Option<String>,
}
125
126pub fn take_mlp2(
127    weights: &mut WeightMap,
128    gguf_packed: Option<&GgufPackedParams>,
129    base: &str,
130    in_dim: usize,
131    hidden: usize,
132    out_dim: usize,
133) -> Result<Mlp2> {
134    let (w0_t, w0_gguf_key) =
135        take_transposed_with_gguf_key(weights, gguf_packed, &format!("{base}.layers.0.weight"))?;
136    let (b0, _) = take_or_gguf(weights, gguf_packed, &format!("{base}.layers.0.bias"))?;
137    let (w1_t, w1_gguf_key) =
138        take_transposed_with_gguf_key(weights, gguf_packed, &format!("{base}.layers.1.weight"))?;
139    let (b1, _) = take_or_gguf(weights, gguf_packed, &format!("{base}.layers.1.bias"))?;
140    Ok(Mlp2 {
141        w0_t,
142        b0,
143        w1_t,
144        b1,
145        in_dim,
146        hidden,
147        out_dim,
148        w0_gguf_key,
149        w1_gguf_key,
150    })
151}
152
153pub fn take_mlp3(
154    weights: &mut WeightMap,
155    gguf_packed: Option<&GgufPackedParams>,
156    base: &str,
157    in_dim: usize,
158    hidden: usize,
159    out_dim: usize,
160) -> Result<Mlp3> {
161    let (w0_t, w0_gguf_key) =
162        take_transposed_with_gguf_key(weights, gguf_packed, &format!("{base}.layers.0.weight"))?;
163    let (b0, _) = take_or_gguf(weights, gguf_packed, &format!("{base}.layers.0.bias"))?;
164    let (w1_t, w1_gguf_key) =
165        take_transposed_with_gguf_key(weights, gguf_packed, &format!("{base}.layers.1.weight"))?;
166    let (b1, _) = take_or_gguf(weights, gguf_packed, &format!("{base}.layers.1.bias"))?;
167    let (w2_t, w2_gguf_key) =
168        take_transposed_with_gguf_key(weights, gguf_packed, &format!("{base}.layers.2.weight"))?;
169    let (b2, _) = take_or_gguf(weights, gguf_packed, &format!("{base}.layers.2.bias"))?;
170    Ok(Mlp3 {
171        w0_t,
172        b0,
173        w1_t,
174        b1,
175        w2_t,
176        b2,
177        in_dim,
178        hidden,
179        out_dim,
180        w0_gguf_key,
181        w1_gguf_key,
182        w2_gguf_key,
183    })
184}
185
186pub fn mlp2_forward(
187    mlp: &Mlp2,
188    x: &[f32],
189    rows: usize,
190    gguf_packed: Option<&GgufPackedParams>,
191) -> Result<Vec<f32>> {
192    let mut h = linear_maybe_gguf(
193        x,
194        rows,
195        mlp.in_dim,
196        &mlp.w0_t,
197        mlp.w0_gguf_key.as_deref(),
198        gguf_packed,
199        mlp.hidden,
200        &mlp.b0,
201    )?;
202    for v in h.iter_mut() {
203        if *v < 0.0 {
204            *v = 0.0;
205        }
206    }
207    linear_maybe_gguf(
208        &h,
209        rows,
210        mlp.hidden,
211        &mlp.w1_t,
212        mlp.w1_gguf_key.as_deref(),
213        gguf_packed,
214        mlp.out_dim,
215        &mlp.b1,
216    )
217}
218
219/// Apply the decoder's `bbox_embed` MLP to `[rows, D]` and return `[rows, 4]` deltas.
220pub fn bbox_embed_forward(
221    weights: &Sam3DecoderWeights,
222    x: &[f32],
223    rows: usize,
224    gguf_packed: Option<&GgufPackedParams>,
225) -> Result<Vec<f32>> {
226    mlp3_forward(&weights.bbox_embed, x, rows, gguf_packed)
227}
228
229pub fn mlp3_forward(
230    mlp: &Mlp3,
231    x: &[f32],
232    rows: usize,
233    gguf_packed: Option<&GgufPackedParams>,
234) -> Result<Vec<f32>> {
235    let mut h = linear_maybe_gguf(
236        x,
237        rows,
238        mlp.in_dim,
239        &mlp.w0_t,
240        mlp.w0_gguf_key.as_deref(),
241        gguf_packed,
242        mlp.hidden,
243        &mlp.b0,
244    )?;
245    for v in h.iter_mut() {
246        if *v < 0.0 {
247            *v = 0.0;
248        }
249    }
250    h = linear_maybe_gguf(
251        &h,
252        rows,
253        mlp.hidden,
254        &mlp.w1_t,
255        mlp.w1_gguf_key.as_deref(),
256        gguf_packed,
257        mlp.hidden,
258        &mlp.b1,
259    )?;
260    for v in h.iter_mut() {
261        if *v < 0.0 {
262            *v = 0.0;
263        }
264    }
265    linear_maybe_gguf(
266        &h,
267        rows,
268        mlp.hidden,
269        &mlp.w2_t,
270        mlp.w2_gguf_key.as_deref(),
271        gguf_packed,
272        mlp.out_dim,
273        &mlp.b2,
274    )
275}
276
277/// In-place mlp2: `out = w1·relu(w0·x + b0) + b1`. Uses BLAS when weights are F32.
278pub fn mlp2_forward_into(
279    mlp: &Mlp2,
280    x: &[f32],
281    rows: usize,
282    hidden: &mut [f32],
283    out: &mut [f32],
284    gguf_packed: Option<&GgufPackedParams>,
285) -> Result<()> {
286    if mlp.w0_gguf_key.is_none() && !mlp.w0_t.is_empty() {
287        rlx_cpu::blas::sgemm_bias_epilogue(
288            x,
289            &mlp.w0_t,
290            &mlp.b0,
291            hidden,
292            rows,
293            mlp.in_dim,
294            mlp.hidden,
295            |v| if v < 0.0 { 0.0 } else { v },
296        );
297        rlx_cpu::blas::sgemm_bias(
298            hidden,
299            &mlp.w1_t,
300            &mlp.b1,
301            out,
302            rows,
303            mlp.hidden,
304            mlp.out_dim,
305        );
306        return Ok(());
307    }
308    let mut h = linear_maybe_gguf(
309        x,
310        rows,
311        mlp.in_dim,
312        &mlp.w0_t,
313        mlp.w0_gguf_key.as_deref(),
314        gguf_packed,
315        mlp.hidden,
316        &mlp.b0,
317    )?;
318    for v in h.iter_mut() {
319        if *v < 0.0 {
320            *v = 0.0;
321        }
322    }
323    hidden.copy_from_slice(&h);
324    let h2 = linear_maybe_gguf(
325        hidden,
326        rows,
327        mlp.hidden,
328        &mlp.w1_t,
329        mlp.w1_gguf_key.as_deref(),
330        gguf_packed,
331        mlp.out_dim,
332        &mlp.b1,
333    )?;
334    out.copy_from_slice(&h2);
335    Ok(())
336}
337
338/// In-place mlp3 into caller buffers (no alloc when F32 + BLAS).
339pub fn mlp3_forward_into(
340    mlp: &Mlp3,
341    x: &[f32],
342    rows: usize,
343    h0: &mut [f32],
344    h1: &mut [f32],
345    out: &mut [f32],
346    gguf_packed: Option<&GgufPackedParams>,
347) -> Result<()> {
348    if mlp.w0_gguf_key.is_none() && !mlp.w0_t.is_empty() {
349        let relu = |v: f32| if v < 0.0 { 0.0 } else { v };
350        rlx_cpu::blas::sgemm_bias_epilogue(
351            x, &mlp.w0_t, &mlp.b0, h0, rows, mlp.in_dim, mlp.hidden, relu,
352        );
353        rlx_cpu::blas::sgemm_bias_epilogue(
354            h0, &mlp.w1_t, &mlp.b1, h1, rows, mlp.hidden, mlp.hidden, relu,
355        );
356        rlx_cpu::blas::sgemm_bias(h1, &mlp.w2_t, &mlp.b2, out, rows, mlp.hidden, mlp.out_dim);
357        return Ok(());
358    }
359    let o = mlp3_forward(mlp, x, rows, gguf_packed)?;
360    out.copy_from_slice(&o);
361    Ok(())
362}
363
364pub fn extract_decoder_weights(
365    weights: &mut WeightMap,
366    gguf_packed: Option<&GgufPackedParams>,
367) -> Result<Sam3DecoderWeights> {
368    let base = "detector.transformer.decoder";
369    ensure!(
370        weights.has(&format!("{base}.query_embed.weight")),
371        "SAM3 detector decoder not found"
372    );
373
374    let mut layers = Vec::with_capacity(N_LAYERS);
375    for i in 0..N_LAYERS {
376        let p = format!("{base}.layers.{i}");
377        let (self_attn_in_w_t, self_attn_in_gguf_key) = take_transposed_with_gguf_key(
378            weights,
379            gguf_packed,
380            &format!("{p}.self_attn.in_proj_weight"),
381        )?;
382        let (self_attn_in_b, _) =
383            take_or_gguf(weights, gguf_packed, &format!("{p}.self_attn.in_proj_bias"))?;
384        let (self_attn_out_w_t, self_attn_out_gguf_key) = take_transposed_with_gguf_key(
385            weights,
386            gguf_packed,
387            &format!("{p}.self_attn.out_proj.weight"),
388        )?;
389        let (self_attn_out_b, _) = take_or_gguf(
390            weights,
391            gguf_packed,
392            &format!("{p}.self_attn.out_proj.bias"),
393        )?;
394        let (ca_text_in_w_t, ca_text_in_gguf_key) = take_transposed_with_gguf_key(
395            weights,
396            gguf_packed,
397            &format!("{p}.ca_text.in_proj_weight"),
398        )?;
399        let (ca_text_in_b, _) =
400            take_or_gguf(weights, gguf_packed, &format!("{p}.ca_text.in_proj_bias"))?;
401        let (ca_text_out_w_t, ca_text_out_gguf_key) = take_transposed_with_gguf_key(
402            weights,
403            gguf_packed,
404            &format!("{p}.ca_text.out_proj.weight"),
405        )?;
406        let (ca_text_out_b, _) =
407            take_or_gguf(weights, gguf_packed, &format!("{p}.ca_text.out_proj.bias"))?;
408        let (cross_attn_in_w_t, cross_attn_in_gguf_key) = take_transposed_with_gguf_key(
409            weights,
410            gguf_packed,
411            &format!("{p}.cross_attn.in_proj_weight"),
412        )?;
413        let (cross_attn_in_b, _) = take_or_gguf(
414            weights,
415            gguf_packed,
416            &format!("{p}.cross_attn.in_proj.bias"),
417        )?;
418        let (cross_attn_out_w_t, cross_attn_out_gguf_key) = take_transposed_with_gguf_key(
419            weights,
420            gguf_packed,
421            &format!("{p}.cross_attn.out_proj.weight"),
422        )?;
423        let (cross_attn_out_b, _) = take_or_gguf(
424            weights,
425            gguf_packed,
426            &format!("{p}.cross_attn.out_proj.bias"),
427        )?;
428        let (linear1_w_t, linear1_gguf_key) =
429            take_transposed_with_gguf_key(weights, gguf_packed, &format!("{p}.linear1.weight"))?;
430        let (linear1_b, _) = take_or_gguf(weights, gguf_packed, &format!("{p}.linear1.bias"))?;
431        let (linear2_w_t, linear2_gguf_key) =
432            take_transposed_with_gguf_key(weights, gguf_packed, &format!("{p}.linear2.weight"))?;
433        let (linear2_b, _) = take_or_gguf(weights, gguf_packed, &format!("{p}.linear2.bias"))?;
434        let (norm1_w, _) = take_or_gguf(weights, gguf_packed, &format!("{p}.norm1.weight"))?;
435        let (norm1_b, _) = take_or_gguf(weights, gguf_packed, &format!("{p}.norm1.bias"))?;
436        let (norm2_w, _) = take_or_gguf(weights, gguf_packed, &format!("{p}.norm2.weight"))?;
437        let (norm2_b, _) = take_or_gguf(weights, gguf_packed, &format!("{p}.norm2.bias"))?;
438        let (norm3_w, _) = take_or_gguf(weights, gguf_packed, &format!("{p}.norm3.weight"))?;
439        let (norm3_b, _) = take_or_gguf(weights, gguf_packed, &format!("{p}.norm3.bias"))?;
440        let (catext_norm_w, _) =
441            take_or_gguf(weights, gguf_packed, &format!("{p}.catext_norm.weight"))?;
442        let (catext_norm_b, _) =
443            take_or_gguf(weights, gguf_packed, &format!("{p}.catext_norm.bias"))?;
444        layers.push(Sam3DecoderLayerWeights {
445            self_attn_in_w_t,
446            self_attn_in_b,
447            self_attn_in_gguf_key,
448            self_attn_out_w_t,
449            self_attn_out_b,
450            self_attn_out_gguf_key,
451            ca_text_in_w_t,
452            ca_text_in_b,
453            ca_text_in_gguf_key,
454            ca_text_out_w_t,
455            ca_text_out_b,
456            ca_text_out_gguf_key,
457            cross_attn_in_w_t,
458            cross_attn_in_b,
459            cross_attn_in_gguf_key,
460            cross_attn_out_w_t,
461            cross_attn_out_b,
462            cross_attn_out_gguf_key,
463            linear1_w_t,
464            linear1_b,
465            linear1_gguf_key,
466            linear2_w_t,
467            linear2_b,
468            linear2_gguf_key,
469            norm1_w,
470            norm1_b,
471            norm2_w,
472            norm2_b,
473            norm3_w,
474            norm3_b,
475            catext_norm_w,
476            catext_norm_b,
477        });
478    }
479
480    let (query_embed, qs) =
481        take_or_gguf(weights, gguf_packed, &format!("{base}.query_embed.weight"))?;
482    ensure!(qs == vec![NUM_QUERIES, D_MODEL], "query_embed shape {qs:?}");
483    let (reference_points, rs) = take_or_gguf(
484        weights,
485        gguf_packed,
486        &format!("{base}.reference_points.weight"),
487    )?;
488    ensure!(rs == vec![NUM_QUERIES, 4], "reference_points shape {rs:?}");
489    let (norm_w, _) = take_or_gguf(weights, gguf_packed, &format!("{base}.norm.weight"))?;
490    let (norm_b, _) = take_or_gguf(weights, gguf_packed, &format!("{base}.norm.bias"))?;
491    let bbox_embed = take_mlp3(
492        weights,
493        gguf_packed,
494        &format!("{base}.bbox_embed"),
495        D_MODEL,
496        D_MODEL,
497        4,
498    )?;
499    let ref_point_head = take_mlp2(
500        weights,
501        gguf_packed,
502        &format!("{base}.ref_point_head"),
503        2 * D_MODEL,
504        D_MODEL,
505        D_MODEL,
506    )?;
507    let boxrpb_x = take_mlp2(
508        weights,
509        gguf_packed,
510        &format!("{base}.boxRPB_embed_x"),
511        2,
512        D_MODEL,
513        N_HEADS,
514    )?;
515    let boxrpb_y = take_mlp2(
516        weights,
517        gguf_packed,
518        &format!("{base}.boxRPB_embed_y"),
519        2,
520        D_MODEL,
521        N_HEADS,
522    )?;
523    let (presence_token, ps) = take_or_gguf(
524        weights,
525        gguf_packed,
526        &format!("{base}.presence_token.weight"),
527    )?;
528    ensure!(ps == vec![1, D_MODEL], "presence_token shape {ps:?}");
529    let presence_token_head = take_mlp3(
530        weights,
531        gguf_packed,
532        &format!("{base}.presence_token_head"),
533        D_MODEL,
534        D_MODEL,
535        1,
536    )?;
537    let (presence_token_out_norm_w, _) = take_or_gguf(
538        weights,
539        gguf_packed,
540        &format!("{base}.presence_token_out_norm.weight"),
541    )?;
542    let (presence_token_out_norm_b, _) = take_or_gguf(
543        weights,
544        gguf_packed,
545        &format!("{base}.presence_token_out_norm.bias"),
546    )?;
547
548    Ok(Sam3DecoderWeights {
549        loaded: true,
550        prefix: base.to_string(),
551        layers,
552        query_embed,
553        reference_points,
554        norm_w,
555        norm_b,
556        bbox_embed,
557        ref_point_head,
558        boxrpb_x,
559        boxrpb_y,
560        presence_token,
561        presence_token_head,
562        presence_token_out_norm_w,
563        presence_token_out_norm_b,
564    })
565}
566
/// Per-layer outputs of the SAM3 detector decoder.
#[derive(Default, Clone, Debug)]
pub struct Sam3DecoderOutput {
    /// Post-norm hidden states of every layer, `[num_layers, num_queries, batch, d_model]`.
    pub intermediate: Vec<f32>,
    /// Refined reference boxes in sigmoid scale, `[num_layers, num_queries, batch, 4]`.
    /// Entry 0 holds the *initial* boxes; the final entry is layer 5's output,
    /// per upstream convention.
    pub intermediate_ref_boxes: Vec<f32>,
    /// Presence logit of every layer, `[num_layers, batch, 1]`.
    pub presence_logits: Vec<f32>,
    /// Final presence features, `[1, batch, d_model]`.
    pub presence_feats: Vec<f32>,
    /// Number of layers recorded in the per-layer buffers.
    pub num_layers: usize,
    /// Number of queries per batch element.
    pub num_queries: usize,
    /// Batch size.
    pub batch: usize,
    /// Hidden width.
    pub d_model: usize,
}
584
/// Sine/cosine embedding of 4-D positions `[nq, bs, 4]` → `[nq, bs, 2*D]`,
/// matching `sam3.model.model_misc.gen_sineembed_for_position` with
/// `num_feats=256`.
///
/// Each coordinate is scaled by 2π and fills a `d_model / 2`-wide block. The
/// blocks are concatenated in `(y, x, w, h)` order. Within a block, slot `i`
/// holds `sin` (even `i`) or `cos` (odd `i`) of
/// `coord / 10000^(2 * floor(i / 2) / (d_model / 2))`.
fn sineembed_for_position_4d(pos: &[f32], nq: usize, bs: usize, d_model: usize) -> Vec<f32> {
    let feats = d_model / 2;
    let two_pi = 2.0 * std::f32::consts::PI;
    let freqs: Vec<f32> = (0..feats)
        .map(|i| 10000.0f32.powf(2.0 * ((i / 2) as f32) / feats as f32))
        .collect();

    let stride = 2 * d_model;
    let mut emb = vec![0.0f32; nq * bs * stride];
    // Rows are (query, batch) pairs flattened in query-major order.
    for row in 0..nq * bs {
        let coords = &pos[row * 4..(row + 1) * 4];
        let ordered = [
            coords[1] * two_pi,
            coords[0] * two_pi,
            coords[2] * two_pi,
            coords[3] * two_pi,
        ];
        let row_base = row * stride;
        for (axis, &val) in ordered.iter().enumerate() {
            let block = &mut emb[row_base + axis * feats..row_base + (axis + 1) * feats];
            for (i, (slot, &freq)) in block.iter_mut().zip(freqs.iter()).enumerate() {
                let theta = val / freq;
                *slot = if i % 2 == 0 { theta.sin() } else { theta.cos() };
            }
        }
    }
    emb
}
618
/// Guarded logit: `ln(x / (1 - x))`, with `x` first clamped to `[0, 1]` and
/// both the numerator and the denominator floored at `1e-3`.
fn inverse_sigmoid(x: f32) -> f32 {
    const EPS: f32 = 1e-3;
    let p = x.clamp(0.0, 1.0);
    let numer = p.max(EPS);
    let denom = (1.0 - p).max(EPS);
    (numer / denom).ln()
}
626
/// Logistic function `1 / (1 + e^(-x))`.
fn sigmoid(x: f32) -> f32 {
    let denom = 1.0 + (-x).exp();
    denom.recip()
}
630
631/// Compute the log-scale boxRPB attention mask for one batch element.
632/// Returns flat `[n_heads * num_queries, H * W]` to be broadcast-added to
633/// the attention scores per head.
634fn boxrpb_log_mask(
635    weights: &Sam3DecoderWeights,
636    reference_boxes: &[f32], // [nq, 1, 4] cxcywh in [0, 1]
637    nq: usize,
638    h: usize,
639    w: usize,
640    gguf_packed: Option<&GgufPackedParams>,
641) -> Result<Vec<f32>> {
642    // coords_h[y] = y / H, coords_w[x] = x / W.
643    let coords_h: Vec<f32> = (0..h).map(|y| y as f32 / h as f32).collect();
644    let coords_w: Vec<f32> = (0..w).map(|x| x as f32 / w as f32).collect();
645
646    // For each query, compute boxes_xyxy = (cx-w/2, cy-h/2, cx+w/2, cy+h/2).
647    let mut deltas_x = vec![0f32; nq * w * 2];
648    let mut deltas_y = vec![0f32; nq * h * 2];
649    for q in 0..nq {
650        let p = &reference_boxes[q * 4..(q + 1) * 4];
651        let (cx, cy, bw, bh) = (p[0], p[1], p[2], p[3]);
652        let x0 = cx - 0.5 * bw;
653        let x1 = cx + 0.5 * bw;
654        let y0 = cy - 0.5 * bh;
655        let y1 = cy + 0.5 * bh;
656        for xi in 0..w {
657            let dx0 = (coords_w[xi] - x0) * 8.0;
658            let dx1 = (coords_w[xi] - x1) * 8.0;
659            deltas_x[(q * w + xi) * 2] = log_norm(dx0);
660            deltas_x[(q * w + xi) * 2 + 1] = log_norm(dx1);
661        }
662        for yi in 0..h {
663            let dy0 = (coords_h[yi] - y0) * 8.0;
664            let dy1 = (coords_h[yi] - y1) * 8.0;
665            deltas_y[(q * h + yi) * 2] = log_norm(dy0);
666            deltas_y[(q * h + yi) * 2 + 1] = log_norm(dy1);
667        }
668    }
669    // MLPs: [nq*W, 2] → [nq*W, n_heads].
670    let dx_feats = mlp2_forward(&weights.boxrpb_x, &deltas_x, nq * w, gguf_packed)?;
671    let dy_feats = mlp2_forward(&weights.boxrpb_y, &deltas_y, nq * h, gguf_packed)?;
672
673    // B[q, y, x, head] = dy_feats[q, y, head] + dx_feats[q, x, head].
674    // Repack to [n_heads, nq, h*w] for use as additive attention mask.
675    let mut out = vec![0f32; N_HEADS * nq * h * w];
676    for q in 0..nq {
677        for y in 0..h {
678            for x in 0..w {
679                for head in 0..N_HEADS {
680                    let dy = dy_feats[(q * h + y) * N_HEADS + head];
681                    let dx = dx_feats[(q * w + x) * N_HEADS + head];
682                    out[(head * nq + q) * h * w + y * w + x] = dy + dx;
683                }
684            }
685        }
686    }
687    Ok(out)
688}
689
/// Signed log scaling used by boxRPB: `sign(v) * log2(|v| + 1) / log2(8)`.
/// Zero (and `-0.0`) map to `+0.0`.
fn log_norm(v: f32) -> f32 {
    let magnitude = (v.abs() + 1.0).log2() / 8.0f32.log2();
    if v < 0.0 {
        -magnitude
    } else {
        magnitude
    }
}
695
/// Copies columns `start..start + len` of every row of a row-major
/// `[rows, width]` buffer into a new `[rows, len]` buffer.
fn narrow_last(row: &[f32], rows: usize, width: usize, start: usize, len: usize) -> Vec<f32> {
    let mut narrowed = Vec::with_capacity(rows * len);
    for r in 0..rows {
        let from = r * width + start;
        narrowed.extend_from_slice(&row[from..from + len]);
    }
    narrowed
}
704
705/// Multi-head attention with optional GGUF packed linears, per-head bias, and key mask.
706#[allow(clippy::too_many_arguments)]
707pub(crate) fn mha_with_bias_maybe_gguf(
708    q: &[f32],
709    k: &[f32],
710    v: &[f32],
711    in_proj_w_t: &[f32],
712    in_proj_b: &[f32],
713    in_gguf_key: Option<&str>,
714    out_proj_w_t: &[f32],
715    out_proj_b: &[f32],
716    out_gguf_key: Option<&str>,
717    gguf_packed: Option<&GgufPackedParams>,
718    batch: usize,
719    l_q: usize,
720    l_k: usize,
721    embed_dim: usize,
722    num_heads: usize,
723    attn_bias_h_lq_lk: Option<&[f32]>,
724    key_padding_mask: Option<&[u8]>,
725) -> Result<Vec<f32>> {
726    if in_gguf_key.is_none() && out_gguf_key.is_none() {
727        return mha_with_bias_f32(
728            q,
729            k,
730            v,
731            in_proj_w_t,
732            in_proj_b,
733            out_proj_w_t,
734            out_proj_b,
735            batch,
736            l_q,
737            l_k,
738            embed_dim,
739            num_heads,
740            attn_bias_h_lq_lk,
741            key_padding_mask,
742        );
743    }
744
745    use super::tensor::{matmul, matmul_bt, softmax_rows};
746    let head_dim = embed_dim / num_heads;
747    let rows_q = batch * l_q;
748    let rows_k = batch * l_k;
749
750    let (q_proj, k_proj, v_proj) = if let Some(in_key) = in_gguf_key {
751        let qkv_q = linear_maybe_gguf(
752            q,
753            rows_q,
754            embed_dim,
755            in_proj_w_t,
756            Some(in_key),
757            gguf_packed,
758            3 * embed_dim,
759            in_proj_b,
760        )?;
761        let qkv_k = linear_maybe_gguf(
762            k,
763            rows_k,
764            embed_dim,
765            in_proj_w_t,
766            Some(in_key),
767            gguf_packed,
768            3 * embed_dim,
769            in_proj_b,
770        )?;
771        let qkv_v = linear_maybe_gguf(
772            v,
773            rows_k,
774            embed_dim,
775            in_proj_w_t,
776            Some(in_key),
777            gguf_packed,
778            3 * embed_dim,
779            in_proj_b,
780        )?;
781        (
782            narrow_last(&qkv_q, rows_q, 3 * embed_dim, 0, embed_dim),
783            narrow_last(&qkv_k, rows_k, 3 * embed_dim, embed_dim, embed_dim),
784            narrow_last(&qkv_v, rows_k, 3 * embed_dim, 2 * embed_dim, embed_dim),
785        )
786    } else {
787        let (wq, wk, wv) = split3(in_proj_w_t, embed_dim);
788        let bq = &in_proj_b[0..embed_dim];
789        let bk = &in_proj_b[embed_dim..2 * embed_dim];
790        let bv = &in_proj_b[2 * embed_dim..3 * embed_dim];
791        (
792            linear_maybe_gguf(q, rows_q, embed_dim, &wq, None, gguf_packed, embed_dim, bq)?,
793            linear_maybe_gguf(k, rows_k, embed_dim, &wk, None, gguf_packed, embed_dim, bk)?,
794            linear_maybe_gguf(v, rows_k, embed_dim, &wv, None, gguf_packed, embed_dim, bv)?,
795        )
796    };
797
798    let bh = batch * num_heads;
799    let mut qh = vec![0f32; bh * l_q * head_dim];
800    let mut kh = vec![0f32; bh * l_k * head_dim];
801    let mut vh = vec![0f32; bh * l_k * head_dim];
802    repack(&q_proj, &mut qh, batch, l_q, num_heads, head_dim);
803    repack(&k_proj, &mut kh, batch, l_k, num_heads, head_dim);
804    repack(&v_proj, &mut vh, batch, l_k, num_heads, head_dim);
805
806    let scale = 1.0f32 / (head_dim as f32).sqrt();
807    let mut scores = vec![0f32; l_q * l_k];
808    let mut attn_out = vec![0f32; bh * l_q * head_dim];
809    for bi in 0..batch {
810        for h in 0..num_heads {
811            let bhi = bi * num_heads + h;
812            let q_h = &qh[bhi * l_q * head_dim..(bhi + 1) * l_q * head_dim];
813            let k_h = &kh[bhi * l_k * head_dim..(bhi + 1) * l_k * head_dim];
814            let v_h = &vh[bhi * l_k * head_dim..(bhi + 1) * l_k * head_dim];
815            matmul_bt(q_h, k_h, &mut scores, l_q, head_dim, l_k, scale);
816            if let Some(bias) = attn_bias_h_lq_lk {
817                let bias_h = &bias[h * l_q * l_k..(h + 1) * l_q * l_k];
818                for i in 0..scores.len() {
819                    scores[i] += bias_h[i];
820                }
821            }
822            if let Some(mask) = key_padding_mask {
823                let mask_b = &mask[bi * l_k..(bi + 1) * l_k];
824                for r in 0..l_q {
825                    let row = &mut scores[r * l_k..(r + 1) * l_k];
826                    for (c, m) in mask_b.iter().enumerate() {
827                        if *m != 0 {
828                            row[c] = f32::NEG_INFINITY;
829                        }
830                    }
831                }
832            }
833            softmax_rows(&mut scores, l_q, l_k);
834            let out_h = &mut attn_out[bhi * l_q * head_dim..(bhi + 1) * l_q * head_dim];
835            matmul(&scores, v_h, out_h, l_q, l_k, head_dim);
836        }
837    }
838
839    let mut packed = vec![0f32; batch * l_q * embed_dim];
840    for bi in 0..batch {
841        for l in 0..l_q {
842            for h in 0..num_heads {
843                let src = ((bi * num_heads + h) * l_q + l) * head_dim;
844                let dst = (bi * l_q + l) * embed_dim + h * head_dim;
845                packed[dst..dst + head_dim].copy_from_slice(&attn_out[src..src + head_dim]);
846            }
847        }
848    }
849    linear_maybe_gguf(
850        &packed,
851        batch * l_q,
852        embed_dim,
853        out_proj_w_t,
854        out_gguf_key,
855        gguf_packed,
856        embed_dim,
857        out_proj_b,
858    )
859}
860
861/// F32-only MHA (used when weights are materialized at extract).
862#[allow(clippy::too_many_arguments)]
863fn mha_with_bias_f32(
864    q: &[f32],
865    k: &[f32],
866    v: &[f32],
867    in_proj_w_t: &[f32],
868    in_proj_b: &[f32],
869    out_proj_w_t: &[f32],
870    out_proj_b: &[f32],
871    batch: usize,
872    l_q: usize,
873    l_k: usize,
874    embed_dim: usize,
875    num_heads: usize,
876    attn_bias_h_lq_lk: Option<&[f32]>,
877    key_padding_mask: Option<&[u8]>,
878) -> Result<Vec<f32>> {
879    use super::tensor::{matmul, matmul_bt, softmax_rows};
880    let head_dim = embed_dim / num_heads;
881    let (wq, wk, wv) = split3(in_proj_w_t, embed_dim);
882    let bq = &in_proj_b[0..embed_dim];
883    let bk = &in_proj_b[embed_dim..2 * embed_dim];
884    let bv = &in_proj_b[2 * embed_dim..3 * embed_dim];
885
886    let q_proj = linear(q, batch * l_q, embed_dim, &wq, embed_dim, bq)?;
887    let k_proj = linear(k, batch * l_k, embed_dim, &wk, embed_dim, bk)?;
888    let v_proj = linear(v, batch * l_k, embed_dim, &wv, embed_dim, bv)?;
889
890    let bh = batch * num_heads;
891    let mut qh = vec![0f32; bh * l_q * head_dim];
892    let mut kh = vec![0f32; bh * l_k * head_dim];
893    let mut vh = vec![0f32; bh * l_k * head_dim];
894    repack(&q_proj, &mut qh, batch, l_q, num_heads, head_dim);
895    repack(&k_proj, &mut kh, batch, l_k, num_heads, head_dim);
896    repack(&v_proj, &mut vh, batch, l_k, num_heads, head_dim);
897
898    let scale = 1.0f32 / (head_dim as f32).sqrt();
899    let mut scores = vec![0f32; l_q * l_k];
900    let mut attn_out = vec![0f32; bh * l_q * head_dim];
901    for bi in 0..batch {
902        for h in 0..num_heads {
903            let bhi = bi * num_heads + h;
904            let q_h = &qh[bhi * l_q * head_dim..(bhi + 1) * l_q * head_dim];
905            let k_h = &kh[bhi * l_k * head_dim..(bhi + 1) * l_k * head_dim];
906            let v_h = &vh[bhi * l_k * head_dim..(bhi + 1) * l_k * head_dim];
907            matmul_bt(q_h, k_h, &mut scores, l_q, head_dim, l_k, scale);
908            if let Some(bias) = attn_bias_h_lq_lk {
909                let bias_h = &bias[h * l_q * l_k..(h + 1) * l_q * l_k];
910                for i in 0..scores.len() {
911                    scores[i] += bias_h[i];
912                }
913            }
914            if let Some(mask) = key_padding_mask {
915                let mask_b = &mask[bi * l_k..(bi + 1) * l_k];
916                for r in 0..l_q {
917                    let row = &mut scores[r * l_k..(r + 1) * l_k];
918                    for (c, m) in mask_b.iter().enumerate() {
919                        if *m != 0 {
920                            row[c] = f32::NEG_INFINITY;
921                        }
922                    }
923                }
924            }
925            softmax_rows(&mut scores, l_q, l_k);
926            let out_h = &mut attn_out[bhi * l_q * head_dim..(bhi + 1) * l_q * head_dim];
927            matmul(&scores, v_h, out_h, l_q, l_k, head_dim);
928        }
929    }
930
931    let mut packed = vec![0f32; batch * l_q * embed_dim];
932    for bi in 0..batch {
933        for l in 0..l_q {
934            for h in 0..num_heads {
935                let src = ((bi * num_heads + h) * l_q + l) * head_dim;
936                let dst = (bi * l_q + l) * embed_dim + h * head_dim;
937                packed[dst..dst + head_dim].copy_from_slice(&attn_out[src..src + head_dim]);
938            }
939        }
940    }
941    linear(
942        &packed,
943        batch * l_q,
944        embed_dim,
945        out_proj_w_t,
946        embed_dim,
947        out_proj_b,
948    )
949}
950
/// Split a row-major `[e, 3e]` matrix into its three `[e, e]` column blocks,
/// returned in order (first, second, third third of each row), i.e. Q, K, V
/// for a packed input projection.
fn split3(w_t: &[f32], e: usize) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
    let stride = 3 * e;
    let mut q_part = Vec::with_capacity(e * e);
    let mut k_part = Vec::with_capacity(e * e);
    let mut v_part = Vec::with_capacity(e * e);
    for row_idx in 0..e {
        let row = &w_t[row_idx * stride..(row_idx + 1) * stride];
        q_part.extend_from_slice(&row[..e]);
        k_part.extend_from_slice(&row[e..2 * e]);
        v_part.extend_from_slice(&row[2 * e..]);
    }
    (q_part, k_part, v_part)
}
964
/// Scatter a batch-first `[batch, l, num_heads * head_dim]` buffer into the
/// per-head layout `[batch, num_heads, l, head_dim]`, writing into `dst`.
fn repack(src: &[f32], dst: &mut [f32], batch: usize, l: usize, num_heads: usize, head_dim: usize) {
    let width = num_heads * head_dim;
    // Walk the source rows in order; each flat row index is `b * l + pos`.
    for row in 0..batch * l {
        let (b, pos) = (row / l, row % l);
        for head in 0..num_heads {
            let from = row * width + head * head_dim;
            let to = ((b * num_heads + head) * l + pos) * head_dim;
            dst[to..to + head_dim].copy_from_slice(&src[from..from + head_dim]);
        }
    }
}
977
978/// Run the full decoder (batch must be 1 for the boxRPB path).
979#[allow(clippy::too_many_arguments)]
980pub fn forward_decoder(
981    weights: &Sam3DecoderWeights,
982    memory: &[f32],      // [batch, h*w, D] batch-first (matches our encoder output)
983    memory_pos: &[f32],  // same layout
984    memory_text: &[f32], // [seq, batch, D] seq-first
985    text_attention_mask: &[u8], // [batch, seq]
986    batch: usize,
987    h: usize,
988    w: usize,
989    seq_len: usize,
990    gguf_packed: Option<&GgufPackedParams>,
991) -> Result<Sam3DecoderOutput> {
992    ensure!(weights.loaded, "SAM3 detector decoder not loaded");
993    ensure!(batch == 1, "decoder forward requires batch=1 for boxRPB");
994    let hw = h * w;
995    let nq = NUM_QUERIES;
996
997    // Build initial tgt and reference boxes.
998    let mut tgt = vec![0f32; nq * batch * D_MODEL]; // seq-first [nq, bs, D]
999    for q in 0..nq {
1000        let src = &weights.query_embed[q * D_MODEL..(q + 1) * D_MODEL];
1001        for b in 0..batch {
1002            tgt[(q * batch + b) * D_MODEL..(q * batch + b + 1) * D_MODEL].copy_from_slice(src);
1003        }
1004    }
1005    let mut reference_boxes = vec![0f32; nq * batch * 4];
1006    for q in 0..nq {
1007        let src = &weights.reference_points[q * 4..(q + 1) * 4];
1008        for b in 0..batch {
1009            let dst = &mut reference_boxes[(q * batch + b) * 4..(q * batch + b + 1) * 4];
1010            for k in 0..4 {
1011                dst[k] = sigmoid(src[k]);
1012            }
1013        }
1014    }
1015
1016    let mut presence_out = vec![0f32; batch * D_MODEL];
1017    for b in 0..batch {
1018        presence_out[b * D_MODEL..(b + 1) * D_MODEL].copy_from_slice(&weights.presence_token);
1019    }
1020
1021    let mut intermediate = Vec::with_capacity(N_LAYERS);
1022    let mut intermediate_ref_boxes = Vec::with_capacity(N_LAYERS);
1023    let mut presence_logits = Vec::with_capacity(N_LAYERS);
1024
1025    // First entry of intermediate_ref_boxes is the initial reference boxes.
1026    intermediate_ref_boxes.push(reference_boxes.clone());
1027
1028    // Reorder memory and memory_text into batch-first [bs, len, D] for MHA.
1029    // memory is already [bs, hw, D]. memory_text is [seq, bs, D] → [bs, seq, D].
1030    let mut memory_text_bf = vec![0f32; batch * seq_len * D_MODEL];
1031    for b in 0..batch {
1032        for l in 0..seq_len {
1033            let src = (l * batch + b) * D_MODEL;
1034            let dst = (b * seq_len + l) * D_MODEL;
1035            memory_text_bf[dst..dst + D_MODEL].copy_from_slice(&memory_text[src..src + D_MODEL]);
1036        }
1037    }
1038
1039    for (layer_idx, layer) in weights.layers.iter().enumerate() {
1040        // Compute query_pos = ref_point_head(sineembed(ref_boxes)).
1041        let sine = sineembed_for_position_4d(&reference_boxes, nq, batch, D_MODEL);
1042        let query_pos = mlp2_forward(&weights.ref_point_head, &sine, nq * batch, gguf_packed)?;
1043        if let Some(dir) = rlx_ir::env::var("RLX_SAM3_DECODER_DUMP_DIR") {
1044            use std::io::Write as _;
1045            let path = format!("{dir}/host_layer{layer_idx}_query_pos.f32");
1046            let mut f = std::fs::File::create(&path).unwrap();
1047            for v in &query_pos {
1048                f.write_all(&v.to_le_bytes()).unwrap();
1049            }
1050        }
1051
1052        // Build self-attn input: prepend presence_token to tgt. The
1053        // sequence becomes [presence, q0, ..., q199] = 201 tokens with
1054        // pos = [0, qp0, ..., qp199] (zeros for presence).
1055        let sa_len = 1 + nq;
1056        let mut sa_x = vec![0f32; sa_len * batch * D_MODEL];
1057        let mut sa_pos = vec![0f32; sa_len * batch * D_MODEL];
1058        for b in 0..batch {
1059            sa_x[b * D_MODEL..(b + 1) * D_MODEL]
1060                .copy_from_slice(&presence_out[b * D_MODEL..(b + 1) * D_MODEL]);
1061        }
1062        for q in 0..nq {
1063            for b in 0..batch {
1064                let src = &tgt[(q * batch + b) * D_MODEL..(q * batch + b + 1) * D_MODEL];
1065                sa_x[((1 + q) * batch + b) * D_MODEL..((1 + q) * batch + b + 1) * D_MODEL]
1066                    .copy_from_slice(src);
1067                let qp = &query_pos[(q * batch + b) * D_MODEL..(q * batch + b + 1) * D_MODEL];
1068                sa_pos[((1 + q) * batch + b) * D_MODEL..((1 + q) * batch + b + 1) * D_MODEL]
1069                    .copy_from_slice(qp);
1070            }
1071        }
1072        // Reorder to batch-first for MHA helper.
1073        let mut sa_x_bf = vec![0f32; batch * sa_len * D_MODEL];
1074        let mut sa_pos_bf = vec![0f32; batch * sa_len * D_MODEL];
1075        for b in 0..batch {
1076            for l in 0..sa_len {
1077                let s = (l * batch + b) * D_MODEL;
1078                let d = (b * sa_len + l) * D_MODEL;
1079                sa_x_bf[d..d + D_MODEL].copy_from_slice(&sa_x[s..s + D_MODEL]);
1080                sa_pos_bf[d..d + D_MODEL].copy_from_slice(&sa_pos[s..s + D_MODEL]);
1081            }
1082        }
1083        // q=k=sa_x+sa_pos, v=sa_x.
1084        let mut qk = vec![0f32; sa_x_bf.len()];
1085        for i in 0..qk.len() {
1086            qk[i] = sa_x_bf[i] + sa_pos_bf[i];
1087        }
1088        let sa = mha_with_bias_maybe_gguf(
1089            &qk,
1090            &qk,
1091            &sa_x_bf,
1092            &layer.self_attn_in_w_t,
1093            &layer.self_attn_in_b,
1094            layer.self_attn_in_gguf_key.as_deref(),
1095            &layer.self_attn_out_w_t,
1096            &layer.self_attn_out_b,
1097            layer.self_attn_out_gguf_key.as_deref(),
1098            gguf_packed,
1099            batch,
1100            sa_len,
1101            sa_len,
1102            D_MODEL,
1103            N_HEADS,
1104            None,
1105            None,
1106        )?;
1107        for i in 0..sa_x_bf.len() {
1108            sa_x_bf[i] += sa[i];
1109        }
1110        // Post-norm.
1111        let sa_x_bf = layer_norm(&sa_x_bf, &layer.norm2_w, &layer.norm2_b, D_MODEL, 1e-5)?;
1112        // Split presence + tgt back out (seq-first ordering).
1113        let mut new_presence = vec![0f32; batch * D_MODEL];
1114        for b in 0..batch {
1115            let src = &sa_x_bf[(b * sa_len) * D_MODEL..(b * sa_len + 1) * D_MODEL];
1116            new_presence[b * D_MODEL..(b + 1) * D_MODEL].copy_from_slice(src);
1117        }
1118        let mut after_sa = vec![0f32; batch * nq * D_MODEL];
1119        for b in 0..batch {
1120            for q in 0..nq {
1121                let src = (b * sa_len + 1 + q) * D_MODEL;
1122                let dst = (b * nq + q) * D_MODEL;
1123                after_sa[dst..dst + D_MODEL].copy_from_slice(&sa_x_bf[src..src + D_MODEL]);
1124            }
1125        }
1126        if let Some(dir) = rlx_ir::env::var("RLX_SAM3_DECODER_DUMP_DIR") {
1127            use std::io::Write as _;
1128            let path = format!("{dir}/host_layer{layer_idx}_sa_queries.f32");
1129            let mut f = std::fs::File::create(&path).unwrap();
1130            for v in &after_sa {
1131                f.write_all(&v.to_le_bytes()).unwrap();
1132            }
1133        }
1134
1135        // Text cross-attention. q = after_sa + query_pos (batch-first
1136        // ordering: [bs, nq, D]).
1137        let mut q_text = vec![0f32; batch * nq * D_MODEL];
1138        for b in 0..batch {
1139            for q in 0..nq {
1140                let dst = (b * nq + q) * D_MODEL;
1141                let qp = &query_pos[(q * batch + b) * D_MODEL..(q * batch + b + 1) * D_MODEL];
1142                for c in 0..D_MODEL {
1143                    q_text[dst + c] = after_sa[dst + c] + qp[c];
1144                }
1145            }
1146        }
1147        let text_attn = mha_with_bias_maybe_gguf(
1148            &q_text,
1149            &memory_text_bf,
1150            &memory_text_bf,
1151            &layer.ca_text_in_w_t,
1152            &layer.ca_text_in_b,
1153            layer.ca_text_in_gguf_key.as_deref(),
1154            &layer.ca_text_out_w_t,
1155            &layer.ca_text_out_b,
1156            layer.ca_text_out_gguf_key.as_deref(),
1157            gguf_packed,
1158            batch,
1159            nq,
1160            seq_len,
1161            D_MODEL,
1162            N_HEADS,
1163            None,
1164            Some(text_attention_mask),
1165        )?;
1166        let mut tgt_after_ca_text = vec![0f32; batch * nq * D_MODEL];
1167        for i in 0..tgt_after_ca_text.len() {
1168            tgt_after_ca_text[i] = after_sa[i] + text_attn[i];
1169        }
1170        let tgt_after_ca_text = layer_norm(
1171            &tgt_after_ca_text,
1172            &layer.catext_norm_w,
1173            &layer.catext_norm_b,
1174            D_MODEL,
1175            1e-5,
1176        )?;
1177        if let Some(dir) = rlx_ir::env::var("RLX_SAM3_DECODER_DUMP_DIR") {
1178            use std::io::Write as _;
1179            let path = format!("{dir}/host_layer{layer_idx}_after_ca_text_q.f32");
1180            let mut f = std::fs::File::create(&path).unwrap();
1181            for v in &tgt_after_ca_text {
1182                f.write_all(&v.to_le_bytes()).unwrap();
1183            }
1184        }
1185
1186        // Image cross-attention with boxRPB log mask.
1187        let rpb = boxrpb_log_mask(weights, &reference_boxes, nq, h, w, gguf_packed)?;
1188        // Need to prepend a row of zeros for the presence token mask, but
1189        // since we don't include presence in the cross-attn for our path
1190        // (presence was consumed during self-attn only and cross-attn
1191        // doesn't gain a presence row in the layer forward — let me
1192        // re-check upstream): upstream concatenates a zero row in mask
1193        // before calling cross_attn so that cross_attn's mask has shape
1194        // (bs*nheads, nq+1, hw). The presence token IS included in cross-
1195        // attn. We follow the same path.
1196        let cross_len_q = 1 + nq;
1197        // Build a [n_heads, cross_len_q, hw] mask with presence row of 0.
1198        let mut full_mask = vec![0f32; N_HEADS * cross_len_q * hw];
1199        for head in 0..N_HEADS {
1200            // presence row: 0s already.
1201            for q in 0..nq {
1202                let src = (head * nq + q) * hw;
1203                let dst = (head * cross_len_q + 1 + q) * hw;
1204                full_mask[dst..dst + hw].copy_from_slice(&rpb[src..src + hw]);
1205            }
1206        }
1207        // Cross-attention input: prepend the new presence to the tgt.
1208        let mut ca_in_seq_first = vec![0f32; cross_len_q * batch * D_MODEL];
1209        for b in 0..batch {
1210            // Presence at position 0.
1211            ca_in_seq_first[b * D_MODEL..(b + 1) * D_MODEL]
1212                .copy_from_slice(&new_presence[b * D_MODEL..(b + 1) * D_MODEL]);
1213            for q in 0..nq {
1214                let src = &tgt_after_ca_text[(b * nq + q) * D_MODEL..(b * nq + q + 1) * D_MODEL];
1215                ca_in_seq_first
1216                    [((1 + q) * batch + b) * D_MODEL..((1 + q) * batch + b + 1) * D_MODEL]
1217                    .copy_from_slice(src);
1218            }
1219        }
1220        // Reorder to batch-first.
1221        let mut ca_in_bf = vec![0f32; batch * cross_len_q * D_MODEL];
1222        let mut ca_pos_bf = vec![0f32; batch * cross_len_q * D_MODEL];
1223        for b in 0..batch {
1224            for l in 0..cross_len_q {
1225                let s = (l * batch + b) * D_MODEL;
1226                let d = (b * cross_len_q + l) * D_MODEL;
1227                ca_in_bf[d..d + D_MODEL].copy_from_slice(&ca_in_seq_first[s..s + D_MODEL]);
1228                if l == 0 {
1229                    // presence pos = 0
1230                } else {
1231                    let qp = &query_pos
1232                        [((l - 1) * batch + b) * D_MODEL..((l - 1) * batch + b + 1) * D_MODEL];
1233                    ca_pos_bf[d..d + D_MODEL].copy_from_slice(qp);
1234                }
1235            }
1236        }
1237        // Q = ca_in + ca_pos; K = memory + memory_pos; V = memory.
1238        let mut q_img = vec![0f32; ca_in_bf.len()];
1239        for i in 0..q_img.len() {
1240            q_img[i] = ca_in_bf[i] + ca_pos_bf[i];
1241        }
1242        let mut k_img = vec![0f32; memory.len()];
1243        for i in 0..k_img.len() {
1244            k_img[i] = memory[i] + memory_pos[i];
1245        }
1246        let ca_out = mha_with_bias_maybe_gguf(
1247            &q_img,
1248            &k_img,
1249            memory,
1250            &layer.cross_attn_in_w_t,
1251            &layer.cross_attn_in_b,
1252            layer.cross_attn_in_gguf_key.as_deref(),
1253            &layer.cross_attn_out_w_t,
1254            &layer.cross_attn_out_b,
1255            layer.cross_attn_out_gguf_key.as_deref(),
1256            gguf_packed,
1257            batch,
1258            cross_len_q,
1259            hw,
1260            D_MODEL,
1261            N_HEADS,
1262            Some(&full_mask),
1263            None,
1264        )?;
1265        if let Some(dir) = rlx_ir::env::var("RLX_SAM3_DECODER_DUMP_DIR") {
1266            use std::io::Write as _;
1267            let path = format!("{dir}/host_layer{layer_idx}_ca_img_proj.f32");
1268            let mut f = std::fs::File::create(&path).unwrap();
1269            for v in &ca_out {
1270                f.write_all(&v.to_le_bytes()).unwrap();
1271            }
1272        }
1273        for i in 0..ca_in_bf.len() {
1274            ca_in_bf[i] += ca_out[i];
1275        }
1276        // Post-norm1.
1277        let ca_in_bf = layer_norm(&ca_in_bf, &layer.norm1_w, &layer.norm1_b, D_MODEL, 1e-5)?;
1278        if let Some(dir) = rlx_ir::env::var("RLX_SAM3_DECODER_DUMP_DIR") {
1279            use std::io::Write as _;
1280            // Extract queries only (rows 1..lq) of [B, lq, D].
1281            let mut q_only = vec![0f32; batch * nq * D_MODEL];
1282            for b in 0..batch {
1283                for q in 0..nq {
1284                    let src = (b * cross_len_q + 1 + q) * D_MODEL;
1285                    let dst = (b * nq + q) * D_MODEL;
1286                    q_only[dst..dst + D_MODEL].copy_from_slice(&ca_in_bf[src..src + D_MODEL]);
1287                }
1288            }
1289            let path = format!("{dir}/host_layer{layer_idx}_after_ca_img_q.f32");
1290            let mut f = std::fs::File::create(&path).unwrap();
1291            for v in &q_only {
1292                f.write_all(&v.to_le_bytes()).unwrap();
1293            }
1294        }
1295
1296        // FFN.
1297        let mut ff = linear_maybe_gguf(
1298            &ca_in_bf,
1299            batch * cross_len_q,
1300            D_MODEL,
1301            &layer.linear1_w_t,
1302            layer.linear1_gguf_key.as_deref(),
1303            gguf_packed,
1304            DIM_FF,
1305            &layer.linear1_b,
1306        )?;
1307        for v in ff.iter_mut() {
1308            if *v < 0.0 {
1309                *v = 0.0;
1310            }
1311        }
1312        let ffn = linear_maybe_gguf(
1313            &ff,
1314            batch * cross_len_q,
1315            DIM_FF,
1316            &layer.linear2_w_t,
1317            layer.linear2_gguf_key.as_deref(),
1318            gguf_packed,
1319            D_MODEL,
1320            &layer.linear2_b,
1321        )?;
1322        let mut after_ffn = ca_in_bf.clone();
1323        for i in 0..after_ffn.len() {
1324            after_ffn[i] += ffn[i];
1325        }
1326        // Post-norm3.
1327        let after_ffn = layer_norm(&after_ffn, &layer.norm3_w, &layer.norm3_b, D_MODEL, 1e-5)?;
1328
1329        // Split off presence and tgt.
1330        let mut layer_presence = vec![0f32; batch * D_MODEL];
1331        let mut layer_tgt = vec![0f32; batch * nq * D_MODEL];
1332        for b in 0..batch {
1333            let src_p = &after_ffn[(b * cross_len_q) * D_MODEL..(b * cross_len_q + 1) * D_MODEL];
1334            layer_presence[b * D_MODEL..(b + 1) * D_MODEL].copy_from_slice(src_p);
1335            for q in 0..nq {
1336                let src = (b * cross_len_q + 1 + q) * D_MODEL;
1337                let dst = (b * nq + q) * D_MODEL;
1338                layer_tgt[dst..dst + D_MODEL].copy_from_slice(&after_ffn[src..src + D_MODEL]);
1339            }
1340        }
1341        if let Some(dir) = rlx_ir::env::var("RLX_SAM3_DECODER_DUMP_DIR") {
1342            use std::io::Write as _;
1343            for (vals, name) in [(&layer_tgt, "new_tgt"), (&layer_presence, "new_presence")] {
1344                let path = format!("{dir}/host_layer{layer_idx}_{name}.f32");
1345                let mut f = std::fs::File::create(&path).unwrap();
1346                for v in vals {
1347                    f.write_all(&v.to_le_bytes()).unwrap();
1348                }
1349            }
1350        }
1351        // tgt becomes the layer_tgt (in seq-first ordering).
1352        for q in 0..nq {
1353            for b in 0..batch {
1354                let src = (b * nq + q) * D_MODEL;
1355                let dst = (q * batch + b) * D_MODEL;
1356                tgt[dst..dst + D_MODEL].copy_from_slice(&layer_tgt[src..src + D_MODEL]);
1357            }
1358        }
1359        // presence_out is the new layer presence (in batch ordering).
1360        presence_out.copy_from_slice(&layer_presence);
1361
1362        // Box refinement: delta = bbox_embed(out_norm(layer_tgt)); ref = sigmoid(inv_sig(ref) + delta).
1363        // out_norm in batch-first order.
1364        let out_norm = layer_norm(&layer_tgt, &weights.norm_w, &weights.norm_b, D_MODEL, 1e-5)?;
1365        if let Some(dir) = rlx_ir::env::var("RLX_SAM3_DECODER_DUMP_DIR") {
1366            use std::io::Write as _;
1367            let path = format!("{dir}/host_layer{layer_idx}_out_norm.f32");
1368            let mut f = std::fs::File::create(&path).unwrap();
1369            for v in &out_norm {
1370                f.write_all(&v.to_le_bytes()).unwrap();
1371            }
1372        }
1373        let delta = mlp3_forward(&weights.bbox_embed, &out_norm, batch * nq, gguf_packed)?;
1374        let mut new_ref = vec![0f32; nq * batch * 4];
1375        for q in 0..nq {
1376            for b in 0..batch {
1377                let cur = &reference_boxes[(q * batch + b) * 4..(q * batch + b + 1) * 4];
1378                let d = &delta[(b * nq + q) * 4..(b * nq + q + 1) * 4];
1379                for k in 0..4 {
1380                    new_ref[(q * batch + b) * 4 + k] = sigmoid(inverse_sigmoid(cur[k]) + d[k]);
1381                }
1382            }
1383        }
1384        reference_boxes = new_ref;
1385        if layer_idx != N_LAYERS - 1 {
1386            intermediate_ref_boxes.push(reference_boxes.clone());
1387        }
1388
1389        // Intermediate output post-norm in seq-first.
1390        let mut out_seq_first = vec![0f32; nq * batch * D_MODEL];
1391        for q in 0..nq {
1392            for b in 0..batch {
1393                let src = (b * nq + q) * D_MODEL;
1394                let dst = (q * batch + b) * D_MODEL;
1395                out_seq_first[dst..dst + D_MODEL].copy_from_slice(&out_norm[src..src + D_MODEL]);
1396            }
1397        }
1398        intermediate.push(out_seq_first);
1399
1400        // Presence logits per layer.
1401        let p_norm = layer_norm(
1402            &layer_presence,
1403            &weights.presence_token_out_norm_w,
1404            &weights.presence_token_out_norm_b,
1405            D_MODEL,
1406            1e-5,
1407        )?;
1408        let p_logit = mlp3_forward(&weights.presence_token_head, &p_norm, batch, gguf_packed)?;
1409        presence_logits.push(p_logit);
1410    }
1411
1412    // Stack.
1413    let mut int_stack = vec![0f32; N_LAYERS * nq * batch * D_MODEL];
1414    for (li, layer_out) in intermediate.iter().enumerate() {
1415        int_stack[li * nq * batch * D_MODEL..(li + 1) * nq * batch * D_MODEL]
1416            .copy_from_slice(layer_out);
1417    }
1418    let mut ref_stack = vec![0f32; N_LAYERS * nq * batch * 4];
1419    for (li, ref_l) in intermediate_ref_boxes.iter().enumerate() {
1420        ref_stack[li * nq * batch * 4..(li + 1) * nq * batch * 4].copy_from_slice(ref_l);
1421    }
1422    let mut presence_stack = vec![0f32; N_LAYERS * batch];
1423    for (li, p) in presence_logits.iter().enumerate() {
1424        for b in 0..batch {
1425            presence_stack[li * batch + b] = p[b];
1426        }
1427    }
1428
1429    Ok(Sam3DecoderOutput {
1430        intermediate: int_stack,
1431        intermediate_ref_boxes: ref_stack,
1432        presence_logits: presence_stack,
1433        presence_feats: presence_out,
1434        num_layers: N_LAYERS,
1435        num_queries: nq,
1436        batch,
1437        d_model: D_MODEL,
1438    })
1439}