1use super::tensor::{layer_norm, linear};
29use anyhow::{Result, ensure};
30use rlx_core::weight_map::WeightMap;
31use rlx_flow::GgufPackedParams;
32
33use crate::packed_gguf::{linear_maybe_gguf, take_or_gguf, take_transposed_with_gguf_key};
34
/// Width of every decoder token (queries, presence token, text/image memory rows).
const D_MODEL: usize = 256;
/// Presumably the inner width of each layer's `linear1`/`linear2` feed-forward pair
/// (not referenced in this part of the file — confirm against its use site).
const DIM_FF: usize = 2_048;
/// Attention heads per attention block; also the output width of the boxRPB MLPs.
const N_HEADS: usize = 8;
/// Number of decoder layers loaded by `extract_decoder_weights`.
const N_LAYERS: usize = 6;
/// Learned object queries (rows of `query_embed` and `reference_points`).
const NUM_QUERIES: usize = 200;
40
/// Host-side parameters of one SAM3 detector decoder layer.
///
/// Every `*_w_t` matrix is stored transposed, as returned by
/// `take_transposed_with_gguf_key`; the sibling `*_gguf_key` is `Some` when that
/// projection is backed by a packed GGUF tensor and is then forwarded to
/// `linear_maybe_gguf` / `mha_with_bias_maybe_gguf` instead of using the raw matrix.
#[derive(Clone)]
pub struct Sam3DecoderLayerWeights {
    // Self-attention over `[presence token; queries]`.
    /// Transposed packed QKV in-projection (`self_attn.in_proj_weight`).
    pub self_attn_in_w_t: Vec<f32>,
    /// QKV in-projection bias (`self_attn.in_proj_bias`).
    pub self_attn_in_b: Vec<f32>,
    pub self_attn_in_gguf_key: Option<String>,
    /// Transposed output projection (`self_attn.out_proj.weight`).
    pub self_attn_out_w_t: Vec<f32>,
    pub self_attn_out_b: Vec<f32>,
    pub self_attn_out_gguf_key: Option<String>,

    // Cross-attention from the queries to the text memory.
    /// Transposed packed QKV in-projection (`ca_text.in_proj_weight`).
    pub ca_text_in_w_t: Vec<f32>,
    pub ca_text_in_b: Vec<f32>,
    pub ca_text_in_gguf_key: Option<String>,
    /// Transposed output projection (`ca_text.out_proj.weight`).
    pub ca_text_out_w_t: Vec<f32>,
    pub ca_text_out_b: Vec<f32>,
    pub ca_text_out_gguf_key: Option<String>,

    // Cross-attention from `[presence token; queries]` to the image memory.
    /// Transposed packed QKV in-projection (`cross_attn.in_proj_weight`).
    pub cross_attn_in_w_t: Vec<f32>,
    pub cross_attn_in_b: Vec<f32>,
    pub cross_attn_in_gguf_key: Option<String>,
    /// Transposed output projection (`cross_attn.out_proj.weight`).
    pub cross_attn_out_w_t: Vec<f32>,
    pub cross_attn_out_b: Vec<f32>,
    pub cross_attn_out_gguf_key: Option<String>,

    // `linear1` / `linear2` — presumably the feed-forward block (used outside this view).
    pub linear1_w_t: Vec<f32>,
    pub linear1_b: Vec<f32>,
    pub linear1_gguf_key: Option<String>,
    pub linear2_w_t: Vec<f32>,
    pub linear2_b: Vec<f32>,
    pub linear2_gguf_key: Option<String>,

    // Per-layer normalization parameters (weight, bias).
    pub norm1_w: Vec<f32>,
    pub norm1_b: Vec<f32>,
    /// Applied after the self-attention residual.
    pub norm2_w: Vec<f32>,
    pub norm2_b: Vec<f32>,
    pub norm3_w: Vec<f32>,
    pub norm3_b: Vec<f32>,
    /// Applied after the text cross-attention residual.
    pub catext_norm_w: Vec<f32>,
    pub catext_norm_b: Vec<f32>,
}
76
77#[derive(Clone, Default)]
78pub struct Sam3DecoderWeights {
79 pub loaded: bool,
80 pub prefix: String,
82 pub layers: Vec<Sam3DecoderLayerWeights>,
83 pub query_embed: Vec<f32>, pub reference_points: Vec<f32>, pub norm_w: Vec<f32>,
86 pub norm_b: Vec<f32>,
87 pub bbox_embed: Mlp3, pub ref_point_head: Mlp2, pub boxrpb_x: Mlp2, pub boxrpb_y: Mlp2, pub presence_token: Vec<f32>, pub presence_token_head: Mlp3, pub presence_token_out_norm_w: Vec<f32>,
94 pub presence_token_out_norm_b: Vec<f32>,
95}
96
/// Two-layer perceptron evaluated as `relu(x · W0 + b0) · W1 + b1`
/// (see `mlp2_forward` / `mlp2_forward_into`).
///
/// Matrices are kept transposed (`in x out`), as returned by
/// `take_transposed_with_gguf_key`. A `Some` GGUF key marks a layer that the
/// forward helpers route through `linear_maybe_gguf` with that key.
#[derive(Clone, Default)]
pub struct Mlp2 {
    /// Transposed first-layer weight.
    pub w0_t: Vec<f32>,
    /// First-layer bias (`hidden` values).
    pub b0: Vec<f32>,
    /// Transposed second-layer weight.
    pub w1_t: Vec<f32>,
    /// Second-layer bias (`out_dim` values).
    pub b1: Vec<f32>,
    /// Input width.
    pub in_dim: usize,
    /// Width after the first layer.
    pub hidden: usize,
    /// Output width.
    pub out_dim: usize,
    /// GGUF key of the first-layer weight, if packed.
    pub w0_gguf_key: Option<String>,
    /// GGUF key of the second-layer weight, if packed.
    pub w1_gguf_key: Option<String>,
}
109
/// Three-layer perceptron evaluated as
/// `relu(relu(x · W0 + b0) · W1 + b1) · W2 + b2` (see `mlp3_forward`).
///
/// Both hidden layers share the width `hidden`. Matrices are kept transposed,
/// as returned by `take_transposed_with_gguf_key`; a `Some` GGUF key marks a
/// layer routed through `linear_maybe_gguf` with that key.
#[derive(Clone, Default)]
pub struct Mlp3 {
    /// Transposed first-layer weight (`in_dim -> hidden`).
    pub w0_t: Vec<f32>,
    pub b0: Vec<f32>,
    /// Transposed second-layer weight (`hidden -> hidden`).
    pub w1_t: Vec<f32>,
    pub b1: Vec<f32>,
    /// Transposed third-layer weight (`hidden -> out_dim`).
    pub w2_t: Vec<f32>,
    pub b2: Vec<f32>,
    /// Input width.
    pub in_dim: usize,
    /// Width of both hidden layers.
    pub hidden: usize,
    /// Output width.
    pub out_dim: usize,
    /// GGUF keys of the three weights, if packed.
    pub w0_gguf_key: Option<String>,
    pub w1_gguf_key: Option<String>,
    pub w2_gguf_key: Option<String>,
}
125
126pub fn take_mlp2(
127 weights: &mut WeightMap,
128 gguf_packed: Option<&GgufPackedParams>,
129 base: &str,
130 in_dim: usize,
131 hidden: usize,
132 out_dim: usize,
133) -> Result<Mlp2> {
134 let (w0_t, w0_gguf_key) =
135 take_transposed_with_gguf_key(weights, gguf_packed, &format!("{base}.layers.0.weight"))?;
136 let (b0, _) = take_or_gguf(weights, gguf_packed, &format!("{base}.layers.0.bias"))?;
137 let (w1_t, w1_gguf_key) =
138 take_transposed_with_gguf_key(weights, gguf_packed, &format!("{base}.layers.1.weight"))?;
139 let (b1, _) = take_or_gguf(weights, gguf_packed, &format!("{base}.layers.1.bias"))?;
140 Ok(Mlp2 {
141 w0_t,
142 b0,
143 w1_t,
144 b1,
145 in_dim,
146 hidden,
147 out_dim,
148 w0_gguf_key,
149 w1_gguf_key,
150 })
151}
152
153pub fn take_mlp3(
154 weights: &mut WeightMap,
155 gguf_packed: Option<&GgufPackedParams>,
156 base: &str,
157 in_dim: usize,
158 hidden: usize,
159 out_dim: usize,
160) -> Result<Mlp3> {
161 let (w0_t, w0_gguf_key) =
162 take_transposed_with_gguf_key(weights, gguf_packed, &format!("{base}.layers.0.weight"))?;
163 let (b0, _) = take_or_gguf(weights, gguf_packed, &format!("{base}.layers.0.bias"))?;
164 let (w1_t, w1_gguf_key) =
165 take_transposed_with_gguf_key(weights, gguf_packed, &format!("{base}.layers.1.weight"))?;
166 let (b1, _) = take_or_gguf(weights, gguf_packed, &format!("{base}.layers.1.bias"))?;
167 let (w2_t, w2_gguf_key) =
168 take_transposed_with_gguf_key(weights, gguf_packed, &format!("{base}.layers.2.weight"))?;
169 let (b2, _) = take_or_gguf(weights, gguf_packed, &format!("{base}.layers.2.bias"))?;
170 Ok(Mlp3 {
171 w0_t,
172 b0,
173 w1_t,
174 b1,
175 w2_t,
176 b2,
177 in_dim,
178 hidden,
179 out_dim,
180 w0_gguf_key,
181 w1_gguf_key,
182 w2_gguf_key,
183 })
184}
185
186pub fn mlp2_forward(
187 mlp: &Mlp2,
188 x: &[f32],
189 rows: usize,
190 gguf_packed: Option<&GgufPackedParams>,
191) -> Result<Vec<f32>> {
192 let mut h = linear_maybe_gguf(
193 x,
194 rows,
195 mlp.in_dim,
196 &mlp.w0_t,
197 mlp.w0_gguf_key.as_deref(),
198 gguf_packed,
199 mlp.hidden,
200 &mlp.b0,
201 )?;
202 for v in h.iter_mut() {
203 if *v < 0.0 {
204 *v = 0.0;
205 }
206 }
207 linear_maybe_gguf(
208 &h,
209 rows,
210 mlp.hidden,
211 &mlp.w1_t,
212 mlp.w1_gguf_key.as_deref(),
213 gguf_packed,
214 mlp.out_dim,
215 &mlp.b1,
216 )
217}
218
219pub fn bbox_embed_forward(
221 weights: &Sam3DecoderWeights,
222 x: &[f32],
223 rows: usize,
224 gguf_packed: Option<&GgufPackedParams>,
225) -> Result<Vec<f32>> {
226 mlp3_forward(&weights.bbox_embed, x, rows, gguf_packed)
227}
228
229pub fn mlp3_forward(
230 mlp: &Mlp3,
231 x: &[f32],
232 rows: usize,
233 gguf_packed: Option<&GgufPackedParams>,
234) -> Result<Vec<f32>> {
235 let mut h = linear_maybe_gguf(
236 x,
237 rows,
238 mlp.in_dim,
239 &mlp.w0_t,
240 mlp.w0_gguf_key.as_deref(),
241 gguf_packed,
242 mlp.hidden,
243 &mlp.b0,
244 )?;
245 for v in h.iter_mut() {
246 if *v < 0.0 {
247 *v = 0.0;
248 }
249 }
250 h = linear_maybe_gguf(
251 &h,
252 rows,
253 mlp.hidden,
254 &mlp.w1_t,
255 mlp.w1_gguf_key.as_deref(),
256 gguf_packed,
257 mlp.hidden,
258 &mlp.b1,
259 )?;
260 for v in h.iter_mut() {
261 if *v < 0.0 {
262 *v = 0.0;
263 }
264 }
265 linear_maybe_gguf(
266 &h,
267 rows,
268 mlp.hidden,
269 &mlp.w2_t,
270 mlp.w2_gguf_key.as_deref(),
271 gguf_packed,
272 mlp.out_dim,
273 &mlp.b2,
274 )
275}
276
277pub fn mlp2_forward_into(
279 mlp: &Mlp2,
280 x: &[f32],
281 rows: usize,
282 hidden: &mut [f32],
283 out: &mut [f32],
284 gguf_packed: Option<&GgufPackedParams>,
285) -> Result<()> {
286 if mlp.w0_gguf_key.is_none() && !mlp.w0_t.is_empty() {
287 rlx_cpu::blas::sgemm_bias_epilogue(
288 x,
289 &mlp.w0_t,
290 &mlp.b0,
291 hidden,
292 rows,
293 mlp.in_dim,
294 mlp.hidden,
295 |v| if v < 0.0 { 0.0 } else { v },
296 );
297 rlx_cpu::blas::sgemm_bias(
298 hidden,
299 &mlp.w1_t,
300 &mlp.b1,
301 out,
302 rows,
303 mlp.hidden,
304 mlp.out_dim,
305 );
306 return Ok(());
307 }
308 let mut h = linear_maybe_gguf(
309 x,
310 rows,
311 mlp.in_dim,
312 &mlp.w0_t,
313 mlp.w0_gguf_key.as_deref(),
314 gguf_packed,
315 mlp.hidden,
316 &mlp.b0,
317 )?;
318 for v in h.iter_mut() {
319 if *v < 0.0 {
320 *v = 0.0;
321 }
322 }
323 hidden.copy_from_slice(&h);
324 let h2 = linear_maybe_gguf(
325 hidden,
326 rows,
327 mlp.hidden,
328 &mlp.w1_t,
329 mlp.w1_gguf_key.as_deref(),
330 gguf_packed,
331 mlp.out_dim,
332 &mlp.b1,
333 )?;
334 out.copy_from_slice(&h2);
335 Ok(())
336}
337
338pub fn mlp3_forward_into(
340 mlp: &Mlp3,
341 x: &[f32],
342 rows: usize,
343 h0: &mut [f32],
344 h1: &mut [f32],
345 out: &mut [f32],
346 gguf_packed: Option<&GgufPackedParams>,
347) -> Result<()> {
348 if mlp.w0_gguf_key.is_none() && !mlp.w0_t.is_empty() {
349 let relu = |v: f32| if v < 0.0 { 0.0 } else { v };
350 rlx_cpu::blas::sgemm_bias_epilogue(
351 x, &mlp.w0_t, &mlp.b0, h0, rows, mlp.in_dim, mlp.hidden, relu,
352 );
353 rlx_cpu::blas::sgemm_bias_epilogue(
354 h0, &mlp.w1_t, &mlp.b1, h1, rows, mlp.hidden, mlp.hidden, relu,
355 );
356 rlx_cpu::blas::sgemm_bias(h1, &mlp.w2_t, &mlp.b2, out, rows, mlp.hidden, mlp.out_dim);
357 return Ok(());
358 }
359 let o = mlp3_forward(mlp, x, rows, gguf_packed)?;
360 out.copy_from_slice(&o);
361 Ok(())
362}
363
364pub fn extract_decoder_weights(
365 weights: &mut WeightMap,
366 gguf_packed: Option<&GgufPackedParams>,
367) -> Result<Sam3DecoderWeights> {
368 let base = "detector.transformer.decoder";
369 ensure!(
370 weights.has(&format!("{base}.query_embed.weight")),
371 "SAM3 detector decoder not found"
372 );
373
374 let mut layers = Vec::with_capacity(N_LAYERS);
375 for i in 0..N_LAYERS {
376 let p = format!("{base}.layers.{i}");
377 let (self_attn_in_w_t, self_attn_in_gguf_key) = take_transposed_with_gguf_key(
378 weights,
379 gguf_packed,
380 &format!("{p}.self_attn.in_proj_weight"),
381 )?;
382 let (self_attn_in_b, _) =
383 take_or_gguf(weights, gguf_packed, &format!("{p}.self_attn.in_proj_bias"))?;
384 let (self_attn_out_w_t, self_attn_out_gguf_key) = take_transposed_with_gguf_key(
385 weights,
386 gguf_packed,
387 &format!("{p}.self_attn.out_proj.weight"),
388 )?;
389 let (self_attn_out_b, _) = take_or_gguf(
390 weights,
391 gguf_packed,
392 &format!("{p}.self_attn.out_proj.bias"),
393 )?;
394 let (ca_text_in_w_t, ca_text_in_gguf_key) = take_transposed_with_gguf_key(
395 weights,
396 gguf_packed,
397 &format!("{p}.ca_text.in_proj_weight"),
398 )?;
399 let (ca_text_in_b, _) =
400 take_or_gguf(weights, gguf_packed, &format!("{p}.ca_text.in_proj_bias"))?;
401 let (ca_text_out_w_t, ca_text_out_gguf_key) = take_transposed_with_gguf_key(
402 weights,
403 gguf_packed,
404 &format!("{p}.ca_text.out_proj.weight"),
405 )?;
406 let (ca_text_out_b, _) =
407 take_or_gguf(weights, gguf_packed, &format!("{p}.ca_text.out_proj.bias"))?;
408 let (cross_attn_in_w_t, cross_attn_in_gguf_key) = take_transposed_with_gguf_key(
409 weights,
410 gguf_packed,
411 &format!("{p}.cross_attn.in_proj_weight"),
412 )?;
413 let (cross_attn_in_b, _) = take_or_gguf(
414 weights,
415 gguf_packed,
416 &format!("{p}.cross_attn.in_proj.bias"),
417 )?;
418 let (cross_attn_out_w_t, cross_attn_out_gguf_key) = take_transposed_with_gguf_key(
419 weights,
420 gguf_packed,
421 &format!("{p}.cross_attn.out_proj.weight"),
422 )?;
423 let (cross_attn_out_b, _) = take_or_gguf(
424 weights,
425 gguf_packed,
426 &format!("{p}.cross_attn.out_proj.bias"),
427 )?;
428 let (linear1_w_t, linear1_gguf_key) =
429 take_transposed_with_gguf_key(weights, gguf_packed, &format!("{p}.linear1.weight"))?;
430 let (linear1_b, _) = take_or_gguf(weights, gguf_packed, &format!("{p}.linear1.bias"))?;
431 let (linear2_w_t, linear2_gguf_key) =
432 take_transposed_with_gguf_key(weights, gguf_packed, &format!("{p}.linear2.weight"))?;
433 let (linear2_b, _) = take_or_gguf(weights, gguf_packed, &format!("{p}.linear2.bias"))?;
434 let (norm1_w, _) = take_or_gguf(weights, gguf_packed, &format!("{p}.norm1.weight"))?;
435 let (norm1_b, _) = take_or_gguf(weights, gguf_packed, &format!("{p}.norm1.bias"))?;
436 let (norm2_w, _) = take_or_gguf(weights, gguf_packed, &format!("{p}.norm2.weight"))?;
437 let (norm2_b, _) = take_or_gguf(weights, gguf_packed, &format!("{p}.norm2.bias"))?;
438 let (norm3_w, _) = take_or_gguf(weights, gguf_packed, &format!("{p}.norm3.weight"))?;
439 let (norm3_b, _) = take_or_gguf(weights, gguf_packed, &format!("{p}.norm3.bias"))?;
440 let (catext_norm_w, _) =
441 take_or_gguf(weights, gguf_packed, &format!("{p}.catext_norm.weight"))?;
442 let (catext_norm_b, _) =
443 take_or_gguf(weights, gguf_packed, &format!("{p}.catext_norm.bias"))?;
444 layers.push(Sam3DecoderLayerWeights {
445 self_attn_in_w_t,
446 self_attn_in_b,
447 self_attn_in_gguf_key,
448 self_attn_out_w_t,
449 self_attn_out_b,
450 self_attn_out_gguf_key,
451 ca_text_in_w_t,
452 ca_text_in_b,
453 ca_text_in_gguf_key,
454 ca_text_out_w_t,
455 ca_text_out_b,
456 ca_text_out_gguf_key,
457 cross_attn_in_w_t,
458 cross_attn_in_b,
459 cross_attn_in_gguf_key,
460 cross_attn_out_w_t,
461 cross_attn_out_b,
462 cross_attn_out_gguf_key,
463 linear1_w_t,
464 linear1_b,
465 linear1_gguf_key,
466 linear2_w_t,
467 linear2_b,
468 linear2_gguf_key,
469 norm1_w,
470 norm1_b,
471 norm2_w,
472 norm2_b,
473 norm3_w,
474 norm3_b,
475 catext_norm_w,
476 catext_norm_b,
477 });
478 }
479
480 let (query_embed, qs) =
481 take_or_gguf(weights, gguf_packed, &format!("{base}.query_embed.weight"))?;
482 ensure!(qs == vec![NUM_QUERIES, D_MODEL], "query_embed shape {qs:?}");
483 let (reference_points, rs) = take_or_gguf(
484 weights,
485 gguf_packed,
486 &format!("{base}.reference_points.weight"),
487 )?;
488 ensure!(rs == vec![NUM_QUERIES, 4], "reference_points shape {rs:?}");
489 let (norm_w, _) = take_or_gguf(weights, gguf_packed, &format!("{base}.norm.weight"))?;
490 let (norm_b, _) = take_or_gguf(weights, gguf_packed, &format!("{base}.norm.bias"))?;
491 let bbox_embed = take_mlp3(
492 weights,
493 gguf_packed,
494 &format!("{base}.bbox_embed"),
495 D_MODEL,
496 D_MODEL,
497 4,
498 )?;
499 let ref_point_head = take_mlp2(
500 weights,
501 gguf_packed,
502 &format!("{base}.ref_point_head"),
503 2 * D_MODEL,
504 D_MODEL,
505 D_MODEL,
506 )?;
507 let boxrpb_x = take_mlp2(
508 weights,
509 gguf_packed,
510 &format!("{base}.boxRPB_embed_x"),
511 2,
512 D_MODEL,
513 N_HEADS,
514 )?;
515 let boxrpb_y = take_mlp2(
516 weights,
517 gguf_packed,
518 &format!("{base}.boxRPB_embed_y"),
519 2,
520 D_MODEL,
521 N_HEADS,
522 )?;
523 let (presence_token, ps) = take_or_gguf(
524 weights,
525 gguf_packed,
526 &format!("{base}.presence_token.weight"),
527 )?;
528 ensure!(ps == vec![1, D_MODEL], "presence_token shape {ps:?}");
529 let presence_token_head = take_mlp3(
530 weights,
531 gguf_packed,
532 &format!("{base}.presence_token_head"),
533 D_MODEL,
534 D_MODEL,
535 1,
536 )?;
537 let (presence_token_out_norm_w, _) = take_or_gguf(
538 weights,
539 gguf_packed,
540 &format!("{base}.presence_token_out_norm.weight"),
541 )?;
542 let (presence_token_out_norm_b, _) = take_or_gguf(
543 weights,
544 gguf_packed,
545 &format!("{base}.presence_token_out_norm.bias"),
546 )?;
547
548 Ok(Sam3DecoderWeights {
549 loaded: true,
550 prefix: base.to_string(),
551 layers,
552 query_embed,
553 reference_points,
554 norm_w,
555 norm_b,
556 bbox_embed,
557 ref_point_head,
558 boxrpb_x,
559 boxrpb_y,
560 presence_token,
561 presence_token_head,
562 presence_token_out_norm_w,
563 presence_token_out_norm_b,
564 })
565}
566
/// Result of `forward_decoder`.
///
/// The flat buffers are sized by the `num_layers` / `num_queries` / `batch` /
/// `d_model` fields below; their exact layout is defined by the code that fills
/// them (further down in this file) — NOTE(review): confirm before relying on
/// a specific index order.
#[derive(Debug, Clone, Default)]
pub struct Sam3DecoderOutput {
    /// Per-layer query features.
    pub intermediate: Vec<f32>,
    /// Per-layer reference boxes; the initial (sigmoid-ed) boxes are recorded first.
    pub intermediate_ref_boxes: Vec<f32>,
    /// Per-layer presence logits.
    pub presence_logits: Vec<f32>,
    /// Presence-token features.
    pub presence_feats: Vec<f32>,
    /// Number of decoder layers represented in the buffers.
    pub num_layers: usize,
    /// Number of object queries per batch item.
    pub num_queries: usize,
    /// Batch size.
    pub batch: usize,
    /// Feature width.
    pub d_model: usize,
}
584
/// Sine/cosine embedding of 4-D boxes, producing `2 * d_model` values per box.
///
/// `pos` holds `nq * bs` rows of `(x, y, w, h)`; each coordinate is scaled by
/// `2π` and expanded into `d_model / 2` channels, where channel `i` uses the
/// frequency `10000^(2 * floor(i / 2) / (d_model / 2))` and takes `sin` for even
/// `i` and `cos` for odd `i`. The four channel blocks are ordered y, x, w, h.
///
/// # Panics
/// If `pos` has fewer than `nq * bs * 4` values.
fn sineembed_for_position_4d(pos: &[f32], nq: usize, bs: usize, d_model: usize) -> Vec<f32> {
    let half = d_model / 2;
    let scale = 2.0 * std::f32::consts::PI;
    let freqs: Vec<f32> = (0..half)
        .map(|i| 10000.0f32.powf(2.0 * ((i / 2) as f32) / half as f32))
        .collect();

    let row_width = 2 * d_model;
    let mut out = vec![0.0f32; nq * bs * row_width];
    // Rows are visited in the same (query-major, batch-minor) order as `pos`.
    for row in 0..nq * bs {
        let coords = &pos[row * 4..(row + 1) * 4];
        let ordered = [coords[1], coords[0], coords[2], coords[3]];
        for (axis, &coord) in ordered.iter().enumerate() {
            let val = coord * scale;
            let start = row * row_width + axis * half;
            let block = &mut out[start..start + half];
            for (i, (slot, &freq)) in block.iter_mut().zip(&freqs).enumerate() {
                let theta = val / freq;
                *slot = if i % 2 == 0 { theta.sin() } else { theta.cos() };
            }
        }
    }
    out
}
618
/// Clamped logit: `ln(x / (1 - x))` with `x` first clamped to `[0, 1]` and both
/// numerator and denominator floored at `1e-3`, so the result stays finite
/// (roughly within `±6.9`).
fn inverse_sigmoid(x: f32) -> f32 {
    const EPS: f32 = 1e-3;
    let p = x.clamp(0.0, 1.0);
    let numer = p.max(EPS);
    let denom = (1.0 - p).max(EPS);
    (numer / denom).ln()
}
626
/// Logistic function `1 / (1 + e^(-x))`.
fn sigmoid(x: f32) -> f32 {
    let denom = 1.0 + (-x).exp();
    1.0 / denom
}
630
631fn boxrpb_log_mask(
635 weights: &Sam3DecoderWeights,
636 reference_boxes: &[f32], nq: usize,
638 h: usize,
639 w: usize,
640 gguf_packed: Option<&GgufPackedParams>,
641) -> Result<Vec<f32>> {
642 let coords_h: Vec<f32> = (0..h).map(|y| y as f32 / h as f32).collect();
644 let coords_w: Vec<f32> = (0..w).map(|x| x as f32 / w as f32).collect();
645
646 let mut deltas_x = vec![0f32; nq * w * 2];
648 let mut deltas_y = vec![0f32; nq * h * 2];
649 for q in 0..nq {
650 let p = &reference_boxes[q * 4..(q + 1) * 4];
651 let (cx, cy, bw, bh) = (p[0], p[1], p[2], p[3]);
652 let x0 = cx - 0.5 * bw;
653 let x1 = cx + 0.5 * bw;
654 let y0 = cy - 0.5 * bh;
655 let y1 = cy + 0.5 * bh;
656 for xi in 0..w {
657 let dx0 = (coords_w[xi] - x0) * 8.0;
658 let dx1 = (coords_w[xi] - x1) * 8.0;
659 deltas_x[(q * w + xi) * 2] = log_norm(dx0);
660 deltas_x[(q * w + xi) * 2 + 1] = log_norm(dx1);
661 }
662 for yi in 0..h {
663 let dy0 = (coords_h[yi] - y0) * 8.0;
664 let dy1 = (coords_h[yi] - y1) * 8.0;
665 deltas_y[(q * h + yi) * 2] = log_norm(dy0);
666 deltas_y[(q * h + yi) * 2 + 1] = log_norm(dy1);
667 }
668 }
669 let dx_feats = mlp2_forward(&weights.boxrpb_x, &deltas_x, nq * w, gguf_packed)?;
671 let dy_feats = mlp2_forward(&weights.boxrpb_y, &deltas_y, nq * h, gguf_packed)?;
672
673 let mut out = vec![0f32; N_HEADS * nq * h * w];
676 for q in 0..nq {
677 for y in 0..h {
678 for x in 0..w {
679 for head in 0..N_HEADS {
680 let dy = dy_feats[(q * h + y) * N_HEADS + head];
681 let dx = dx_feats[(q * w + x) * N_HEADS + head];
682 out[(head * nq + q) * h * w + y * w + x] = dy + dx;
683 }
684 }
685 }
686 }
687 Ok(out)
688}
689
/// Sign-preserving log squash: `sign(v) * log2(|v| + 1) / log2(8)`.
///
/// `0.0` and `-0.0` both map to `0.0`, and `±7` maps to `±1`.
fn log_norm(v: f32) -> f32 {
    let magnitude = (v.abs() + 1.0).log2() / 8.0f32.log2();
    if v < 0.0 {
        -magnitude
    } else {
        magnitude
    }
}
695
/// Copies columns `start..start + len` out of a row-major `rows x width` matrix,
/// returning a compact `rows x len` matrix.
///
/// # Panics
/// If any requested column range falls outside `row`.
fn narrow_last(row: &[f32], rows: usize, width: usize, start: usize, len: usize) -> Vec<f32> {
    let mut narrowed = Vec::with_capacity(rows * len);
    for r in 0..rows {
        let from = r * width + start;
        narrowed.extend_from_slice(&row[from..from + len]);
    }
    narrowed
}
704
705#[allow(clippy::too_many_arguments)]
707pub(crate) fn mha_with_bias_maybe_gguf(
708 q: &[f32],
709 k: &[f32],
710 v: &[f32],
711 in_proj_w_t: &[f32],
712 in_proj_b: &[f32],
713 in_gguf_key: Option<&str>,
714 out_proj_w_t: &[f32],
715 out_proj_b: &[f32],
716 out_gguf_key: Option<&str>,
717 gguf_packed: Option<&GgufPackedParams>,
718 batch: usize,
719 l_q: usize,
720 l_k: usize,
721 embed_dim: usize,
722 num_heads: usize,
723 attn_bias_h_lq_lk: Option<&[f32]>,
724 key_padding_mask: Option<&[u8]>,
725) -> Result<Vec<f32>> {
726 if in_gguf_key.is_none() && out_gguf_key.is_none() {
727 return mha_with_bias_f32(
728 q,
729 k,
730 v,
731 in_proj_w_t,
732 in_proj_b,
733 out_proj_w_t,
734 out_proj_b,
735 batch,
736 l_q,
737 l_k,
738 embed_dim,
739 num_heads,
740 attn_bias_h_lq_lk,
741 key_padding_mask,
742 );
743 }
744
745 use super::tensor::{matmul, matmul_bt, softmax_rows};
746 let head_dim = embed_dim / num_heads;
747 let rows_q = batch * l_q;
748 let rows_k = batch * l_k;
749
750 let (q_proj, k_proj, v_proj) = if let Some(in_key) = in_gguf_key {
751 let qkv_q = linear_maybe_gguf(
752 q,
753 rows_q,
754 embed_dim,
755 in_proj_w_t,
756 Some(in_key),
757 gguf_packed,
758 3 * embed_dim,
759 in_proj_b,
760 )?;
761 let qkv_k = linear_maybe_gguf(
762 k,
763 rows_k,
764 embed_dim,
765 in_proj_w_t,
766 Some(in_key),
767 gguf_packed,
768 3 * embed_dim,
769 in_proj_b,
770 )?;
771 let qkv_v = linear_maybe_gguf(
772 v,
773 rows_k,
774 embed_dim,
775 in_proj_w_t,
776 Some(in_key),
777 gguf_packed,
778 3 * embed_dim,
779 in_proj_b,
780 )?;
781 (
782 narrow_last(&qkv_q, rows_q, 3 * embed_dim, 0, embed_dim),
783 narrow_last(&qkv_k, rows_k, 3 * embed_dim, embed_dim, embed_dim),
784 narrow_last(&qkv_v, rows_k, 3 * embed_dim, 2 * embed_dim, embed_dim),
785 )
786 } else {
787 let (wq, wk, wv) = split3(in_proj_w_t, embed_dim);
788 let bq = &in_proj_b[0..embed_dim];
789 let bk = &in_proj_b[embed_dim..2 * embed_dim];
790 let bv = &in_proj_b[2 * embed_dim..3 * embed_dim];
791 (
792 linear_maybe_gguf(q, rows_q, embed_dim, &wq, None, gguf_packed, embed_dim, bq)?,
793 linear_maybe_gguf(k, rows_k, embed_dim, &wk, None, gguf_packed, embed_dim, bk)?,
794 linear_maybe_gguf(v, rows_k, embed_dim, &wv, None, gguf_packed, embed_dim, bv)?,
795 )
796 };
797
798 let bh = batch * num_heads;
799 let mut qh = vec![0f32; bh * l_q * head_dim];
800 let mut kh = vec![0f32; bh * l_k * head_dim];
801 let mut vh = vec![0f32; bh * l_k * head_dim];
802 repack(&q_proj, &mut qh, batch, l_q, num_heads, head_dim);
803 repack(&k_proj, &mut kh, batch, l_k, num_heads, head_dim);
804 repack(&v_proj, &mut vh, batch, l_k, num_heads, head_dim);
805
806 let scale = 1.0f32 / (head_dim as f32).sqrt();
807 let mut scores = vec![0f32; l_q * l_k];
808 let mut attn_out = vec![0f32; bh * l_q * head_dim];
809 for bi in 0..batch {
810 for h in 0..num_heads {
811 let bhi = bi * num_heads + h;
812 let q_h = &qh[bhi * l_q * head_dim..(bhi + 1) * l_q * head_dim];
813 let k_h = &kh[bhi * l_k * head_dim..(bhi + 1) * l_k * head_dim];
814 let v_h = &vh[bhi * l_k * head_dim..(bhi + 1) * l_k * head_dim];
815 matmul_bt(q_h, k_h, &mut scores, l_q, head_dim, l_k, scale);
816 if let Some(bias) = attn_bias_h_lq_lk {
817 let bias_h = &bias[h * l_q * l_k..(h + 1) * l_q * l_k];
818 for i in 0..scores.len() {
819 scores[i] += bias_h[i];
820 }
821 }
822 if let Some(mask) = key_padding_mask {
823 let mask_b = &mask[bi * l_k..(bi + 1) * l_k];
824 for r in 0..l_q {
825 let row = &mut scores[r * l_k..(r + 1) * l_k];
826 for (c, m) in mask_b.iter().enumerate() {
827 if *m != 0 {
828 row[c] = f32::NEG_INFINITY;
829 }
830 }
831 }
832 }
833 softmax_rows(&mut scores, l_q, l_k);
834 let out_h = &mut attn_out[bhi * l_q * head_dim..(bhi + 1) * l_q * head_dim];
835 matmul(&scores, v_h, out_h, l_q, l_k, head_dim);
836 }
837 }
838
839 let mut packed = vec![0f32; batch * l_q * embed_dim];
840 for bi in 0..batch {
841 for l in 0..l_q {
842 for h in 0..num_heads {
843 let src = ((bi * num_heads + h) * l_q + l) * head_dim;
844 let dst = (bi * l_q + l) * embed_dim + h * head_dim;
845 packed[dst..dst + head_dim].copy_from_slice(&attn_out[src..src + head_dim]);
846 }
847 }
848 }
849 linear_maybe_gguf(
850 &packed,
851 batch * l_q,
852 embed_dim,
853 out_proj_w_t,
854 out_gguf_key,
855 gguf_packed,
856 embed_dim,
857 out_proj_b,
858 )
859}
860
861#[allow(clippy::too_many_arguments)]
863fn mha_with_bias_f32(
864 q: &[f32],
865 k: &[f32],
866 v: &[f32],
867 in_proj_w_t: &[f32],
868 in_proj_b: &[f32],
869 out_proj_w_t: &[f32],
870 out_proj_b: &[f32],
871 batch: usize,
872 l_q: usize,
873 l_k: usize,
874 embed_dim: usize,
875 num_heads: usize,
876 attn_bias_h_lq_lk: Option<&[f32]>,
877 key_padding_mask: Option<&[u8]>,
878) -> Result<Vec<f32>> {
879 use super::tensor::{matmul, matmul_bt, softmax_rows};
880 let head_dim = embed_dim / num_heads;
881 let (wq, wk, wv) = split3(in_proj_w_t, embed_dim);
882 let bq = &in_proj_b[0..embed_dim];
883 let bk = &in_proj_b[embed_dim..2 * embed_dim];
884 let bv = &in_proj_b[2 * embed_dim..3 * embed_dim];
885
886 let q_proj = linear(q, batch * l_q, embed_dim, &wq, embed_dim, bq)?;
887 let k_proj = linear(k, batch * l_k, embed_dim, &wk, embed_dim, bk)?;
888 let v_proj = linear(v, batch * l_k, embed_dim, &wv, embed_dim, bv)?;
889
890 let bh = batch * num_heads;
891 let mut qh = vec![0f32; bh * l_q * head_dim];
892 let mut kh = vec![0f32; bh * l_k * head_dim];
893 let mut vh = vec![0f32; bh * l_k * head_dim];
894 repack(&q_proj, &mut qh, batch, l_q, num_heads, head_dim);
895 repack(&k_proj, &mut kh, batch, l_k, num_heads, head_dim);
896 repack(&v_proj, &mut vh, batch, l_k, num_heads, head_dim);
897
898 let scale = 1.0f32 / (head_dim as f32).sqrt();
899 let mut scores = vec![0f32; l_q * l_k];
900 let mut attn_out = vec![0f32; bh * l_q * head_dim];
901 for bi in 0..batch {
902 for h in 0..num_heads {
903 let bhi = bi * num_heads + h;
904 let q_h = &qh[bhi * l_q * head_dim..(bhi + 1) * l_q * head_dim];
905 let k_h = &kh[bhi * l_k * head_dim..(bhi + 1) * l_k * head_dim];
906 let v_h = &vh[bhi * l_k * head_dim..(bhi + 1) * l_k * head_dim];
907 matmul_bt(q_h, k_h, &mut scores, l_q, head_dim, l_k, scale);
908 if let Some(bias) = attn_bias_h_lq_lk {
909 let bias_h = &bias[h * l_q * l_k..(h + 1) * l_q * l_k];
910 for i in 0..scores.len() {
911 scores[i] += bias_h[i];
912 }
913 }
914 if let Some(mask) = key_padding_mask {
915 let mask_b = &mask[bi * l_k..(bi + 1) * l_k];
916 for r in 0..l_q {
917 let row = &mut scores[r * l_k..(r + 1) * l_k];
918 for (c, m) in mask_b.iter().enumerate() {
919 if *m != 0 {
920 row[c] = f32::NEG_INFINITY;
921 }
922 }
923 }
924 }
925 softmax_rows(&mut scores, l_q, l_k);
926 let out_h = &mut attn_out[bhi * l_q * head_dim..(bhi + 1) * l_q * head_dim];
927 matmul(&scores, v_h, out_h, l_q, l_k, head_dim);
928 }
929 }
930
931 let mut packed = vec![0f32; batch * l_q * embed_dim];
932 for bi in 0..batch {
933 for l in 0..l_q {
934 for h in 0..num_heads {
935 let src = ((bi * num_heads + h) * l_q + l) * head_dim;
936 let dst = (bi * l_q + l) * embed_dim + h * head_dim;
937 packed[dst..dst + head_dim].copy_from_slice(&attn_out[src..src + head_dim]);
938 }
939 }
940 }
941 linear(
942 &packed,
943 batch * l_q,
944 embed_dim,
945 out_proj_w_t,
946 embed_dim,
947 out_proj_b,
948 )
949}
950
/// Splits a transposed fused QKV weight into its three `e x e` parts.
///
/// `w_t` is read as `e` rows of `3 * e` values whose column blocks are
/// Q | K | V; each returned matrix keeps the row-major `e x e` layout.
///
/// # Panics
/// If `w_t` has fewer than `3 * e * e` values.
fn split3(w_t: &[f32], e: usize) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
    let mut wq = Vec::with_capacity(e * e);
    let mut wk = Vec::with_capacity(e * e);
    let mut wv = Vec::with_capacity(e * e);
    for i in 0..e {
        let row = &w_t[i * 3 * e..i * 3 * e + 3 * e];
        wq.extend_from_slice(&row[..e]);
        wk.extend_from_slice(&row[e..2 * e]);
        wv.extend_from_slice(&row[2 * e..]);
    }
    (wq, wk, wv)
}
964
/// Reorders row-major `(batch, l, num_heads * head_dim)` data into
/// `(batch, num_heads, l, head_dim)`, so each (batch, head) block is contiguous.
///
/// # Panics
/// If `src` or `dst` is shorter than `batch * l * num_heads * head_dim`.
fn repack(src: &[f32], dst: &mut [f32], batch: usize, l: usize, num_heads: usize, head_dim: usize) {
    let width = num_heads * head_dim;
    for b in 0..batch {
        for head in 0..num_heads {
            let head_block = (b * num_heads + head) * l;
            for pos in 0..l {
                let from = (b * l + pos) * width + head * head_dim;
                let to = (head_block + pos) * head_dim;
                dst[to..to + head_dim].copy_from_slice(&src[from..from + head_dim]);
            }
        }
    }
}
977
978#[allow(clippy::too_many_arguments)]
980pub fn forward_decoder(
981 weights: &Sam3DecoderWeights,
982 memory: &[f32], memory_pos: &[f32], memory_text: &[f32], text_attention_mask: &[u8], batch: usize,
987 h: usize,
988 w: usize,
989 seq_len: usize,
990 gguf_packed: Option<&GgufPackedParams>,
991) -> Result<Sam3DecoderOutput> {
992 ensure!(weights.loaded, "SAM3 detector decoder not loaded");
993 ensure!(batch == 1, "decoder forward requires batch=1 for boxRPB");
994 let hw = h * w;
995 let nq = NUM_QUERIES;
996
997 let mut tgt = vec![0f32; nq * batch * D_MODEL]; for q in 0..nq {
1000 let src = &weights.query_embed[q * D_MODEL..(q + 1) * D_MODEL];
1001 for b in 0..batch {
1002 tgt[(q * batch + b) * D_MODEL..(q * batch + b + 1) * D_MODEL].copy_from_slice(src);
1003 }
1004 }
1005 let mut reference_boxes = vec![0f32; nq * batch * 4];
1006 for q in 0..nq {
1007 let src = &weights.reference_points[q * 4..(q + 1) * 4];
1008 for b in 0..batch {
1009 let dst = &mut reference_boxes[(q * batch + b) * 4..(q * batch + b + 1) * 4];
1010 for k in 0..4 {
1011 dst[k] = sigmoid(src[k]);
1012 }
1013 }
1014 }
1015
1016 let mut presence_out = vec![0f32; batch * D_MODEL];
1017 for b in 0..batch {
1018 presence_out[b * D_MODEL..(b + 1) * D_MODEL].copy_from_slice(&weights.presence_token);
1019 }
1020
1021 let mut intermediate = Vec::with_capacity(N_LAYERS);
1022 let mut intermediate_ref_boxes = Vec::with_capacity(N_LAYERS);
1023 let mut presence_logits = Vec::with_capacity(N_LAYERS);
1024
1025 intermediate_ref_boxes.push(reference_boxes.clone());
1027
1028 let mut memory_text_bf = vec![0f32; batch * seq_len * D_MODEL];
1031 for b in 0..batch {
1032 for l in 0..seq_len {
1033 let src = (l * batch + b) * D_MODEL;
1034 let dst = (b * seq_len + l) * D_MODEL;
1035 memory_text_bf[dst..dst + D_MODEL].copy_from_slice(&memory_text[src..src + D_MODEL]);
1036 }
1037 }
1038
1039 for (layer_idx, layer) in weights.layers.iter().enumerate() {
1040 let sine = sineembed_for_position_4d(&reference_boxes, nq, batch, D_MODEL);
1042 let query_pos = mlp2_forward(&weights.ref_point_head, &sine, nq * batch, gguf_packed)?;
1043 if let Some(dir) = rlx_ir::env::var("RLX_SAM3_DECODER_DUMP_DIR") {
1044 use std::io::Write as _;
1045 let path = format!("{dir}/host_layer{layer_idx}_query_pos.f32");
1046 let mut f = std::fs::File::create(&path).unwrap();
1047 for v in &query_pos {
1048 f.write_all(&v.to_le_bytes()).unwrap();
1049 }
1050 }
1051
1052 let sa_len = 1 + nq;
1056 let mut sa_x = vec![0f32; sa_len * batch * D_MODEL];
1057 let mut sa_pos = vec![0f32; sa_len * batch * D_MODEL];
1058 for b in 0..batch {
1059 sa_x[b * D_MODEL..(b + 1) * D_MODEL]
1060 .copy_from_slice(&presence_out[b * D_MODEL..(b + 1) * D_MODEL]);
1061 }
1062 for q in 0..nq {
1063 for b in 0..batch {
1064 let src = &tgt[(q * batch + b) * D_MODEL..(q * batch + b + 1) * D_MODEL];
1065 sa_x[((1 + q) * batch + b) * D_MODEL..((1 + q) * batch + b + 1) * D_MODEL]
1066 .copy_from_slice(src);
1067 let qp = &query_pos[(q * batch + b) * D_MODEL..(q * batch + b + 1) * D_MODEL];
1068 sa_pos[((1 + q) * batch + b) * D_MODEL..((1 + q) * batch + b + 1) * D_MODEL]
1069 .copy_from_slice(qp);
1070 }
1071 }
1072 let mut sa_x_bf = vec![0f32; batch * sa_len * D_MODEL];
1074 let mut sa_pos_bf = vec![0f32; batch * sa_len * D_MODEL];
1075 for b in 0..batch {
1076 for l in 0..sa_len {
1077 let s = (l * batch + b) * D_MODEL;
1078 let d = (b * sa_len + l) * D_MODEL;
1079 sa_x_bf[d..d + D_MODEL].copy_from_slice(&sa_x[s..s + D_MODEL]);
1080 sa_pos_bf[d..d + D_MODEL].copy_from_slice(&sa_pos[s..s + D_MODEL]);
1081 }
1082 }
1083 let mut qk = vec![0f32; sa_x_bf.len()];
1085 for i in 0..qk.len() {
1086 qk[i] = sa_x_bf[i] + sa_pos_bf[i];
1087 }
1088 let sa = mha_with_bias_maybe_gguf(
1089 &qk,
1090 &qk,
1091 &sa_x_bf,
1092 &layer.self_attn_in_w_t,
1093 &layer.self_attn_in_b,
1094 layer.self_attn_in_gguf_key.as_deref(),
1095 &layer.self_attn_out_w_t,
1096 &layer.self_attn_out_b,
1097 layer.self_attn_out_gguf_key.as_deref(),
1098 gguf_packed,
1099 batch,
1100 sa_len,
1101 sa_len,
1102 D_MODEL,
1103 N_HEADS,
1104 None,
1105 None,
1106 )?;
1107 for i in 0..sa_x_bf.len() {
1108 sa_x_bf[i] += sa[i];
1109 }
1110 let sa_x_bf = layer_norm(&sa_x_bf, &layer.norm2_w, &layer.norm2_b, D_MODEL, 1e-5)?;
1112 let mut new_presence = vec![0f32; batch * D_MODEL];
1114 for b in 0..batch {
1115 let src = &sa_x_bf[(b * sa_len) * D_MODEL..(b * sa_len + 1) * D_MODEL];
1116 new_presence[b * D_MODEL..(b + 1) * D_MODEL].copy_from_slice(src);
1117 }
1118 let mut after_sa = vec![0f32; batch * nq * D_MODEL];
1119 for b in 0..batch {
1120 for q in 0..nq {
1121 let src = (b * sa_len + 1 + q) * D_MODEL;
1122 let dst = (b * nq + q) * D_MODEL;
1123 after_sa[dst..dst + D_MODEL].copy_from_slice(&sa_x_bf[src..src + D_MODEL]);
1124 }
1125 }
1126 if let Some(dir) = rlx_ir::env::var("RLX_SAM3_DECODER_DUMP_DIR") {
1127 use std::io::Write as _;
1128 let path = format!("{dir}/host_layer{layer_idx}_sa_queries.f32");
1129 let mut f = std::fs::File::create(&path).unwrap();
1130 for v in &after_sa {
1131 f.write_all(&v.to_le_bytes()).unwrap();
1132 }
1133 }
1134
1135 let mut q_text = vec![0f32; batch * nq * D_MODEL];
1138 for b in 0..batch {
1139 for q in 0..nq {
1140 let dst = (b * nq + q) * D_MODEL;
1141 let qp = &query_pos[(q * batch + b) * D_MODEL..(q * batch + b + 1) * D_MODEL];
1142 for c in 0..D_MODEL {
1143 q_text[dst + c] = after_sa[dst + c] + qp[c];
1144 }
1145 }
1146 }
1147 let text_attn = mha_with_bias_maybe_gguf(
1148 &q_text,
1149 &memory_text_bf,
1150 &memory_text_bf,
1151 &layer.ca_text_in_w_t,
1152 &layer.ca_text_in_b,
1153 layer.ca_text_in_gguf_key.as_deref(),
1154 &layer.ca_text_out_w_t,
1155 &layer.ca_text_out_b,
1156 layer.ca_text_out_gguf_key.as_deref(),
1157 gguf_packed,
1158 batch,
1159 nq,
1160 seq_len,
1161 D_MODEL,
1162 N_HEADS,
1163 None,
1164 Some(text_attention_mask),
1165 )?;
1166 let mut tgt_after_ca_text = vec![0f32; batch * nq * D_MODEL];
1167 for i in 0..tgt_after_ca_text.len() {
1168 tgt_after_ca_text[i] = after_sa[i] + text_attn[i];
1169 }
1170 let tgt_after_ca_text = layer_norm(
1171 &tgt_after_ca_text,
1172 &layer.catext_norm_w,
1173 &layer.catext_norm_b,
1174 D_MODEL,
1175 1e-5,
1176 )?;
1177 if let Some(dir) = rlx_ir::env::var("RLX_SAM3_DECODER_DUMP_DIR") {
1178 use std::io::Write as _;
1179 let path = format!("{dir}/host_layer{layer_idx}_after_ca_text_q.f32");
1180 let mut f = std::fs::File::create(&path).unwrap();
1181 for v in &tgt_after_ca_text {
1182 f.write_all(&v.to_le_bytes()).unwrap();
1183 }
1184 }
1185
1186 let rpb = boxrpb_log_mask(weights, &reference_boxes, nq, h, w, gguf_packed)?;
1188 let cross_len_q = 1 + nq;
1197 let mut full_mask = vec![0f32; N_HEADS * cross_len_q * hw];
1199 for head in 0..N_HEADS {
1200 for q in 0..nq {
1202 let src = (head * nq + q) * hw;
1203 let dst = (head * cross_len_q + 1 + q) * hw;
1204 full_mask[dst..dst + hw].copy_from_slice(&rpb[src..src + hw]);
1205 }
1206 }
1207 let mut ca_in_seq_first = vec![0f32; cross_len_q * batch * D_MODEL];
1209 for b in 0..batch {
1210 ca_in_seq_first[b * D_MODEL..(b + 1) * D_MODEL]
1212 .copy_from_slice(&new_presence[b * D_MODEL..(b + 1) * D_MODEL]);
1213 for q in 0..nq {
1214 let src = &tgt_after_ca_text[(b * nq + q) * D_MODEL..(b * nq + q + 1) * D_MODEL];
1215 ca_in_seq_first
1216 [((1 + q) * batch + b) * D_MODEL..((1 + q) * batch + b + 1) * D_MODEL]
1217 .copy_from_slice(src);
1218 }
1219 }
1220 let mut ca_in_bf = vec![0f32; batch * cross_len_q * D_MODEL];
1222 let mut ca_pos_bf = vec![0f32; batch * cross_len_q * D_MODEL];
1223 for b in 0..batch {
1224 for l in 0..cross_len_q {
1225 let s = (l * batch + b) * D_MODEL;
1226 let d = (b * cross_len_q + l) * D_MODEL;
1227 ca_in_bf[d..d + D_MODEL].copy_from_slice(&ca_in_seq_first[s..s + D_MODEL]);
1228 if l == 0 {
1229 } else {
1231 let qp = &query_pos
1232 [((l - 1) * batch + b) * D_MODEL..((l - 1) * batch + b + 1) * D_MODEL];
1233 ca_pos_bf[d..d + D_MODEL].copy_from_slice(qp);
1234 }
1235 }
1236 }
1237 let mut q_img = vec![0f32; ca_in_bf.len()];
1239 for i in 0..q_img.len() {
1240 q_img[i] = ca_in_bf[i] + ca_pos_bf[i];
1241 }
1242 let mut k_img = vec![0f32; memory.len()];
1243 for i in 0..k_img.len() {
1244 k_img[i] = memory[i] + memory_pos[i];
1245 }
1246 let ca_out = mha_with_bias_maybe_gguf(
1247 &q_img,
1248 &k_img,
1249 memory,
1250 &layer.cross_attn_in_w_t,
1251 &layer.cross_attn_in_b,
1252 layer.cross_attn_in_gguf_key.as_deref(),
1253 &layer.cross_attn_out_w_t,
1254 &layer.cross_attn_out_b,
1255 layer.cross_attn_out_gguf_key.as_deref(),
1256 gguf_packed,
1257 batch,
1258 cross_len_q,
1259 hw,
1260 D_MODEL,
1261 N_HEADS,
1262 Some(&full_mask),
1263 None,
1264 )?;
1265 if let Some(dir) = rlx_ir::env::var("RLX_SAM3_DECODER_DUMP_DIR") {
1266 use std::io::Write as _;
1267 let path = format!("{dir}/host_layer{layer_idx}_ca_img_proj.f32");
1268 let mut f = std::fs::File::create(&path).unwrap();
1269 for v in &ca_out {
1270 f.write_all(&v.to_le_bytes()).unwrap();
1271 }
1272 }
1273 for i in 0..ca_in_bf.len() {
1274 ca_in_bf[i] += ca_out[i];
1275 }
1276 let ca_in_bf = layer_norm(&ca_in_bf, &layer.norm1_w, &layer.norm1_b, D_MODEL, 1e-5)?;
1278 if let Some(dir) = rlx_ir::env::var("RLX_SAM3_DECODER_DUMP_DIR") {
1279 use std::io::Write as _;
1280 let mut q_only = vec![0f32; batch * nq * D_MODEL];
1282 for b in 0..batch {
1283 for q in 0..nq {
1284 let src = (b * cross_len_q + 1 + q) * D_MODEL;
1285 let dst = (b * nq + q) * D_MODEL;
1286 q_only[dst..dst + D_MODEL].copy_from_slice(&ca_in_bf[src..src + D_MODEL]);
1287 }
1288 }
1289 let path = format!("{dir}/host_layer{layer_idx}_after_ca_img_q.f32");
1290 let mut f = std::fs::File::create(&path).unwrap();
1291 for v in &q_only {
1292 f.write_all(&v.to_le_bytes()).unwrap();
1293 }
1294 }
1295
1296 let mut ff = linear_maybe_gguf(
1298 &ca_in_bf,
1299 batch * cross_len_q,
1300 D_MODEL,
1301 &layer.linear1_w_t,
1302 layer.linear1_gguf_key.as_deref(),
1303 gguf_packed,
1304 DIM_FF,
1305 &layer.linear1_b,
1306 )?;
1307 for v in ff.iter_mut() {
1308 if *v < 0.0 {
1309 *v = 0.0;
1310 }
1311 }
1312 let ffn = linear_maybe_gguf(
1313 &ff,
1314 batch * cross_len_q,
1315 DIM_FF,
1316 &layer.linear2_w_t,
1317 layer.linear2_gguf_key.as_deref(),
1318 gguf_packed,
1319 D_MODEL,
1320 &layer.linear2_b,
1321 )?;
1322 let mut after_ffn = ca_in_bf.clone();
1323 for i in 0..after_ffn.len() {
1324 after_ffn[i] += ffn[i];
1325 }
1326 let after_ffn = layer_norm(&after_ffn, &layer.norm3_w, &layer.norm3_b, D_MODEL, 1e-5)?;
1328
1329 let mut layer_presence = vec![0f32; batch * D_MODEL];
1331 let mut layer_tgt = vec![0f32; batch * nq * D_MODEL];
1332 for b in 0..batch {
1333 let src_p = &after_ffn[(b * cross_len_q) * D_MODEL..(b * cross_len_q + 1) * D_MODEL];
1334 layer_presence[b * D_MODEL..(b + 1) * D_MODEL].copy_from_slice(src_p);
1335 for q in 0..nq {
1336 let src = (b * cross_len_q + 1 + q) * D_MODEL;
1337 let dst = (b * nq + q) * D_MODEL;
1338 layer_tgt[dst..dst + D_MODEL].copy_from_slice(&after_ffn[src..src + D_MODEL]);
1339 }
1340 }
1341 if let Some(dir) = rlx_ir::env::var("RLX_SAM3_DECODER_DUMP_DIR") {
1342 use std::io::Write as _;
1343 for (vals, name) in [(&layer_tgt, "new_tgt"), (&layer_presence, "new_presence")] {
1344 let path = format!("{dir}/host_layer{layer_idx}_{name}.f32");
1345 let mut f = std::fs::File::create(&path).unwrap();
1346 for v in vals {
1347 f.write_all(&v.to_le_bytes()).unwrap();
1348 }
1349 }
1350 }
1351 for q in 0..nq {
1353 for b in 0..batch {
1354 let src = (b * nq + q) * D_MODEL;
1355 let dst = (q * batch + b) * D_MODEL;
1356 tgt[dst..dst + D_MODEL].copy_from_slice(&layer_tgt[src..src + D_MODEL]);
1357 }
1358 }
1359 presence_out.copy_from_slice(&layer_presence);
1361
1362 let out_norm = layer_norm(&layer_tgt, &weights.norm_w, &weights.norm_b, D_MODEL, 1e-5)?;
1365 if let Some(dir) = rlx_ir::env::var("RLX_SAM3_DECODER_DUMP_DIR") {
1366 use std::io::Write as _;
1367 let path = format!("{dir}/host_layer{layer_idx}_out_norm.f32");
1368 let mut f = std::fs::File::create(&path).unwrap();
1369 for v in &out_norm {
1370 f.write_all(&v.to_le_bytes()).unwrap();
1371 }
1372 }
1373 let delta = mlp3_forward(&weights.bbox_embed, &out_norm, batch * nq, gguf_packed)?;
1374 let mut new_ref = vec![0f32; nq * batch * 4];
1375 for q in 0..nq {
1376 for b in 0..batch {
1377 let cur = &reference_boxes[(q * batch + b) * 4..(q * batch + b + 1) * 4];
1378 let d = &delta[(b * nq + q) * 4..(b * nq + q + 1) * 4];
1379 for k in 0..4 {
1380 new_ref[(q * batch + b) * 4 + k] = sigmoid(inverse_sigmoid(cur[k]) + d[k]);
1381 }
1382 }
1383 }
1384 reference_boxes = new_ref;
1385 if layer_idx != N_LAYERS - 1 {
1386 intermediate_ref_boxes.push(reference_boxes.clone());
1387 }
1388
1389 let mut out_seq_first = vec![0f32; nq * batch * D_MODEL];
1391 for q in 0..nq {
1392 for b in 0..batch {
1393 let src = (b * nq + q) * D_MODEL;
1394 let dst = (q * batch + b) * D_MODEL;
1395 out_seq_first[dst..dst + D_MODEL].copy_from_slice(&out_norm[src..src + D_MODEL]);
1396 }
1397 }
1398 intermediate.push(out_seq_first);
1399
1400 let p_norm = layer_norm(
1402 &layer_presence,
1403 &weights.presence_token_out_norm_w,
1404 &weights.presence_token_out_norm_b,
1405 D_MODEL,
1406 1e-5,
1407 )?;
1408 let p_logit = mlp3_forward(&weights.presence_token_head, &p_norm, batch, gguf_packed)?;
1409 presence_logits.push(p_logit);
1410 }
1411
1412 let mut int_stack = vec![0f32; N_LAYERS * nq * batch * D_MODEL];
1414 for (li, layer_out) in intermediate.iter().enumerate() {
1415 int_stack[li * nq * batch * D_MODEL..(li + 1) * nq * batch * D_MODEL]
1416 .copy_from_slice(layer_out);
1417 }
1418 let mut ref_stack = vec![0f32; N_LAYERS * nq * batch * 4];
1419 for (li, ref_l) in intermediate_ref_boxes.iter().enumerate() {
1420 ref_stack[li * nq * batch * 4..(li + 1) * nq * batch * 4].copy_from_slice(ref_l);
1421 }
1422 let mut presence_stack = vec![0f32; N_LAYERS * batch];
1423 for (li, p) in presence_logits.iter().enumerate() {
1424 for b in 0..batch {
1425 presence_stack[li * batch + b] = p[b];
1426 }
1427 }
1428
1429 Ok(Sam3DecoderOutput {
1430 intermediate: int_stack,
1431 intermediate_ref_boxes: ref_stack,
1432 presence_logits: presence_stack,
1433 presence_feats: presence_out,
1434 num_layers: N_LAYERS,
1435 num_queries: nq,
1436 batch,
1437 d_model: D_MODEL,
1438 })
1439}