1use super::detector_decoder::{
25 Mlp2, Mlp3, Sam3DecoderLayerWeights, Sam3DecoderOutput, Sam3DecoderWeights, mlp2_forward,
26 mlp2_forward_into, mlp3_forward, mlp3_forward_into,
27};
28use super::packed_gguf::packed_linear;
29use anyhow::{Result, ensure};
30use rlx_flow::CompileProfile;
31use rlx_flow::{GgufPackedLinear, GgufPackedParams};
32use rlx_ir::hir::{HirGraphExt, HirModule, HirMut, HirNodeId};
33use rlx_ir::op::{Activation, MaskKind, Op};
34use rlx_ir::shape;
35use rlx_ir::{DType, Shape};
36use rlx_runtime::{CompiledGraph, Device};
37use std::collections::HashMap;
38
/// Channel width used for every decoder activation (`[batch, tokens, D_MODEL]`).
const D_MODEL: usize = 256;
/// Output width of the first feed-forward projection (`ffn.fc1`).
const DIM_FF: usize = 2048;
/// Number of attention heads passed to the attention ops.
const N_HEADS: usize = 8;
/// Per-head channel count, derived so that `N_HEADS * HEAD_DIM == D_MODEL`.
const HEAD_DIM: usize = D_MODEL / N_HEADS;
/// Number of query tokens per batch element (the presence token is extra).
const NUM_QUERIES: usize = 200;
/// Number of decoder layers iterated by the run loop.
const N_LAYERS: usize = 6;
45
46type LayerHirParts = (
47 HirModule,
48 HashMap<String, Vec<f32>>,
49 Vec<(String, Vec<u8>, DType)>,
50);
51type LayerRunOut = (Vec<f32>, Vec<f32>, Vec<f32>, Vec<f32>);
52
/// Builds the dotted tensor key `"{base}.layers.{li}.{suffix}"` for a per-layer weight.
fn dec_layer_key(base: &str, li: usize, suffix: &str) -> String {
    let layer_index = li.to_string();
    [base, "layers", layer_index.as_str(), suffix].join(".")
}
56
57fn gguf_weight_param(
58 g: &mut HirMut<'_>,
59 typed: &mut Vec<(String, Vec<u8>, DType)>,
60 cache: &mut HashMap<String, HirNodeId>,
61 ir_name: &str,
62 p: &GgufPackedLinear,
63) -> HirNodeId {
64 if let Some(&id) = cache.get(ir_name) {
65 return id;
66 }
67 let id = g.param(ir_name, Shape::new(&[p.w_q.len()], DType::U8));
68 typed.push((ir_name.to_string(), p.w_q.clone(), DType::U8));
69 cache.insert(ir_name.to_string(), id);
70 id
71}
72
73fn linear_gguf_matmul(
74 g: &mut HirMut<'_>,
75 typed: &mut Vec<(String, Vec<u8>, DType)>,
76 cache: &mut HashMap<String, HirNodeId>,
77 ir_stem: &str,
78 p: &GgufPackedLinear,
79 input: HirNodeId,
80 in_dim: usize,
81 out_dim: usize,
82) -> Result<HirNodeId> {
83 ensure!(
84 p.in_dim == in_dim && p.out_dim == out_dim,
85 "packed linear {ir_stem}: shape {}x{} vs {in_dim}x{out_dim}",
86 p.in_dim,
87 p.out_dim
88 );
89 let w_name = format!("{ir_stem}.w");
90 let w_id = gguf_weight_param(g, typed, cache, &w_name, p);
91 let cur = g.shape(input);
92 let mut dims: Vec<usize> = cur.dims().iter().map(|d| d.unwrap_static()).collect();
93 *dims.last_mut().unwrap() = out_dim;
94 let out_shape = Shape::new(&dims, DType::F32);
95 Ok(g.add_node(
96 Op::DequantMatMul { scheme: p.scheme },
97 vec![input, w_id],
98 out_shape,
99 ))
100}
101
102fn add_f32_bias(
103 g: &mut HirMut<'_>,
104 params: &mut HashMap<String, Vec<f32>>,
105 name: &str,
106 input: HirNodeId,
107 bias: &[f32],
108) -> HirNodeId {
109 if bias.iter().all(|&v| v == 0.0) {
110 return input;
111 }
112 let out_dim = bias.len();
113 let b_id = add_param(
114 g,
115 params,
116 name,
117 bias.to_vec(),
118 Shape::new(&[out_dim], DType::F32),
119 );
120 g.add(input, b_id)
121}
122
123fn linear_gguf_bias(
124 g: &mut HirMut<'_>,
125 params: &mut HashMap<String, Vec<f32>>,
126 typed: &mut Vec<(String, Vec<u8>, DType)>,
127 cache: &mut HashMap<String, HirNodeId>,
128 ir_stem: &str,
129 p: &GgufPackedLinear,
130 input: HirNodeId,
131 bias: &[f32],
132 in_dim: usize,
133 out_dim: usize,
134) -> Result<HirNodeId> {
135 let y = linear_gguf_matmul(g, typed, cache, ir_stem, p, input, in_dim, out_dim)?;
136 Ok(add_f32_bias(g, params, &format!("{ir_stem}.b"), y, bias))
137}
138
139fn in_proj_qkv(
140 g: &mut HirMut<'_>,
141 params: &mut HashMap<String, Vec<f32>>,
142 typed: &mut Vec<(String, Vec<u8>, DType)>,
143 cache: &mut HashMap<String, HirNodeId>,
144 gguf_packed: Option<&GgufPackedParams>,
145 gguf_key: &str,
146 ir_stem: &str,
147 layer_w_t: &[f32],
148 layer_b: &[f32],
149 input_q: HirNodeId,
150 input_k: HirNodeId,
151 input_v: HirNodeId,
152 d: usize,
153) -> Result<(HirNodeId, HirNodeId, HirNodeId)> {
154 if let Some(p) = gguf_packed.and_then(|m| packed_linear(m, gguf_key)) {
155 let qkv_q = linear_gguf_bias(
156 g,
157 params,
158 typed,
159 cache,
160 ir_stem,
161 p,
162 input_q,
163 layer_b,
164 d,
165 3 * d,
166 )?;
167 let qkv_k = linear_gguf_bias(
168 g,
169 params,
170 typed,
171 cache,
172 ir_stem,
173 p,
174 input_k,
175 layer_b,
176 d,
177 3 * d,
178 )?;
179 let qkv_v = linear_gguf_bias(
180 g,
181 params,
182 typed,
183 cache,
184 ir_stem,
185 p,
186 input_v,
187 layer_b,
188 d,
189 3 * d,
190 )?;
191 let axis = g.shape(qkv_q).rank().saturating_sub(1);
192 let q = g.narrow_(qkv_q, axis, 0, d);
193 let k = g.narrow_(qkv_k, axis, d, d);
194 let v = g.narrow_(qkv_v, axis, 2 * d, d);
195 return Ok((q, k, v));
196 }
197 let (wq, wk, wv) = split_qkv(layer_w_t, d);
198 let bq = layer_b[0..d].to_vec();
199 let bk = layer_b[d..2 * d].to_vec();
200 let bv = layer_b[2 * d..3 * d].to_vec();
201 let batch_q = g.shape(input_q).dims()[0].unwrap_static();
202 let seq_q = g.shape(input_q).dims()[1].unwrap_static();
203 let batch_k = g.shape(input_k).dims()[0].unwrap_static();
204 let seq_k = g.shape(input_k).dims()[1].unwrap_static();
205 let batch_v = g.shape(input_v).dims()[0].unwrap_static();
206 let seq_v = g.shape(input_v).dims()[1].unwrap_static();
207 let q = linear_bias_shaped(
208 g,
209 params,
210 &format!("{ir_stem}.q"),
211 input_q,
212 wq,
213 bq,
214 d,
215 d,
216 Some(batch_q),
217 Some(seq_q),
218 );
219 let k = linear_bias_shaped(
220 g,
221 params,
222 &format!("{ir_stem}.k"),
223 input_k,
224 wk,
225 bk,
226 d,
227 d,
228 Some(batch_k),
229 Some(seq_k),
230 );
231 let v = linear_bias_shaped(
232 g,
233 params,
234 &format!("{ir_stem}.v"),
235 input_v,
236 wv,
237 bv,
238 d,
239 d,
240 Some(batch_v),
241 Some(seq_v),
242 );
243 Ok((q, k, v))
244}
245
246fn linear_fused_or_gguf(
247 g: &mut HirMut<'_>,
248 params: &mut HashMap<String, Vec<f32>>,
249 typed: &mut Vec<(String, Vec<u8>, DType)>,
250 cache: &mut HashMap<String, HirNodeId>,
251 gguf_packed: Option<&GgufPackedParams>,
252 gguf_key: &str,
253 ir_stem: &str,
254 input: HirNodeId,
255 w_t: Vec<f32>,
256 bias: Vec<f32>,
257 in_dim: usize,
258 out_dim: usize,
259) -> Result<HirNodeId> {
260 if let Some(p) = gguf_packed.and_then(|m| packed_linear(m, gguf_key)) {
261 return linear_gguf_bias(
262 g, params, typed, cache, ir_stem, p, input, &bias, in_dim, out_dim,
263 );
264 }
265 Ok(linear_bias(
266 g, params, ir_stem, input, w_t, bias, in_dim, out_dim,
267 ))
268}
269
/// Splits a fused row-major `[e, 3e]` matrix into three `[e, e]` matrices.
///
/// Columns `[0, e)` of every row go to the first result, `[e, 2e)` to the
/// second and `[2e, 3e)` to the third. Values past the first `3 * e * e`
/// entries are ignored.
///
/// # Panics
/// Panics if `w_t` holds fewer than `3 * e * e` values.
fn split_qkv(w_t: &[f32], e: usize) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
    let row_len = 3 * e;
    let mut wq = Vec::with_capacity(e * e);
    let mut wk = Vec::with_capacity(e * e);
    let mut wv = Vec::with_capacity(e * e);
    for row in 0..e {
        let start = row * row_len;
        wq.extend_from_slice(&w_t[start..start + e]);
        wk.extend_from_slice(&w_t[start + e..start + 2 * e]);
        wv.extend_from_slice(&w_t[start + 2 * e..start + row_len]);
    }
    (wq, wk, wv)
}
283
284fn add_param(
285 g: &mut HirMut<'_>,
286 params: &mut HashMap<String, Vec<f32>>,
287 name: &str,
288 data: Vec<f32>,
289 shape: Shape,
290) -> HirNodeId {
291 let id = g.param(name, shape);
292 params.insert(name.to_string(), data);
293 id
294}
295
296fn linear_bias(
297 g: &mut HirMut<'_>,
298 params: &mut HashMap<String, Vec<f32>>,
299 name: &str,
300 input: HirNodeId,
301 w: Vec<f32>,
302 b: Vec<f32>,
303 in_dim: usize,
304 out_dim: usize,
305) -> HirNodeId {
306 linear_bias_shaped(g, params, name, input, w, b, in_dim, out_dim, None, None)
307}
308
309fn linear_bias_shaped(
310 g: &mut HirMut<'_>,
311 params: &mut HashMap<String, Vec<f32>>,
312 name: &str,
313 input: HirNodeId,
314 w: Vec<f32>,
315 b: Vec<f32>,
316 in_dim: usize,
317 out_dim: usize,
318 batch: Option<usize>,
319 seq: Option<usize>,
320) -> HirNodeId {
321 let f = DType::F32;
322 let w_id = add_param(
323 g,
324 params,
325 &format!("{name}.w"),
326 w,
327 Shape::new(&[in_dim, out_dim], f),
328 );
329 let b_id = add_param(
330 g,
331 params,
332 &format!("{name}.b"),
333 b,
334 Shape::new(&[out_dim], f),
335 );
336 let out_shape = if let (Some(batch), Some(seq)) = (batch, seq) {
337 Shape::new(&[batch, seq, out_dim], f)
338 } else {
339 let cur = g.shape(input);
340 let mut out_dims: Vec<usize> = cur.dims().iter().map(|d| d.unwrap_static()).collect();
341 *out_dims.last_mut().unwrap() = out_dim;
342 Shape::new(&out_dims, f)
343 };
344 g.add_node(
345 Op::FusedMatMulBiasAct { activation: None },
346 vec![input, w_id, b_id],
347 out_shape,
348 )
349}
350
351fn fused_matmul_bias_act(
352 g: &mut HirMut<'_>,
353 input: HirNodeId,
354 w: HirNodeId,
355 b: HirNodeId,
356 activation: Option<Activation>,
357 out_shape: Shape,
358) -> HirNodeId {
359 g.add_node(
360 Op::FusedMatMulBiasAct { activation },
361 vec![input, w, b],
362 out_shape,
363 )
364}
365
366fn attention_bias(
367 g: &mut HirMut<'_>,
368 q: HirNodeId,
369 k: HirNodeId,
370 v: HirNodeId,
371 bias: HirNodeId,
372 num_heads: usize,
373 head_dim: usize,
374) -> HirNodeId {
375 let attn_shape = shape::attention_shape(g.shape(q));
376 g.add_node(
377 Op::Attention {
378 num_heads,
379 head_dim,
380 mask_kind: MaskKind::Bias,
381 score_scale: None,
382 attn_logit_softcap: None,
383 },
384 vec![q, k, v, bias],
385 attn_shape,
386 )
387}
388
389#[allow(clippy::too_many_arguments)]
396fn mlp2_relu_pair_gguf(
397 g: &mut HirMut<'_>,
398 params: &mut HashMap<String, Vec<f32>>,
399 typed: &mut Vec<(String, Vec<u8>, DType)>,
400 cache: &mut HashMap<String, HirNodeId>,
401 gguf_packed: Option<&GgufPackedParams>,
402 mlp: &Mlp2,
403 stem: &str,
404 input: HirNodeId,
405 rows: usize,
406 hidden_dim: usize,
407 out_dim: usize,
408) -> Result<HirNodeId> {
409 let h = if let Some(p) = mlp
410 .w0_gguf_key
411 .as_deref()
412 .and_then(|key| gguf_packed.and_then(|m| super::packed_gguf::packed_linear(m, key)))
413 {
414 let y = linear_gguf_bias(
415 g,
416 params,
417 typed,
418 cache,
419 &format!("{stem}.fc0"),
420 p,
421 input,
422 &mlp.b0,
423 mlp.in_dim,
424 hidden_dim,
425 )?;
426 g.relu(y)
427 } else {
428 let w_id = add_param(
429 g,
430 params,
431 &format!("{stem}.w0"),
432 mlp.w0_t.clone(),
433 Shape::new(&[mlp.in_dim, hidden_dim], DType::F32),
434 );
435 let b_id = add_param(
436 g,
437 params,
438 &format!("{stem}.b0"),
439 mlp.b0.clone(),
440 Shape::new(&[hidden_dim], DType::F32),
441 );
442 fused_matmul_bias_act(
443 g,
444 input,
445 w_id,
446 b_id,
447 Some(Activation::Relu),
448 Shape::new(&[rows, hidden_dim], DType::F32),
449 )
450 };
451 if let Some(p) = mlp
452 .w1_gguf_key
453 .as_deref()
454 .and_then(|key| gguf_packed.and_then(|m| super::packed_gguf::packed_linear(m, key)))
455 {
456 return linear_gguf_bias(
457 g,
458 params,
459 typed,
460 cache,
461 &format!("{stem}.fc1"),
462 p,
463 h,
464 &mlp.b1,
465 hidden_dim,
466 out_dim,
467 );
468 }
469 let w_id = add_param(
470 g,
471 params,
472 &format!("{stem}.w1"),
473 mlp.w1_t.clone(),
474 Shape::new(&[hidden_dim, out_dim], DType::F32),
475 );
476 let b_id = add_param(
477 g,
478 params,
479 &format!("{stem}.b1"),
480 mlp.b1.clone(),
481 Shape::new(&[out_dim], DType::F32),
482 );
483 Ok(fused_matmul_bias_act(
484 g,
485 h,
486 w_id,
487 b_id,
488 None,
489 Shape::new(&[rows, out_dim], DType::F32),
490 ))
491}
492
493fn build_boxrpb_subgraph(
494 g: &mut HirMut<'_>,
495 params: &mut HashMap<String, Vec<f32>>,
496 typed: &mut Vec<(String, Vec<u8>, DType)>,
497 gguf_cache: &mut HashMap<String, HirNodeId>,
498 gguf_packed: Option<&GgufPackedParams>,
499 boxrpb_x: &Mlp2,
500 boxrpb_y: &Mlp2,
501 deltas_x: HirNodeId,
502 deltas_y: HirNodeId,
503 batch: usize,
504 nq: usize,
505 nh: usize,
506 h: usize,
507 w: usize,
508) -> Result<HirNodeId> {
509 let f = DType::F32;
510 let hidden_x = boxrpb_x.hidden;
511 let hidden_y = boxrpb_y.hidden;
512 assert_eq!(boxrpb_x.in_dim, 2);
513 assert_eq!(boxrpb_y.in_dim, 2);
514 assert_eq!(boxrpb_x.out_dim, nh);
515 assert_eq!(boxrpb_y.out_dim, nh);
516
517 let dx_flat = g.reshape_(deltas_x, vec![(batch * nq * w) as i64, 2]);
518 let dx_o = mlp2_relu_pair_gguf(
519 g,
520 params,
521 typed,
522 gguf_cache,
523 gguf_packed,
524 boxrpb_x,
525 "boxrpb_x",
526 dx_flat,
527 batch * nq * w,
528 hidden_x,
529 nh,
530 )?;
531 let dx_4d = g.reshape_(dx_o, vec![batch as i64, nq as i64, w as i64, nh as i64]);
532 let dx_perm = g.transpose_(dx_4d, vec![0, 3, 1, 2]);
533 let dx_bc = g.reshape_(
534 dx_perm,
535 vec![batch as i64, nh as i64, nq as i64, 1, w as i64],
536 );
537
538 let dy_flat = g.reshape_(deltas_y, vec![(batch * nq * h) as i64, 2]);
539 let dy_o = mlp2_relu_pair_gguf(
540 g,
541 params,
542 typed,
543 gguf_cache,
544 gguf_packed,
545 boxrpb_y,
546 "boxrpb_y",
547 dy_flat,
548 batch * nq * h,
549 hidden_y,
550 nh,
551 )?;
552 let dy_4d = g.reshape_(dy_o, vec![batch as i64, nq as i64, h as i64, nh as i64]);
553 let dy_perm = g.transpose_(dy_4d, vec![0, 3, 1, 2]);
554 let dy_bc = g.reshape_(
555 dy_perm,
556 vec![batch as i64, nh as i64, nq as i64, h as i64, 1],
557 );
558
559 let rpb_q = g.add(dx_bc, dy_bc);
560 let rpb_q_flat = g.reshape_(
561 rpb_q,
562 vec![batch as i64, nh as i64, nq as i64, (h * w) as i64],
563 );
564
565 let hw = h * w;
566 let _lq = nq + 1;
567 let zero_pres = add_param(
568 g,
569 params,
570 "rpb_zero_presence",
571 vec![0f32; batch * nh * hw],
572 Shape::new(&[batch, nh, 1, hw], f),
573 );
574 Ok(g.concat_(vec![zero_pres, rpb_q_flat], 2))
575}
576
577struct DecoderLayerHirParts {
578 params: HashMap<String, Vec<f32>>,
579 typed_params: Vec<(String, Vec<u8>, DType)>,
580}
581
582#[allow(clippy::too_many_arguments)]
584fn build_layer_body(
585 hir: &mut HirModule,
586 layer: &Sam3DecoderLayerWeights,
587 boxrpb_x: &Mlp2,
588 boxrpb_y: &Mlp2,
589 norm_w: &[f32],
590 norm_b: &[f32],
591 dec_base: &str,
592 li: usize,
593 batch: usize,
594 h: usize,
595 w: usize,
596 seq: usize,
597 use_bias_attn: bool,
598 boxrpb_in_ir: bool,
599 gguf_packed: Option<&GgufPackedParams>,
600) -> Result<DecoderLayerHirParts> {
601 let hw = h * w;
602 let mut g = HirMut::new(hir);
603 let mut params: HashMap<String, Vec<f32>> = HashMap::new();
604 let mut typed_params = Vec::new();
605 let mut gguf_w_cache: HashMap<String, HirNodeId> = HashMap::new();
606 let f = DType::F32;
607 let d = D_MODEL;
608 let nh = N_HEADS;
609 let dh = HEAD_DIM;
610 let nq = NUM_QUERIES;
611 let lq = nq + 1;
612
613 let tgt = g.input("tgt", Shape::new(&[batch, nq, d], f));
614 let query_pos = g.input("query_pos", Shape::new(&[batch, nq, d], f));
615 let presence = g.input("presence", Shape::new(&[batch, 1, d], f));
616 let memory = g.input("memory", Shape::new(&[batch, hw, d], f));
617 let memory_pos = g.input("memory_pos", Shape::new(&[batch, hw, d], f));
618 let text = g.input("text", Shape::new(&[batch, seq, d], f));
619 let text_kpm_inv = g.input("text_kpm_inv", Shape::new(&[batch, seq], f));
620 let rpb_bias = if boxrpb_in_ir {
621 let dx = g.input("deltas_x", Shape::new(&[batch, nq, w, 2], f));
622 let dy = g.input("deltas_y", Shape::new(&[batch, nq, h, 2], f));
623 build_boxrpb_subgraph(
624 &mut g,
625 &mut params,
626 &mut typed_params,
627 &mut gguf_w_cache,
628 gguf_packed,
629 boxrpb_x,
630 boxrpb_y,
631 dx,
632 dy,
633 batch,
634 nq,
635 nh,
636 h,
637 w,
638 )?
639 } else {
640 g.input("rpb_bias", Shape::new(&[batch, nh, lq, hw], f))
641 };
642
643 let sa_x = g.concat_(vec![presence, tgt], 1);
644 let zero_pos = add_param(
645 &mut g,
646 &mut params,
647 "zero_presence_pos",
648 vec![0f32; batch * d],
649 Shape::new(&[batch, 1, d], f),
650 );
651 let sa_pos = g.concat_(vec![zero_pos, query_pos], 1);
652 let sa_qk = g.add(sa_x, sa_pos);
653
654 let (q_sa, k_sa, v_sa) = in_proj_qkv(
655 &mut g,
656 &mut params,
657 &mut typed_params,
658 &mut gguf_w_cache,
659 gguf_packed,
660 &dec_layer_key(dec_base, li, "self_attn.in_proj_weight"),
661 "sa.in_proj",
662 &layer.self_attn_in_w_t,
663 &layer.self_attn_in_b,
664 sa_qk,
665 sa_qk,
666 sa_x,
667 d,
668 )?;
669 let sa_attn = g.attention_kind(
670 q_sa,
671 k_sa,
672 v_sa,
673 nh,
674 dh,
675 MaskKind::None,
676 shape::attention_shape(g.shape(q_sa)),
677 );
678 let sa_proj = linear_fused_or_gguf(
679 &mut g,
680 &mut params,
681 &mut typed_params,
682 &mut gguf_w_cache,
683 gguf_packed,
684 &dec_layer_key(dec_base, li, "self_attn.out_proj.weight"),
685 "sa.out",
686 sa_attn,
687 layer.self_attn_out_w_t.clone(),
688 layer.self_attn_out_b.clone(),
689 d,
690 d,
691 )?;
692 let sa_res = g.add(sa_x, sa_proj);
693 let n2_w = add_param(
694 &mut g,
695 &mut params,
696 "norm2.w",
697 layer.norm2_w.clone(),
698 Shape::new(&[d], f),
699 );
700 let n2_b = add_param(
701 &mut g,
702 &mut params,
703 "norm2.b",
704 layer.norm2_b.clone(),
705 Shape::new(&[d], f),
706 );
707 let sa_normed = g.ln(sa_res, n2_w, n2_b, 1e-5);
708 let presence_after_sa = g.narrow_(sa_normed, 1, 0, 1);
709 let queries_after_sa = g.narrow_(sa_normed, 1, 1, nq);
710
711 let q_text_in = g.add(queries_after_sa, query_pos);
712 let (q_text, k_text, v_text) = in_proj_qkv(
713 &mut g,
714 &mut params,
715 &mut typed_params,
716 &mut gguf_w_cache,
717 gguf_packed,
718 &dec_layer_key(dec_base, li, "ca_text.in_proj_weight"),
719 "ca_text.in_proj",
720 &layer.ca_text_in_w_t,
721 &layer.ca_text_in_b,
722 q_text_in,
723 text,
724 text,
725 d,
726 )?;
727 let ca_text_attn = g.attention(
728 q_text,
729 k_text,
730 v_text,
731 text_kpm_inv,
732 nh,
733 dh,
734 shape::attention_shape(g.shape(q_text)),
735 );
736 let ca_text_proj = linear_fused_or_gguf(
737 &mut g,
738 &mut params,
739 &mut typed_params,
740 &mut gguf_w_cache,
741 gguf_packed,
742 &dec_layer_key(dec_base, li, "ca_text.out_proj.weight"),
743 "ca_text.out",
744 ca_text_attn,
745 layer.ca_text_out_w_t.clone(),
746 layer.ca_text_out_b.clone(),
747 d,
748 d,
749 )?;
750 let after_ca_text_res = g.add(queries_after_sa, ca_text_proj);
751 let cat_w = add_param(
752 &mut g,
753 &mut params,
754 "catext_norm.w",
755 layer.catext_norm_w.clone(),
756 Shape::new(&[d], f),
757 );
758 let cat_b = add_param(
759 &mut g,
760 &mut params,
761 "catext_norm.b",
762 layer.catext_norm_b.clone(),
763 Shape::new(&[d], f),
764 );
765 let after_ca_text = g.ln(after_ca_text_res, cat_w, cat_b, 1e-5);
766
767 let ca_in = g.concat_(vec![presence_after_sa, after_ca_text], 1);
768 let ca_q_in = g.add(ca_in, sa_pos);
769 let k_mem_in = g.add(memory, memory_pos);
770
771 let (q_img, k_img, v_img) = in_proj_qkv(
772 &mut g,
773 &mut params,
774 &mut typed_params,
775 &mut gguf_w_cache,
776 gguf_packed,
777 &dec_layer_key(dec_base, li, "cross_attn.in_proj_weight"),
778 "ca_img.in_proj",
779 &layer.cross_attn_in_w_t,
780 &layer.cross_attn_in_b,
781 ca_q_in,
782 k_mem_in,
783 memory,
784 d,
785 )?;
786
787 let attn_flat = if use_bias_attn {
788 attention_bias(&mut g, q_img, k_img, v_img, rpb_bias, nh, dh)
789 } else {
790 let q_4d = g.reshape_(q_img, vec![batch as i64, lq as i64, nh as i64, dh as i64]);
791 let q_perm = g.transpose_(q_4d, vec![0, 2, 1, 3]);
792 let k_4d = g.reshape_(k_img, vec![batch as i64, hw as i64, nh as i64, dh as i64]);
793 let k_perm = g.transpose_(k_4d, vec![0, 2, 1, 3]);
794 let v_4d = g.reshape_(v_img, vec![batch as i64, hw as i64, nh as i64, dh as i64]);
795 let v_perm = g.transpose_(v_4d, vec![0, 2, 1, 3]);
796 let k_t = g.transpose_(k_perm, vec![0, 1, 3, 2]);
797 let scores = g.mm(q_perm, k_t);
798 let scale_val = 1.0f32 / (HEAD_DIM as f32).sqrt();
799 let scale_node = add_param(
800 &mut g,
801 &mut params,
802 "img.scale",
803 vec![scale_val],
804 Shape::new(&[1], f),
805 );
806 let scores_scaled = g.mul(scores, scale_node);
807 let scores_biased = g.add(scores_scaled, rpb_bias);
808 let probs = g.sm(scores_biased, -1);
809 let attn_out = g.mm(probs, v_perm);
810 let attn_perm = g.transpose_(attn_out, vec![0, 2, 1, 3]);
811 g.reshape_(attn_perm, vec![batch as i64, lq as i64, d as i64])
812 };
813 let ca_img_proj = linear_fused_or_gguf(
814 &mut g,
815 &mut params,
816 &mut typed_params,
817 &mut gguf_w_cache,
818 gguf_packed,
819 &dec_layer_key(dec_base, li, "cross_attn.out_proj.weight"),
820 "ca_img.out",
821 attn_flat,
822 layer.cross_attn_out_w_t.clone(),
823 layer.cross_attn_out_b.clone(),
824 d,
825 d,
826 )?;
827 let ca_img_res = g.add(ca_in, ca_img_proj);
828 let n1_w = add_param(
829 &mut g,
830 &mut params,
831 "norm1.w",
832 layer.norm1_w.clone(),
833 Shape::new(&[d], f),
834 );
835 let n1_b = add_param(
836 &mut g,
837 &mut params,
838 "norm1.b",
839 layer.norm1_b.clone(),
840 Shape::new(&[d], f),
841 );
842 let after_ca_img = g.ln(ca_img_res, n1_w, n1_b, 1e-5);
843
844 let ff1 = linear_fused_or_gguf(
845 &mut g,
846 &mut params,
847 &mut typed_params,
848 &mut gguf_w_cache,
849 gguf_packed,
850 &dec_layer_key(dec_base, li, "linear1.weight"),
851 "ffn.fc1",
852 after_ca_img,
853 layer.linear1_w_t.clone(),
854 layer.linear1_b.clone(),
855 d,
856 DIM_FF,
857 )?;
858 let relud = g.relu(ff1);
859 let ff2 = linear_fused_or_gguf(
860 &mut g,
861 &mut params,
862 &mut typed_params,
863 &mut gguf_w_cache,
864 gguf_packed,
865 &dec_layer_key(dec_base, li, "linear2.weight"),
866 "ffn.fc2",
867 relud,
868 layer.linear2_w_t.clone(),
869 layer.linear2_b.clone(),
870 DIM_FF,
871 d,
872 )?;
873 let ffn_res = g.add(after_ca_img, ff2);
874 let n3_w = add_param(
875 &mut g,
876 &mut params,
877 "norm3.w",
878 layer.norm3_w.clone(),
879 Shape::new(&[d], f),
880 );
881 let n3_b = add_param(
882 &mut g,
883 &mut params,
884 "norm3.b",
885 layer.norm3_b.clone(),
886 Shape::new(&[d], f),
887 );
888 let after_ffn = g.ln(ffn_res, n3_w, n3_b, 1e-5);
889
890 let new_presence = g.narrow_(after_ffn, 1, 0, 1);
891 let new_tgt = g.narrow_(after_ffn, 1, 1, nq);
892
893 let dec_norm_w = add_param(
894 &mut g,
895 &mut params,
896 "dec.norm.w",
897 norm_w.to_vec(),
898 Shape::new(&[d], f),
899 );
900 let dec_norm_b = add_param(
901 &mut g,
902 &mut params,
903 "dec.norm.b",
904 norm_b.to_vec(),
905 Shape::new(&[d], f),
906 );
907 let out_norm = g.ln(new_tgt, dec_norm_w, dec_norm_b, 1e-5);
908
909 g.set_outputs(vec![new_tgt, new_presence, out_norm]);
910 let _ = (q_img, k_img, v_img, ca_img_proj);
911 Ok(DecoderLayerHirParts {
912 params,
913 typed_params,
914 })
915}
916
917fn build_layer_hir(
919 layer: &Sam3DecoderLayerWeights,
920 boxrpb_x: &Mlp2,
921 boxrpb_y: &Mlp2,
922 norm_w: &[f32],
923 norm_b: &[f32],
924 dec_base: &str,
925 li: usize,
926 batch: usize,
927 h: usize,
928 w: usize,
929 seq: usize,
930 use_bias_attn: bool,
931 boxrpb_in_ir: bool,
932 gguf_packed: Option<&GgufPackedParams>,
933) -> Result<LayerHirParts> {
934 let mut hir = HirModule::new("sam3_dec_layer");
935 let parts = build_layer_body(
936 &mut hir,
937 layer,
938 boxrpb_x,
939 boxrpb_y,
940 norm_w,
941 norm_b,
942 dec_base,
943 li,
944 batch,
945 h,
946 w,
947 seq,
948 use_bias_attn,
949 boxrpb_in_ir,
950 gguf_packed,
951 )?;
952 Ok((hir, parts.params, parts.typed_params))
953}
954
955pub struct Sam3CompiledDecoder {
957 layers: Vec<CompiledGraph>,
958 bbox_embed: Mlp3,
959 ref_point_head: Mlp2,
960 boxrpb_x: Mlp2,
961 boxrpb_y: Mlp2,
962 initial_query_embed: Vec<f32>,
963 initial_reference_points: Vec<f32>,
964 cached_layer0_query_pos: Vec<f32>,
965 cached_layer0_deltas_x: Option<Vec<f32>>,
968 cached_layer0_deltas_y: Option<Vec<f32>>,
969 cached_layer0_rpb: Option<Vec<f32>>,
972 #[allow(dead_code)]
973 cached_initial_ref_boxes: Vec<f32>,
974 boxrpb_in_ir: bool,
975 presence_token: Vec<f32>,
976 presence_head: Mlp3,
977 presence_norm_w: Vec<f32>,
978 presence_norm_b: Vec<f32>,
979 scratch_deltas_x: Vec<f32>,
982 scratch_deltas_y: Vec<f32>,
983 scratch_rpb: Option<Vec<f32>>,
986 scratch_dx_thq: Option<Vec<f32>>,
987 scratch_dy_thq: Option<Vec<f32>>,
988 scratch_boxrpb_x_hidden: Option<Vec<f32>>,
989 scratch_boxrpb_y_hidden: Option<Vec<f32>>,
990 scratch_boxrpb_x_feats: Option<Vec<f32>>,
991 scratch_boxrpb_y_feats: Option<Vec<f32>>,
992 scratch_sine: Vec<f32>,
994 scratch_rph_hidden: Vec<f32>,
995 scratch_query_pos: Vec<f32>,
997 scratch_bbox_h0: Vec<f32>,
999 scratch_bbox_h1: Vec<f32>,
1000 scratch_bbox_out: Vec<f32>,
1001 pub batch: usize,
1002 pub hw: usize,
1003 pub seq: usize,
1004 gguf_packed: Option<GgufPackedParams>,
1005}
1006
1007impl Sam3CompiledDecoder {
1008 pub fn new(
1009 weights: &Sam3DecoderWeights,
1010 batch: usize,
1011 hw: usize,
1012 seq: usize,
1013 device: Device,
1014 ) -> Result<Self> {
1015 Self::new_with_profile(weights, batch, hw, seq, device, &CompileProfile::sam3())
1016 }
1017
1018 pub fn new_with_profile(
1019 weights: &Sam3DecoderWeights,
1020 batch: usize,
1021 hw: usize,
1022 seq: usize,
1023 device: Device,
1024 profile: &CompileProfile,
1025 ) -> Result<Self> {
1026 Self::new_with_profile_and_gguf(weights, batch, hw, seq, device, profile, None)
1027 }
1028
1029 pub fn new_with_profile_and_gguf(
1030 weights: &Sam3DecoderWeights,
1031 batch: usize,
1032 hw: usize,
1033 seq: usize,
1034 device: Device,
1035 profile: &CompileProfile,
1036 gguf_packed: Option<&GgufPackedParams>,
1037 ) -> Result<Self> {
1038 ensure!(weights.loaded, "decoder weights not loaded");
1039 let nq = NUM_QUERIES;
1040 let d = D_MODEL;
1041 let h_w = (hw as f64).sqrt().round() as usize;
1042 ensure!(
1043 h_w * h_w == hw,
1044 "boxRPB cache requires square spatial grid; got hw={hw}"
1045 );
1046 let mut layers = Vec::with_capacity(N_LAYERS);
1047 let use_bias_attn = if matches!(device, Device::Metal) {
1052 rlx_ir::env::flag("RLX_SAM3_METAL_BIAS_SDPA")
1053 } else {
1054 true
1055 };
1056 let boxrpb_in_ir = matches!(device, Device::Mlx)
1061 || (matches!(device, Device::Cpu) && rlx_ir::env::flag("RLX_SAM3_BOXRPB_IR"));
1062 let dec_base = &weights.prefix;
1063 for (li, layer) in weights.layers.iter().enumerate() {
1064 let (hir, params, typed) = build_layer_hir(
1065 layer,
1066 &weights.boxrpb_x,
1067 &weights.boxrpb_y,
1068 &weights.norm_w,
1069 &weights.norm_b,
1070 dec_base,
1071 li,
1072 batch,
1073 h_w,
1074 h_w,
1075 seq,
1076 use_bias_attn,
1077 boxrpb_in_ir,
1078 gguf_packed,
1079 )?;
1080 let mut compiled =
1081 rlx_core::flow_bridge::compile_hir_with_profile(device, hir, profile)?;
1082 rlx_core::flow_util::attach_built_params(&mut compiled, params, &typed);
1083 layers.push(compiled);
1084 }
1085 let mut cached_initial_ref_boxes = vec![0f32; batch * nq * 4];
1090 for b in 0..batch {
1091 for q in 0..nq {
1092 for k in 0..4 {
1093 let v = weights.reference_points[q * 4 + k];
1094 cached_initial_ref_boxes[(b * nq + q) * 4 + k] = sigmoid(v);
1095 }
1096 }
1097 }
1098 let sine = sineembed_4d(&cached_initial_ref_boxes, batch, nq, d);
1099 let cached_layer0_query_pos =
1100 mlp2_forward(&weights.ref_point_head, &sine, batch * nq, gguf_packed)?;
1101 let lq = nq + 1;
1102 let nh = N_HEADS;
1103 let (cached_layer0_deltas_x, cached_layer0_deltas_y, cached_layer0_rpb) = if boxrpb_in_ir {
1104 let mut dx = vec![0f32; batch * nq * h_w * 2];
1105 let mut dy = vec![0f32; batch * nq * h_w * 2];
1106 compute_deltas_into(
1107 &cached_initial_ref_boxes,
1108 batch,
1109 nq,
1110 h_w,
1111 h_w,
1112 &mut dx,
1113 &mut dy,
1114 );
1115 (Some(dx), Some(dy), None)
1116 } else {
1117 let rpb = boxrpb_log_full(
1118 &weights.boxrpb_x,
1119 &weights.boxrpb_y,
1120 &cached_initial_ref_boxes,
1121 batch,
1122 nq,
1123 h_w,
1124 h_w,
1125 gguf_packed,
1126 )?;
1127 (None, None, Some(rpb))
1128 };
1129 Ok(Self {
1130 layers,
1131 bbox_embed: weights.bbox_embed.clone(),
1132 ref_point_head: weights.ref_point_head.clone(),
1133 boxrpb_x: weights.boxrpb_x.clone(),
1134 boxrpb_y: weights.boxrpb_y.clone(),
1135 initial_query_embed: weights.query_embed.clone(),
1136 initial_reference_points: weights.reference_points.clone(),
1137 cached_layer0_query_pos,
1138 cached_layer0_deltas_x,
1139 cached_layer0_deltas_y,
1140 cached_layer0_rpb,
1141 cached_initial_ref_boxes,
1142 boxrpb_in_ir,
1143 presence_token: weights.presence_token.clone(),
1144 presence_head: weights.presence_token_head.clone(),
1145 presence_norm_w: weights.presence_token_out_norm_w.clone(),
1146 presence_norm_b: weights.presence_token_out_norm_b.clone(),
1147 scratch_deltas_x: if boxrpb_in_ir {
1148 vec![0f32; batch * nq * h_w * 2]
1149 } else {
1150 Vec::new()
1151 },
1152 scratch_deltas_y: if boxrpb_in_ir {
1153 vec![0f32; batch * nq * h_w * 2]
1154 } else {
1155 Vec::new()
1156 },
1157 scratch_rpb: (!boxrpb_in_ir).then(|| vec![0f32; batch * nh * lq * hw]),
1158 scratch_dx_thq: (!boxrpb_in_ir).then(|| vec![0f32; nh * nq * h_w]),
1159 scratch_dy_thq: (!boxrpb_in_ir).then(|| vec![0f32; nh * nq * h_w]),
1160 scratch_boxrpb_x_hidden: (!boxrpb_in_ir)
1161 .then(|| vec![0f32; nq * h_w * weights.boxrpb_x.hidden]),
1162 scratch_boxrpb_y_hidden: (!boxrpb_in_ir)
1163 .then(|| vec![0f32; nq * h_w * weights.boxrpb_y.hidden]),
1164 scratch_boxrpb_x_feats: (!boxrpb_in_ir)
1165 .then(|| vec![0f32; nq * h_w * weights.boxrpb_x.out_dim]),
1166 scratch_boxrpb_y_feats: (!boxrpb_in_ir)
1167 .then(|| vec![0f32; nq * h_w * weights.boxrpb_y.out_dim]),
1168 scratch_sine: vec![0f32; batch * nq * 2 * d],
1169 scratch_rph_hidden: vec![0f32; batch * nq * weights.ref_point_head.hidden],
1170 scratch_query_pos: vec![0f32; batch * nq * weights.ref_point_head.out_dim],
1171 scratch_bbox_h0: vec![0f32; batch * nq * weights.bbox_embed.hidden],
1172 scratch_bbox_h1: vec![0f32; batch * nq * weights.bbox_embed.hidden],
1173 scratch_bbox_out: vec![0f32; batch * nq * weights.bbox_embed.out_dim],
1174 batch,
1175 hw,
1176 seq,
1177 gguf_packed: gguf_packed.cloned(),
1178 })
1179 }
1180
1181 pub fn run(
1185 &mut self,
1186 memory: &[f32],
1187 memory_pos: &[f32],
1188 text_seq_first: &[f32],
1189 text_kpm: &[u8],
1190 h: usize,
1191 w: usize,
1192 ) -> Result<LayerRunOut> {
1193 let hw = h * w;
1194 ensure!(hw == self.hw);
1195 let batch = self.batch;
1196 let nq = NUM_QUERIES;
1197 let d = D_MODEL;
1198 let nh = N_HEADS;
1199 let lq = nq + 1;
1200 let seq = self.seq;
1201
1202 let mut tgt = vec![0f32; batch * nq * d];
1204 for b in 0..batch {
1205 tgt[b * nq * d..(b + 1) * nq * d].copy_from_slice(&self.initial_query_embed);
1206 }
1207 let mut ref_boxes = vec![0f32; batch * nq * 4];
1209 for b in 0..batch {
1210 for q in 0..nq {
1211 for k in 0..4 {
1212 let v = self.initial_reference_points[q * 4 + k];
1213 ref_boxes[(b * nq + q) * 4 + k] = sigmoid(v);
1214 }
1215 }
1216 }
1217 let mut presence = vec![0f32; batch * d];
1218 for b in 0..batch {
1219 presence[b * d..(b + 1) * d].copy_from_slice(&self.presence_token);
1220 }
1221
1222 let mut text_bf = vec![0f32; batch * seq * d];
1224 for b in 0..batch {
1225 for l in 0..seq {
1226 let s = (l * batch + b) * d;
1227 let dst = (b * seq + l) * d;
1228 text_bf[dst..dst + d].copy_from_slice(&text_seq_first[s..s + d]);
1229 }
1230 }
1231 let text_kpm_inv: Vec<f32> = text_kpm
1232 .iter()
1233 .map(|&v| if v == 0 { 1.0 } else { 0.0 })
1234 .collect();
1235
1236 let mut intermediate = Vec::with_capacity(N_LAYERS);
1237 let mut intermediate_ref_boxes = Vec::with_capacity(N_LAYERS);
1238 intermediate_ref_boxes.push(ref_boxes.clone());
1239 let mut presence_logits = Vec::with_capacity(N_LAYERS);
1240
1241 let profile = rlx_ir::env::flag("RLX_SAM3_PROFILE");
1242 let mut t_qpos = 0u128;
1243 let mut t_rpb = 0u128;
1244 let mut t_graph = 0u128;
1245 let mut t_box = 0u128;
1246 let mut t_other = 0u128;
1247 for li in 0..N_LAYERS {
1248 let tq = std::time::Instant::now();
1249 let query_pos_slice: &[f32];
1257 let rpb_slice: &[f32];
1258 let deltas_x_slice: &[f32];
1259 let deltas_y_slice: &[f32];
1260 if li == 0 {
1261 query_pos_slice = &self.cached_layer0_query_pos;
1262 if self.boxrpb_in_ir {
1263 deltas_x_slice = self.cached_layer0_deltas_x.as_ref().unwrap();
1264 deltas_y_slice = self.cached_layer0_deltas_y.as_ref().unwrap();
1265 rpb_slice = &[];
1266 } else {
1267 rpb_slice = self.cached_layer0_rpb.as_ref().unwrap();
1268 deltas_x_slice = &[];
1269 deltas_y_slice = &[];
1270 }
1271 } else {
1272 sineembed_4d_into(&ref_boxes, batch, nq, d, &mut self.scratch_sine);
1273 mlp2_forward_into(
1274 &self.ref_point_head,
1275 &self.scratch_sine,
1276 batch * nq,
1277 &mut self.scratch_rph_hidden,
1278 &mut self.scratch_query_pos,
1279 self.gguf_packed.as_ref(),
1280 )?;
1281 query_pos_slice = &self.scratch_query_pos;
1282 if self.boxrpb_in_ir {
1283 compute_deltas_into(
1284 &ref_boxes,
1285 batch,
1286 nq,
1287 h,
1288 w,
1289 &mut self.scratch_deltas_x,
1290 &mut self.scratch_deltas_y,
1291 );
1292 deltas_x_slice = &self.scratch_deltas_x;
1293 deltas_y_slice = &self.scratch_deltas_y;
1294 rpb_slice = &[];
1295 } else {
1296 let mut host_deltas_x = vec![0f32; nq * w * 2];
1297 let mut host_deltas_y = vec![0f32; nq * h * 2];
1298 boxrpb_log_full_into(
1299 &self.boxrpb_x,
1300 &self.boxrpb_y,
1301 &ref_boxes,
1302 batch,
1303 nq,
1304 h,
1305 w,
1306 self.scratch_rpb.as_mut().unwrap(),
1307 self.scratch_dx_thq.as_mut().unwrap(),
1308 self.scratch_dy_thq.as_mut().unwrap(),
1309 &mut host_deltas_x,
1310 &mut host_deltas_y,
1311 self.scratch_boxrpb_x_hidden.as_mut().unwrap(),
1312 self.scratch_boxrpb_y_hidden.as_mut().unwrap(),
1313 self.scratch_boxrpb_x_feats.as_mut().unwrap(),
1314 self.scratch_boxrpb_y_feats.as_mut().unwrap(),
1315 self.gguf_packed.as_ref(),
1316 )?;
1317 rpb_slice = self.scratch_rpb.as_ref().unwrap();
1318 deltas_x_slice = &[];
1319 deltas_y_slice = &[];
1320 }
1321 }
1322 if profile {
1323 t_qpos += tq.elapsed().as_micros();
1324 }
1325
1326 let tr = std::time::Instant::now();
1327 if profile {
1328 t_rpb += tr.elapsed().as_micros();
1329 }
1330
1331 let tg = std::time::Instant::now();
1333 let outputs = if self.boxrpb_in_ir {
1334 self.layers[li].run(&[
1335 ("tgt", tgt.as_slice()),
1336 ("query_pos", query_pos_slice),
1337 ("presence", presence.as_slice()),
1338 ("memory", memory),
1339 ("memory_pos", memory_pos),
1340 ("text", text_bf.as_slice()),
1341 ("text_kpm_inv", text_kpm_inv.as_slice()),
1342 ("deltas_x", deltas_x_slice),
1343 ("deltas_y", deltas_y_slice),
1344 ])
1345 } else {
1346 self.layers[li].run(&[
1347 ("tgt", tgt.as_slice()),
1348 ("query_pos", query_pos_slice),
1349 ("presence", presence.as_slice()),
1350 ("memory", memory),
1351 ("memory_pos", memory_pos),
1352 ("text", text_bf.as_slice()),
1353 ("text_kpm_inv", text_kpm_inv.as_slice()),
1354 ("rpb_bias", rpb_slice),
1355 ])
1356 };
1357 if profile {
1358 t_graph += tg.elapsed().as_micros();
1359 }
1360 ensure!(outputs.len() == 3, "decoder layer expected 3 outputs");
1361 tgt = outputs[0].clone();
1362 presence = outputs[1].clone();
1363 let out_norm = outputs[2].clone();
1364
1365 let tb = std::time::Instant::now();
1366 mlp3_forward_into(
1368 &self.bbox_embed,
1369 &out_norm,
1370 batch * nq,
1371 &mut self.scratch_bbox_h0,
1372 &mut self.scratch_bbox_h1,
1373 &mut self.scratch_bbox_out,
1374 self.gguf_packed.as_ref(),
1375 )?;
1376 let delta: &[f32] = &self.scratch_bbox_out;
1377 if profile {
1378 t_box += tb.elapsed().as_micros();
1379 }
1380 let to = std::time::Instant::now();
1381 let _ = to;
1382 let _ = &mut t_other;
1383 let mut new_ref = vec![0f32; batch * nq * 4];
1384 for q in 0..nq {
1385 for b in 0..batch {
1386 let cur = &ref_boxes[(b * nq + q) * 4..(b * nq + q + 1) * 4];
1387 let dl = &delta[(b * nq + q) * 4..(b * nq + q + 1) * 4];
1388 for k in 0..4 {
1389 new_ref[(b * nq + q) * 4 + k] = sigmoid(inv_sigmoid(cur[k]) + dl[k]);
1390 }
1391 }
1392 }
1393 ref_boxes = new_ref;
1394 if li != N_LAYERS - 1 {
1395 intermediate_ref_boxes.push(ref_boxes.clone());
1396 }
1397
1398 let mut out_seq_first = vec![0f32; nq * batch * d];
1400 for q in 0..nq {
1401 for b in 0..batch {
1402 let src = (b * nq + q) * d;
1403 let dst = (q * batch + b) * d;
1404 out_seq_first[dst..dst + d].copy_from_slice(&out_norm[src..src + d]);
1405 }
1406 }
1407 intermediate.push(out_seq_first);
1408
1409 let p_norm =
1411 layer_norm_host(&presence, &self.presence_norm_w, &self.presence_norm_b, d);
1412 let p_logit = mlp3_forward(
1413 &self.presence_head,
1414 &p_norm,
1415 batch,
1416 self.gguf_packed.as_ref(),
1417 )?;
1418 presence_logits.push(p_logit);
1419 }
1420 if profile {
1421 let to_ms = |us: u128| us as f32 / 1000.0;
1422 eprintln!(
1423 " decoder per-stage (6 layers total): qpos={:.1}ms rpb={:.1}ms graph={:.1}ms box={:.1}ms",
1424 to_ms(t_qpos),
1425 to_ms(t_rpb),
1426 to_ms(t_graph),
1427 to_ms(t_box)
1428 );
1429 }
1430
1431 let mut int_stack = vec![0f32; N_LAYERS * nq * batch * d];
1433 for (li, l) in intermediate.iter().enumerate() {
1434 int_stack[li * nq * batch * d..(li + 1) * nq * batch * d].copy_from_slice(l);
1435 }
1436 let mut ref_stack = vec![0f32; N_LAYERS * nq * batch * 4];
1437 for (li, r) in intermediate_ref_boxes.iter().enumerate() {
1438 ref_stack[li * nq * batch * 4..(li + 1) * nq * batch * 4].copy_from_slice(r);
1439 }
1440 let mut presence_stack = vec![0f32; N_LAYERS * batch];
1441 for (li, p) in presence_logits.iter().enumerate() {
1442 for b in 0..batch {
1443 presence_stack[li * batch + b] = p[b];
1444 }
1445 }
1446 let _ = nh;
1447 let _ = lq;
1448 Ok((int_stack, ref_stack, presence_stack, presence))
1449 }
1450}
1451
1452#[allow(clippy::too_many_arguments)]
1454pub fn forward_decoder_ir_on(
1455 weights: &Sam3DecoderWeights,
1456 memory: &[f32],
1457 memory_pos: &[f32],
1458 memory_text: &[f32],
1459 text_attention_mask: &[u8],
1460 batch: usize,
1461 h: usize,
1462 w: usize,
1463 seq_len: usize,
1464 device: Device,
1465) -> Result<Sam3DecoderOutput> {
1466 forward_decoder_ir_on_with_profile(
1467 weights,
1468 memory,
1469 memory_pos,
1470 memory_text,
1471 text_attention_mask,
1472 batch,
1473 h,
1474 w,
1475 seq_len,
1476 device,
1477 &CompileProfile::sam3(),
1478 None,
1479 )
1480}
1481
1482#[allow(clippy::too_many_arguments)]
1484pub fn forward_decoder_ir_on_with_profile(
1485 weights: &Sam3DecoderWeights,
1486 memory: &[f32],
1487 memory_pos: &[f32],
1488 memory_text: &[f32],
1489 text_attention_mask: &[u8],
1490 batch: usize,
1491 h: usize,
1492 w: usize,
1493 seq_len: usize,
1494 device: Device,
1495 profile: &CompileProfile,
1496 gguf_packed: Option<&GgufPackedParams>,
1497) -> Result<Sam3DecoderOutput> {
1498 ensure!(weights.loaded, "decoder weights not loaded");
1499 ensure!(batch == 1, "decoder IR forward requires batch=1 for boxRPB");
1500 let hw = h * w;
1501 let mut dec = Sam3CompiledDecoder::new_with_profile_and_gguf(
1502 weights,
1503 batch,
1504 hw,
1505 seq_len,
1506 device,
1507 profile,
1508 gguf_packed,
1509 )?;
1510 let (intermediate, intermediate_ref_boxes, presence_logits, presence_feats) =
1511 dec.run(memory, memory_pos, memory_text, text_attention_mask, h, w)?;
1512 Ok(Sam3DecoderOutput {
1513 intermediate,
1514 intermediate_ref_boxes,
1515 presence_logits,
1516 presence_feats,
1517 num_layers: N_LAYERS,
1518 num_queries: NUM_QUERIES,
1519 batch,
1520 d_model: D_MODEL,
1521 })
1522}
1523
/// Logistic function `1 / (1 + e^(-x))`.
fn sigmoid(x: f32) -> f32 {
    let denominator = 1.0 + (-x).exp();
    denominator.recip()
}
1529
/// Logit `ln(x / (1 - x))`, with `x` first forced into `[1e-3, 1 - 1e-3]` so the
/// result stays finite for inputs at or beyond 0 and 1.
fn inv_sigmoid(x: f32) -> f32 {
    const EPS: f32 = 1e-3;
    // The clamp/max/min chain is kept as three steps on purpose: it also maps a NaN
    // input to `EPS`, which a single `clamp(EPS, 1 - EPS)` would not.
    let unit = x.clamp(0.0, 1.0);
    let bounded = unit.max(EPS).min(1.0 - EPS);
    let odds = bounded / (1.0 - bounded);
    odds.ln()
}
1535
/// Row-wise layer normalisation on the host.
///
/// `x` is read as consecutive rows of `dim` values. Each row is shifted by its mean,
/// scaled by `1 / sqrt(variance + 1e-5)` (population variance), then multiplied by
/// `gamma` and offset by `beta` element-wise. Trailing values of `x` that do not fill
/// a whole row come back as `0.0`.
///
/// # Panics
/// Panics when `dim == 0`, or when a row is processed and `gamma` or `beta` holds
/// fewer than `dim` values.
fn layer_norm_host(x: &[f32], gamma: &[f32], beta: &[f32], dim: usize) -> Vec<f32> {
    let width = dim as f32;
    let mut normed = vec![0f32; x.len()];
    for (src, dst) in x.chunks_exact(dim).zip(normed.chunks_exact_mut(dim)) {
        let mean = src.iter().sum::<f32>() / width;
        let var = src.iter().map(|v| (*v - mean).powi(2)).sum::<f32>() / width;
        let inv = 1.0 / (var + 1e-5).sqrt();
        // Slicing (rather than relying on zip truncation) keeps the panic on
        // undersized scale/shift vectors.
        let (scale, shift) = (&gamma[..dim], &beta[..dim]);
        for (((y, &v), &g), &b) in dst.iter_mut().zip(src).zip(scale).zip(shift) {
            *y = (v - mean) * inv * g + b;
        }
    }
    normed
}
1550
1551#[allow(dead_code)]
1552fn host_mlp2_forward(mlp: &Mlp2, x: &[f32], rows: usize) -> Result<Vec<f32>> {
1553 let h = matmul_bias_relu(x, &mlp.w0_t, &mlp.b0, rows, mlp.in_dim, mlp.hidden);
1554 Ok(matmul_bias(
1555 &h,
1556 &mlp.w1_t,
1557 &mlp.b1,
1558 rows,
1559 mlp.hidden,
1560 mlp.out_dim,
1561 ))
1562}
1563
1564#[allow(dead_code)]
1568fn host_mlp2_forward_into(mlp: &Mlp2, x: &[f32], rows: usize, hidden: &mut [f32], out: &mut [f32]) {
1569 rlx_cpu::blas::sgemm_bias_epilogue(
1570 x,
1571 &mlp.w0_t,
1572 &mlp.b0,
1573 hidden,
1574 rows,
1575 mlp.in_dim,
1576 mlp.hidden,
1577 |v| if v < 0.0 { 0.0 } else { v },
1578 );
1579 rlx_cpu::blas::sgemm_bias(
1580 hidden,
1581 &mlp.w1_t,
1582 &mlp.b1,
1583 out,
1584 rows,
1585 mlp.hidden,
1586 mlp.out_dim,
1587 );
1588}
1589
1590#[allow(dead_code)]
1591fn host_mlp3_forward(mlp: &Mlp3, x: &[f32], rows: usize) -> Result<Vec<f32>> {
1592 let h = matmul_bias_relu(x, &mlp.w0_t, &mlp.b0, rows, mlp.in_dim, mlp.hidden);
1593 let h = matmul_bias_relu(&h, &mlp.w1_t, &mlp.b1, rows, mlp.hidden, mlp.hidden);
1594 Ok(matmul_bias(
1595 &h,
1596 &mlp.w2_t,
1597 &mlp.b2,
1598 rows,
1599 mlp.hidden,
1600 mlp.out_dim,
1601 ))
1602}
1603
1604#[allow(dead_code)]
1605fn host_mlp3_forward_into(
1606 mlp: &Mlp3,
1607 x: &[f32],
1608 rows: usize,
1609 h0: &mut [f32],
1610 h1: &mut [f32],
1611 out: &mut [f32],
1612) {
1613 let relu = |v: f32| if v < 0.0 { 0.0 } else { v };
1614 rlx_cpu::blas::sgemm_bias_epilogue(
1615 x, &mlp.w0_t, &mlp.b0, h0, rows, mlp.in_dim, mlp.hidden, relu,
1616 );
1617 rlx_cpu::blas::sgemm_bias_epilogue(
1618 h0, &mlp.w1_t, &mlp.b1, h1, rows, mlp.hidden, mlp.hidden, relu,
1619 );
1620 rlx_cpu::blas::sgemm_bias(h1, &mlp.w2_t, &mlp.b2, out, rows, mlp.hidden, mlp.out_dim);
1621}
1622
1623#[allow(dead_code)]
1624fn matmul_bias(x: &[f32], w_t: &[f32], b: &[f32], rows: usize, k: usize, n: usize) -> Vec<f32> {
1625 let mut out = vec![0f32; rows * n];
1626 rlx_cpu::blas::sgemm_bias(x, w_t, b, &mut out, rows, k, n);
1627 out
1628}
1629
1630#[allow(dead_code)]
1631fn matmul_bias_relu(
1632 x: &[f32],
1633 w_t: &[f32],
1634 b: &[f32],
1635 rows: usize,
1636 k: usize,
1637 n: usize,
1638) -> Vec<f32> {
1639 let mut out = matmul_bias(x, w_t, b, rows, k, n);
1640 for v in out.iter_mut() {
1641 if *v < 0.0 {
1642 *v = 0.0;
1643 }
1644 }
1645 out
1646}
1647
1648fn sineembed_4d(pos: &[f32], batch: usize, nq: usize, d_model: usize) -> Vec<f32> {
1649 let mut out = vec![0.0f32; batch * nq * 2 * d_model];
1650 sineembed_4d_into(pos, batch, nq, d_model, &mut out);
1651 out
1652}
1653
/// Writes a sinusoidal embedding of 4-component boxes into `out`.
///
/// `pos` holds `batch * nq` boxes of four values each. For every box, `out` receives
/// `2 * d_model` values made of four lanes of `d_model / 2` entries, built from the
/// box components in the order index 1, index 0, index 2, index 3. Each component is
/// scaled by `2π` and divided by `10000^(2 * (i / 2) / (d_model / 2))`; even lane
/// positions store the sine and odd positions the cosine.
///
/// `out.len()` is only verified in debug builds.
///
/// # Panics
/// Panics if `pos` or `out` is too short for the requested `batch * nq` boxes.
fn sineembed_4d_into(pos: &[f32], batch: usize, nq: usize, d_model: usize, out: &mut [f32]) {
    const TWO_PI: f32 = 2.0 * std::f32::consts::PI;
    let half = d_model / 2;
    let row_width = 2 * d_model;
    // Integer `i / 2` makes each sin/cos pair share one frequency.
    let dim_t: Vec<f32> = (0..half)
        .map(|i| 10000.0f32.powf(2.0 * ((i / 2) as f32) / half as f32))
        .collect();
    debug_assert_eq!(out.len(), batch * nq * row_width);

    for row in 0..batch * nq {
        let p = &pos[row * 4..(row + 1) * 4];
        // Components 0 and 1 are swapped on purpose; the rest keep their order.
        let scaled = [p[1] * TWO_PI, p[0] * TWO_PI, p[2] * TWO_PI, p[3] * TWO_PI];
        let base = row * row_width;
        for (axis, &angle) in scaled.iter().enumerate() {
            let lane = &mut out[base + axis * half..base + (axis + 1) * half];
            for (i, (dst, &freq)) in lane.iter_mut().zip(&dim_t).enumerate() {
                let theta = angle / freq;
                *dst = if i % 2 == 0 { theta.sin() } else { theta.cos() };
            }
        }
    }
}
1678
1679fn boxrpb_log_full(
1683 boxrpb_x: &Mlp2,
1684 boxrpb_y: &Mlp2,
1685 reference_boxes: &[f32],
1686 batch: usize,
1687 nq: usize,
1688 h: usize,
1689 w: usize,
1690 gguf_packed: Option<&GgufPackedParams>,
1691) -> Result<Vec<f32>> {
1692 let nh = N_HEADS;
1693 let lq = nq + 1;
1694 let mut out = vec![0f32; batch * nh * lq * h * w];
1695 let mut dx_thq = vec![0f32; nh * nq * w];
1696 let mut dy_thq = vec![0f32; nh * nq * h];
1697 let mut deltas_x = vec![0f32; nq * w * 2];
1698 let mut deltas_y = vec![0f32; nq * h * 2];
1699 let mut hidden_x = vec![0f32; nq * w * boxrpb_x.hidden];
1700 let mut hidden_y = vec![0f32; nq * h * boxrpb_y.hidden];
1701 let mut feats_x = vec![0f32; nq * w * boxrpb_x.out_dim];
1702 let mut feats_y = vec![0f32; nq * h * boxrpb_y.out_dim];
1703 boxrpb_log_full_into(
1704 boxrpb_x,
1705 boxrpb_y,
1706 reference_boxes,
1707 batch,
1708 nq,
1709 h,
1710 w,
1711 &mut out,
1712 &mut dx_thq,
1713 &mut dy_thq,
1714 &mut deltas_x,
1715 &mut deltas_y,
1716 &mut hidden_x,
1717 &mut hidden_y,
1718 &mut feats_x,
1719 &mut feats_y,
1720 gguf_packed,
1721 )?;
1722 Ok(out)
1723}
1724
1725#[allow(clippy::too_many_arguments)]
1726fn boxrpb_log_full_into(
1727 boxrpb_x: &Mlp2,
1728 boxrpb_y: &Mlp2,
1729 reference_boxes: &[f32],
1730 batch: usize,
1731 nq: usize,
1732 h: usize,
1733 w: usize,
1734 out: &mut [f32],
1735 dx_thq: &mut [f32],
1736 dy_thq: &mut [f32],
1737 deltas_x: &mut [f32],
1738 deltas_y: &mut [f32],
1739 hidden_x: &mut [f32],
1740 hidden_y: &mut [f32],
1741 feats_x: &mut [f32],
1742 feats_y: &mut [f32],
1743 gguf_packed: Option<&GgufPackedParams>,
1744) -> Result<()> {
1745 let nh = N_HEADS;
1746 let lq = nq + 1;
1747 debug_assert_eq!(out.len(), batch * nh * lq * h * w);
1748 debug_assert_eq!(dx_thq.len(), nh * nq * w);
1749 debug_assert_eq!(dy_thq.len(), nh * nq * h);
1750 debug_assert_eq!(deltas_x.len(), nq * w * 2);
1751 debug_assert_eq!(deltas_y.len(), nq * h * 2);
1752 debug_assert_eq!(hidden_x.len(), nq * w * boxrpb_x.hidden);
1753 debug_assert_eq!(hidden_y.len(), nq * h * boxrpb_y.hidden);
1754 debug_assert_eq!(feats_x.len(), nq * w * boxrpb_x.out_dim);
1755 debug_assert_eq!(feats_y.len(), nq * h * boxrpb_y.out_dim);
1756 for head in 0..nh {
1758 for b in 0..batch {
1759 let off = b * nh * lq * h * w + head * lq * h * w;
1760 for i in 0..h * w {
1762 out[off + i] = 0.0;
1763 }
1764 }
1765 }
1766 let coords_h: Vec<f32> = (0..h).map(|y| y as f32 / h as f32).collect();
1767 let coords_w: Vec<f32> = (0..w).map(|x| x as f32 / w as f32).collect();
1768
1769 for b in 0..batch {
1770 for q in 0..nq {
1771 let p = &reference_boxes[(b * nq + q) * 4..(b * nq + q + 1) * 4];
1772 let (cx, cy, bw, bh) = (p[0], p[1], p[2], p[3]);
1773 let x0 = cx - 0.5 * bw;
1774 let x1 = cx + 0.5 * bw;
1775 let y0 = cy - 0.5 * bh;
1776 let y1 = cy + 0.5 * bh;
1777 for xi in 0..w {
1778 let dx0 = (coords_w[xi] - x0) * 8.0;
1779 let dx1 = (coords_w[xi] - x1) * 8.0;
1780 deltas_x[(q * w + xi) * 2] = log_norm(dx0);
1781 deltas_x[(q * w + xi) * 2 + 1] = log_norm(dx1);
1782 }
1783 for yi in 0..h {
1784 let dy0 = (coords_h[yi] - y0) * 8.0;
1785 let dy1 = (coords_h[yi] - y1) * 8.0;
1786 deltas_y[(q * h + yi) * 2] = log_norm(dy0);
1787 deltas_y[(q * h + yi) * 2 + 1] = log_norm(dy1);
1788 }
1789 }
1790 mlp2_forward_into(boxrpb_x, deltas_x, nq * w, hidden_x, feats_x, gguf_packed)?;
1791 mlp2_forward_into(boxrpb_y, deltas_y, nq * h, hidden_y, feats_y, gguf_packed)?;
1792 let dx_feats: &[f32] = feats_x;
1793 let dy_feats: &[f32] = feats_y;
1794 for q in 0..nq {
1797 for xi in 0..w {
1798 let src_base = (q * w + xi) * nh;
1799 for head in 0..nh {
1800 dx_thq[(head * nq + q) * w + xi] = dx_feats[src_base + head];
1801 }
1802 }
1803 for yi in 0..h {
1804 let src_base = (q * h + yi) * nh;
1805 for head in 0..nh {
1806 dy_thq[(head * nq + q) * h + yi] = dy_feats[src_base + head];
1807 }
1808 }
1809 }
1810 let base = b * nh * lq * h * w;
1811 let total = nh * nq;
1812 let out_ptr = out.as_mut_ptr() as usize;
1813 let dx_ptr = dx_thq.as_ptr() as usize;
1814 let dy_ptr = dy_thq.as_ptr() as usize;
1815 rlx_cpu::pool::par_for(total, 8, &|off, cnt| unsafe {
1816 for idx in off..off + cnt {
1817 let head = idx / nq;
1818 let q = idx % nq;
1819 let dst = (out_ptr as *mut f32).add(base + (head * lq + 1 + q) * h * w);
1820 let dx_row =
1821 std::slice::from_raw_parts((dx_ptr as *const f32).add((head * nq + q) * w), w);
1822 let dy_row =
1823 std::slice::from_raw_parts((dy_ptr as *const f32).add((head * nq + q) * h), h);
1824 for y in 0..h {
1825 let dy = dy_row[y];
1826 let row_dst = dst.add(y * w);
1827 for x in 0..w {
1828 *row_dst.add(x) = dy + dx_row[x];
1829 }
1830 }
1831 }
1832 });
1833 }
1834 Ok(())
1835}
1836
/// Signed logarithmic squashing: `sign(v) * log2(|v| + 1) / log2(8)`.
///
/// Values that do not compare below zero (including `-0.0` and NaN) take the
/// positive branch.
fn log_norm(v: f32) -> f32 {
    let magnitude = (v.abs() + 1.0).log2() / 8.0f32.log2();
    if v < 0.0 { -magnitude } else { magnitude }
}
1841
1842fn compute_deltas_into(
1847 reference_boxes: &[f32],
1848 batch: usize,
1849 nq: usize,
1850 h: usize,
1851 w: usize,
1852 deltas_x: &mut [f32],
1853 deltas_y: &mut [f32],
1854) {
1855 debug_assert_eq!(deltas_x.len(), batch * nq * w * 2);
1856 debug_assert_eq!(deltas_y.len(), batch * nq * h * 2);
1857 let coords_h: Vec<f32> = (0..h).map(|y| y as f32 / h as f32).collect();
1858 let coords_w: Vec<f32> = (0..w).map(|x| x as f32 / w as f32).collect();
1859 for b in 0..batch {
1860 for q in 0..nq {
1861 let p = &reference_boxes[(b * nq + q) * 4..(b * nq + q + 1) * 4];
1862 let (cx, cy, bw, bh) = (p[0], p[1], p[2], p[3]);
1863 let x0 = cx - 0.5 * bw;
1864 let x1 = cx + 0.5 * bw;
1865 let y0 = cy - 0.5 * bh;
1866 let y1 = cy + 0.5 * bh;
1867 let dx_off = ((b * nq + q) * w) * 2;
1868 for xi in 0..w {
1869 let dx0 = (coords_w[xi] - x0) * 8.0;
1870 let dx1 = (coords_w[xi] - x1) * 8.0;
1871 deltas_x[dx_off + xi * 2] = log_norm(dx0);
1872 deltas_x[dx_off + xi * 2 + 1] = log_norm(dx1);
1873 }
1874 let dy_off = ((b * nq + q) * h) * 2;
1875 for yi in 0..h {
1876 let dy0 = (coords_h[yi] - y0) * 8.0;
1877 let dy1 = (coords_h[yi] - y1) * 8.0;
1878 deltas_y[dy_off + yi * 2] = log_norm(dy0);
1879 deltas_y[dy_off + yi * 2 + 1] = log_norm(dy1);
1880 }
1881 }
1882 }
1883}
1884
1885#[allow(dead_code)]
1887fn build_boxrpb_check_hir(
1888 boxrpb_x: &Mlp2,
1889 boxrpb_y: &Mlp2,
1890 batch: usize,
1891 nq: usize,
1892 h: usize,
1893 w: usize,
1894) -> Result<(HirModule, HashMap<String, Vec<f32>>)> {
1895 let nh = N_HEADS;
1896 let mut hir = HirModule::new("sam3_boxrpb_check");
1897 let mut params = HashMap::new();
1898 let mut typed = Vec::new();
1899 let mut gguf_cache = HashMap::new();
1900 {
1901 let mut g = HirMut::new(&mut hir);
1902 let f = DType::F32;
1903 let deltas_x = g.input("deltas_x", Shape::new(&[batch, nq, w, 2], f));
1904 let deltas_y = g.input("deltas_y", Shape::new(&[batch, nq, h, 2], f));
1905 let out = build_boxrpb_subgraph(
1906 &mut g,
1907 &mut params,
1908 &mut typed,
1909 &mut gguf_cache,
1910 None,
1911 boxrpb_x,
1912 boxrpb_y,
1913 deltas_x,
1914 deltas_y,
1915 batch,
1916 nq,
1917 nh,
1918 h,
1919 w,
1920 )?;
1921 g.set_outputs(vec![out]);
1922 }
1923 Ok((hir, params))
1924}
1925
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `Mlp2` with constant weights (0.01 for the first layer, 0.02 for the
    /// second), zero biases and no GGUF keys.
    fn synth_mlp2(in_d: usize, hidden: usize, out_d: usize) -> Mlp2 {
        Mlp2 {
            w0_t: vec![0.01; in_d * hidden],
            b0: vec![0.0; hidden],
            w1_t: vec![0.02; hidden * out_d],
            b1: vec![0.0; out_d],
            in_dim: in_d,
            hidden,
            out_dim: out_d,
            w0_gguf_key: None,
            w1_gguf_key: None,
        }
    }

    /// The boxRPB subgraph compiled for the CPU must agree with the host
    /// implementation (`boxrpb_log_full`) to within 5e-2 on a tiny 4x4 grid.
    #[test]
    fn sam3_boxrpb_ir_matches_host_cpu() -> Result<()> {
        let batch = 1usize;
        let nq = 2usize;
        let h = 4usize;
        let w = 4usize;
        let nh = N_HEADS;
        let boxrpb_x = synth_mlp2(2, 16, nh);
        let boxrpb_y = synth_mlp2(2, 16, nh);
        // Two boxes as (cx, cy, w, h).
        let ref_boxes = vec![
            0.5, 0.5, 0.4, 0.4, //
            0.3, 0.7, 0.2, 0.3,
        ];
        let host = boxrpb_log_full(&boxrpb_x, &boxrpb_y, &ref_boxes, batch, nq, h, w, None)?;

        let mut deltas_x = vec![0f32; batch * nq * w * 2];
        let mut deltas_y = vec![0f32; batch * nq * h * 2];
        compute_deltas_into(&ref_boxes, batch, nq, h, w, &mut deltas_x, &mut deltas_y);

        let (hir, params) = build_boxrpb_check_hir(&boxrpb_x, &boxrpb_y, batch, nq, h, w)?;
        let mut compiled = rlx_core::flow_bridge::compile_hir_sam(Device::Cpu, hir)?;
        for (name, data) in &params {
            compiled.set_param(name, data);
        }
        let ir = compiled
            .run(&[("deltas_x", &deltas_x), ("deltas_y", &deltas_y)])
            .into_iter()
            .next()
            .expect("compiled boxRPB graph returned no outputs");

        // Largest element-wise absolute difference between host and IR results.
        let fd = host
            .iter()
            .zip(&ir)
            .map(|(a, b)| (a - b).abs())
            .fold(0f32, f32::max);
        assert!(fd < 5e-2, "sam3 boxRPB IR vs host max |Δ| = {fd:.3e}");
        Ok(())
    }
}