Skip to main content

rlx_sam/
image_encoder.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! SAM v1 ViT image encoder HIR builder.
17//!
18//! Mirrors `candle-transformers/src/models/segment_anything/image_encoder.rs`.
19//! Decomposes attention into primitives (rlx-ir's `attention_` op is a
20//! black box and can't host the inline rel-pos add SAM uses).
21//!
22//! Two attention modes:
23//!   - **Global** (window_size == 0): full S = hw·hw attention. Used by
24//!     blocks listed in `global_attn_indexes`.
25//!   - **Windowed** (window_size > 0): pad spatial dims to a multiple
26//!     of `window_size` via concat-with-zeros, reshape into
27//!     `[B·nW, ws, ws, C]`, attention within each window, reverse the
28//!     reshape, narrow off the padding.
29//!
30//! The neck (Conv2d 1×1 + LN2d + Conv2d 3×3 + LN2d → `[B, 256, hw, hw]`)
31//! is appended to the encoder HIR via [`rlx_core::vision_ops_ir`].
32
33use super::config::{SAM_EMBED_HW, SamEncoderConfig};
34use super::preprocess::{SamPreprocessWeights, extract_preprocess_weights};
35use anyhow::{Result, anyhow, ensure};
36use rlx_core::vision_ops_ir::{bhwc_to_nchw, conv2d_bias, conv2d_no_bias, layer_norm2d_nchw};
37use rlx_core::weight_map::WeightMap;
38use rlx_ir::HirGraphExt;
39use rlx_ir::hir::{HirModule, HirMut, HirNodeId};
40use rlx_ir::*;
41use std::collections::HashMap;
42
43struct SamBuilder {
44    hir: HirModule,
45    params: HashMap<String, Vec<f32>>,
46}
47
48impl SamBuilder {
49    fn new(name: &str) -> Self {
50        Self {
51            hir: HirModule::new(name),
52            params: HashMap::new(),
53        }
54    }
55
56    fn m(&mut self) -> HirMut<'_> {
57        HirMut::new(&mut self.hir)
58    }
59}
60
61#[allow(dead_code)]
62fn lower_hir(hir: HirModule) -> Result<Graph> {
63    Graph::from_hir(hir).map_err(|e| anyhow!("{e}"))
64}
65
66/// Build the SAM ViT image-encoder HIR (body + neck).
67///
68/// Input: `"hidden"` shape `[1, hw·hw, embed_dim]` — patch tokens from
69/// [`crate::sam::preprocess::assemble_patch_tokens`].
70///
71/// Output: `[1, out_chans, hw, hw]` NCHW image embeddings.
72pub fn build_sam_encoder_hir(
73    cfg: &SamEncoderConfig,
74    weights: &mut WeightMap,
75) -> Result<(HirModule, HashMap<String, Vec<f32>>, SamPreprocessWeights)> {
76    let mut b = SamBuilder::new("sam_image_encoder");
77    let f = DType::F32;
78
79    // Host-side preprocess weights (patch projection + abs pos embed).
80    // Drain these *before* iterating blocks so the keys are gone when
81    // we later assert the WeightMap is empty.
82    let preprocess = extract_preprocess_weights(weights, cfg)?;
83
84    let e = cfg.embed_dim;
85    let nh = cfg.num_heads;
86    let dh = cfg.head_dim();
87    let scale = 1.0 / (dh as f32).sqrt();
88    let eps = cfg.layer_norm_eps as f32;
89    let hw = SAM_EMBED_HW;
90    let s = hw * hw; // 64·64 = 4096
91
92    // Input: pre-assembled patch tokens [1, 4096, E].
93    let hidden_input = b.m().input("hidden", Shape::new(&[1, s, e], f));
94
95    let mut x = hidden_input;
96    for layer_idx in 0..cfg.depth {
97        let lp = format!("image_encoder.blocks.{layer_idx}");
98        let is_global = cfg.global_attn_indexes.contains(&layer_idx);
99        let ws = if is_global { 0 } else { cfg.window_size };
100
101        // ── Pre-LN1 ──
102        let n1_g = load_p(&mut b, weights, &format!("{lp}.norm1.weight"), false)?;
103        let n1_b = load_p(&mut b, weights, &format!("{lp}.norm1.bias"), false)?;
104        let normed = b.m().ln(x, n1_g, n1_b, eps);
105
106        // ── Attention (windowed or global) ──
107        let attn_out = if ws == 0 {
108            attention_global(
109                &mut b,
110                weights,
111                &lp,
112                normed,
113                e,
114                nh,
115                dh,
116                scale,
117                hw,
118                cfg.use_rel_pos,
119                cfg.qkv_bias,
120            )?
121        } else {
122            attention_windowed(
123                &mut b,
124                weights,
125                &lp,
126                normed,
127                e,
128                nh,
129                dh,
130                scale,
131                hw,
132                ws,
133                cfg.use_rel_pos,
134                cfg.qkv_bias,
135            )?
136        };
137
138        // Residual
139        x = b.m().add(x, attn_out);
140
141        // ── Pre-LN2 + MLP (4× expansion, plain GELU) ──
142        let n2_g = load_p(&mut b, weights, &format!("{lp}.norm2.weight"), false)?;
143        let n2_b = load_p(&mut b, weights, &format!("{lp}.norm2.bias"), false)?;
144        let normed2 = b.m().ln(x, n2_g, n2_b, eps);
145
146        let fc1_w = load_p(&mut b, weights, &format!("{lp}.mlp.lin1.weight"), true)?;
147        let fc1_b = load_p(&mut b, weights, &format!("{lp}.mlp.lin1.bias"), false)?;
148        let fc2_w = load_p(&mut b, weights, &format!("{lp}.mlp.lin2.weight"), true)?;
149        let fc2_b = load_p(&mut b, weights, &format!("{lp}.mlp.lin2.bias"), false)?;
150
151        let up_mm = b.m().mm(normed2, fc1_w);
152        let up = b.m().add(up_mm, fc1_b);
153        // candle's `Activation::Gelu` dispatches to `Tensor::gelu_erf()`
154        // — the exact erf form — for SAM's MlpBlock. Use the matching
155        // erf kernel here.
156        let act = b.m().gelu(up);
157        let down_mm = b.m().mm(act, fc2_w);
158        let ffn = b.m().add(down_mm, fc2_b);
159
160        x = b.m().add(x, ffn);
161    }
162
163    // ── Neck: BHWC → NCHW, 1×1 conv, LN2d, 3×3 conv, LN2d ──
164    let oc = cfg.out_chans;
165    let nchw = bhwc_to_nchw(&mut b.m(), x, 1, hw, hw, e);
166    let c1_w = load_p(&mut b, weights, "image_encoder.neck.0.weight", false)?;
167    let c1_b = load_p(&mut b, weights, "image_encoder.neck.0.bias", false)?;
168    let feat = conv2d_bias(
169        &mut b.m(),
170        nchw,
171        c1_w,
172        c1_b,
173        1,
174        oc,
175        1,
176        1,
177        [1, 1],
178        [0, 0],
179        hw,
180        hw,
181    );
182    let ln1_g = load_p(&mut b, weights, "image_encoder.neck.1.weight", false)?;
183    let ln1_b = load_p(&mut b, weights, "image_encoder.neck.1.bias", false)?;
184    let feat = layer_norm2d_nchw(&mut b.m(), feat, ln1_g, ln1_b, eps);
185    let c2_w = load_p(&mut b, weights, "image_encoder.neck.2.weight", false)?;
186    let feat = conv2d_no_bias(&mut b.m(), feat, c2_w, 1, oc, 3, 3, [1, 1], [1, 1], hw, hw);
187    let ln2_g = load_p(&mut b, weights, "image_encoder.neck.3.weight", false)?;
188    let ln2_b = load_p(&mut b, weights, "image_encoder.neck.3.bias", false)?;
189    let out = layer_norm2d_nchw(&mut b.m(), feat, ln2_g, ln2_b, eps);
190
191    b.hir.set_outputs(vec![out]);
192
193    Ok((b.hir, b.params, preprocess))
194}
195
196/// Lowered graph wrapper for legacy callers (via [`super::flow::SamEncoderFlow`]).
197pub fn build_sam_encoder_graph(
198    cfg: &SamEncoderConfig,
199    weights: &mut WeightMap,
200) -> Result<(Graph, HashMap<String, Vec<f32>>, SamPreprocessWeights)> {
201    let built = super::flow::build_sam_encoder_built(cfg, weights)?;
202    let (graph, params) = rlx_core::flow_util::graph_from_built(built.model)?;
203    Ok((graph, params, built.preprocess))
204}
205
206/// Global-attention block: full self-attention over all `hw·hw` tokens.
207#[allow(clippy::too_many_arguments)]
208fn attention_global(
209    sb: &mut SamBuilder,
210    w: &mut WeightMap,
211    lp: &str,
212    x: HirNodeId, // [1, S, E]
213    e: usize,
214    nh: usize,
215    dh: usize,
216    scale: f32,
217    hw: usize,
218    use_rel_pos: bool,
219    qkv_bias: bool,
220) -> Result<HirNodeId> {
221    let s = hw * hw;
222    decomposed_attention(
223        sb,
224        w,
225        lp,
226        x,
227        e,
228        nh,
229        dh,
230        scale,
231        hw,
232        hw,
233        s,
234        1,
235        use_rel_pos,
236        qkv_bias,
237    )
238}
239
240/// Windowed-attention block: pad → partition into `nW = (hw_p/ws)²`
241/// windows → attention within each window → reverse partition → crop.
242#[allow(clippy::too_many_arguments)]
243fn attention_windowed(
244    sb: &mut SamBuilder,
245    w: &mut WeightMap,
246    lp: &str,
247    x: HirNodeId, // [1, S, E] flat (= [1, hw, hw, E] BHWC, flattened)
248    e: usize,
249    nh: usize,
250    dh: usize,
251    scale: f32,
252    hw: usize,
253    ws: usize,
254    use_rel_pos: bool,
255    qkv_bias: bool,
256) -> Result<HirNodeId> {
257    // Restore spatial: [1, S, E] → [1, hw, hw, E]
258    let bhwc = sb.m().reshape_(x, vec![1, hw as i64, hw as i64, e as i64]);
259
260    let pad = (ws - hw % ws) % ws;
261    let hw_p = hw + pad;
262    let n_win_per_side = hw_p / ws;
263    let n_win = n_win_per_side * n_win_per_side;
264
265    // Pad with concat-zeros along axes 1, 2.
266    let padded = if pad > 0 {
267        let z_h = pad_zero_param(sb, &format!("{lp}.attn._pad_h"), &[1, pad, hw, e]);
268        let p1 = sb.m().concat_(vec![bhwc, z_h], 1); // [1, hw_p, hw, E]
269        let z_w = pad_zero_param(sb, &format!("{lp}.attn._pad_w"), &[1, hw_p, pad, e]);
270        sb.m().concat_(vec![p1, z_w], 2) // [1, hw_p, hw_p, E]
271    } else {
272        bhwc
273    };
274
275    // [1, hw_p, hw_p, E] → [1, nw, ws, nw, ws, E] → transpose(2,3)
276    //   → [1, nw, nw, ws, ws, E] → reshape [nw², ws, ws, E]
277    let reshaped = sb.m().reshape_(
278        padded,
279        vec![
280            1,
281            n_win_per_side as i64,
282            ws as i64,
283            n_win_per_side as i64,
284            ws as i64,
285            e as i64,
286        ],
287    );
288    let transposed = sb.m().transpose_(reshaped, vec![0, 1, 3, 2, 4, 5]);
289    let windowed = sb.m().reshape_(
290        transposed,
291        vec![n_win as i64, ws as i64, ws as i64, e as i64],
292    );
293    // Flatten spatial for the attention: [nw², ws², E]
294    let win_flat = sb
295        .m()
296        .reshape_(windowed, vec![n_win as i64, (ws * ws) as i64, e as i64]);
297
298    // Run decomposed attention. Window has spatial dims (ws, ws);
299    // sequence length S = ws·ws; batch dim = n_win.
300    let attn_out = decomposed_attention(
301        sb,
302        w,
303        lp,
304        win_flat,
305        e,
306        nh,
307        dh,
308        scale,
309        ws,
310        ws,
311        ws * ws,
312        n_win,
313        use_rel_pos,
314        qkv_bias,
315    )?;
316    // attn_out: [nw², ws·ws, E]
317
318    // Reverse: [nw², ws², E] → [nw², ws, ws, E] → [1, nw, nw, ws, ws, E]
319    //   → transpose(2,3) → [1, nw, ws, nw, ws, E] → [1, hw_p, hw_p, E]
320    let un = sb
321        .m()
322        .reshape_(attn_out, vec![n_win as i64, ws as i64, ws as i64, e as i64]);
323    let un = sb.m().reshape_(
324        un,
325        vec![
326            1,
327            n_win_per_side as i64,
328            n_win_per_side as i64,
329            ws as i64,
330            ws as i64,
331            e as i64,
332        ],
333    );
334    let un = sb.m().transpose_(un, vec![0, 1, 3, 2, 4, 5]);
335    let un = sb
336        .m()
337        .reshape_(un, vec![1, hw_p as i64, hw_p as i64, e as i64]);
338    // Crop off the padding
339    let un = if pad > 0 {
340        let cropped_h = sb.m().narrow_(un, 1, 0, hw);
341        sb.m().narrow_(cropped_h, 2, 0, hw)
342    } else {
343        un
344    };
345    // Flatten back to [1, S, E]
346    Ok(sb.m().reshape_(un, vec![1, (hw * hw) as i64, e as i64]))
347}
348
349/// Decomposed multi-head attention with optional decomposed rel_pos.
350/// Input `[B, S, E]`; output `[B, S, E]`.
351///
352/// `h, w` are the spatial dims of the attention window (S = h·w).
353/// For windowed attention `B = n_win`, `h = w = ws`. For global,
354/// `B = 1`, `h = w = hw`.
355#[allow(clippy::too_many_arguments)]
356fn decomposed_attention(
357    sb: &mut SamBuilder,
358    w: &mut WeightMap,
359    lp: &str,
360    x: HirNodeId, // [B, S, E]
361    e: usize,
362    nh: usize,
363    dh: usize,
364    scale: f32,
365    h: usize,
366    w_dim: usize,
367    s: usize, // = h * w_dim
368    batch: usize,
369    use_rel_pos: bool,
370    qkv_bias: bool,
371) -> Result<HirNodeId> {
372    // 1) QKV projection. Bias param is loaded *before* the mm so its
373    //    HirNodeId is lower — `FuseMatMulBiasAct` walks nodes in topo
374    //    order and assumes the bias has been copied into the new id
375    //    map before the matmul is rewritten.
376    let qkv_w_node = load_p(sb, w, &format!("{lp}.attn.qkv.weight"), true)?;
377    let qkv_b_node = if qkv_bias {
378        Some(load_p(sb, w, &format!("{lp}.attn.qkv.bias"), false)?)
379    } else {
380        None
381    };
382    let qkv_mm = sb.m().mm(x, qkv_w_node); // [B, S, 3E]
383    let qkv = if let Some(b) = qkv_b_node {
384        sb.m().add(qkv_mm, b)
385    } else {
386        qkv_mm
387    };
388
389    // 2) Reshape & permute to [3, B·nh, S, dh].
390    //    [B, S, 3E] → [B, S, 3, nh, dh] → permute(2,0,3,1,4) → [3, B, nh, S, dh]
391    //    → reshape [3, B·nh, S, dh].
392    let qkv5 = sb
393        .m()
394        .reshape_(qkv, vec![batch as i64, s as i64, 3, nh as i64, dh as i64]);
395    let qkv_perm = sb.m().transpose_(qkv5, vec![2, 0, 3, 1, 4]); // [3, B, nh, S, dh]
396    let qkv_flat = sb
397        .m()
398        .reshape_(qkv_perm, vec![3, (batch * nh) as i64, s as i64, dh as i64]);
399    let q = sb.m().narrow_(qkv_flat, 0, 0, 1);
400    let q = sb
401        .m()
402        .reshape_(q, vec![(batch * nh) as i64, s as i64, dh as i64]);
403    let k = sb.m().narrow_(qkv_flat, 0, 1, 1);
404    let k = sb
405        .m()
406        .reshape_(k, vec![(batch * nh) as i64, s as i64, dh as i64]);
407    let v = sb.m().narrow_(qkv_flat, 0, 2, 1);
408    let v = sb
409        .m()
410        .reshape_(v, vec![(batch * nh) as i64, s as i64, dh as i64]);
411
412    // 3) attn = (q * scale) @ k.T   shape [B·nh, S, S]
413    let scale_node = scalar_param(sb, &format!("{lp}.attn._scale"), scale);
414    let q_scaled = sb.m().mul(q, scale_node);
415    let k_t = sb.m().transpose_(k, vec![0, 2, 1]); // [B·nh, dh, S]
416    let scores = sb.m().mm(q_scaled, k_t); // [B·nh, S, S]
417
418    // 4) Optionally add decomposed rel_pos.
419    let scores = if use_rel_pos {
420        // rel_pos_h: [2h-1, dh]  rel_pos_w: [2w-1, dh]
421        // We pre-resolve get_rel_pos() host-side into r_h: [h, h, dh] and
422        // r_w: [w, w, dh] indexed buffers (cheap, ≤ 27×27×64 elements).
423        let (mut r_h_data, mut r_w_data) = extract_rel_pos(w, lp, h, w_dim, dh)?;
424        // Bisect helpers:
425        //   RLX_SAM_DEBUG_ZERO_RELPOS=1  zero both r_h and r_w
426        //   RLX_SAM_DEBUG_ZERO_RELH=1    zero only r_h (keep rel_w)
427        //   RLX_SAM_DEBUG_ZERO_RELW=1    zero only r_w (keep rel_h)
428        if rlx_ir::env::flag("RLX_SAM_DEBUG_ZERO_RELPOS") {
429            r_h_data.iter_mut().for_each(|v| *v = 0.0);
430            r_w_data.iter_mut().for_each(|v| *v = 0.0);
431        }
432        if rlx_ir::env::flag("RLX_SAM_DEBUG_ZERO_RELH") {
433            r_h_data.iter_mut().for_each(|v| *v = 0.0);
434        }
435        if rlx_ir::env::flag("RLX_SAM_DEBUG_ZERO_RELW") {
436            r_w_data.iter_mut().for_each(|v| *v = 0.0);
437        }
438        let r_h_node = const_param(
439            sb,
440            &format!("{lp}.attn._rel_h_indexed"),
441            &[h, h, dh],
442            r_h_data,
443        );
444        let r_w_node = const_param(
445            sb,
446            &format!("{lp}.attn._rel_w_indexed"),
447            &[w_dim, w_dim, dh],
448            r_w_data,
449        );
450        add_decomposed_rel_pos(sb, scores, q, r_h_node, r_w_node, batch, nh, h, w_dim, dh)?
451    } else {
452        scores
453    };
454
455    // 5) softmax over last axis
456    let attn_w = sb.m().sm(scores, -1);
457
458    // 6) attn @ V → [B·nh, S, dh]
459    let attn_v = sb.m().mm(attn_w, v);
460
461    // 7) Reverse the head split: [B·nh, S, dh] → [B, nh, S, dh] → [B, S, nh, dh] → [B, S, E]
462    let reshaped = sb
463        .m()
464        .reshape_(attn_v, vec![batch as i64, nh as i64, s as i64, dh as i64]);
465    let perm = sb.m().transpose_(reshaped, vec![0, 2, 1, 3]); // [B, S, nh, dh]
466    let merged = sb
467        .m()
468        .reshape_(perm, vec![batch as i64, s as i64, e as i64]);
469
470    // 8) Output projection (always biased).
471    let proj_w = load_p(sb, w, &format!("{lp}.attn.proj.weight"), true)?;
472    let proj_b = load_p(sb, w, &format!("{lp}.attn.proj.bias"), false)?;
473    let proj_mm = sb.m().mm(merged, proj_w);
474    Ok(sb.m().add(proj_mm, proj_b))
475}
476
477/// Add decomposed relative positional bias to attention scores.
478///
479/// Math (per the SAM paper, candle's `add_decomposed_rel_pos`):
480///   r_q = q.reshape(B·nh, h, w, dh)
481///   rel_h[bhw,c] = sum_c r_q[bhw,c] · r_h_indexed[hq, hk, c]    → [B·nh, h, w, h]
482///   rel_w[bhw,c] = sum_c r_q[bhw,c] · r_w_indexed[wq, wk, c]    → [B·nh, h, w, w]
483///   scores += rel_h.unsqueeze(4) + rel_w.unsqueeze(3)           → [B·nh, h, w, h, w]
484///   scores.reshape(B·nh, h·w, h·w)
485#[allow(clippy::too_many_arguments)]
486fn add_decomposed_rel_pos(
487    sb: &mut SamBuilder,
488    scores: HirNodeId, // [B·nh, S, S]
489    q: HirNodeId,      // [B·nh, S, dh]
490    r_h: HirNodeId,    // [h, h, dh]  (pre-indexed)
491    r_w: HirNodeId,    // [w, w, dh]
492    batch: usize,
493    nh: usize,
494    h: usize,
495    w: usize,
496    dh: usize,
497) -> Result<HirNodeId> {
498    let bh = batch * nh;
499    // r_q: [bh, h, w, dh]
500    let r_q = sb
501        .m()
502        .reshape_(q, vec![bh as i64, h as i64, w as i64, dh as i64]);
503
504    // rel_h: "bhwc, hkc -> bhwk".
505    // Unrolled-per-h_q: rlx-cpu's batched 3-D matmul gives subtly wrong
506    // results in this exact shape regime, so we lower the einsum to
507    // `h` independent 2-D matmuls (one per h_q index) and `sb.m().concat_`
508    // them back. Each per-h_q matmul is `[bh, w, dh] @ [dh, h_k]`,
509    // which uses the well-tested flat sgemm path (rhs has no batch
510    // dim, only the lhs does — that's the case the Sgemm flatten
511    // trick was designed for).
512    let mut rel_h_slices: Vec<HirNodeId> = Vec::with_capacity(h);
513    for h_q in 0..h {
514        // r_q at h_q: narrow axis 1, then squeeze.
515        let rq_slice = sb.m().narrow_(r_q, 1, h_q, 1); // [bh, 1, w, dh]
516        let rq_slice = sb
517            .m()
518            .reshape_(rq_slice, vec![bh as i64, w as i64, dh as i64]);
519        // r_h at h_q: narrow axis 0, then squeeze + transpose to [dh, h].
520        let rh_slice = sb.m().narrow_(r_h, 0, h_q, 1); // [1, h, dh]
521        let rh_slice = sb.m().reshape_(rh_slice, vec![h as i64, dh as i64]); // [h_k, dh]
522        let rh_t = sb.m().transpose_(rh_slice, vec![1, 0]); // [dh, h_k]
523        let mm = sb.m().mm(rq_slice, rh_t); // [bh, w, h_k]
524        // Add a leading length-1 axis so we can concat into [bh, h, w, h_k].
525        let mm5 = sb.m().reshape_(mm, vec![bh as i64, 1, w as i64, h as i64]);
526        rel_h_slices.push(mm5);
527    }
528    let rel_h_4d = sb.m().concat_(rel_h_slices, 1); // [bh, h, w, h]
529
530    // rel_w: same idea, w_q as the unrolled axis.
531    let mut rel_w_slices: Vec<HirNodeId> = Vec::with_capacity(w);
532    for w_q in 0..w {
533        let rq_slice = sb.m().narrow_(r_q, 2, w_q, 1); // [bh, h, 1, dh]
534        let rq_slice = sb
535            .m()
536            .reshape_(rq_slice, vec![bh as i64, h as i64, dh as i64]);
537        let rw_slice = sb.m().narrow_(r_w, 0, w_q, 1); // [1, w, dh]
538        let rw_slice = sb.m().reshape_(rw_slice, vec![w as i64, dh as i64]); // [w_k, dh]
539        let rw_t = sb.m().transpose_(rw_slice, vec![1, 0]); // [dh, w_k]
540        let mm = sb.m().mm(rq_slice, rw_t); // [bh, h, w_k]
541        let mm5 = sb.m().reshape_(mm, vec![bh as i64, h as i64, 1, w as i64]);
542        rel_w_slices.push(mm5);
543    }
544    let rel_w_4d = sb.m().concat_(rel_w_slices, 2); // [bh, h, w, w]
545
546    // Broadcast-add into the [bh, h, w, h, w] view of scores.
547    //
548    // History: rlx-cpu's BiasAdd misroute for mid-shape singletons is
549    // now fixed (`is_trailing_bias_broadcast`), so CPU uses simple
550    // unsqueeze+add. The rlx-metal BinaryBroadcast MSL kernel exists
551    // but produces wrong results on the SAM rel_pos pattern (suspect:
552    // setBytes alignment of inline `constant uint*` for ranks > 4 —
553    // needs focused debugging). Until then, materialise both rel
554    // tensors to the full output shape via `concat`-tile so the add
555    // is a same-shape op and works on every backend.
556    let scores_5d = sb.m().reshape_(
557        scores,
558        vec![bh as i64, h as i64, w as i64, h as i64, w as i64],
559    );
560    let rel_h_5d = sb
561        .m()
562        .reshape_(rel_h_4d, vec![bh as i64, h as i64, w as i64, h as i64, 1]);
563    let rel_h_tiled = {
564        let mut copies = Vec::with_capacity(w);
565        for _ in 0..w {
566            copies.push(rel_h_5d);
567        }
568        sb.m().concat_(copies, 4) // [bh, h, w, h, w]
569    };
570    let rel_w_5d = sb
571        .m()
572        .reshape_(rel_w_4d, vec![bh as i64, h as i64, w as i64, 1, w as i64]);
573    let rel_w_tiled = {
574        let mut copies = Vec::with_capacity(h);
575        for _ in 0..h {
576            copies.push(rel_w_5d);
577        }
578        sb.m().concat_(copies, 3) // [bh, h, w, h, w]
579    };
580    let s1 = sb.m().add(scores_5d, rel_h_tiled);
581    let s2 = sb.m().add(s1, rel_w_tiled);
582    Ok(sb
583        .m()
584        .reshape_(s2, vec![bh as i64, (h * w) as i64, (h * w) as i64]))
585}
586
587/// Resolve candle's `get_rel_pos()` host-side into per-axis bias
588/// tables of shape `[q_size, k_size, dh]` (here q_size == k_size).
589///
590/// Stored `rel_pos_h` has shape `[2·max(q,k)-1, dh]`; we gather along
591/// axis 0 using `relative_coords[i,j] = i - j + (k-1)` (since q==k,
592/// scale factors collapse to 1).
593fn extract_rel_pos(
594    weights: &mut WeightMap,
595    lp: &str,
596    h: usize,
597    w: usize,
598    dh: usize,
599) -> Result<(Vec<f32>, Vec<f32>)> {
600    let (rel_h_raw, rh_shape) = weights.take(&format!("{lp}.attn.rel_pos_h"))?;
601    let (rel_w_raw, rw_shape) = weights.take(&format!("{lp}.attn.rel_pos_w"))?;
602    ensure!(
603        rh_shape == vec![2 * h - 1, dh],
604        "{lp}.attn.rel_pos_h expected [{}, {dh}], got {rh_shape:?}",
605        2 * h - 1
606    );
607    ensure!(
608        rw_shape == vec![2 * w - 1, dh],
609        "{lp}.attn.rel_pos_w expected [{}, {dh}], got {rw_shape:?}",
610        2 * w - 1
611    );
612
613    let mut r_h = vec![0f32; h * h * dh];
614    for q in 0..h {
615        for k in 0..h {
616            let idx = (q as isize - k as isize + (h as isize - 1)) as usize;
617            let src = &rel_h_raw[idx * dh..(idx + 1) * dh];
618            let dst = &mut r_h[(q * h + k) * dh..(q * h + k + 1) * dh];
619            dst.copy_from_slice(src);
620        }
621    }
622    let mut r_w = vec![0f32; w * w * dh];
623    for q in 0..w {
624        for k in 0..w {
625            let idx = (q as isize - k as isize + (w as isize - 1)) as usize;
626            let src = &rel_w_raw[idx * dh..(idx + 1) * dh];
627            let dst = &mut r_w[(q * w + k) * dh..(q * w + k + 1) * dh];
628            dst.copy_from_slice(src);
629        }
630    }
631    Ok((r_h, r_w))
632}
633
634// ─── Neck (Conv2d 1×1 + LN2d + Conv2d 3×3 + LN2d) host-side ────────
635
/// Neck weights held on the host for the reference implementation
/// [`apply_neck_host`] (filled by `extract_neck_weights`).
///
/// The main encoder path does not use this struct: `build_sam_encoder_hir`
/// builds the neck inside the HIR with `rlx_core::vision_ops_ir` conv and
/// LayerNorm2d ops.
#[derive(Debug, Clone)]
pub struct NeckWeights {
    /// 1×1 conv kernel `[out_chans, embed_dim, 1, 1]`; in memory this is a
    /// row-major `[out_chans, embed_dim]` matrix.
    pub conv1_w: Vec<f32>,
    /// First LayerNorm2d scale, `[out_chans]`.
    pub ln1_g: Vec<f32>,
    /// First LayerNorm2d shift, `[out_chans]`.
    pub ln1_b: Vec<f32>,
    /// 3×3 conv kernel, `[out_chans, out_chans, 3, 3]`.
    pub conv2_w: Vec<f32>,
    /// Second LayerNorm2d scale, `[out_chans]`.
    pub ln2_g: Vec<f32>,
    /// Second LayerNorm2d shift, `[out_chans]`.
    pub ln2_b: Vec<f32>,
    /// Channel count of the encoder-body output fed into the 1×1 conv.
    pub embed_dim: usize,
    /// Channel count of the neck output.
    pub out_chans: usize,
    /// LayerNorm2d epsilon.
    pub eps: f32,
}
650
651#[allow(dead_code)]
652fn extract_neck_weights(weights: &mut WeightMap, cfg: &SamEncoderConfig) -> Result<NeckWeights> {
653    let (conv1_w_raw, c1_shape) = weights.take("image_encoder.neck.0.weight")?;
654    ensure!(
655        c1_shape == vec![cfg.out_chans, cfg.embed_dim, 1, 1],
656        "neck.0.weight expected [{}, {}, 1, 1], got {c1_shape:?}",
657        cfg.out_chans,
658        cfg.embed_dim
659    );
660    let conv1_w = conv1_w_raw; // [out_chans, embed_dim] after flattening last two singleton dims
661    let (ln1_g, _) = weights.take("image_encoder.neck.1.weight")?;
662    let (ln1_b, _) = weights.take("image_encoder.neck.1.bias")?;
663    let (conv2_w, c2_shape) = weights.take("image_encoder.neck.2.weight")?;
664    ensure!(
665        c2_shape == vec![cfg.out_chans, cfg.out_chans, 3, 3],
666        "neck.2.weight expected [{}, {}, 3, 3], got {c2_shape:?}",
667        cfg.out_chans,
668        cfg.out_chans
669    );
670    let (ln2_g, _) = weights.take("image_encoder.neck.3.weight")?;
671    let (ln2_b, _) = weights.take("image_encoder.neck.3.bias")?;
672    Ok(NeckWeights {
673        conv1_w,
674        ln1_g,
675        ln1_b,
676        conv2_w,
677        ln2_g,
678        ln2_b,
679        embed_dim: cfg.embed_dim,
680        out_chans: cfg.out_chans,
681        eps: cfg.layer_norm_eps as f32,
682    })
683}
684
685/// Run the encoder neck on the host. `body_out` is the encoder body's
686/// output reshaped to `[hw·hw, embed_dim]` (BHWC flattened). Returns
687/// `[out_chans, hw, hw]` NCHW image embeddings.
688pub fn apply_neck_host(neck: &NeckWeights, body_out: &[f32], hw: usize) -> Vec<f32> {
689    let e = neck.embed_dim;
690    let oc = neck.out_chans;
691    let eps = neck.eps;
692
693    // 1) Conv 1×1: per-pixel linear projection from embed_dim → out_chans.
694    //    body_out is BHWC; treat as [hw·hw, embed_dim] and matmul by
695    //    conv1_w.T (i.e. `out[s, oc] = sum_e body_out[s, e] * conv1_w[oc, e]`).
696    let s = hw * hw;
697    let mut feat = vec![0f32; s * oc]; // BHWC: [hw·hw, oc]
698    for si in 0..s {
699        for oi in 0..oc {
700            let mut acc = 0f32;
701            for ei in 0..e {
702                acc += body_out[si * e + ei] * neck.conv1_w[oi * e + ei];
703            }
704            feat[si * oc + oi] = acc;
705        }
706    }
707
708    // 2) LN2d: normalize over channel dim (per spatial position).
709    layernorm2d_inplace(&mut feat, s, oc, &neck.ln1_g, &neck.ln1_b, eps);
710
711    // 3) Conv 3×3 padding=1, stride=1. We compute it in NCHW. The input
712    //    is currently BHWC = [hw·hw, oc]; convert to NCHW = [oc, hw, hw].
713    let mut nchw = vec![0f32; oc * hw * hw];
714    for y in 0..hw {
715        for x in 0..hw {
716            for c in 0..oc {
717                nchw[c * hw * hw + y * hw + x] = feat[(y * hw + x) * oc + c];
718            }
719        }
720    }
721    let conv2_out = conv2d_3x3_pad1(&nchw, oc, oc, hw, hw, &neck.conv2_w);
722
723    // 4) LN2d again. Convert back to BHWC for the LN, then back to NCHW.
724    let mut bhwc = vec![0f32; s * oc];
725    for c in 0..oc {
726        for y in 0..hw {
727            for x in 0..hw {
728                bhwc[(y * hw + x) * oc + c] = conv2_out[c * hw * hw + y * hw + x];
729            }
730        }
731    }
732    layernorm2d_inplace(&mut bhwc, s, oc, &neck.ln2_g, &neck.ln2_b, eps);
733
734    let mut out_nchw = vec![0f32; oc * hw * hw];
735    for y in 0..hw {
736        for x in 0..hw {
737            for c in 0..oc {
738                out_nchw[c * hw * hw + y * hw + x] = bhwc[(y * hw + x) * oc + c];
739            }
740        }
741    }
742    out_nchw
743}
744
/// In-place LayerNorm over the channel axis of a BHWC buffer viewed as
/// `[s, c]` (candle's LayerNorm2d semantics): each of the first `s` rows is
/// normalised with its own mean and biased variance, then scaled by `g`
/// and shifted by `b`. Elements past `s·c` are left untouched.
fn layernorm2d_inplace(data: &mut [f32], s: usize, c: usize, g: &[f32], b: &[f32], eps: f32) {
    let n = c as f32;
    for pos in 0..s {
        let start = pos * c;
        let chans = &mut data[start..start + c];
        let mu = chans.iter().sum::<f32>() / n;
        let sigma2 = chans.iter().map(|&x| (x - mu) * (x - mu)).sum::<f32>() / n;
        let rstd = 1.0 / (sigma2 + eps).sqrt();
        for ((x, &gamma), &beta) in chans.iter_mut().zip(&g[..c]).zip(&b[..c]) {
            *x = (*x - mu) * rstd * gamma + beta;
        }
    }
}
757
/// Reference 3×3 Conv2d with stride 1, zero padding 1 and no bias.
///
/// NCHW with a batch of one: `input` is `[in_c, h, w]`, `weight` is
/// `[out_c, in_c, 3, 3]`, and the result is `[out_c, h, w]`. Plain nested
/// loops — fine for the single neck call per inference.
fn conv2d_3x3_pad1(
    input: &[f32],
    in_c: usize,
    out_c: usize,
    h: usize,
    w: usize,
    weight: &[f32], // [out_c, in_c, 3, 3]
) -> Vec<f32> {
    let plane = h * w;
    let mut out = Vec::with_capacity(out_c * plane);
    for o in 0..out_c {
        for y in 0..h {
            for x in 0..w {
                let mut sum = 0f32;
                for i in 0..in_c {
                    for ky in 0..3usize {
                        // Source row y + ky - 1; rows outside [0, h) are padding.
                        let iy = match (y + ky).checked_sub(1) {
                            Some(row) if row < h => row,
                            _ => continue,
                        };
                        for kx in 0..3usize {
                            // Source column x + kx - 1; same padding rule.
                            let ix = match (x + kx).checked_sub(1) {
                                Some(col) if col < w => col,
                                _ => continue,
                            };
                            let tap = ((o * in_c + i) * 3 + ky) * 3 + kx;
                            sum += input[i * plane + iy * w + ix] * weight[tap];
                        }
                    }
                }
                out.push(sum);
            }
        }
    }
    out
}
797
798// ─── Small builder helpers ─────────────────────────────────────────
799
800fn load_p(
801    sb: &mut SamBuilder,
802    weights: &mut WeightMap,
803    key: &str,
804    transpose: bool,
805) -> Result<HirNodeId> {
806    let (data, shape) = if transpose {
807        weights
808            .take_transposed(key)
809            .map_err(|e| anyhow!("transpose-load `{key}`: {e}"))?
810    } else {
811        weights
812            .take(key)
813            .map_err(|e| anyhow!("load `{key}`: {e}"))?
814    };
815    let name = key.to_string();
816    let id = sb.m().param(&name, Shape::new(&shape, DType::F32));
817    sb.params.insert(name, data);
818    Ok(id)
819}
820
821#[allow(dead_code)]
822fn scalar_param(sb: &mut SamBuilder, name: &str, value: f32) -> HirNodeId {
823    let id = sb.m().param(name, Shape::new(&[1], DType::F32));
824    sb.params.insert(name.to_string(), vec![value]);
825    id
826}
827
828fn const_param(sb: &mut SamBuilder, name: &str, shape: &[usize], data: Vec<f32>) -> HirNodeId {
829    let id = sb.m().param(name, Shape::new(shape, DType::F32));
830    sb.params.insert(name.to_string(), data);
831    id
832}
833
834fn pad_zero_param(sb: &mut SamBuilder, name: &str, shape: &[usize]) -> HirNodeId {
835    let n: usize = shape.iter().product();
836    let id = sb.m().param(name, Shape::new(shape, DType::F32));
837    sb.params.insert(name.to_string(), vec![0f32; n]);
838    id
839}