Skip to main content

rlx_sam3/
detector_decoder_ir.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! IR-lowered detector decoder.
//!
//! The per-layer compute (self-attention, text cross-attention, image
//! cross-attention with an explicit boxRPB bias add, and the FFN) is expressed
//! as a native `HirModule` built through `HirMut` and compiled once per layer
//! on the requested device. Box refinement, the sine embedding fed to
//! `ref_point_head`, and the presence-token state carried between layers stay
//! in Rust because they are iterative and small.
23
24use super::detector_decoder::{
25    Mlp2, Mlp3, Sam3DecoderLayerWeights, Sam3DecoderOutput, Sam3DecoderWeights, mlp2_forward,
26    mlp2_forward_into, mlp3_forward, mlp3_forward_into,
27};
28use super::packed_gguf::packed_linear;
29use anyhow::{Result, ensure};
30use rlx_flow::CompileProfile;
31use rlx_flow::{GgufPackedLinear, GgufPackedParams};
32use rlx_ir::hir::{HirGraphExt, HirModule, HirMut, HirNodeId};
33use rlx_ir::op::{Activation, MaskKind, Op};
34use rlx_ir::shape;
35use rlx_ir::{DType, Shape};
36use rlx_runtime::{CompiledGraph, Device};
37use std::collections::HashMap;
38
/// Decoder hidden width.
const D_MODEL: usize = 256;
/// FFN inner width.
const DIM_FF: usize = 2048;
/// Attention heads per attention block.
const N_HEADS: usize = 8;
/// Per-head width.
const HEAD_DIM: usize = D_MODEL / N_HEADS;
/// Object queries per image (the presence token is extra: `lq = nq + 1`).
const NUM_QUERIES: usize = 200;
/// Expected number of decoder layers (used as a capacity hint).
const N_LAYERS: usize = 6;

// `HEAD_DIM` uses integer division: fail the build instead of silently
// truncating if `D_MODEL` / `N_HEADS` are ever changed to non-divisible values.
const _: () = assert!(
    D_MODEL % N_HEADS == 0,
    "D_MODEL must be divisible by N_HEADS"
);
45
46type LayerHirParts = (
47    HirModule,
48    HashMap<String, Vec<f32>>,
49    Vec<(String, Vec<u8>, DType)>,
50);
51type LayerRunOut = (Vec<f32>, Vec<f32>, Vec<f32>, Vec<f32>);
52
/// Checkpoint key for a per-layer tensor: `{base}.layers.{li}.{suffix}`.
fn dec_layer_key(base: &str, li: usize, suffix: &str) -> String {
    let layer_index = li.to_string();
    [base, "layers", layer_index.as_str(), suffix].join(".")
}
56
57fn gguf_weight_param(
58    g: &mut HirMut<'_>,
59    typed: &mut Vec<(String, Vec<u8>, DType)>,
60    cache: &mut HashMap<String, HirNodeId>,
61    ir_name: &str,
62    p: &GgufPackedLinear,
63) -> HirNodeId {
64    if let Some(&id) = cache.get(ir_name) {
65        return id;
66    }
67    let id = g.param(ir_name, Shape::new(&[p.w_q.len()], DType::U8));
68    typed.push((ir_name.to_string(), p.w_q.clone(), DType::U8));
69    cache.insert(ir_name.to_string(), id);
70    id
71}
72
73fn linear_gguf_matmul(
74    g: &mut HirMut<'_>,
75    typed: &mut Vec<(String, Vec<u8>, DType)>,
76    cache: &mut HashMap<String, HirNodeId>,
77    ir_stem: &str,
78    p: &GgufPackedLinear,
79    input: HirNodeId,
80    in_dim: usize,
81    out_dim: usize,
82) -> Result<HirNodeId> {
83    ensure!(
84        p.in_dim == in_dim && p.out_dim == out_dim,
85        "packed linear {ir_stem}: shape {}x{} vs {in_dim}x{out_dim}",
86        p.in_dim,
87        p.out_dim
88    );
89    let w_name = format!("{ir_stem}.w");
90    let w_id = gguf_weight_param(g, typed, cache, &w_name, p);
91    let cur = g.shape(input);
92    let mut dims: Vec<usize> = cur.dims().iter().map(|d| d.unwrap_static()).collect();
93    *dims.last_mut().unwrap() = out_dim;
94    let out_shape = Shape::new(&dims, DType::F32);
95    Ok(g.add_node(
96        Op::DequantMatMul { scheme: p.scheme },
97        vec![input, w_id],
98        out_shape,
99    ))
100}
101
102fn add_f32_bias(
103    g: &mut HirMut<'_>,
104    params: &mut HashMap<String, Vec<f32>>,
105    name: &str,
106    input: HirNodeId,
107    bias: &[f32],
108) -> HirNodeId {
109    if bias.iter().all(|&v| v == 0.0) {
110        return input;
111    }
112    let out_dim = bias.len();
113    let b_id = add_param(
114        g,
115        params,
116        name,
117        bias.to_vec(),
118        Shape::new(&[out_dim], DType::F32),
119    );
120    g.add(input, b_id)
121}
122
123fn linear_gguf_bias(
124    g: &mut HirMut<'_>,
125    params: &mut HashMap<String, Vec<f32>>,
126    typed: &mut Vec<(String, Vec<u8>, DType)>,
127    cache: &mut HashMap<String, HirNodeId>,
128    ir_stem: &str,
129    p: &GgufPackedLinear,
130    input: HirNodeId,
131    bias: &[f32],
132    in_dim: usize,
133    out_dim: usize,
134) -> Result<HirNodeId> {
135    let y = linear_gguf_matmul(g, typed, cache, ir_stem, p, input, in_dim, out_dim)?;
136    Ok(add_f32_bias(g, params, &format!("{ir_stem}.b"), y, bias))
137}
138
139fn in_proj_qkv(
140    g: &mut HirMut<'_>,
141    params: &mut HashMap<String, Vec<f32>>,
142    typed: &mut Vec<(String, Vec<u8>, DType)>,
143    cache: &mut HashMap<String, HirNodeId>,
144    gguf_packed: Option<&GgufPackedParams>,
145    gguf_key: &str,
146    ir_stem: &str,
147    layer_w_t: &[f32],
148    layer_b: &[f32],
149    input_q: HirNodeId,
150    input_k: HirNodeId,
151    input_v: HirNodeId,
152    d: usize,
153) -> Result<(HirNodeId, HirNodeId, HirNodeId)> {
154    if let Some(p) = gguf_packed.and_then(|m| packed_linear(m, gguf_key)) {
155        let qkv_q = linear_gguf_bias(
156            g,
157            params,
158            typed,
159            cache,
160            ir_stem,
161            p,
162            input_q,
163            layer_b,
164            d,
165            3 * d,
166        )?;
167        let qkv_k = linear_gguf_bias(
168            g,
169            params,
170            typed,
171            cache,
172            ir_stem,
173            p,
174            input_k,
175            layer_b,
176            d,
177            3 * d,
178        )?;
179        let qkv_v = linear_gguf_bias(
180            g,
181            params,
182            typed,
183            cache,
184            ir_stem,
185            p,
186            input_v,
187            layer_b,
188            d,
189            3 * d,
190        )?;
191        let axis = g.shape(qkv_q).rank().saturating_sub(1);
192        let q = g.narrow_(qkv_q, axis, 0, d);
193        let k = g.narrow_(qkv_k, axis, d, d);
194        let v = g.narrow_(qkv_v, axis, 2 * d, d);
195        return Ok((q, k, v));
196    }
197    let (wq, wk, wv) = split_qkv(layer_w_t, d);
198    let bq = layer_b[0..d].to_vec();
199    let bk = layer_b[d..2 * d].to_vec();
200    let bv = layer_b[2 * d..3 * d].to_vec();
201    let batch_q = g.shape(input_q).dims()[0].unwrap_static();
202    let seq_q = g.shape(input_q).dims()[1].unwrap_static();
203    let batch_k = g.shape(input_k).dims()[0].unwrap_static();
204    let seq_k = g.shape(input_k).dims()[1].unwrap_static();
205    let batch_v = g.shape(input_v).dims()[0].unwrap_static();
206    let seq_v = g.shape(input_v).dims()[1].unwrap_static();
207    let q = linear_bias_shaped(
208        g,
209        params,
210        &format!("{ir_stem}.q"),
211        input_q,
212        wq,
213        bq,
214        d,
215        d,
216        Some(batch_q),
217        Some(seq_q),
218    );
219    let k = linear_bias_shaped(
220        g,
221        params,
222        &format!("{ir_stem}.k"),
223        input_k,
224        wk,
225        bk,
226        d,
227        d,
228        Some(batch_k),
229        Some(seq_k),
230    );
231    let v = linear_bias_shaped(
232        g,
233        params,
234        &format!("{ir_stem}.v"),
235        input_v,
236        wv,
237        bv,
238        d,
239        d,
240        Some(batch_v),
241        Some(seq_v),
242    );
243    Ok((q, k, v))
244}
245
246fn linear_fused_or_gguf(
247    g: &mut HirMut<'_>,
248    params: &mut HashMap<String, Vec<f32>>,
249    typed: &mut Vec<(String, Vec<u8>, DType)>,
250    cache: &mut HashMap<String, HirNodeId>,
251    gguf_packed: Option<&GgufPackedParams>,
252    gguf_key: &str,
253    ir_stem: &str,
254    input: HirNodeId,
255    w_t: Vec<f32>,
256    bias: Vec<f32>,
257    in_dim: usize,
258    out_dim: usize,
259) -> Result<HirNodeId> {
260    if let Some(p) = gguf_packed.and_then(|m| packed_linear(m, gguf_key)) {
261        return linear_gguf_bias(
262            g, params, typed, cache, ir_stem, p, input, &bias, in_dim, out_dim,
263        );
264    }
265    Ok(linear_bias(
266        g, params, ir_stem, input, w_t, bias, in_dim, out_dim,
267    ))
268}
269
/// Split a transposed stacked `[e, 3e]` in-projection weight (row-major) into
/// its Q, K and V `[e, e]` column blocks.
///
/// Panics if `w_t` holds fewer than `3 * e * e` elements; extra trailing
/// elements are ignored.
fn split_qkv(w_t: &[f32], e: usize) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
    let row_len = 3 * e;
    let mut q_part = Vec::with_capacity(e * e);
    let mut k_part = Vec::with_capacity(e * e);
    let mut v_part = Vec::with_capacity(e * e);
    for row in 0..e {
        let packed_row = &w_t[row * row_len..(row + 1) * row_len];
        q_part.extend_from_slice(&packed_row[..e]);
        k_part.extend_from_slice(&packed_row[e..2 * e]);
        v_part.extend_from_slice(&packed_row[2 * e..]);
    }
    (q_part, k_part, v_part)
}
283
284fn add_param(
285    g: &mut HirMut<'_>,
286    params: &mut HashMap<String, Vec<f32>>,
287    name: &str,
288    data: Vec<f32>,
289    shape: Shape,
290) -> HirNodeId {
291    let id = g.param(name, shape);
292    params.insert(name.to_string(), data);
293    id
294}
295
296fn linear_bias(
297    g: &mut HirMut<'_>,
298    params: &mut HashMap<String, Vec<f32>>,
299    name: &str,
300    input: HirNodeId,
301    w: Vec<f32>,
302    b: Vec<f32>,
303    in_dim: usize,
304    out_dim: usize,
305) -> HirNodeId {
306    linear_bias_shaped(g, params, name, input, w, b, in_dim, out_dim, None, None)
307}
308
309fn linear_bias_shaped(
310    g: &mut HirMut<'_>,
311    params: &mut HashMap<String, Vec<f32>>,
312    name: &str,
313    input: HirNodeId,
314    w: Vec<f32>,
315    b: Vec<f32>,
316    in_dim: usize,
317    out_dim: usize,
318    batch: Option<usize>,
319    seq: Option<usize>,
320) -> HirNodeId {
321    let f = DType::F32;
322    let w_id = add_param(
323        g,
324        params,
325        &format!("{name}.w"),
326        w,
327        Shape::new(&[in_dim, out_dim], f),
328    );
329    let b_id = add_param(
330        g,
331        params,
332        &format!("{name}.b"),
333        b,
334        Shape::new(&[out_dim], f),
335    );
336    let out_shape = if let (Some(batch), Some(seq)) = (batch, seq) {
337        Shape::new(&[batch, seq, out_dim], f)
338    } else {
339        let cur = g.shape(input);
340        let mut out_dims: Vec<usize> = cur.dims().iter().map(|d| d.unwrap_static()).collect();
341        *out_dims.last_mut().unwrap() = out_dim;
342        Shape::new(&out_dims, f)
343    };
344    g.add_node(
345        Op::FusedMatMulBiasAct { activation: None },
346        vec![input, w_id, b_id],
347        out_shape,
348    )
349}
350
351fn fused_matmul_bias_act(
352    g: &mut HirMut<'_>,
353    input: HirNodeId,
354    w: HirNodeId,
355    b: HirNodeId,
356    activation: Option<Activation>,
357    out_shape: Shape,
358) -> HirNodeId {
359    g.add_node(
360        Op::FusedMatMulBiasAct { activation },
361        vec![input, w, b],
362        out_shape,
363    )
364}
365
366fn attention_bias(
367    g: &mut HirMut<'_>,
368    q: HirNodeId,
369    k: HirNodeId,
370    v: HirNodeId,
371    bias: HirNodeId,
372    num_heads: usize,
373    head_dim: usize,
374) -> HirNodeId {
375    let attn_shape = shape::attention_shape(g.shape(q));
376    g.add_node(
377        Op::Attention {
378            num_heads,
379            head_dim,
380            mask_kind: MaskKind::Bias,
381            score_scale: None,
382            attn_logit_softcap: None,
383        },
384        vec![q, k, v, bias],
385        attn_shape,
386    )
387}
388
389/// Build the boxRPB MLP + outer-add subgraph. Takes log-normed deltas
390/// (cheap geometry computed on host) and emits the `[B, H, nq+1, hw]`
391/// additive log-bias used by the image cross-attention. Replaces the
392/// host-side `boxrpb_log_full_into` for the per-layer hot path.
393///
394/// `deltas_x: [B, nq, w, 2]`, `deltas_y: [B, nq, h, 2]` → bias.
395#[allow(clippy::too_many_arguments)]
396fn mlp2_relu_pair_gguf(
397    g: &mut HirMut<'_>,
398    params: &mut HashMap<String, Vec<f32>>,
399    typed: &mut Vec<(String, Vec<u8>, DType)>,
400    cache: &mut HashMap<String, HirNodeId>,
401    gguf_packed: Option<&GgufPackedParams>,
402    mlp: &Mlp2,
403    stem: &str,
404    input: HirNodeId,
405    rows: usize,
406    hidden_dim: usize,
407    out_dim: usize,
408) -> Result<HirNodeId> {
409    let h = if let Some(p) = mlp
410        .w0_gguf_key
411        .as_deref()
412        .and_then(|key| gguf_packed.and_then(|m| super::packed_gguf::packed_linear(m, key)))
413    {
414        let y = linear_gguf_bias(
415            g,
416            params,
417            typed,
418            cache,
419            &format!("{stem}.fc0"),
420            p,
421            input,
422            &mlp.b0,
423            mlp.in_dim,
424            hidden_dim,
425        )?;
426        g.relu(y)
427    } else {
428        let w_id = add_param(
429            g,
430            params,
431            &format!("{stem}.w0"),
432            mlp.w0_t.clone(),
433            Shape::new(&[mlp.in_dim, hidden_dim], DType::F32),
434        );
435        let b_id = add_param(
436            g,
437            params,
438            &format!("{stem}.b0"),
439            mlp.b0.clone(),
440            Shape::new(&[hidden_dim], DType::F32),
441        );
442        fused_matmul_bias_act(
443            g,
444            input,
445            w_id,
446            b_id,
447            Some(Activation::Relu),
448            Shape::new(&[rows, hidden_dim], DType::F32),
449        )
450    };
451    if let Some(p) = mlp
452        .w1_gguf_key
453        .as_deref()
454        .and_then(|key| gguf_packed.and_then(|m| super::packed_gguf::packed_linear(m, key)))
455    {
456        return linear_gguf_bias(
457            g,
458            params,
459            typed,
460            cache,
461            &format!("{stem}.fc1"),
462            p,
463            h,
464            &mlp.b1,
465            hidden_dim,
466            out_dim,
467        );
468    }
469    let w_id = add_param(
470        g,
471        params,
472        &format!("{stem}.w1"),
473        mlp.w1_t.clone(),
474        Shape::new(&[hidden_dim, out_dim], DType::F32),
475    );
476    let b_id = add_param(
477        g,
478        params,
479        &format!("{stem}.b1"),
480        mlp.b1.clone(),
481        Shape::new(&[out_dim], DType::F32),
482    );
483    Ok(fused_matmul_bias_act(
484        g,
485        h,
486        w_id,
487        b_id,
488        None,
489        Shape::new(&[rows, out_dim], DType::F32),
490    ))
491}
492
493fn build_boxrpb_subgraph(
494    g: &mut HirMut<'_>,
495    params: &mut HashMap<String, Vec<f32>>,
496    typed: &mut Vec<(String, Vec<u8>, DType)>,
497    gguf_cache: &mut HashMap<String, HirNodeId>,
498    gguf_packed: Option<&GgufPackedParams>,
499    boxrpb_x: &Mlp2,
500    boxrpb_y: &Mlp2,
501    deltas_x: HirNodeId,
502    deltas_y: HirNodeId,
503    batch: usize,
504    nq: usize,
505    nh: usize,
506    h: usize,
507    w: usize,
508) -> Result<HirNodeId> {
509    let f = DType::F32;
510    let hidden_x = boxrpb_x.hidden;
511    let hidden_y = boxrpb_y.hidden;
512    assert_eq!(boxrpb_x.in_dim, 2);
513    assert_eq!(boxrpb_y.in_dim, 2);
514    assert_eq!(boxrpb_x.out_dim, nh);
515    assert_eq!(boxrpb_y.out_dim, nh);
516
517    let dx_flat = g.reshape_(deltas_x, vec![(batch * nq * w) as i64, 2]);
518    let dx_o = mlp2_relu_pair_gguf(
519        g,
520        params,
521        typed,
522        gguf_cache,
523        gguf_packed,
524        boxrpb_x,
525        "boxrpb_x",
526        dx_flat,
527        batch * nq * w,
528        hidden_x,
529        nh,
530    )?;
531    let dx_4d = g.reshape_(dx_o, vec![batch as i64, nq as i64, w as i64, nh as i64]);
532    let dx_perm = g.transpose_(dx_4d, vec![0, 3, 1, 2]);
533    let dx_bc = g.reshape_(
534        dx_perm,
535        vec![batch as i64, nh as i64, nq as i64, 1, w as i64],
536    );
537
538    let dy_flat = g.reshape_(deltas_y, vec![(batch * nq * h) as i64, 2]);
539    let dy_o = mlp2_relu_pair_gguf(
540        g,
541        params,
542        typed,
543        gguf_cache,
544        gguf_packed,
545        boxrpb_y,
546        "boxrpb_y",
547        dy_flat,
548        batch * nq * h,
549        hidden_y,
550        nh,
551    )?;
552    let dy_4d = g.reshape_(dy_o, vec![batch as i64, nq as i64, h as i64, nh as i64]);
553    let dy_perm = g.transpose_(dy_4d, vec![0, 3, 1, 2]);
554    let dy_bc = g.reshape_(
555        dy_perm,
556        vec![batch as i64, nh as i64, nq as i64, h as i64, 1],
557    );
558
559    let rpb_q = g.add(dx_bc, dy_bc);
560    let rpb_q_flat = g.reshape_(
561        rpb_q,
562        vec![batch as i64, nh as i64, nq as i64, (h * w) as i64],
563    );
564
565    let hw = h * w;
566    let _lq = nq + 1;
567    let zero_pres = add_param(
568        g,
569        params,
570        "rpb_zero_presence",
571        vec![0f32; batch * nh * hw],
572        Shape::new(&[batch, nh, 1, hw], f),
573    );
574    Ok(g.concat_(vec![zero_pres, rpb_q_flat], 2))
575}
576
577struct DecoderLayerHirParts {
578    params: HashMap<String, Vec<f32>>,
579    typed_params: Vec<(String, Vec<u8>, DType)>,
580}
581
582/// Build the per-layer decoder HIR body.
583#[allow(clippy::too_many_arguments)]
584fn build_layer_body(
585    hir: &mut HirModule,
586    layer: &Sam3DecoderLayerWeights,
587    boxrpb_x: &Mlp2,
588    boxrpb_y: &Mlp2,
589    norm_w: &[f32],
590    norm_b: &[f32],
591    dec_base: &str,
592    li: usize,
593    batch: usize,
594    h: usize,
595    w: usize,
596    seq: usize,
597    use_bias_attn: bool,
598    boxrpb_in_ir: bool,
599    gguf_packed: Option<&GgufPackedParams>,
600) -> Result<DecoderLayerHirParts> {
601    let hw = h * w;
602    let mut g = HirMut::new(hir);
603    let mut params: HashMap<String, Vec<f32>> = HashMap::new();
604    let mut typed_params = Vec::new();
605    let mut gguf_w_cache: HashMap<String, HirNodeId> = HashMap::new();
606    let f = DType::F32;
607    let d = D_MODEL;
608    let nh = N_HEADS;
609    let dh = HEAD_DIM;
610    let nq = NUM_QUERIES;
611    let lq = nq + 1;
612
613    let tgt = g.input("tgt", Shape::new(&[batch, nq, d], f));
614    let query_pos = g.input("query_pos", Shape::new(&[batch, nq, d], f));
615    let presence = g.input("presence", Shape::new(&[batch, 1, d], f));
616    let memory = g.input("memory", Shape::new(&[batch, hw, d], f));
617    let memory_pos = g.input("memory_pos", Shape::new(&[batch, hw, d], f));
618    let text = g.input("text", Shape::new(&[batch, seq, d], f));
619    let text_kpm_inv = g.input("text_kpm_inv", Shape::new(&[batch, seq], f));
620    let rpb_bias = if boxrpb_in_ir {
621        let dx = g.input("deltas_x", Shape::new(&[batch, nq, w, 2], f));
622        let dy = g.input("deltas_y", Shape::new(&[batch, nq, h, 2], f));
623        build_boxrpb_subgraph(
624            &mut g,
625            &mut params,
626            &mut typed_params,
627            &mut gguf_w_cache,
628            gguf_packed,
629            boxrpb_x,
630            boxrpb_y,
631            dx,
632            dy,
633            batch,
634            nq,
635            nh,
636            h,
637            w,
638        )?
639    } else {
640        g.input("rpb_bias", Shape::new(&[batch, nh, lq, hw], f))
641    };
642
643    let sa_x = g.concat_(vec![presence, tgt], 1);
644    let zero_pos = add_param(
645        &mut g,
646        &mut params,
647        "zero_presence_pos",
648        vec![0f32; batch * d],
649        Shape::new(&[batch, 1, d], f),
650    );
651    let sa_pos = g.concat_(vec![zero_pos, query_pos], 1);
652    let sa_qk = g.add(sa_x, sa_pos);
653
654    let (q_sa, k_sa, v_sa) = in_proj_qkv(
655        &mut g,
656        &mut params,
657        &mut typed_params,
658        &mut gguf_w_cache,
659        gguf_packed,
660        &dec_layer_key(dec_base, li, "self_attn.in_proj_weight"),
661        "sa.in_proj",
662        &layer.self_attn_in_w_t,
663        &layer.self_attn_in_b,
664        sa_qk,
665        sa_qk,
666        sa_x,
667        d,
668    )?;
669    let sa_attn = g.attention_kind(
670        q_sa,
671        k_sa,
672        v_sa,
673        nh,
674        dh,
675        MaskKind::None,
676        shape::attention_shape(g.shape(q_sa)),
677    );
678    let sa_proj = linear_fused_or_gguf(
679        &mut g,
680        &mut params,
681        &mut typed_params,
682        &mut gguf_w_cache,
683        gguf_packed,
684        &dec_layer_key(dec_base, li, "self_attn.out_proj.weight"),
685        "sa.out",
686        sa_attn,
687        layer.self_attn_out_w_t.clone(),
688        layer.self_attn_out_b.clone(),
689        d,
690        d,
691    )?;
692    let sa_res = g.add(sa_x, sa_proj);
693    let n2_w = add_param(
694        &mut g,
695        &mut params,
696        "norm2.w",
697        layer.norm2_w.clone(),
698        Shape::new(&[d], f),
699    );
700    let n2_b = add_param(
701        &mut g,
702        &mut params,
703        "norm2.b",
704        layer.norm2_b.clone(),
705        Shape::new(&[d], f),
706    );
707    let sa_normed = g.ln(sa_res, n2_w, n2_b, 1e-5);
708    let presence_after_sa = g.narrow_(sa_normed, 1, 0, 1);
709    let queries_after_sa = g.narrow_(sa_normed, 1, 1, nq);
710
711    let q_text_in = g.add(queries_after_sa, query_pos);
712    let (q_text, k_text, v_text) = in_proj_qkv(
713        &mut g,
714        &mut params,
715        &mut typed_params,
716        &mut gguf_w_cache,
717        gguf_packed,
718        &dec_layer_key(dec_base, li, "ca_text.in_proj_weight"),
719        "ca_text.in_proj",
720        &layer.ca_text_in_w_t,
721        &layer.ca_text_in_b,
722        q_text_in,
723        text,
724        text,
725        d,
726    )?;
727    let ca_text_attn = g.attention(
728        q_text,
729        k_text,
730        v_text,
731        text_kpm_inv,
732        nh,
733        dh,
734        shape::attention_shape(g.shape(q_text)),
735    );
736    let ca_text_proj = linear_fused_or_gguf(
737        &mut g,
738        &mut params,
739        &mut typed_params,
740        &mut gguf_w_cache,
741        gguf_packed,
742        &dec_layer_key(dec_base, li, "ca_text.out_proj.weight"),
743        "ca_text.out",
744        ca_text_attn,
745        layer.ca_text_out_w_t.clone(),
746        layer.ca_text_out_b.clone(),
747        d,
748        d,
749    )?;
750    let after_ca_text_res = g.add(queries_after_sa, ca_text_proj);
751    let cat_w = add_param(
752        &mut g,
753        &mut params,
754        "catext_norm.w",
755        layer.catext_norm_w.clone(),
756        Shape::new(&[d], f),
757    );
758    let cat_b = add_param(
759        &mut g,
760        &mut params,
761        "catext_norm.b",
762        layer.catext_norm_b.clone(),
763        Shape::new(&[d], f),
764    );
765    let after_ca_text = g.ln(after_ca_text_res, cat_w, cat_b, 1e-5);
766
767    let ca_in = g.concat_(vec![presence_after_sa, after_ca_text], 1);
768    let ca_q_in = g.add(ca_in, sa_pos);
769    let k_mem_in = g.add(memory, memory_pos);
770
771    let (q_img, k_img, v_img) = in_proj_qkv(
772        &mut g,
773        &mut params,
774        &mut typed_params,
775        &mut gguf_w_cache,
776        gguf_packed,
777        &dec_layer_key(dec_base, li, "cross_attn.in_proj_weight"),
778        "ca_img.in_proj",
779        &layer.cross_attn_in_w_t,
780        &layer.cross_attn_in_b,
781        ca_q_in,
782        k_mem_in,
783        memory,
784        d,
785    )?;
786
787    let attn_flat = if use_bias_attn {
788        attention_bias(&mut g, q_img, k_img, v_img, rpb_bias, nh, dh)
789    } else {
790        let q_4d = g.reshape_(q_img, vec![batch as i64, lq as i64, nh as i64, dh as i64]);
791        let q_perm = g.transpose_(q_4d, vec![0, 2, 1, 3]);
792        let k_4d = g.reshape_(k_img, vec![batch as i64, hw as i64, nh as i64, dh as i64]);
793        let k_perm = g.transpose_(k_4d, vec![0, 2, 1, 3]);
794        let v_4d = g.reshape_(v_img, vec![batch as i64, hw as i64, nh as i64, dh as i64]);
795        let v_perm = g.transpose_(v_4d, vec![0, 2, 1, 3]);
796        let k_t = g.transpose_(k_perm, vec![0, 1, 3, 2]);
797        let scores = g.mm(q_perm, k_t);
798        let scale_val = 1.0f32 / (HEAD_DIM as f32).sqrt();
799        let scale_node = add_param(
800            &mut g,
801            &mut params,
802            "img.scale",
803            vec![scale_val],
804            Shape::new(&[1], f),
805        );
806        let scores_scaled = g.mul(scores, scale_node);
807        let scores_biased = g.add(scores_scaled, rpb_bias);
808        let probs = g.sm(scores_biased, -1);
809        let attn_out = g.mm(probs, v_perm);
810        let attn_perm = g.transpose_(attn_out, vec![0, 2, 1, 3]);
811        g.reshape_(attn_perm, vec![batch as i64, lq as i64, d as i64])
812    };
813    let ca_img_proj = linear_fused_or_gguf(
814        &mut g,
815        &mut params,
816        &mut typed_params,
817        &mut gguf_w_cache,
818        gguf_packed,
819        &dec_layer_key(dec_base, li, "cross_attn.out_proj.weight"),
820        "ca_img.out",
821        attn_flat,
822        layer.cross_attn_out_w_t.clone(),
823        layer.cross_attn_out_b.clone(),
824        d,
825        d,
826    )?;
827    let ca_img_res = g.add(ca_in, ca_img_proj);
828    let n1_w = add_param(
829        &mut g,
830        &mut params,
831        "norm1.w",
832        layer.norm1_w.clone(),
833        Shape::new(&[d], f),
834    );
835    let n1_b = add_param(
836        &mut g,
837        &mut params,
838        "norm1.b",
839        layer.norm1_b.clone(),
840        Shape::new(&[d], f),
841    );
842    let after_ca_img = g.ln(ca_img_res, n1_w, n1_b, 1e-5);
843
844    let ff1 = linear_fused_or_gguf(
845        &mut g,
846        &mut params,
847        &mut typed_params,
848        &mut gguf_w_cache,
849        gguf_packed,
850        &dec_layer_key(dec_base, li, "linear1.weight"),
851        "ffn.fc1",
852        after_ca_img,
853        layer.linear1_w_t.clone(),
854        layer.linear1_b.clone(),
855        d,
856        DIM_FF,
857    )?;
858    let relud = g.relu(ff1);
859    let ff2 = linear_fused_or_gguf(
860        &mut g,
861        &mut params,
862        &mut typed_params,
863        &mut gguf_w_cache,
864        gguf_packed,
865        &dec_layer_key(dec_base, li, "linear2.weight"),
866        "ffn.fc2",
867        relud,
868        layer.linear2_w_t.clone(),
869        layer.linear2_b.clone(),
870        DIM_FF,
871        d,
872    )?;
873    let ffn_res = g.add(after_ca_img, ff2);
874    let n3_w = add_param(
875        &mut g,
876        &mut params,
877        "norm3.w",
878        layer.norm3_w.clone(),
879        Shape::new(&[d], f),
880    );
881    let n3_b = add_param(
882        &mut g,
883        &mut params,
884        "norm3.b",
885        layer.norm3_b.clone(),
886        Shape::new(&[d], f),
887    );
888    let after_ffn = g.ln(ffn_res, n3_w, n3_b, 1e-5);
889
890    let new_presence = g.narrow_(after_ffn, 1, 0, 1);
891    let new_tgt = g.narrow_(after_ffn, 1, 1, nq);
892
893    let dec_norm_w = add_param(
894        &mut g,
895        &mut params,
896        "dec.norm.w",
897        norm_w.to_vec(),
898        Shape::new(&[d], f),
899    );
900    let dec_norm_b = add_param(
901        &mut g,
902        &mut params,
903        "dec.norm.b",
904        norm_b.to_vec(),
905        Shape::new(&[d], f),
906    );
907    let out_norm = g.ln(new_tgt, dec_norm_w, dec_norm_b, 1e-5);
908
909    g.set_outputs(vec![new_tgt, new_presence, out_norm]);
910    let _ = (q_img, k_img, v_img, ca_img_proj);
911    Ok(DecoderLayerHirParts {
912        params,
913        typed_params,
914    })
915}
916
917/// Build a per-layer native HIR module.
918fn build_layer_hir(
919    layer: &Sam3DecoderLayerWeights,
920    boxrpb_x: &Mlp2,
921    boxrpb_y: &Mlp2,
922    norm_w: &[f32],
923    norm_b: &[f32],
924    dec_base: &str,
925    li: usize,
926    batch: usize,
927    h: usize,
928    w: usize,
929    seq: usize,
930    use_bias_attn: bool,
931    boxrpb_in_ir: bool,
932    gguf_packed: Option<&GgufPackedParams>,
933) -> Result<LayerHirParts> {
934    let mut hir = HirModule::new("sam3_dec_layer");
935    let parts = build_layer_body(
936        &mut hir,
937        layer,
938        boxrpb_x,
939        boxrpb_y,
940        norm_w,
941        norm_b,
942        dec_base,
943        li,
944        batch,
945        h,
946        w,
947        seq,
948        use_bias_attn,
949        boxrpb_in_ir,
950        gguf_packed,
951    )?;
952    Ok((hir, parts.params, parts.typed_params))
953}
954
955/// Compile-once-per-layer decoder, runnable across many frames.
956pub struct Sam3CompiledDecoder {
957    layers: Vec<CompiledGraph>,
958    bbox_embed: Mlp3,
959    ref_point_head: Mlp2,
960    boxrpb_x: Mlp2,
961    boxrpb_y: Mlp2,
962    initial_query_embed: Vec<f32>,
963    initial_reference_points: Vec<f32>,
964    cached_layer0_query_pos: Vec<f32>,
965    /// Layer-0 cached inputs for the boxRPB IR subgraph path (GPU
966    /// backends). Constant geometry, ~115KB each.
967    cached_layer0_deltas_x: Option<Vec<f32>>,
968    cached_layer0_deltas_y: Option<Vec<f32>>,
969    /// Layer-0 cached `rpb_bias [B, H, lq, hw]` for the host-MLP path
970    /// (CPU backend). Constant boxRPB tensor, ~66MB.
971    cached_layer0_rpb: Option<Vec<f32>>,
972    #[allow(dead_code)]
973    cached_initial_ref_boxes: Vec<f32>,
974    boxrpb_in_ir: bool,
975    presence_token: Vec<f32>,
976    presence_head: Mlp3,
977    presence_norm_w: Vec<f32>,
978    presence_norm_b: Vec<f32>,
979    /// Per-call delta scratch (geometry only — the MLP forward and
980    /// outer-add run inside the IR graph for GPU backends).
981    scratch_deltas_x: Vec<f32>,
982    scratch_deltas_y: Vec<f32>,
983    /// Host-MLP path scratch buffers (CPU backend). Allocated only
984    /// when boxrpb_in_ir is false.
985    scratch_rpb: Option<Vec<f32>>,
986    scratch_dx_thq: Option<Vec<f32>>,
987    scratch_dy_thq: Option<Vec<f32>>,
988    scratch_boxrpb_x_hidden: Option<Vec<f32>>,
989    scratch_boxrpb_y_hidden: Option<Vec<f32>>,
990    scratch_boxrpb_x_feats: Option<Vec<f32>>,
991    scratch_boxrpb_y_feats: Option<Vec<f32>>,
992    /// Scratch for ref_point_head sineembed + MLP intermediates.
993    scratch_sine: Vec<f32>,
994    scratch_rph_hidden: Vec<f32>,
995    /// Output of ref_point_head MLP for layers 1..N (layer 0 uses cache).
996    scratch_query_pos: Vec<f32>,
997    /// Scratch for box-refinement `mlp3_forward(bbox_embed)`.
998    scratch_bbox_h0: Vec<f32>,
999    scratch_bbox_h1: Vec<f32>,
1000    scratch_bbox_out: Vec<f32>,
1001    pub batch: usize,
1002    pub hw: usize,
1003    pub seq: usize,
1004    gguf_packed: Option<GgufPackedParams>,
1005}
1006
1007impl Sam3CompiledDecoder {
1008    pub fn new(
1009        weights: &Sam3DecoderWeights,
1010        batch: usize,
1011        hw: usize,
1012        seq: usize,
1013        device: Device,
1014    ) -> Result<Self> {
1015        Self::new_with_profile(weights, batch, hw, seq, device, &CompileProfile::sam3())
1016    }
1017
1018    pub fn new_with_profile(
1019        weights: &Sam3DecoderWeights,
1020        batch: usize,
1021        hw: usize,
1022        seq: usize,
1023        device: Device,
1024        profile: &CompileProfile,
1025    ) -> Result<Self> {
1026        Self::new_with_profile_and_gguf(weights, batch, hw, seq, device, profile, None)
1027    }
1028
1029    pub fn new_with_profile_and_gguf(
1030        weights: &Sam3DecoderWeights,
1031        batch: usize,
1032        hw: usize,
1033        seq: usize,
1034        device: Device,
1035        profile: &CompileProfile,
1036        gguf_packed: Option<&GgufPackedParams>,
1037    ) -> Result<Self> {
1038        ensure!(weights.loaded, "decoder weights not loaded");
1039        let nq = NUM_QUERIES;
1040        let d = D_MODEL;
1041        let h_w = (hw as f64).sqrt().round() as usize;
1042        ensure!(
1043            h_w * h_w == hw,
1044            "boxRPB cache requires square spatial grid; got hw={hw}"
1045        );
1046        let mut layers = Vec::with_capacity(N_LAYERS);
1047        // Metal: opt-in to bias-mask SDPA via env var. The default
1048        // routes Metal through the MPSGraph manual-decomp because the
1049        // bias-aware SDPA kernels (sdpa_long, sdpa_fa_f32) currently
1050        // produce incorrect output — needs debugging.
1051        let use_bias_attn = if matches!(device, Device::Metal) {
1052            rlx_ir::env::flag("RLX_SAM3_METAL_BIAS_SDPA")
1053        } else {
1054            true
1055        };
1056        // MLX: boxRPB in-graph (saves per-call rpb_bias upload). CPU: opt-in
1057        // via `RLX_SAM3_BOXRPB_IR` (host BLAS is faster for tiny K=2 GEMMs).
1058        // Metal: 5D broadcast outer-add in boxRPB subgraph is wrong on Metal IR
1059        // — keep host boxRPB until fixed (`RLX_SAM3_BOXRPB_IR` ignored on Metal).
1060        let boxrpb_in_ir = matches!(device, Device::Mlx)
1061            || (matches!(device, Device::Cpu) && rlx_ir::env::flag("RLX_SAM3_BOXRPB_IR"));
1062        let dec_base = &weights.prefix;
1063        for (li, layer) in weights.layers.iter().enumerate() {
1064            let (hir, params, typed) = build_layer_hir(
1065                layer,
1066                &weights.boxrpb_x,
1067                &weights.boxrpb_y,
1068                &weights.norm_w,
1069                &weights.norm_b,
1070                dec_base,
1071                li,
1072                batch,
1073                h_w,
1074                h_w,
1075                seq,
1076                use_bias_attn,
1077                boxrpb_in_ir,
1078                gguf_packed,
1079            )?;
1080            let mut compiled =
1081                rlx_core::flow_bridge::compile_hir_with_profile(device, hir, profile)?;
1082            rlx_core::flow_util::attach_built_params(&mut compiled, params, &typed);
1083            layers.push(compiled);
1084        }
1085        // Precompute layer-0 inputs that depend only on constant model
1086        // weights (initial reference_points → cached query_pos and
1087        // boxRPB deltas). The boxRPB MLP+outer-add now runs inside the
1088        // graph, so we only cache the cheap geometry deltas.
1089        let mut cached_initial_ref_boxes = vec![0f32; batch * nq * 4];
1090        for b in 0..batch {
1091            for q in 0..nq {
1092                for k in 0..4 {
1093                    let v = weights.reference_points[q * 4 + k];
1094                    cached_initial_ref_boxes[(b * nq + q) * 4 + k] = sigmoid(v);
1095                }
1096            }
1097        }
1098        let sine = sineembed_4d(&cached_initial_ref_boxes, batch, nq, d);
1099        let cached_layer0_query_pos =
1100            mlp2_forward(&weights.ref_point_head, &sine, batch * nq, gguf_packed)?;
1101        let lq = nq + 1;
1102        let nh = N_HEADS;
1103        let (cached_layer0_deltas_x, cached_layer0_deltas_y, cached_layer0_rpb) = if boxrpb_in_ir {
1104            let mut dx = vec![0f32; batch * nq * h_w * 2];
1105            let mut dy = vec![0f32; batch * nq * h_w * 2];
1106            compute_deltas_into(
1107                &cached_initial_ref_boxes,
1108                batch,
1109                nq,
1110                h_w,
1111                h_w,
1112                &mut dx,
1113                &mut dy,
1114            );
1115            (Some(dx), Some(dy), None)
1116        } else {
1117            let rpb = boxrpb_log_full(
1118                &weights.boxrpb_x,
1119                &weights.boxrpb_y,
1120                &cached_initial_ref_boxes,
1121                batch,
1122                nq,
1123                h_w,
1124                h_w,
1125                gguf_packed,
1126            )?;
1127            (None, None, Some(rpb))
1128        };
1129        Ok(Self {
1130            layers,
1131            bbox_embed: weights.bbox_embed.clone(),
1132            ref_point_head: weights.ref_point_head.clone(),
1133            boxrpb_x: weights.boxrpb_x.clone(),
1134            boxrpb_y: weights.boxrpb_y.clone(),
1135            initial_query_embed: weights.query_embed.clone(),
1136            initial_reference_points: weights.reference_points.clone(),
1137            cached_layer0_query_pos,
1138            cached_layer0_deltas_x,
1139            cached_layer0_deltas_y,
1140            cached_layer0_rpb,
1141            cached_initial_ref_boxes,
1142            boxrpb_in_ir,
1143            presence_token: weights.presence_token.clone(),
1144            presence_head: weights.presence_token_head.clone(),
1145            presence_norm_w: weights.presence_token_out_norm_w.clone(),
1146            presence_norm_b: weights.presence_token_out_norm_b.clone(),
1147            scratch_deltas_x: if boxrpb_in_ir {
1148                vec![0f32; batch * nq * h_w * 2]
1149            } else {
1150                Vec::new()
1151            },
1152            scratch_deltas_y: if boxrpb_in_ir {
1153                vec![0f32; batch * nq * h_w * 2]
1154            } else {
1155                Vec::new()
1156            },
1157            scratch_rpb: (!boxrpb_in_ir).then(|| vec![0f32; batch * nh * lq * hw]),
1158            scratch_dx_thq: (!boxrpb_in_ir).then(|| vec![0f32; nh * nq * h_w]),
1159            scratch_dy_thq: (!boxrpb_in_ir).then(|| vec![0f32; nh * nq * h_w]),
1160            scratch_boxrpb_x_hidden: (!boxrpb_in_ir)
1161                .then(|| vec![0f32; nq * h_w * weights.boxrpb_x.hidden]),
1162            scratch_boxrpb_y_hidden: (!boxrpb_in_ir)
1163                .then(|| vec![0f32; nq * h_w * weights.boxrpb_y.hidden]),
1164            scratch_boxrpb_x_feats: (!boxrpb_in_ir)
1165                .then(|| vec![0f32; nq * h_w * weights.boxrpb_x.out_dim]),
1166            scratch_boxrpb_y_feats: (!boxrpb_in_ir)
1167                .then(|| vec![0f32; nq * h_w * weights.boxrpb_y.out_dim]),
1168            scratch_sine: vec![0f32; batch * nq * 2 * d],
1169            scratch_rph_hidden: vec![0f32; batch * nq * weights.ref_point_head.hidden],
1170            scratch_query_pos: vec![0f32; batch * nq * weights.ref_point_head.out_dim],
1171            scratch_bbox_h0: vec![0f32; batch * nq * weights.bbox_embed.hidden],
1172            scratch_bbox_h1: vec![0f32; batch * nq * weights.bbox_embed.hidden],
1173            scratch_bbox_out: vec![0f32; batch * nq * weights.bbox_embed.out_dim],
1174            batch,
1175            hw,
1176            seq,
1177            gguf_packed: gguf_packed.cloned(),
1178        })
1179    }
1180
1181    /// Run the decoder. Inputs are batch-first: `memory [B, hw, D]`,
1182    /// `memory_pos [B, hw, D]`, `text [B, seq, D]` (note: text is
1183    /// batch-first here, not seq-first), `text_kpm` (1 = PAD).
1184    pub fn run(
1185        &mut self,
1186        memory: &[f32],
1187        memory_pos: &[f32],
1188        text_seq_first: &[f32],
1189        text_kpm: &[u8],
1190        h: usize,
1191        w: usize,
1192    ) -> Result<LayerRunOut> {
1193        let hw = h * w;
1194        ensure!(hw == self.hw);
1195        let batch = self.batch;
1196        let nq = NUM_QUERIES;
1197        let d = D_MODEL;
1198        let nh = N_HEADS;
1199        let lq = nq + 1;
1200        let seq = self.seq;
1201
1202        // Initial tgt = query_embed expanded to batch.
1203        let mut tgt = vec![0f32; batch * nq * d];
1204        for b in 0..batch {
1205            tgt[b * nq * d..(b + 1) * nq * d].copy_from_slice(&self.initial_query_embed);
1206        }
1207        // Initial ref_boxes = sigmoid(reference_points).
1208        let mut ref_boxes = vec![0f32; batch * nq * 4];
1209        for b in 0..batch {
1210            for q in 0..nq {
1211                for k in 0..4 {
1212                    let v = self.initial_reference_points[q * 4 + k];
1213                    ref_boxes[(b * nq + q) * 4 + k] = sigmoid(v);
1214                }
1215            }
1216        }
1217        let mut presence = vec![0f32; batch * d];
1218        for b in 0..batch {
1219            presence[b * d..(b + 1) * d].copy_from_slice(&self.presence_token);
1220        }
1221
1222        // Text → batch-first.
1223        let mut text_bf = vec![0f32; batch * seq * d];
1224        for b in 0..batch {
1225            for l in 0..seq {
1226                let s = (l * batch + b) * d;
1227                let dst = (b * seq + l) * d;
1228                text_bf[dst..dst + d].copy_from_slice(&text_seq_first[s..s + d]);
1229            }
1230        }
1231        let text_kpm_inv: Vec<f32> = text_kpm
1232            .iter()
1233            .map(|&v| if v == 0 { 1.0 } else { 0.0 })
1234            .collect();
1235
1236        let mut intermediate = Vec::with_capacity(N_LAYERS);
1237        let mut intermediate_ref_boxes = Vec::with_capacity(N_LAYERS);
1238        intermediate_ref_boxes.push(ref_boxes.clone());
1239        let mut presence_logits = Vec::with_capacity(N_LAYERS);
1240
1241        let profile = rlx_ir::env::flag("RLX_SAM3_PROFILE");
1242        let mut t_qpos = 0u128;
1243        let mut t_rpb = 0u128;
1244        let mut t_graph = 0u128;
1245        let mut t_box = 0u128;
1246        let mut t_other = 0u128;
1247        for li in 0..N_LAYERS {
1248            let tq = std::time::Instant::now();
1249            // Compute query_pos = ref_point_head(sineembed(ref_boxes)).
1250            // Layer 0's ref_boxes is the constant `sigmoid(reference_points)`,
1251            // so its query_pos and boxRPB are precomputed once at
1252            // construction and reused per call.
1253            // For layer 0, use the cached query_pos slice directly and
1254            // the cached rpb buffer. For other layers, recompute into
1255            // pre-allocated scratch buffers so we don't malloc 33MB/layer.
1256            let query_pos_slice: &[f32];
1257            let rpb_slice: &[f32];
1258            let deltas_x_slice: &[f32];
1259            let deltas_y_slice: &[f32];
1260            if li == 0 {
1261                query_pos_slice = &self.cached_layer0_query_pos;
1262                if self.boxrpb_in_ir {
1263                    deltas_x_slice = self.cached_layer0_deltas_x.as_ref().unwrap();
1264                    deltas_y_slice = self.cached_layer0_deltas_y.as_ref().unwrap();
1265                    rpb_slice = &[];
1266                } else {
1267                    rpb_slice = self.cached_layer0_rpb.as_ref().unwrap();
1268                    deltas_x_slice = &[];
1269                    deltas_y_slice = &[];
1270                }
1271            } else {
1272                sineembed_4d_into(&ref_boxes, batch, nq, d, &mut self.scratch_sine);
1273                mlp2_forward_into(
1274                    &self.ref_point_head,
1275                    &self.scratch_sine,
1276                    batch * nq,
1277                    &mut self.scratch_rph_hidden,
1278                    &mut self.scratch_query_pos,
1279                    self.gguf_packed.as_ref(),
1280                )?;
1281                query_pos_slice = &self.scratch_query_pos;
1282                if self.boxrpb_in_ir {
1283                    compute_deltas_into(
1284                        &ref_boxes,
1285                        batch,
1286                        nq,
1287                        h,
1288                        w,
1289                        &mut self.scratch_deltas_x,
1290                        &mut self.scratch_deltas_y,
1291                    );
1292                    deltas_x_slice = &self.scratch_deltas_x;
1293                    deltas_y_slice = &self.scratch_deltas_y;
1294                    rpb_slice = &[];
1295                } else {
1296                    let mut host_deltas_x = vec![0f32; nq * w * 2];
1297                    let mut host_deltas_y = vec![0f32; nq * h * 2];
1298                    boxrpb_log_full_into(
1299                        &self.boxrpb_x,
1300                        &self.boxrpb_y,
1301                        &ref_boxes,
1302                        batch,
1303                        nq,
1304                        h,
1305                        w,
1306                        self.scratch_rpb.as_mut().unwrap(),
1307                        self.scratch_dx_thq.as_mut().unwrap(),
1308                        self.scratch_dy_thq.as_mut().unwrap(),
1309                        &mut host_deltas_x,
1310                        &mut host_deltas_y,
1311                        self.scratch_boxrpb_x_hidden.as_mut().unwrap(),
1312                        self.scratch_boxrpb_y_hidden.as_mut().unwrap(),
1313                        self.scratch_boxrpb_x_feats.as_mut().unwrap(),
1314                        self.scratch_boxrpb_y_feats.as_mut().unwrap(),
1315                        self.gguf_packed.as_ref(),
1316                    )?;
1317                    rpb_slice = self.scratch_rpb.as_ref().unwrap();
1318                    deltas_x_slice = &[];
1319                    deltas_y_slice = &[];
1320                }
1321            }
1322            if profile {
1323                t_qpos += tq.elapsed().as_micros();
1324            }
1325
1326            let tr = std::time::Instant::now();
1327            if profile {
1328                t_rpb += tr.elapsed().as_micros();
1329            }
1330
1331            // Run graph.
1332            let tg = std::time::Instant::now();
1333            let outputs = if self.boxrpb_in_ir {
1334                self.layers[li].run(&[
1335                    ("tgt", tgt.as_slice()),
1336                    ("query_pos", query_pos_slice),
1337                    ("presence", presence.as_slice()),
1338                    ("memory", memory),
1339                    ("memory_pos", memory_pos),
1340                    ("text", text_bf.as_slice()),
1341                    ("text_kpm_inv", text_kpm_inv.as_slice()),
1342                    ("deltas_x", deltas_x_slice),
1343                    ("deltas_y", deltas_y_slice),
1344                ])
1345            } else {
1346                self.layers[li].run(&[
1347                    ("tgt", tgt.as_slice()),
1348                    ("query_pos", query_pos_slice),
1349                    ("presence", presence.as_slice()),
1350                    ("memory", memory),
1351                    ("memory_pos", memory_pos),
1352                    ("text", text_bf.as_slice()),
1353                    ("text_kpm_inv", text_kpm_inv.as_slice()),
1354                    ("rpb_bias", rpb_slice),
1355                ])
1356            };
1357            if profile {
1358                t_graph += tg.elapsed().as_micros();
1359            }
1360            ensure!(outputs.len() == 3, "decoder layer expected 3 outputs");
1361            tgt = outputs[0].clone();
1362            presence = outputs[1].clone();
1363            let out_norm = outputs[2].clone();
1364
1365            let tb = std::time::Instant::now();
1366            // Box refinement: delta = bbox_embed(out_norm); ref = sigmoid(inv_sig(ref) + delta).
1367            mlp3_forward_into(
1368                &self.bbox_embed,
1369                &out_norm,
1370                batch * nq,
1371                &mut self.scratch_bbox_h0,
1372                &mut self.scratch_bbox_h1,
1373                &mut self.scratch_bbox_out,
1374                self.gguf_packed.as_ref(),
1375            )?;
1376            let delta: &[f32] = &self.scratch_bbox_out;
1377            if profile {
1378                t_box += tb.elapsed().as_micros();
1379            }
1380            let to = std::time::Instant::now();
1381            let _ = to;
1382            let _ = &mut t_other;
1383            let mut new_ref = vec![0f32; batch * nq * 4];
1384            for q in 0..nq {
1385                for b in 0..batch {
1386                    let cur = &ref_boxes[(b * nq + q) * 4..(b * nq + q + 1) * 4];
1387                    let dl = &delta[(b * nq + q) * 4..(b * nq + q + 1) * 4];
1388                    for k in 0..4 {
1389                        new_ref[(b * nq + q) * 4 + k] = sigmoid(inv_sigmoid(cur[k]) + dl[k]);
1390                    }
1391                }
1392            }
1393            ref_boxes = new_ref;
1394            if li != N_LAYERS - 1 {
1395                intermediate_ref_boxes.push(ref_boxes.clone());
1396            }
1397
1398            // Intermediate output in seq-first convention.
1399            let mut out_seq_first = vec![0f32; nq * batch * d];
1400            for q in 0..nq {
1401                for b in 0..batch {
1402                    let src = (b * nq + q) * d;
1403                    let dst = (q * batch + b) * d;
1404                    out_seq_first[dst..dst + d].copy_from_slice(&out_norm[src..src + d]);
1405                }
1406            }
1407            intermediate.push(out_seq_first);
1408
1409            // Presence logits.
1410            let p_norm =
1411                layer_norm_host(&presence, &self.presence_norm_w, &self.presence_norm_b, d);
1412            let p_logit = mlp3_forward(
1413                &self.presence_head,
1414                &p_norm,
1415                batch,
1416                self.gguf_packed.as_ref(),
1417            )?;
1418            presence_logits.push(p_logit);
1419        }
1420        if profile {
1421            let to_ms = |us: u128| us as f32 / 1000.0;
1422            eprintln!(
1423                "  decoder per-stage (6 layers total): qpos={:.1}ms  rpb={:.1}ms  graph={:.1}ms  box={:.1}ms",
1424                to_ms(t_qpos),
1425                to_ms(t_rpb),
1426                to_ms(t_graph),
1427                to_ms(t_box)
1428            );
1429        }
1430
1431        // Stack.
1432        let mut int_stack = vec![0f32; N_LAYERS * nq * batch * d];
1433        for (li, l) in intermediate.iter().enumerate() {
1434            int_stack[li * nq * batch * d..(li + 1) * nq * batch * d].copy_from_slice(l);
1435        }
1436        let mut ref_stack = vec![0f32; N_LAYERS * nq * batch * 4];
1437        for (li, r) in intermediate_ref_boxes.iter().enumerate() {
1438            ref_stack[li * nq * batch * 4..(li + 1) * nq * batch * 4].copy_from_slice(r);
1439        }
1440        let mut presence_stack = vec![0f32; N_LAYERS * batch];
1441        for (li, p) in presence_logits.iter().enumerate() {
1442            for b in 0..batch {
1443                presence_stack[li * batch + b] = p[b];
1444            }
1445        }
1446        let _ = nh;
1447        let _ = lq;
1448        Ok((int_stack, ref_stack, presence_stack, presence))
1449    }
1450}
1451
1452/// IR-compiled detector decoder on the requested device (6 layer graphs).
1453#[allow(clippy::too_many_arguments)]
1454pub fn forward_decoder_ir_on(
1455    weights: &Sam3DecoderWeights,
1456    memory: &[f32],
1457    memory_pos: &[f32],
1458    memory_text: &[f32],
1459    text_attention_mask: &[u8],
1460    batch: usize,
1461    h: usize,
1462    w: usize,
1463    seq_len: usize,
1464    device: Device,
1465) -> Result<Sam3DecoderOutput> {
1466    forward_decoder_ir_on_with_profile(
1467        weights,
1468        memory,
1469        memory_pos,
1470        memory_text,
1471        text_attention_mask,
1472        batch,
1473        h,
1474        w,
1475        seq_len,
1476        device,
1477        &CompileProfile::sam3(),
1478        None,
1479    )
1480}
1481
1482/// Same as [`forward_decoder_ir_on`] with an explicit tier-1 profile.
1483#[allow(clippy::too_many_arguments)]
1484pub fn forward_decoder_ir_on_with_profile(
1485    weights: &Sam3DecoderWeights,
1486    memory: &[f32],
1487    memory_pos: &[f32],
1488    memory_text: &[f32],
1489    text_attention_mask: &[u8],
1490    batch: usize,
1491    h: usize,
1492    w: usize,
1493    seq_len: usize,
1494    device: Device,
1495    profile: &CompileProfile,
1496    gguf_packed: Option<&GgufPackedParams>,
1497) -> Result<Sam3DecoderOutput> {
1498    ensure!(weights.loaded, "decoder weights not loaded");
1499    ensure!(batch == 1, "decoder IR forward requires batch=1 for boxRPB");
1500    let hw = h * w;
1501    let mut dec = Sam3CompiledDecoder::new_with_profile_and_gguf(
1502        weights,
1503        batch,
1504        hw,
1505        seq_len,
1506        device,
1507        profile,
1508        gguf_packed,
1509    )?;
1510    let (intermediate, intermediate_ref_boxes, presence_logits, presence_feats) =
1511        dec.run(memory, memory_pos, memory_text, text_attention_mask, h, w)?;
1512    Ok(Sam3DecoderOutput {
1513        intermediate,
1514        intermediate_ref_boxes,
1515        presence_logits,
1516        presence_feats,
1517        num_layers: N_LAYERS,
1518        num_queries: NUM_QUERIES,
1519        batch,
1520        d_model: D_MODEL,
1521    })
1522}
1523
1524// ── Host helpers (sineembed, boxRPB, mlp, sigmoid) ─────────────────────
1525
/// Logistic function `1 / (1 + e^(-x))`.
///
/// Saturates cleanly: very negative inputs give `0.0` (via `1 / inf`) and
/// very positive inputs give `1.0`.
fn sigmoid(x: f32) -> f32 {
    let denom = 1.0 + (-x).exp();
    denom.recip()
}
1529
/// Inverse of the logistic function (logit), `ln(p / (1 - p))`.
///
/// The input is first clamped to `[0, 1]`, then to `[1e-3, 1 - 1e-3]` so the
/// result is always finite. The `max`/`min` form (rather than a single
/// `clamp`) also maps a NaN input to the lower bound.
fn inv_sigmoid(x: f32) -> f32 {
    const EPS: f32 = 1e-3;
    let p = x.clamp(0.0, 1.0).max(EPS).min(1.0 - EPS);
    let odds = p / (1.0 - p);
    odds.ln()
}
1535
/// Host layer norm over consecutive rows of length `dim` (eps = 1e-5,
/// biased variance), followed by the affine `gamma`/`beta` transform.
///
/// The output has the same length as `x`; any trailing partial row (when
/// `x.len()` is not a multiple of `dim`) is left as zeros.
fn layer_norm_host(x: &[f32], gamma: &[f32], beta: &[f32], dim: usize) -> Vec<f32> {
    let mut out = vec![0f32; x.len()];
    let n = dim as f32;
    for (src, dst) in x.chunks_exact(dim).zip(out.chunks_exact_mut(dim)) {
        let mean = src.iter().sum::<f32>() / n;
        let var = src.iter().map(|v| (*v - mean).powi(2)).sum::<f32>() / n;
        let inv_std = 1.0 / (var + 1e-5).sqrt();
        for c in 0..dim {
            dst[c] = (src[c] - mean) * inv_std * gamma[c] + beta[c];
        }
    }
    out
}
1550
1551#[allow(dead_code)]
1552fn host_mlp2_forward(mlp: &Mlp2, x: &[f32], rows: usize) -> Result<Vec<f32>> {
1553    let h = matmul_bias_relu(x, &mlp.w0_t, &mlp.b0, rows, mlp.in_dim, mlp.hidden);
1554    Ok(matmul_bias(
1555        &h,
1556        &mlp.w1_t,
1557        &mlp.b1,
1558        rows,
1559        mlp.hidden,
1560        mlp.out_dim,
1561    ))
1562}
1563
1564/// In-place mlp2: `out = w1·relu(w0·x + b0) + b1`. Caller provides the
1565/// hidden scratch buffer and output buffer — no allocation in the hot
1566/// path. First layer uses fused matmul+bias+relu epilogue.
1567#[allow(dead_code)]
1568fn host_mlp2_forward_into(mlp: &Mlp2, x: &[f32], rows: usize, hidden: &mut [f32], out: &mut [f32]) {
1569    rlx_cpu::blas::sgemm_bias_epilogue(
1570        x,
1571        &mlp.w0_t,
1572        &mlp.b0,
1573        hidden,
1574        rows,
1575        mlp.in_dim,
1576        mlp.hidden,
1577        |v| if v < 0.0 { 0.0 } else { v },
1578    );
1579    rlx_cpu::blas::sgemm_bias(
1580        hidden,
1581        &mlp.w1_t,
1582        &mlp.b1,
1583        out,
1584        rows,
1585        mlp.hidden,
1586        mlp.out_dim,
1587    );
1588}
1589
1590#[allow(dead_code)]
1591fn host_mlp3_forward(mlp: &Mlp3, x: &[f32], rows: usize) -> Result<Vec<f32>> {
1592    let h = matmul_bias_relu(x, &mlp.w0_t, &mlp.b0, rows, mlp.in_dim, mlp.hidden);
1593    let h = matmul_bias_relu(&h, &mlp.w1_t, &mlp.b1, rows, mlp.hidden, mlp.hidden);
1594    Ok(matmul_bias(
1595        &h,
1596        &mlp.w2_t,
1597        &mlp.b2,
1598        rows,
1599        mlp.hidden,
1600        mlp.out_dim,
1601    ))
1602}
1603
1604#[allow(dead_code)]
1605fn host_mlp3_forward_into(
1606    mlp: &Mlp3,
1607    x: &[f32],
1608    rows: usize,
1609    h0: &mut [f32],
1610    h1: &mut [f32],
1611    out: &mut [f32],
1612) {
1613    let relu = |v: f32| if v < 0.0 { 0.0 } else { v };
1614    rlx_cpu::blas::sgemm_bias_epilogue(
1615        x, &mlp.w0_t, &mlp.b0, h0, rows, mlp.in_dim, mlp.hidden, relu,
1616    );
1617    rlx_cpu::blas::sgemm_bias_epilogue(
1618        h0, &mlp.w1_t, &mlp.b1, h1, rows, mlp.hidden, mlp.hidden, relu,
1619    );
1620    rlx_cpu::blas::sgemm_bias(h1, &mlp.w2_t, &mlp.b2, out, rows, mlp.hidden, mlp.out_dim);
1621}
1622
1623#[allow(dead_code)]
1624fn matmul_bias(x: &[f32], w_t: &[f32], b: &[f32], rows: usize, k: usize, n: usize) -> Vec<f32> {
1625    let mut out = vec![0f32; rows * n];
1626    rlx_cpu::blas::sgemm_bias(x, w_t, b, &mut out, rows, k, n);
1627    out
1628}
1629
1630#[allow(dead_code)]
1631fn matmul_bias_relu(
1632    x: &[f32],
1633    w_t: &[f32],
1634    b: &[f32],
1635    rows: usize,
1636    k: usize,
1637    n: usize,
1638) -> Vec<f32> {
1639    let mut out = matmul_bias(x, w_t, b, rows, k, n);
1640    for v in out.iter_mut() {
1641        if *v < 0.0 {
1642            *v = 0.0;
1643        }
1644    }
1645    out
1646}
1647
1648fn sineembed_4d(pos: &[f32], batch: usize, nq: usize, d_model: usize) -> Vec<f32> {
1649    let mut out = vec![0.0f32; batch * nq * 2 * d_model];
1650    sineembed_4d_into(pos, batch, nq, d_model, &mut out);
1651    out
1652}
1653
/// Sine/cosine positional embedding of 4-value boxes, written into `out`.
///
/// For each of the `batch * nq` boxes in `pos` (4 floats each, `[p0, p1, p2, p3]`),
/// the coordinates are embedded in the order `p1, p0, p2, p3`. Each coordinate
/// is scaled by `2π` and expanded into `d_model / 2` features:
/// `sin(θ / t_i)` for even `i` and `cos(θ / t_i)` for odd `i`, with
/// `t_i = 10000^(2⌊i/2⌋ / (d_model/2))`. The four per-coordinate blocks are
/// concatenated, so each box starts at offset `box * 2 * d_model` in `out`.
fn sineembed_4d_into(pos: &[f32], batch: usize, nq: usize, d_model: usize, out: &mut [f32]) {
    let half = d_model / 2;
    let two_pi = 2.0 * std::f32::consts::PI;
    let dim_t: Vec<f32> = (0..half)
        .map(|i| {
            let exponent = 2.0 * ((i / 2) as f32) / half as f32;
            10000.0f32.powf(exponent)
        })
        .collect();
    debug_assert_eq!(out.len(), batch * nq * 2 * d_model);
    for idx in 0..batch * nq {
        let p = &pos[idx * 4..idx * 4 + 4];
        let coords = [p[1], p[0], p[2], p[3]];
        let row = &mut out[idx * 2 * d_model..];
        for (axis, &c) in coords.iter().enumerate() {
            let scaled = c * two_pi;
            let block = &mut row[axis * half..axis * half + half];
            for (i, (slot, &t)) in block.iter_mut().zip(&dim_t).enumerate() {
                let theta = scaled / t;
                *slot = if i % 2 == 0 { theta.sin() } else { theta.cos() };
            }
        }
    }
}
1678
1679/// Owning version that allocates; kept for the construct-time cache
1680/// where we run it once. Hot path uses `boxrpb_log_full_into` with a
1681/// pre-allocated scratch buffer.
1682fn boxrpb_log_full(
1683    boxrpb_x: &Mlp2,
1684    boxrpb_y: &Mlp2,
1685    reference_boxes: &[f32],
1686    batch: usize,
1687    nq: usize,
1688    h: usize,
1689    w: usize,
1690    gguf_packed: Option<&GgufPackedParams>,
1691) -> Result<Vec<f32>> {
1692    let nh = N_HEADS;
1693    let lq = nq + 1;
1694    let mut out = vec![0f32; batch * nh * lq * h * w];
1695    let mut dx_thq = vec![0f32; nh * nq * w];
1696    let mut dy_thq = vec![0f32; nh * nq * h];
1697    let mut deltas_x = vec![0f32; nq * w * 2];
1698    let mut deltas_y = vec![0f32; nq * h * 2];
1699    let mut hidden_x = vec![0f32; nq * w * boxrpb_x.hidden];
1700    let mut hidden_y = vec![0f32; nq * h * boxrpb_y.hidden];
1701    let mut feats_x = vec![0f32; nq * w * boxrpb_x.out_dim];
1702    let mut feats_y = vec![0f32; nq * h * boxrpb_y.out_dim];
1703    boxrpb_log_full_into(
1704        boxrpb_x,
1705        boxrpb_y,
1706        reference_boxes,
1707        batch,
1708        nq,
1709        h,
1710        w,
1711        &mut out,
1712        &mut dx_thq,
1713        &mut dy_thq,
1714        &mut deltas_x,
1715        &mut deltas_y,
1716        &mut hidden_x,
1717        &mut hidden_y,
1718        &mut feats_x,
1719        &mut feats_y,
1720        gguf_packed,
1721    )?;
1722    Ok(out)
1723}
1724
1725#[allow(clippy::too_many_arguments)]
1726fn boxrpb_log_full_into(
1727    boxrpb_x: &Mlp2,
1728    boxrpb_y: &Mlp2,
1729    reference_boxes: &[f32],
1730    batch: usize,
1731    nq: usize,
1732    h: usize,
1733    w: usize,
1734    out: &mut [f32],
1735    dx_thq: &mut [f32],
1736    dy_thq: &mut [f32],
1737    deltas_x: &mut [f32],
1738    deltas_y: &mut [f32],
1739    hidden_x: &mut [f32],
1740    hidden_y: &mut [f32],
1741    feats_x: &mut [f32],
1742    feats_y: &mut [f32],
1743    gguf_packed: Option<&GgufPackedParams>,
1744) -> Result<()> {
1745    let nh = N_HEADS;
1746    let lq = nq + 1;
1747    debug_assert_eq!(out.len(), batch * nh * lq * h * w);
1748    debug_assert_eq!(dx_thq.len(), nh * nq * w);
1749    debug_assert_eq!(dy_thq.len(), nh * nq * h);
1750    debug_assert_eq!(deltas_x.len(), nq * w * 2);
1751    debug_assert_eq!(deltas_y.len(), nq * h * 2);
1752    debug_assert_eq!(hidden_x.len(), nq * w * boxrpb_x.hidden);
1753    debug_assert_eq!(hidden_y.len(), nq * h * boxrpb_y.hidden);
1754    debug_assert_eq!(feats_x.len(), nq * w * boxrpb_x.out_dim);
1755    debug_assert_eq!(feats_y.len(), nq * h * boxrpb_y.out_dim);
1756    // Zero the presence rows once — non-presence rows get overwritten.
1757    for head in 0..nh {
1758        for b in 0..batch {
1759            let off = b * nh * lq * h * w + head * lq * h * w;
1760            // Presence row at lq=0
1761            for i in 0..h * w {
1762                out[off + i] = 0.0;
1763            }
1764        }
1765    }
1766    let coords_h: Vec<f32> = (0..h).map(|y| y as f32 / h as f32).collect();
1767    let coords_w: Vec<f32> = (0..w).map(|x| x as f32 / w as f32).collect();
1768
1769    for b in 0..batch {
1770        for q in 0..nq {
1771            let p = &reference_boxes[(b * nq + q) * 4..(b * nq + q + 1) * 4];
1772            let (cx, cy, bw, bh) = (p[0], p[1], p[2], p[3]);
1773            let x0 = cx - 0.5 * bw;
1774            let x1 = cx + 0.5 * bw;
1775            let y0 = cy - 0.5 * bh;
1776            let y1 = cy + 0.5 * bh;
1777            for xi in 0..w {
1778                let dx0 = (coords_w[xi] - x0) * 8.0;
1779                let dx1 = (coords_w[xi] - x1) * 8.0;
1780                deltas_x[(q * w + xi) * 2] = log_norm(dx0);
1781                deltas_x[(q * w + xi) * 2 + 1] = log_norm(dx1);
1782            }
1783            for yi in 0..h {
1784                let dy0 = (coords_h[yi] - y0) * 8.0;
1785                let dy1 = (coords_h[yi] - y1) * 8.0;
1786                deltas_y[(q * h + yi) * 2] = log_norm(dy0);
1787                deltas_y[(q * h + yi) * 2 + 1] = log_norm(dy1);
1788            }
1789        }
1790        mlp2_forward_into(boxrpb_x, deltas_x, nq * w, hidden_x, feats_x, gguf_packed)?;
1791        mlp2_forward_into(boxrpb_y, deltas_y, nq * h, hidden_y, feats_y, gguf_packed)?;
1792        let dx_feats: &[f32] = feats_x;
1793        let dy_feats: &[f32] = feats_y;
1794        // Transpose dx/dy from [pos, head] to [head, q, pos] so the
1795        // outer-add per (head, q) reads contiguous slices.
1796        for q in 0..nq {
1797            for xi in 0..w {
1798                let src_base = (q * w + xi) * nh;
1799                for head in 0..nh {
1800                    dx_thq[(head * nq + q) * w + xi] = dx_feats[src_base + head];
1801                }
1802            }
1803            for yi in 0..h {
1804                let src_base = (q * h + yi) * nh;
1805                for head in 0..nh {
1806                    dy_thq[(head * nq + q) * h + yi] = dy_feats[src_base + head];
1807                }
1808            }
1809        }
1810        let base = b * nh * lq * h * w;
1811        let total = nh * nq;
1812        let out_ptr = out.as_mut_ptr() as usize;
1813        let dx_ptr = dx_thq.as_ptr() as usize;
1814        let dy_ptr = dy_thq.as_ptr() as usize;
1815        rlx_cpu::pool::par_for(total, 8, &|off, cnt| unsafe {
1816            for idx in off..off + cnt {
1817                let head = idx / nq;
1818                let q = idx % nq;
1819                let dst = (out_ptr as *mut f32).add(base + (head * lq + 1 + q) * h * w);
1820                let dx_row =
1821                    std::slice::from_raw_parts((dx_ptr as *const f32).add((head * nq + q) * w), w);
1822                let dy_row =
1823                    std::slice::from_raw_parts((dy_ptr as *const f32).add((head * nq + q) * h), h);
1824                for y in 0..h {
1825                    let dy = dy_row[y];
1826                    let row_dst = dst.add(y * w);
1827                    for x in 0..w {
1828                        *row_dst.add(x) = dy + dx_row[x];
1829                    }
1830                }
1831            }
1832        });
1833    }
1834    Ok(())
1835}
1836
/// Signed log scaling applied to the boxRPB box-edge deltas:
/// `sign(v) * log2(|v| + 1) / log2(8)`.
///
/// Only strictly negative inputs get a negative sign, so `log_norm(0.0) == 0.0`
/// and a NaN input propagates as NaN.
fn log_norm(v: f32) -> f32 {
    let magnitude = (v.abs() + 1.0).log2() / 8.0f32.log2();
    if v < 0.0 { -magnitude } else { magnitude }
}
1841
1842/// Geometry-only delta computation. Output layout matches the IR
1843/// `deltas_x [B, nq, w, 2]` / `deltas_y [B, nq, h, 2]` inputs that
1844/// feed the boxRPB subgraph: per (batch, query, spatial-coord), pair
1845/// of `log_norm((coord - left)*8)` and `log_norm((coord - right)*8)`.
1846fn compute_deltas_into(
1847    reference_boxes: &[f32],
1848    batch: usize,
1849    nq: usize,
1850    h: usize,
1851    w: usize,
1852    deltas_x: &mut [f32],
1853    deltas_y: &mut [f32],
1854) {
1855    debug_assert_eq!(deltas_x.len(), batch * nq * w * 2);
1856    debug_assert_eq!(deltas_y.len(), batch * nq * h * 2);
1857    let coords_h: Vec<f32> = (0..h).map(|y| y as f32 / h as f32).collect();
1858    let coords_w: Vec<f32> = (0..w).map(|x| x as f32 / w as f32).collect();
1859    for b in 0..batch {
1860        for q in 0..nq {
1861            let p = &reference_boxes[(b * nq + q) * 4..(b * nq + q + 1) * 4];
1862            let (cx, cy, bw, bh) = (p[0], p[1], p[2], p[3]);
1863            let x0 = cx - 0.5 * bw;
1864            let x1 = cx + 0.5 * bw;
1865            let y0 = cy - 0.5 * bh;
1866            let y1 = cy + 0.5 * bh;
1867            let dx_off = ((b * nq + q) * w) * 2;
1868            for xi in 0..w {
1869                let dx0 = (coords_w[xi] - x0) * 8.0;
1870                let dx1 = (coords_w[xi] - x1) * 8.0;
1871                deltas_x[dx_off + xi * 2] = log_norm(dx0);
1872                deltas_x[dx_off + xi * 2 + 1] = log_norm(dx1);
1873            }
1874            let dy_off = ((b * nq + q) * h) * 2;
1875            for yi in 0..h {
1876                let dy0 = (coords_h[yi] - y0) * 8.0;
1877                let dy1 = (coords_h[yi] - y1) * 8.0;
1878                deltas_y[dy_off + yi * 2] = log_norm(dy0);
1879                deltas_y[dy_off + yi * 2 + 1] = log_norm(dy1);
1880            }
1881        }
1882    }
1883}
1884
1885/// Standalone boxRPB subgraph for CPU parity checks (`build_boxrpb_subgraph`).
1886#[allow(dead_code)]
1887fn build_boxrpb_check_hir(
1888    boxrpb_x: &Mlp2,
1889    boxrpb_y: &Mlp2,
1890    batch: usize,
1891    nq: usize,
1892    h: usize,
1893    w: usize,
1894) -> Result<(HirModule, HashMap<String, Vec<f32>>)> {
1895    let nh = N_HEADS;
1896    let mut hir = HirModule::new("sam3_boxrpb_check");
1897    let mut params = HashMap::new();
1898    let mut typed = Vec::new();
1899    let mut gguf_cache = HashMap::new();
1900    {
1901        let mut g = HirMut::new(&mut hir);
1902        let f = DType::F32;
1903        let deltas_x = g.input("deltas_x", Shape::new(&[batch, nq, w, 2], f));
1904        let deltas_y = g.input("deltas_y", Shape::new(&[batch, nq, h, 2], f));
1905        let out = build_boxrpb_subgraph(
1906            &mut g,
1907            &mut params,
1908            &mut typed,
1909            &mut gguf_cache,
1910            None,
1911            boxrpb_x,
1912            boxrpb_y,
1913            deltas_x,
1914            deltas_y,
1915            batch,
1916            nq,
1917            nh,
1918            h,
1919            w,
1920        )?;
1921        g.set_outputs(vec![out]);
1922    }
1923    Ok((hir, params))
1924}
1925
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a deterministic 2-layer MLP with constant weights (0.01 for the first
    /// layer, 0.02 for the second) and zero biases, so the host and IR paths can be
    /// compared without loading a checkpoint.
    fn synth_mlp2(in_d: usize, hidden: usize, out_d: usize) -> Mlp2 {
        Mlp2 {
            w0_t: vec![0.01; in_d * hidden],
            b0: vec![0.0; hidden],
            w1_t: vec![0.02; hidden * out_d],
            b1: vec![0.0; out_d],
            in_dim: in_d,
            hidden,
            out_dim: out_d,
            w0_gguf_key: None,
            w1_gguf_key: None,
        }
    }

    /// Parity check: the boxRPB subgraph compiled for `Device::Cpu` must match the
    /// host implementation (`boxrpb_log_full`) on the same reference boxes.
    #[test]
    fn sam3_boxrpb_ir_matches_host_cpu() -> Result<()> {
        let batch = 1usize;
        let nq = 2usize;
        let h = 4usize;
        let w = 4usize;
        let nh = N_HEADS;
        let boxrpb_x = synth_mlp2(2, 16, nh);
        let boxrpb_y = synth_mlp2(2, 16, nh);
        let ref_boxes = vec![
            0.5, 0.5, 0.4, 0.4, //
            0.3, 0.7, 0.2, 0.3,
        ];
        let host = boxrpb_log_full(&boxrpb_x, &boxrpb_y, &ref_boxes, batch, nq, h, w, None)?;

        let mut deltas_x = vec![0f32; batch * nq * w * 2];
        let mut deltas_y = vec![0f32; batch * nq * h * 2];
        compute_deltas_into(&ref_boxes, batch, nq, h, w, &mut deltas_x, &mut deltas_y);

        let (hir, params) = build_boxrpb_check_hir(&boxrpb_x, &boxrpb_y, batch, nq, h, w)?;
        let mut compiled = rlx_core::flow_bridge::compile_hir_sam(Device::Cpu, hir)?;
        for (name, data) in &params {
            compiled.set_param(name, data);
        }
        let ir = compiled
            .run(&[("deltas_x", &deltas_x), ("deltas_y", &deltas_y)])
            .into_iter()
            .next()
            .expect("boxRPB check graph produced no outputs");

        // `zip` stops at the shorter side. Without these checks, an empty output
        // would compare zero elements and report a max |Δ| of 0 (vacuous pass).
        assert!(!host.is_empty(), "host boxRPB output is empty");
        assert!(!ir.is_empty(), "IR boxRPB output is empty");
        // `f32::max` returns the non-NaN operand, so the fold below would silently
        // skip NaN differences. Reject non-finite values explicitly.
        assert!(host.iter().all(|v| v.is_finite()), "host boxRPB output has non-finite values");
        assert!(ir.iter().all(|v| v.is_finite()), "IR boxRPB output has non-finite values");

        let fd = host
            .iter()
            .zip(&ir)
            .map(|(a, b)| (a - b).abs())
            .fold(0f32, f32::max);
        assert!(fd < 5e-2, "sam3 boxRPB IR vs host max |Δ| = {fd:.3e}");
        Ok(())
    }
}