Skip to main content

rlx_sam3/
detector_encoder_ir.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

//! HIR-native SAM3 detector encoder (6 transformer layers).
17
18use super::detector_encoder::{Sam3EncoderLayerWeights, Sam3EncoderWeights};
19use super::packed_gguf::packed_linear;
20use anyhow::{Result, ensure};
21use rlx_flow::CompileProfile;
22use rlx_flow::{GgufPackedLinear, GgufPackedParams};
23use rlx_ir::hir::{HirGraphExt, HirModule, HirMut, HirNodeId};
24use rlx_ir::op::{MaskKind, Op};
25use rlx_ir::shape;
26use rlx_ir::{DType, Graph, Shape};
27use rlx_runtime::{CompiledGraph, Device};
28use std::collections::HashMap;
29
30/// HIR build product (F32 norms + optional U8 packed linears).
31pub struct Sam3EncoderHirParts {
32    pub hir: HirModule,
33    pub params: HashMap<String, Vec<f32>>,
34    pub typed_params: Vec<(String, Vec<u8>, DType)>,
35}
36
37/// Compiled encoder graph + uploaded parameters, ready for many `run`s.
38pub struct Sam3CompiledEncoder {
39    pub compiled: CompiledGraph,
40    pub batch: usize,
41    pub hw: usize,
42    pub seq: usize,
43    pub d: usize,
44}
45
46impl Sam3CompiledEncoder {
47    pub fn new(
48        weights: &Sam3EncoderWeights,
49        batch: usize,
50        hw: usize,
51        seq: usize,
52        device: Device,
53    ) -> Result<Self> {
54        Self::new_with_profile(weights, batch, hw, seq, device, &CompileProfile::sam3())
55    }
56
57    pub fn new_with_profile(
58        weights: &Sam3EncoderWeights,
59        batch: usize,
60        hw: usize,
61        seq: usize,
62        device: Device,
63        profile: &CompileProfile,
64    ) -> Result<Self> {
65        Self::new_with_profile_and_gguf(weights, batch, hw, seq, device, profile, None)
66    }
67
68    pub fn new_with_profile_and_gguf(
69        weights: &Sam3EncoderWeights,
70        batch: usize,
71        hw: usize,
72        seq: usize,
73        device: Device,
74        profile: &CompileProfile,
75        gguf_packed: Option<&GgufPackedParams>,
76    ) -> Result<Self> {
77        let parts = build_encoder_hir(weights, batch, hw, seq, gguf_packed)?;
78        let mut compiled =
79            rlx_core::flow_bridge::compile_hir_with_profile(device, parts.hir, profile)?;
80        rlx_core::flow_util::attach_built_params(&mut compiled, parts.params, &parts.typed_params);
81        Ok(Self {
82            compiled,
83            batch,
84            hw,
85            seq,
86            d: D_MODEL,
87        })
88    }
89
90    #[allow(clippy::too_many_arguments)]
91    pub fn run(
92        &mut self,
93        src_bchw: &[f32],
94        src_pos_bchw: &[f32],
95        prompt_seq_first: &[f32],
96        prompt_kpm: &[u8],
97        src_h: usize,
98        src_w: usize,
99    ) -> Result<Vec<f32>> {
100        let hw = src_h * src_w;
101        ensure!(
102            hw == self.hw,
103            "compiled encoder expects hw={}, got {hw}",
104            self.hw
105        );
106        let mut src_bhwc = vec![0f32; self.batch * hw * self.d];
107        let mut pos_bhwc = vec![0f32; self.batch * hw * self.d];
108        for b in 0..self.batch {
109            for s in 0..hw {
110                for c in 0..self.d {
111                    src_bhwc[(b * hw + s) * self.d + c] = src_bchw[((b * self.d + c) * hw) + s];
112                    pos_bhwc[(b * hw + s) * self.d + c] = src_pos_bchw[((b * self.d + c) * hw) + s];
113                }
114            }
115        }
116        let mut prompt_bf = vec![0f32; self.batch * self.seq * self.d];
117        for b in 0..self.batch {
118            for l in 0..self.seq {
119                let s = (l * self.batch + b) * self.d;
120                let dst = (b * self.seq + l) * self.d;
121                prompt_bf[dst..dst + self.d].copy_from_slice(&prompt_seq_first[s..s + self.d]);
122            }
123        }
124        let prompt_kpm_inv: Vec<f32> = prompt_kpm
125            .iter()
126            .map(|&v| if v == 0 { 1.0 } else { 0.0 })
127            .collect();
128        let outputs = self.compiled.run(&[
129            ("src", src_bhwc.as_slice()),
130            ("src_pos", pos_bhwc.as_slice()),
131            ("prompt", prompt_bf.as_slice()),
132            ("prompt_kpm_inv", prompt_kpm_inv.as_slice()),
133        ]);
134        outputs
135            .into_iter()
136            .next()
137            .ok_or_else(|| anyhow::anyhow!("encoder graph produced no outputs"))
138    }
139}
140
/// Channel width of every encoder token.
const D_MODEL: usize = 256;
/// Hidden width of the feed-forward block (`linear1` output, `linear2` input).
const DIM_FF: usize = 2048;
/// Number of attention heads handed to the attention ops.
const N_HEADS: usize = 8;
/// Width of a single attention head: `D_MODEL` divided across `N_HEADS`.
const HEAD_DIM: usize = D_MODEL / N_HEADS;
145
/// Returns the dotted key `"{base}.layers.{li}.{suffix}"` for a tensor of
/// encoder layer `li`.
fn enc_layer_key(base: &str, li: usize, suffix: &str) -> String {
    let layer_index = li.to_string();
    [base, "layers", layer_index.as_str(), suffix].join(".")
}
149
150fn gguf_weight_param(
151    g: &mut HirMut<'_>,
152    typed: &mut Vec<(String, Vec<u8>, DType)>,
153    cache: &mut HashMap<String, HirNodeId>,
154    ir_name: &str,
155    p: &GgufPackedLinear,
156) -> HirNodeId {
157    if let Some(&id) = cache.get(ir_name) {
158        return id;
159    }
160    let id = g.param(ir_name, Shape::new(&[p.w_q.len()], DType::U8));
161    typed.push((ir_name.to_string(), p.w_q.clone(), DType::U8));
162    cache.insert(ir_name.to_string(), id);
163    id
164}
165
166fn linear_gguf_matmul(
167    g: &mut HirMut<'_>,
168    typed: &mut Vec<(String, Vec<u8>, DType)>,
169    cache: &mut HashMap<String, HirNodeId>,
170    ir_stem: &str,
171    p: &GgufPackedLinear,
172    input: HirNodeId,
173    in_dim: usize,
174    out_dim: usize,
175) -> Result<HirNodeId> {
176    ensure!(
177        p.in_dim == in_dim && p.out_dim == out_dim,
178        "packed linear {ir_stem}: shape {}x{} vs {in_dim}x{out_dim}",
179        p.in_dim,
180        p.out_dim
181    );
182    let w_name = format!("{ir_stem}.w");
183    let w_id = gguf_weight_param(g, typed, cache, &w_name, p);
184    let cur = g.shape(input);
185    let mut dims: Vec<usize> = cur.dims().iter().map(|d| d.unwrap_static()).collect();
186    *dims.last_mut().unwrap() = out_dim;
187    let out_shape = Shape::new(&dims, DType::F32);
188    Ok(g.add_node(
189        Op::DequantMatMul { scheme: p.scheme },
190        vec![input, w_id],
191        out_shape,
192    ))
193}
194
195fn add_f32_bias(
196    g: &mut HirMut<'_>,
197    params: &mut HashMap<String, Vec<f32>>,
198    name: &str,
199    input: HirNodeId,
200    bias: &[f32],
201) -> HirNodeId {
202    if bias.iter().all(|&v| v == 0.0) {
203        return input;
204    }
205    let out_dim = bias.len();
206    let b_id = add_param(g, name, bias.to_vec(), Shape::new(&[out_dim], DType::F32));
207    params.insert(name.to_string(), bias.to_vec());
208    g.add(input, b_id)
209}
210
211fn linear_gguf_bias(
212    g: &mut HirMut<'_>,
213    params: &mut HashMap<String, Vec<f32>>,
214    typed: &mut Vec<(String, Vec<u8>, DType)>,
215    cache: &mut HashMap<String, HirNodeId>,
216    ir_stem: &str,
217    p: &GgufPackedLinear,
218    input: HirNodeId,
219    bias: &[f32],
220    in_dim: usize,
221    out_dim: usize,
222) -> Result<HirNodeId> {
223    let y = linear_gguf_matmul(g, typed, cache, ir_stem, p, input, in_dim, out_dim)?;
224    Ok(add_f32_bias(g, params, &format!("{ir_stem}.b"), y, bias))
225}
226
227fn in_proj_qkv(
228    g: &mut HirMut<'_>,
229    params: &mut HashMap<String, Vec<f32>>,
230    typed: &mut Vec<(String, Vec<u8>, DType)>,
231    cache: &mut HashMap<String, HirNodeId>,
232    gguf_packed: Option<&GgufPackedParams>,
233    gguf_key: &str,
234    ir_stem: &str,
235    layer_w_t: &[f32],
236    layer_b: &[f32],
237    input_q: HirNodeId,
238    input_k: HirNodeId,
239    input_v: HirNodeId,
240    d: usize,
241) -> Result<(HirNodeId, HirNodeId, HirNodeId)> {
242    if let Some(p) = gguf_packed.and_then(|m| packed_linear(m, gguf_key)) {
243        let qkv_q = linear_gguf_bias(
244            g,
245            params,
246            typed,
247            cache,
248            ir_stem,
249            p,
250            input_q,
251            layer_b,
252            d,
253            3 * d,
254        )?;
255        let qkv_k = linear_gguf_bias(
256            g,
257            params,
258            typed,
259            cache,
260            ir_stem,
261            p,
262            input_k,
263            layer_b,
264            d,
265            3 * d,
266        )?;
267        let qkv_v = linear_gguf_bias(
268            g,
269            params,
270            typed,
271            cache,
272            ir_stem,
273            p,
274            input_v,
275            layer_b,
276            d,
277            3 * d,
278        )?;
279        let axis = g.shape(qkv_q).rank().saturating_sub(1);
280        let q = g.narrow_(qkv_q, axis, 0, d);
281        let k = g.narrow_(qkv_k, axis, d, d);
282        let v = g.narrow_(qkv_v, axis, 2 * d, d);
283        return Ok((q, k, v));
284    }
285    let (wq, wk, wv) = split_qkv(layer_w_t, d);
286    let bq = layer_b[0..d].to_vec();
287    let bk = layer_b[d..2 * d].to_vec();
288    let bv = layer_b[2 * d..3 * d].to_vec();
289    let batch_q = g.shape(input_q).dims()[0].unwrap_static();
290    let seq_q = g.shape(input_q).dims()[1].unwrap_static();
291    let batch_k = g.shape(input_k).dims()[0].unwrap_static();
292    let seq_k = g.shape(input_k).dims()[1].unwrap_static();
293    let batch_v = g.shape(input_v).dims()[0].unwrap_static();
294    let seq_v = g.shape(input_v).dims()[1].unwrap_static();
295    let q = qkv_linear(
296        g,
297        params,
298        &format!("{ir_stem}.q"),
299        input_q,
300        wq,
301        bq,
302        batch_q,
303        seq_q,
304        d,
305    );
306    let k = qkv_linear(
307        g,
308        params,
309        &format!("{ir_stem}.k"),
310        input_k,
311        wk,
312        bk,
313        batch_k,
314        seq_k,
315        d,
316    );
317    let v = qkv_linear(
318        g,
319        params,
320        &format!("{ir_stem}.v"),
321        input_v,
322        wv,
323        bv,
324        batch_v,
325        seq_v,
326        d,
327    );
328    Ok((q, k, v))
329}
330
331fn linear_fused_or_gguf(
332    g: &mut HirMut<'_>,
333    params: &mut HashMap<String, Vec<f32>>,
334    typed: &mut Vec<(String, Vec<u8>, DType)>,
335    cache: &mut HashMap<String, HirNodeId>,
336    gguf_packed: Option<&GgufPackedParams>,
337    gguf_key: &str,
338    ir_stem: &str,
339    input: HirNodeId,
340    w_t: Vec<f32>,
341    bias: Vec<f32>,
342    in_dim: usize,
343    out_dim: usize,
344) -> Result<HirNodeId> {
345    if let Some(p) = gguf_packed.and_then(|m| packed_linear(m, gguf_key)) {
346        return linear_gguf_bias(
347            g, params, typed, cache, ir_stem, p, input, &bias, in_dim, out_dim,
348        );
349    }
350    Ok(linear_with_bias(
351        g, params, ir_stem, input, w_t, bias, in_dim, out_dim,
352    ))
353}
354
/// Splits a fused in-projection weight into separate `q`, `k` and `v` weights.
///
/// `w_t` is read as `e` rows of `3 * e` values, each row laid out as
/// `[q | k | v]`; the three results are `e x e` row-major matrices. Values
/// beyond the first `3 * e * e` are ignored.
///
/// # Panics
/// Panics if `w_t` holds fewer than `3 * e * e` values.
fn split_qkv(w_t: &[f32], e: usize) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
    let stride = 3 * e;
    let mut wq = Vec::with_capacity(e * e);
    let mut wk = Vec::with_capacity(e * e);
    let mut wv = Vec::with_capacity(e * e);
    for row_index in 0..e {
        let row = &w_t[row_index * stride..(row_index + 1) * stride];
        let (q_part, rest) = row.split_at(e);
        let (k_part, v_part) = rest.split_at(e);
        wq.extend_from_slice(q_part);
        wk.extend_from_slice(k_part);
        wv.extend_from_slice(v_part);
    }
    (wq, wk, wv)
}
368
369fn add_param(g: &mut HirMut<'_>, name: &str, _data: Vec<f32>, shape: Shape) -> HirNodeId {
370    g.param(name, shape)
371}
372
373/// Lower encoder HIR to legacy [`Graph`] (via [`super::flow::Sam3DetectorEncoderFlow`]).
374pub fn build_sam3_detector_encoder_graph(
375    weights: &Sam3EncoderWeights,
376    batch: usize,
377    hw: usize,
378    seq: usize,
379) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
380    let parts = build_encoder_hir(weights, batch, hw, seq, None)?;
381    rlx_core::flow_util::graph_from_hir(parts.hir, parts.params)
382}
383
384/// Build native HIR encoder module + parameter blobs.
385pub fn build_encoder_hir(
386    weights: &Sam3EncoderWeights,
387    batch: usize,
388    hw: usize,
389    seq: usize,
390    gguf_packed: Option<&GgufPackedParams>,
391) -> Result<Sam3EncoderHirParts> {
392    let mut hir = HirModule::new("sam3_detector_encoder");
393    let mut g = HirMut::new(&mut hir);
394    let mut params: HashMap<String, Vec<f32>> = HashMap::new();
395    let mut typed_params = Vec::new();
396    let mut gguf_w_cache: HashMap<String, HirNodeId> = HashMap::new();
397    let f = DType::F32;
398    let d = D_MODEL;
399    let enc_base = &weights.prefix;
400
401    let src = g.input("src", Shape::new(&[batch, hw, d], f));
402    let src_pos = g.input("src_pos", Shape::new(&[batch, hw, d], f));
403    let prompt = g.input("prompt", Shape::new(&[batch, seq, d], f));
404    let prompt_kpm_inv = g.input("prompt_kpm_inv", Shape::new(&[batch, seq], f));
405
406    let mut tgt = src;
407    for (li, layer) in weights.layers.iter().enumerate() {
408        tgt = emit_sam3_detector_encoder_layer(
409            &mut g,
410            &mut params,
411            &mut typed_params,
412            &mut gguf_w_cache,
413            gguf_packed,
414            enc_base,
415            li,
416            layer,
417            batch,
418            hw,
419            seq,
420            tgt,
421            src_pos,
422            prompt,
423            prompt_kpm_inv,
424        )?;
425    }
426    g.set_outputs(vec![tgt]);
427    Ok(Sam3EncoderHirParts {
428        hir,
429        params,
430        typed_params,
431    })
432}
433
434/// One SAM3 fusion encoder layer (self-attn + cross-attn + FFN).
435#[allow(clippy::too_many_arguments)]
436pub fn emit_sam3_detector_encoder_layer(
437    g: &mut HirMut<'_>,
438    params: &mut HashMap<String, Vec<f32>>,
439    typed_params: &mut Vec<(String, Vec<u8>, DType)>,
440    gguf_w_cache: &mut HashMap<String, HirNodeId>,
441    gguf_packed: Option<&GgufPackedParams>,
442    enc_base: &str,
443    li: usize,
444    layer: &Sam3EncoderLayerWeights,
445    _batch: usize,
446    _hw: usize,
447    _seq: usize,
448    tgt: HirNodeId,
449    src_pos: HirNodeId,
450    prompt: HirNodeId,
451    prompt_kpm_inv: HirNodeId,
452) -> Result<HirNodeId> {
453    let f = DType::F32;
454    let d = D_MODEL;
455    let nh = N_HEADS;
456    let dh = HEAD_DIM;
457    let dim_ff = DIM_FF;
458
459    let n1_w = add_param(
460        g,
461        &format!("l{li}.norm1.w"),
462        layer.norm1_w.clone(),
463        Shape::new(&[d], f),
464    );
465    params.insert(format!("l{li}.norm1.w"), layer.norm1_w.clone());
466    let n1_b = add_param(
467        g,
468        &format!("l{li}.norm1.b"),
469        layer.norm1_b.clone(),
470        Shape::new(&[d], f),
471    );
472    params.insert(format!("l{li}.norm1.b"), layer.norm1_b.clone());
473    let n1 = g.ln(tgt, n1_w, n1_b, 1e-5);
474
475    let qk_in = g.add(n1, src_pos);
476
477    let (q_node, k_node, v_node) = in_proj_qkv(
478        g,
479        params,
480        typed_params,
481        gguf_w_cache,
482        gguf_packed,
483        &enc_layer_key(enc_base, li, "self_attn.in_proj_weight"),
484        &format!("l{li}.sa.in_proj"),
485        &layer.self_attn_in_w_t,
486        &layer.self_attn_in_b,
487        qk_in,
488        qk_in,
489        n1,
490        d,
491    )?;
492
493    let sa_attn = g.attention_kind(
494        q_node,
495        k_node,
496        v_node,
497        nh,
498        dh,
499        MaskKind::None,
500        shape::attention_shape(g.shape(q_node)),
501    );
502    let sa_out = linear_fused_or_gguf(
503        g,
504        params,
505        typed_params,
506        gguf_w_cache,
507        gguf_packed,
508        &enc_layer_key(enc_base, li, "self_attn.out_proj.weight"),
509        &format!("l{li}.sa.proj"),
510        sa_attn,
511        layer.self_attn_out_w_t.clone(),
512        layer.self_attn_out_b.clone(),
513        d,
514        d,
515    )?;
516    let mut tgt = g.add(tgt, sa_out);
517
518    let n2_w = add_param(
519        g,
520        &format!("l{li}.norm2.w"),
521        layer.norm2_w.clone(),
522        Shape::new(&[d], f),
523    );
524    params.insert(format!("l{li}.norm2.w"), layer.norm2_w.clone());
525    let n2_b = add_param(
526        g,
527        &format!("l{li}.norm2.b"),
528        layer.norm2_b.clone(),
529        Shape::new(&[d], f),
530    );
531    params.insert(format!("l{li}.norm2.b"), layer.norm2_b.clone());
532    let n2 = g.ln(tgt, n2_w, n2_b, 1e-5);
533
534    let (qc, kc, vc) = in_proj_qkv(
535        g,
536        params,
537        typed_params,
538        gguf_w_cache,
539        gguf_packed,
540        &enc_layer_key(enc_base, li, "cross_attn_image.in_proj_weight"),
541        &format!("l{li}.ca.in_proj"),
542        &layer.cross_attn_in_w_t,
543        &layer.cross_attn_in_b,
544        n2,
545        prompt,
546        prompt,
547        d,
548    )?;
549
550    let ca_attn = g.attention(
551        qc,
552        kc,
553        vc,
554        prompt_kpm_inv,
555        nh,
556        dh,
557        shape::attention_shape(g.shape(qc)),
558    );
559    let ca_out = linear_fused_or_gguf(
560        g,
561        params,
562        typed_params,
563        gguf_w_cache,
564        gguf_packed,
565        &enc_layer_key(enc_base, li, "cross_attn_image.out_proj.weight"),
566        &format!("l{li}.ca.proj"),
567        ca_attn,
568        layer.cross_attn_out_w_t.clone(),
569        layer.cross_attn_out_b.clone(),
570        d,
571        d,
572    )?;
573    tgt = g.add(tgt, ca_out);
574
575    let n3_w = add_param(
576        g,
577        &format!("l{li}.norm3.w"),
578        layer.norm3_w.clone(),
579        Shape::new(&[d], f),
580    );
581    params.insert(format!("l{li}.norm3.w"), layer.norm3_w.clone());
582    let n3_b = add_param(
583        g,
584        &format!("l{li}.norm3.b"),
585        layer.norm3_b.clone(),
586        Shape::new(&[d], f),
587    );
588    params.insert(format!("l{li}.norm3.b"), layer.norm3_b.clone());
589    let n3 = g.ln(tgt, n3_w, n3_b, 1e-5);
590
591    let ff1 = linear_fused_or_gguf(
592        g,
593        params,
594        typed_params,
595        gguf_w_cache,
596        gguf_packed,
597        &enc_layer_key(enc_base, li, "linear1.weight"),
598        &format!("l{li}.ffn.fc1"),
599        n3,
600        layer.linear1_w_t.clone(),
601        layer.linear1_b.clone(),
602        d,
603        dim_ff,
604    )?;
605    let relud = g.relu(ff1);
606    let ff2 = linear_fused_or_gguf(
607        g,
608        params,
609        typed_params,
610        gguf_w_cache,
611        gguf_packed,
612        &enc_layer_key(enc_base, li, "linear2.weight"),
613        &format!("l{li}.ffn.fc2"),
614        relud,
615        layer.linear2_w_t.clone(),
616        layer.linear2_b.clone(),
617        dim_ff,
618        d,
619    )?;
620    Ok(g.add(tgt, ff2))
621}
622
623fn qkv_linear(
624    g: &mut HirMut<'_>,
625    params: &mut HashMap<String, Vec<f32>>,
626    name: &str,
627    input: HirNodeId,
628    w: Vec<f32>,
629    b: Vec<f32>,
630    batch: usize,
631    seq: usize,
632    d: usize,
633) -> HirNodeId {
634    let f = DType::F32;
635    let w_name = format!("{name}.w");
636    let b_name = format!("{name}.b");
637    let w_id = g.param(&w_name, Shape::new(&[d, d], f));
638    params.insert(w_name, w);
639    let b_id = g.param(&b_name, Shape::new(&[d], f));
640    params.insert(b_name, b);
641    let out_shape = Shape::new(&[batch, seq, d], f);
642    g.add_node(
643        Op::FusedMatMulBiasAct { activation: None },
644        vec![input, w_id, b_id],
645        out_shape,
646    )
647}
648
649fn linear_with_bias(
650    g: &mut HirMut<'_>,
651    params: &mut HashMap<String, Vec<f32>>,
652    name: &str,
653    input: HirNodeId,
654    w: Vec<f32>,
655    b: Vec<f32>,
656    in_dim: usize,
657    out_dim: usize,
658) -> HirNodeId {
659    let f = DType::F32;
660    let w_name = format!("{name}.w");
661    let b_name = format!("{name}.b");
662    let w_id = g.param(&w_name, Shape::new(&[in_dim, out_dim], f));
663    params.insert(w_name, w);
664    let b_id = g.param(&b_name, Shape::new(&[out_dim], f));
665    params.insert(b_name, b);
666    let cur_shape = g.shape(input);
667    let mut out_dims: Vec<usize> = cur_shape.dims().iter().map(|d| d.unwrap_static()).collect();
668    *out_dims.last_mut().unwrap() = out_dim;
669    g.add_node(
670        Op::FusedMatMulBiasAct { activation: None },
671        vec![input, w_id, b_id],
672        Shape::new(&out_dims, f),
673    )
674}
675
676#[allow(clippy::too_many_arguments)]
677pub fn forward_encoder_ir_on(
678    weights: &Sam3EncoderWeights,
679    src_bchw: &[f32],
680    src_pos_bchw: &[f32],
681    prompt_seq_first: &[f32],
682    prompt_kpm: &[u8],
683    batch: usize,
684    src_h: usize,
685    src_w: usize,
686    prompt_len: usize,
687    device: Device,
688) -> Result<Vec<f32>> {
689    forward_encoder_ir_on_with_profile(
690        weights,
691        src_bchw,
692        src_pos_bchw,
693        prompt_seq_first,
694        prompt_kpm,
695        batch,
696        src_h,
697        src_w,
698        prompt_len,
699        device,
700        &CompileProfile::sam3(),
701        None,
702    )
703}
704
705/// Same as [`forward_encoder_ir_on`] with an explicit tier-1 profile.
706#[allow(clippy::too_many_arguments)]
707pub fn forward_encoder_ir_on_with_profile(
708    weights: &Sam3EncoderWeights,
709    src_bchw: &[f32],
710    src_pos_bchw: &[f32],
711    prompt_seq_first: &[f32],
712    prompt_kpm: &[u8],
713    batch: usize,
714    src_h: usize,
715    src_w: usize,
716    prompt_len: usize,
717    device: Device,
718    profile: &CompileProfile,
719    gguf_packed: Option<&GgufPackedParams>,
720) -> Result<Vec<f32>> {
721    ensure!(weights.loaded, "SAM3 detector encoder not loaded");
722    let hw = src_h * src_w;
723    ensure!(
724        src_bchw.len() == batch * D_MODEL * hw,
725        "encoder src shape mismatch"
726    );
727    ensure!(
728        prompt_seq_first.len() == prompt_len * batch * D_MODEL,
729        "encoder prompt shape mismatch"
730    );
731
732    let mut src_bhwc = vec![0f32; batch * hw * D_MODEL];
733    let mut pos_bhwc = vec![0f32; batch * hw * D_MODEL];
734    for b in 0..batch {
735        for s in 0..hw {
736            for c in 0..D_MODEL {
737                src_bhwc[(b * hw + s) * D_MODEL + c] = src_bchw[((b * D_MODEL + c) * hw) + s];
738                pos_bhwc[(b * hw + s) * D_MODEL + c] = src_pos_bchw[((b * D_MODEL + c) * hw) + s];
739            }
740        }
741    }
742
743    let mut prompt_bf = vec![0f32; batch * prompt_len * D_MODEL];
744    for b in 0..batch {
745        for l in 0..prompt_len {
746            let s = (l * batch + b) * D_MODEL;
747            let dst = (b * prompt_len + l) * D_MODEL;
748            prompt_bf[dst..dst + D_MODEL].copy_from_slice(&prompt_seq_first[s..s + D_MODEL]);
749        }
750    }
751    let prompt_kpm_inv: Vec<f32> = prompt_kpm
752        .iter()
753        .map(|&v| if v == 0 { 1.0 } else { 0.0 })
754        .collect();
755
756    let parts = build_encoder_hir(weights, batch, hw, prompt_len, gguf_packed)?;
757    let mut compiled = rlx_core::flow_bridge::compile_hir_with_profile(device, parts.hir, profile)?;
758    rlx_core::flow_util::attach_built_params(&mut compiled, parts.params, &parts.typed_params);
759    let outputs = compiled.run(&[
760        ("src", src_bhwc.as_slice()),
761        ("src_pos", pos_bhwc.as_slice()),
762        ("prompt", prompt_bf.as_slice()),
763        ("prompt_kpm_inv", prompt_kpm_inv.as_slice()),
764    ]);
765    let out = outputs
766        .into_iter()
767        .next()
768        .ok_or_else(|| anyhow::anyhow!("encoder graph produced no outputs"))?;
769    Ok(out)
770}
771
772#[allow(clippy::too_many_arguments)]
773pub fn forward_encoder_ir(
774    weights: &Sam3EncoderWeights,
775    src_bchw: &[f32],
776    src_pos_bchw: &[f32],
777    prompt_seq_first: &[f32],
778    prompt_kpm: &[u8],
779    batch: usize,
780    src_h: usize,
781    src_w: usize,
782    prompt_len: usize,
783) -> Result<Vec<f32>> {
784    forward_encoder_ir_on_with_profile(
785        weights,
786        src_bchw,
787        src_pos_bchw,
788        prompt_seq_first,
789        prompt_kpm,
790        batch,
791        src_h,
792        src_w,
793        prompt_len,
794        Device::Cpu,
795        &CompileProfile::sam3(),
796        None,
797    )
798}