1use super::detector_encoder::{Sam3EncoderLayerWeights, Sam3EncoderWeights};
19use super::packed_gguf::packed_linear;
20use anyhow::{Result, ensure};
21use rlx_flow::CompileProfile;
22use rlx_flow::{GgufPackedLinear, GgufPackedParams};
23use rlx_ir::hir::{HirGraphExt, HirModule, HirMut, HirNodeId};
24use rlx_ir::op::{MaskKind, Op};
25use rlx_ir::shape;
26use rlx_ir::{DType, Graph, Shape};
27use rlx_runtime::{CompiledGraph, Device};
28use std::collections::HashMap;
29
/// Result of [`build_encoder_hir`]: the HIR module plus the parameter data
/// that has to be attached to the graph once it has been compiled.
pub struct Sam3EncoderHirParts {
    /// HIR module describing the encoder graph.
    pub hir: HirModule,
    /// `f32` parameter values keyed by IR parameter name.
    pub params: HashMap<String, Vec<f32>>,
    /// Raw-byte parameters as `(IR name, bytes, dtype)`; packed GGUF weights
    /// are recorded here with `DType::U8`.
    pub typed_params: Vec<(String, Vec<u8>, DType)>,
}
36
/// A compiled encoder graph together with the fixed sizes it was built for.
pub struct Sam3CompiledEncoder {
    /// Compiled graph with its parameters attached.
    pub compiled: CompiledGraph,
    /// Batch size the graph was built for.
    pub batch: usize,
    /// Number of flattened spatial positions (`src_h * src_w`) the graph expects.
    pub hw: usize,
    /// Prompt sequence length the graph was built for.
    pub seq: usize,
    /// Channel width of the activations; always set to `D_MODEL`.
    pub d: usize,
}
45
46impl Sam3CompiledEncoder {
47 pub fn new(
48 weights: &Sam3EncoderWeights,
49 batch: usize,
50 hw: usize,
51 seq: usize,
52 device: Device,
53 ) -> Result<Self> {
54 Self::new_with_profile(weights, batch, hw, seq, device, &CompileProfile::sam3())
55 }
56
57 pub fn new_with_profile(
58 weights: &Sam3EncoderWeights,
59 batch: usize,
60 hw: usize,
61 seq: usize,
62 device: Device,
63 profile: &CompileProfile,
64 ) -> Result<Self> {
65 Self::new_with_profile_and_gguf(weights, batch, hw, seq, device, profile, None)
66 }
67
68 pub fn new_with_profile_and_gguf(
69 weights: &Sam3EncoderWeights,
70 batch: usize,
71 hw: usize,
72 seq: usize,
73 device: Device,
74 profile: &CompileProfile,
75 gguf_packed: Option<&GgufPackedParams>,
76 ) -> Result<Self> {
77 let parts = build_encoder_hir(weights, batch, hw, seq, gguf_packed)?;
78 let mut compiled =
79 rlx_core::flow_bridge::compile_hir_with_profile(device, parts.hir, profile)?;
80 rlx_core::flow_util::attach_built_params(&mut compiled, parts.params, &parts.typed_params);
81 Ok(Self {
82 compiled,
83 batch,
84 hw,
85 seq,
86 d: D_MODEL,
87 })
88 }
89
90 #[allow(clippy::too_many_arguments)]
91 pub fn run(
92 &mut self,
93 src_bchw: &[f32],
94 src_pos_bchw: &[f32],
95 prompt_seq_first: &[f32],
96 prompt_kpm: &[u8],
97 src_h: usize,
98 src_w: usize,
99 ) -> Result<Vec<f32>> {
100 let hw = src_h * src_w;
101 ensure!(
102 hw == self.hw,
103 "compiled encoder expects hw={}, got {hw}",
104 self.hw
105 );
106 let mut src_bhwc = vec![0f32; self.batch * hw * self.d];
107 let mut pos_bhwc = vec![0f32; self.batch * hw * self.d];
108 for b in 0..self.batch {
109 for s in 0..hw {
110 for c in 0..self.d {
111 src_bhwc[(b * hw + s) * self.d + c] = src_bchw[((b * self.d + c) * hw) + s];
112 pos_bhwc[(b * hw + s) * self.d + c] = src_pos_bchw[((b * self.d + c) * hw) + s];
113 }
114 }
115 }
116 let mut prompt_bf = vec![0f32; self.batch * self.seq * self.d];
117 for b in 0..self.batch {
118 for l in 0..self.seq {
119 let s = (l * self.batch + b) * self.d;
120 let dst = (b * self.seq + l) * self.d;
121 prompt_bf[dst..dst + self.d].copy_from_slice(&prompt_seq_first[s..s + self.d]);
122 }
123 }
124 let prompt_kpm_inv: Vec<f32> = prompt_kpm
125 .iter()
126 .map(|&v| if v == 0 { 1.0 } else { 0.0 })
127 .collect();
128 let outputs = self.compiled.run(&[
129 ("src", src_bhwc.as_slice()),
130 ("src_pos", pos_bhwc.as_slice()),
131 ("prompt", prompt_bf.as_slice()),
132 ("prompt_kpm_inv", prompt_kpm_inv.as_slice()),
133 ]);
134 outputs
135 .into_iter()
136 .next()
137 .ok_or_else(|| anyhow::anyhow!("encoder graph produced no outputs"))
138 }
139}
140
/// Channel width: last dimension of the `src`, `src_pos` and `prompt` graph inputs.
const D_MODEL: usize = 256;
/// Hidden width of the feed-forward block (`linear1` output, `linear2` input).
const DIM_FF: usize = 2048;
/// Number of heads passed to the attention ops.
const N_HEADS: usize = 8;
/// Per-head width, derived as `D_MODEL / N_HEADS`.
const HEAD_DIM: usize = D_MODEL / N_HEADS;
145
/// Returns the dotted key `"{base}.layers.{li}.{suffix}"` used to look up a
/// per-layer tensor.
fn enc_layer_key(base: &str, li: usize, suffix: &str) -> String {
    let index = li.to_string();
    [base, "layers", index.as_str(), suffix].join(".")
}
149
150fn gguf_weight_param(
151 g: &mut HirMut<'_>,
152 typed: &mut Vec<(String, Vec<u8>, DType)>,
153 cache: &mut HashMap<String, HirNodeId>,
154 ir_name: &str,
155 p: &GgufPackedLinear,
156) -> HirNodeId {
157 if let Some(&id) = cache.get(ir_name) {
158 return id;
159 }
160 let id = g.param(ir_name, Shape::new(&[p.w_q.len()], DType::U8));
161 typed.push((ir_name.to_string(), p.w_q.clone(), DType::U8));
162 cache.insert(ir_name.to_string(), id);
163 id
164}
165
166fn linear_gguf_matmul(
167 g: &mut HirMut<'_>,
168 typed: &mut Vec<(String, Vec<u8>, DType)>,
169 cache: &mut HashMap<String, HirNodeId>,
170 ir_stem: &str,
171 p: &GgufPackedLinear,
172 input: HirNodeId,
173 in_dim: usize,
174 out_dim: usize,
175) -> Result<HirNodeId> {
176 ensure!(
177 p.in_dim == in_dim && p.out_dim == out_dim,
178 "packed linear {ir_stem}: shape {}x{} vs {in_dim}x{out_dim}",
179 p.in_dim,
180 p.out_dim
181 );
182 let w_name = format!("{ir_stem}.w");
183 let w_id = gguf_weight_param(g, typed, cache, &w_name, p);
184 let cur = g.shape(input);
185 let mut dims: Vec<usize> = cur.dims().iter().map(|d| d.unwrap_static()).collect();
186 *dims.last_mut().unwrap() = out_dim;
187 let out_shape = Shape::new(&dims, DType::F32);
188 Ok(g.add_node(
189 Op::DequantMatMul { scheme: p.scheme },
190 vec![input, w_id],
191 out_shape,
192 ))
193}
194
195fn add_f32_bias(
196 g: &mut HirMut<'_>,
197 params: &mut HashMap<String, Vec<f32>>,
198 name: &str,
199 input: HirNodeId,
200 bias: &[f32],
201) -> HirNodeId {
202 if bias.iter().all(|&v| v == 0.0) {
203 return input;
204 }
205 let out_dim = bias.len();
206 let b_id = add_param(g, name, bias.to_vec(), Shape::new(&[out_dim], DType::F32));
207 params.insert(name.to_string(), bias.to_vec());
208 g.add(input, b_id)
209}
210
211fn linear_gguf_bias(
212 g: &mut HirMut<'_>,
213 params: &mut HashMap<String, Vec<f32>>,
214 typed: &mut Vec<(String, Vec<u8>, DType)>,
215 cache: &mut HashMap<String, HirNodeId>,
216 ir_stem: &str,
217 p: &GgufPackedLinear,
218 input: HirNodeId,
219 bias: &[f32],
220 in_dim: usize,
221 out_dim: usize,
222) -> Result<HirNodeId> {
223 let y = linear_gguf_matmul(g, typed, cache, ir_stem, p, input, in_dim, out_dim)?;
224 Ok(add_f32_bias(g, params, &format!("{ir_stem}.b"), y, bias))
225}
226
227fn in_proj_qkv(
228 g: &mut HirMut<'_>,
229 params: &mut HashMap<String, Vec<f32>>,
230 typed: &mut Vec<(String, Vec<u8>, DType)>,
231 cache: &mut HashMap<String, HirNodeId>,
232 gguf_packed: Option<&GgufPackedParams>,
233 gguf_key: &str,
234 ir_stem: &str,
235 layer_w_t: &[f32],
236 layer_b: &[f32],
237 input_q: HirNodeId,
238 input_k: HirNodeId,
239 input_v: HirNodeId,
240 d: usize,
241) -> Result<(HirNodeId, HirNodeId, HirNodeId)> {
242 if let Some(p) = gguf_packed.and_then(|m| packed_linear(m, gguf_key)) {
243 let qkv_q = linear_gguf_bias(
244 g,
245 params,
246 typed,
247 cache,
248 ir_stem,
249 p,
250 input_q,
251 layer_b,
252 d,
253 3 * d,
254 )?;
255 let qkv_k = linear_gguf_bias(
256 g,
257 params,
258 typed,
259 cache,
260 ir_stem,
261 p,
262 input_k,
263 layer_b,
264 d,
265 3 * d,
266 )?;
267 let qkv_v = linear_gguf_bias(
268 g,
269 params,
270 typed,
271 cache,
272 ir_stem,
273 p,
274 input_v,
275 layer_b,
276 d,
277 3 * d,
278 )?;
279 let axis = g.shape(qkv_q).rank().saturating_sub(1);
280 let q = g.narrow_(qkv_q, axis, 0, d);
281 let k = g.narrow_(qkv_k, axis, d, d);
282 let v = g.narrow_(qkv_v, axis, 2 * d, d);
283 return Ok((q, k, v));
284 }
285 let (wq, wk, wv) = split_qkv(layer_w_t, d);
286 let bq = layer_b[0..d].to_vec();
287 let bk = layer_b[d..2 * d].to_vec();
288 let bv = layer_b[2 * d..3 * d].to_vec();
289 let batch_q = g.shape(input_q).dims()[0].unwrap_static();
290 let seq_q = g.shape(input_q).dims()[1].unwrap_static();
291 let batch_k = g.shape(input_k).dims()[0].unwrap_static();
292 let seq_k = g.shape(input_k).dims()[1].unwrap_static();
293 let batch_v = g.shape(input_v).dims()[0].unwrap_static();
294 let seq_v = g.shape(input_v).dims()[1].unwrap_static();
295 let q = qkv_linear(
296 g,
297 params,
298 &format!("{ir_stem}.q"),
299 input_q,
300 wq,
301 bq,
302 batch_q,
303 seq_q,
304 d,
305 );
306 let k = qkv_linear(
307 g,
308 params,
309 &format!("{ir_stem}.k"),
310 input_k,
311 wk,
312 bk,
313 batch_k,
314 seq_k,
315 d,
316 );
317 let v = qkv_linear(
318 g,
319 params,
320 &format!("{ir_stem}.v"),
321 input_v,
322 wv,
323 bv,
324 batch_v,
325 seq_v,
326 d,
327 );
328 Ok((q, k, v))
329}
330
331fn linear_fused_or_gguf(
332 g: &mut HirMut<'_>,
333 params: &mut HashMap<String, Vec<f32>>,
334 typed: &mut Vec<(String, Vec<u8>, DType)>,
335 cache: &mut HashMap<String, HirNodeId>,
336 gguf_packed: Option<&GgufPackedParams>,
337 gguf_key: &str,
338 ir_stem: &str,
339 input: HirNodeId,
340 w_t: Vec<f32>,
341 bias: Vec<f32>,
342 in_dim: usize,
343 out_dim: usize,
344) -> Result<HirNodeId> {
345 if let Some(p) = gguf_packed.and_then(|m| packed_linear(m, gguf_key)) {
346 return linear_gguf_bias(
347 g, params, typed, cache, ir_stem, p, input, &bias, in_dim, out_dim,
348 );
349 }
350 Ok(linear_with_bias(
351 g, params, ir_stem, input, w_t, bias, in_dim, out_dim,
352 ))
353}
354
/// Splits a fused row-major `e x 3e` matrix into three `e x e` matrices.
///
/// Each of the first `e` rows of `w_t` is `3 * e` values long; the first,
/// second and third `e` columns of a row go to the first, second and third
/// returned matrix.
///
/// # Panics
/// Panics when `w_t` holds fewer than `3 * e * e` values.
fn split_qkv(w_t: &[f32], e: usize) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
    let row_len = 3 * e;
    let mut wq = Vec::with_capacity(e * e);
    let mut wk = Vec::with_capacity(e * e);
    let mut wv = Vec::with_capacity(e * e);
    for row_idx in 0..e {
        let row = &w_t[row_idx * row_len..(row_idx + 1) * row_len];
        let (q_cols, rest) = row.split_at(e);
        let (k_cols, v_cols) = rest.split_at(e);
        wq.extend_from_slice(q_cols);
        wk.extend_from_slice(k_cols);
        wv.extend_from_slice(v_cols);
    }
    (wq, wk, wv)
}
368
/// Declares a graph parameter called `name` with the given `shape` and
/// returns its node id.
///
/// `_data` is accepted but never read here; the callers in this file insert
/// the values into their parameter map themselves.
fn add_param(g: &mut HirMut<'_>, name: &str, _data: Vec<f32>, shape: Shape) -> HirNodeId {
    g.param(name, shape)
}
372
373pub fn build_sam3_detector_encoder_graph(
375 weights: &Sam3EncoderWeights,
376 batch: usize,
377 hw: usize,
378 seq: usize,
379) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
380 let parts = build_encoder_hir(weights, batch, hw, seq, None)?;
381 rlx_core::flow_util::graph_from_hir(parts.hir, parts.params)
382}
383
384pub fn build_encoder_hir(
386 weights: &Sam3EncoderWeights,
387 batch: usize,
388 hw: usize,
389 seq: usize,
390 gguf_packed: Option<&GgufPackedParams>,
391) -> Result<Sam3EncoderHirParts> {
392 let mut hir = HirModule::new("sam3_detector_encoder");
393 let mut g = HirMut::new(&mut hir);
394 let mut params: HashMap<String, Vec<f32>> = HashMap::new();
395 let mut typed_params = Vec::new();
396 let mut gguf_w_cache: HashMap<String, HirNodeId> = HashMap::new();
397 let f = DType::F32;
398 let d = D_MODEL;
399 let enc_base = &weights.prefix;
400
401 let src = g.input("src", Shape::new(&[batch, hw, d], f));
402 let src_pos = g.input("src_pos", Shape::new(&[batch, hw, d], f));
403 let prompt = g.input("prompt", Shape::new(&[batch, seq, d], f));
404 let prompt_kpm_inv = g.input("prompt_kpm_inv", Shape::new(&[batch, seq], f));
405
406 let mut tgt = src;
407 for (li, layer) in weights.layers.iter().enumerate() {
408 tgt = emit_sam3_detector_encoder_layer(
409 &mut g,
410 &mut params,
411 &mut typed_params,
412 &mut gguf_w_cache,
413 gguf_packed,
414 enc_base,
415 li,
416 layer,
417 batch,
418 hw,
419 seq,
420 tgt,
421 src_pos,
422 prompt,
423 prompt_kpm_inv,
424 )?;
425 }
426 g.set_outputs(vec![tgt]);
427 Ok(Sam3EncoderHirParts {
428 hir,
429 params,
430 typed_params,
431 })
432}
433
434#[allow(clippy::too_many_arguments)]
436pub fn emit_sam3_detector_encoder_layer(
437 g: &mut HirMut<'_>,
438 params: &mut HashMap<String, Vec<f32>>,
439 typed_params: &mut Vec<(String, Vec<u8>, DType)>,
440 gguf_w_cache: &mut HashMap<String, HirNodeId>,
441 gguf_packed: Option<&GgufPackedParams>,
442 enc_base: &str,
443 li: usize,
444 layer: &Sam3EncoderLayerWeights,
445 _batch: usize,
446 _hw: usize,
447 _seq: usize,
448 tgt: HirNodeId,
449 src_pos: HirNodeId,
450 prompt: HirNodeId,
451 prompt_kpm_inv: HirNodeId,
452) -> Result<HirNodeId> {
453 let f = DType::F32;
454 let d = D_MODEL;
455 let nh = N_HEADS;
456 let dh = HEAD_DIM;
457 let dim_ff = DIM_FF;
458
459 let n1_w = add_param(
460 g,
461 &format!("l{li}.norm1.w"),
462 layer.norm1_w.clone(),
463 Shape::new(&[d], f),
464 );
465 params.insert(format!("l{li}.norm1.w"), layer.norm1_w.clone());
466 let n1_b = add_param(
467 g,
468 &format!("l{li}.norm1.b"),
469 layer.norm1_b.clone(),
470 Shape::new(&[d], f),
471 );
472 params.insert(format!("l{li}.norm1.b"), layer.norm1_b.clone());
473 let n1 = g.ln(tgt, n1_w, n1_b, 1e-5);
474
475 let qk_in = g.add(n1, src_pos);
476
477 let (q_node, k_node, v_node) = in_proj_qkv(
478 g,
479 params,
480 typed_params,
481 gguf_w_cache,
482 gguf_packed,
483 &enc_layer_key(enc_base, li, "self_attn.in_proj_weight"),
484 &format!("l{li}.sa.in_proj"),
485 &layer.self_attn_in_w_t,
486 &layer.self_attn_in_b,
487 qk_in,
488 qk_in,
489 n1,
490 d,
491 )?;
492
493 let sa_attn = g.attention_kind(
494 q_node,
495 k_node,
496 v_node,
497 nh,
498 dh,
499 MaskKind::None,
500 shape::attention_shape(g.shape(q_node)),
501 );
502 let sa_out = linear_fused_or_gguf(
503 g,
504 params,
505 typed_params,
506 gguf_w_cache,
507 gguf_packed,
508 &enc_layer_key(enc_base, li, "self_attn.out_proj.weight"),
509 &format!("l{li}.sa.proj"),
510 sa_attn,
511 layer.self_attn_out_w_t.clone(),
512 layer.self_attn_out_b.clone(),
513 d,
514 d,
515 )?;
516 let mut tgt = g.add(tgt, sa_out);
517
518 let n2_w = add_param(
519 g,
520 &format!("l{li}.norm2.w"),
521 layer.norm2_w.clone(),
522 Shape::new(&[d], f),
523 );
524 params.insert(format!("l{li}.norm2.w"), layer.norm2_w.clone());
525 let n2_b = add_param(
526 g,
527 &format!("l{li}.norm2.b"),
528 layer.norm2_b.clone(),
529 Shape::new(&[d], f),
530 );
531 params.insert(format!("l{li}.norm2.b"), layer.norm2_b.clone());
532 let n2 = g.ln(tgt, n2_w, n2_b, 1e-5);
533
534 let (qc, kc, vc) = in_proj_qkv(
535 g,
536 params,
537 typed_params,
538 gguf_w_cache,
539 gguf_packed,
540 &enc_layer_key(enc_base, li, "cross_attn_image.in_proj_weight"),
541 &format!("l{li}.ca.in_proj"),
542 &layer.cross_attn_in_w_t,
543 &layer.cross_attn_in_b,
544 n2,
545 prompt,
546 prompt,
547 d,
548 )?;
549
550 let ca_attn = g.attention(
551 qc,
552 kc,
553 vc,
554 prompt_kpm_inv,
555 nh,
556 dh,
557 shape::attention_shape(g.shape(qc)),
558 );
559 let ca_out = linear_fused_or_gguf(
560 g,
561 params,
562 typed_params,
563 gguf_w_cache,
564 gguf_packed,
565 &enc_layer_key(enc_base, li, "cross_attn_image.out_proj.weight"),
566 &format!("l{li}.ca.proj"),
567 ca_attn,
568 layer.cross_attn_out_w_t.clone(),
569 layer.cross_attn_out_b.clone(),
570 d,
571 d,
572 )?;
573 tgt = g.add(tgt, ca_out);
574
575 let n3_w = add_param(
576 g,
577 &format!("l{li}.norm3.w"),
578 layer.norm3_w.clone(),
579 Shape::new(&[d], f),
580 );
581 params.insert(format!("l{li}.norm3.w"), layer.norm3_w.clone());
582 let n3_b = add_param(
583 g,
584 &format!("l{li}.norm3.b"),
585 layer.norm3_b.clone(),
586 Shape::new(&[d], f),
587 );
588 params.insert(format!("l{li}.norm3.b"), layer.norm3_b.clone());
589 let n3 = g.ln(tgt, n3_w, n3_b, 1e-5);
590
591 let ff1 = linear_fused_or_gguf(
592 g,
593 params,
594 typed_params,
595 gguf_w_cache,
596 gguf_packed,
597 &enc_layer_key(enc_base, li, "linear1.weight"),
598 &format!("l{li}.ffn.fc1"),
599 n3,
600 layer.linear1_w_t.clone(),
601 layer.linear1_b.clone(),
602 d,
603 dim_ff,
604 )?;
605 let relud = g.relu(ff1);
606 let ff2 = linear_fused_or_gguf(
607 g,
608 params,
609 typed_params,
610 gguf_w_cache,
611 gguf_packed,
612 &enc_layer_key(enc_base, li, "linear2.weight"),
613 &format!("l{li}.ffn.fc2"),
614 relud,
615 layer.linear2_w_t.clone(),
616 layer.linear2_b.clone(),
617 dim_ff,
618 d,
619 )?;
620 Ok(g.add(tgt, ff2))
621}
622
623fn qkv_linear(
624 g: &mut HirMut<'_>,
625 params: &mut HashMap<String, Vec<f32>>,
626 name: &str,
627 input: HirNodeId,
628 w: Vec<f32>,
629 b: Vec<f32>,
630 batch: usize,
631 seq: usize,
632 d: usize,
633) -> HirNodeId {
634 let f = DType::F32;
635 let w_name = format!("{name}.w");
636 let b_name = format!("{name}.b");
637 let w_id = g.param(&w_name, Shape::new(&[d, d], f));
638 params.insert(w_name, w);
639 let b_id = g.param(&b_name, Shape::new(&[d], f));
640 params.insert(b_name, b);
641 let out_shape = Shape::new(&[batch, seq, d], f);
642 g.add_node(
643 Op::FusedMatMulBiasAct { activation: None },
644 vec![input, w_id, b_id],
645 out_shape,
646 )
647}
648
649fn linear_with_bias(
650 g: &mut HirMut<'_>,
651 params: &mut HashMap<String, Vec<f32>>,
652 name: &str,
653 input: HirNodeId,
654 w: Vec<f32>,
655 b: Vec<f32>,
656 in_dim: usize,
657 out_dim: usize,
658) -> HirNodeId {
659 let f = DType::F32;
660 let w_name = format!("{name}.w");
661 let b_name = format!("{name}.b");
662 let w_id = g.param(&w_name, Shape::new(&[in_dim, out_dim], f));
663 params.insert(w_name, w);
664 let b_id = g.param(&b_name, Shape::new(&[out_dim], f));
665 params.insert(b_name, b);
666 let cur_shape = g.shape(input);
667 let mut out_dims: Vec<usize> = cur_shape.dims().iter().map(|d| d.unwrap_static()).collect();
668 *out_dims.last_mut().unwrap() = out_dim;
669 g.add_node(
670 Op::FusedMatMulBiasAct { activation: None },
671 vec![input, w_id, b_id],
672 Shape::new(&out_dims, f),
673 )
674}
675
676#[allow(clippy::too_many_arguments)]
677pub fn forward_encoder_ir_on(
678 weights: &Sam3EncoderWeights,
679 src_bchw: &[f32],
680 src_pos_bchw: &[f32],
681 prompt_seq_first: &[f32],
682 prompt_kpm: &[u8],
683 batch: usize,
684 src_h: usize,
685 src_w: usize,
686 prompt_len: usize,
687 device: Device,
688) -> Result<Vec<f32>> {
689 forward_encoder_ir_on_with_profile(
690 weights,
691 src_bchw,
692 src_pos_bchw,
693 prompt_seq_first,
694 prompt_kpm,
695 batch,
696 src_h,
697 src_w,
698 prompt_len,
699 device,
700 &CompileProfile::sam3(),
701 None,
702 )
703}
704
705#[allow(clippy::too_many_arguments)]
707pub fn forward_encoder_ir_on_with_profile(
708 weights: &Sam3EncoderWeights,
709 src_bchw: &[f32],
710 src_pos_bchw: &[f32],
711 prompt_seq_first: &[f32],
712 prompt_kpm: &[u8],
713 batch: usize,
714 src_h: usize,
715 src_w: usize,
716 prompt_len: usize,
717 device: Device,
718 profile: &CompileProfile,
719 gguf_packed: Option<&GgufPackedParams>,
720) -> Result<Vec<f32>> {
721 ensure!(weights.loaded, "SAM3 detector encoder not loaded");
722 let hw = src_h * src_w;
723 ensure!(
724 src_bchw.len() == batch * D_MODEL * hw,
725 "encoder src shape mismatch"
726 );
727 ensure!(
728 prompt_seq_first.len() == prompt_len * batch * D_MODEL,
729 "encoder prompt shape mismatch"
730 );
731
732 let mut src_bhwc = vec![0f32; batch * hw * D_MODEL];
733 let mut pos_bhwc = vec![0f32; batch * hw * D_MODEL];
734 for b in 0..batch {
735 for s in 0..hw {
736 for c in 0..D_MODEL {
737 src_bhwc[(b * hw + s) * D_MODEL + c] = src_bchw[((b * D_MODEL + c) * hw) + s];
738 pos_bhwc[(b * hw + s) * D_MODEL + c] = src_pos_bchw[((b * D_MODEL + c) * hw) + s];
739 }
740 }
741 }
742
743 let mut prompt_bf = vec![0f32; batch * prompt_len * D_MODEL];
744 for b in 0..batch {
745 for l in 0..prompt_len {
746 let s = (l * batch + b) * D_MODEL;
747 let dst = (b * prompt_len + l) * D_MODEL;
748 prompt_bf[dst..dst + D_MODEL].copy_from_slice(&prompt_seq_first[s..s + D_MODEL]);
749 }
750 }
751 let prompt_kpm_inv: Vec<f32> = prompt_kpm
752 .iter()
753 .map(|&v| if v == 0 { 1.0 } else { 0.0 })
754 .collect();
755
756 let parts = build_encoder_hir(weights, batch, hw, prompt_len, gguf_packed)?;
757 let mut compiled = rlx_core::flow_bridge::compile_hir_with_profile(device, parts.hir, profile)?;
758 rlx_core::flow_util::attach_built_params(&mut compiled, parts.params, &parts.typed_params);
759 let outputs = compiled.run(&[
760 ("src", src_bhwc.as_slice()),
761 ("src_pos", pos_bhwc.as_slice()),
762 ("prompt", prompt_bf.as_slice()),
763 ("prompt_kpm_inv", prompt_kpm_inv.as_slice()),
764 ]);
765 let out = outputs
766 .into_iter()
767 .next()
768 .ok_or_else(|| anyhow::anyhow!("encoder graph produced no outputs"))?;
769 Ok(out)
770}
771
772#[allow(clippy::too_many_arguments)]
773pub fn forward_encoder_ir(
774 weights: &Sam3EncoderWeights,
775 src_bchw: &[f32],
776 src_pos_bchw: &[f32],
777 prompt_seq_first: &[f32],
778 prompt_kpm: &[u8],
779 batch: usize,
780 src_h: usize,
781 src_w: usize,
782 prompt_len: usize,
783) -> Result<Vec<f32>> {
784 forward_encoder_ir_on_with_profile(
785 weights,
786 src_bchw,
787 src_pos_bchw,
788 prompt_seq_first,
789 prompt_kpm,
790 batch,
791 src_h,
792 src_w,
793 prompt_len,
794 Device::Cpu,
795 &CompileProfile::sam3(),
796 None,
797 )
798}