1use super::detector::Sam3DetectorOutput;
23use super::detector_decoder::{Mlp2, Mlp3};
24use super::sam3::Sam3ImagePrediction;
25use super::segmentation_pixel_ir::{
26 Sam3Conv1x1Compiled, Sam3PixelDecoderStepCompiled, compile_pixel_decoder_steps,
27};
28use super::tensor::{layer_norm, matmul, matmul_bt, multihead_attention, softmax_rows};
29use rlx_core::weight_map::WeightMap;
30use rlx_flow::GgufPackedParams;
31
32use crate::packed_gguf::{
33 conv2d_3x3_nchw_gguf, conv2d_3x3_nchw_pad1, gguf_packed_conv1_to_nchw,
34 gguf_packed_conv3_to_f32, linear_maybe_gguf, packed_linear, take_conv1x1_with_gguf_key,
35 take_conv3x3_with_gguf_key, take_or_gguf, take_transposed_with_gguf_key,
36};
37use anyhow::{Result, ensure};
38use rlx_runtime::Device;
39
/// Channel width shared by the segmentation head and the dot-product scorer in this file.
const D_MODEL: usize = 256;

/// Number of attention heads `forward_segmentation` passes to the prompt cross-attention.
const N_HEADS: usize = 8;
42
43#[derive(Default)]
44pub struct Sam3SegmentationHeadWeights {
45 pub loaded: bool,
46 pub cross_attn_norm_w: Vec<f32>,
47 pub cross_attn_norm_b: Vec<f32>,
48 pub cross_attend_in_w_t: Vec<f32>,
49 pub cross_attend_in_b: Vec<f32>,
50 pub cross_attend_out_w_t: Vec<f32>,
51 pub cross_attend_out_b: Vec<f32>,
52 pub cross_attend_in_gguf_key: Option<String>,
53 pub cross_attend_out_gguf_key: Option<String>,
54 pub mask_embed_w0_gguf_key: Option<String>,
55 pub mask_embed_w1_gguf_key: Option<String>,
56 pub mask_embed_w2_gguf_key: Option<String>,
57 pub pixel_conv_w: Vec<Vec<f32>>,
58 pub pixel_conv_b: Vec<Vec<f32>>,
59 pub pixel_conv_gguf_keys: Vec<Option<String>>,
60 pub pixel_conv_nchw_cache: Vec<Option<Vec<f32>>>,
62 pub pixel_gn_w: Vec<Vec<f32>>,
63 pub pixel_gn_b: Vec<Vec<f32>>,
64 pub inst_w: Vec<f32>,
65 pub inst_b: Vec<f32>,
66 pub inst_gguf_key: Option<String>,
67 pub sem_w: Vec<f32>,
68 pub sem_b: Vec<f32>,
69 pub sem_gguf_key: Option<String>,
70 pub mask_embed: Mlp3,
71 pub pixel_steps: Vec<Sam3PixelDecoderStepCompiled>,
72 pub inst_head: Option<Sam3Conv1x1Compiled>,
73 pub sem_head: Option<Sam3Conv1x1Compiled>,
74}
75
76#[derive(Clone, Default)]
77pub struct Sam3DotProductScoringWeights {
78 pub loaded: bool,
79 pub prompt_mlp: Mlp2,
80 pub prompt_mlp_out_norm_w: Vec<f32>,
81 pub prompt_mlp_out_norm_b: Vec<f32>,
82 pub prompt_proj_w_t: Vec<f32>,
83 pub prompt_proj_b: Vec<f32>,
84 pub hs_proj_w_t: Vec<f32>,
85 pub hs_proj_b: Vec<f32>,
86 pub prompt_mlp_w0_gguf_key: Option<String>,
87 pub prompt_mlp_w1_gguf_key: Option<String>,
88 pub prompt_proj_gguf_key: Option<String>,
89 pub hs_proj_gguf_key: Option<String>,
90}
91
92pub fn extract_segmentation_head_weights(
93 weights: &mut WeightMap,
94 gguf_packed: Option<&GgufPackedParams>,
95) -> Result<Sam3SegmentationHeadWeights> {
96 let base = "detector.segmentation_head";
97
98 let (cross_attn_norm_w, _) = take_or_gguf(
99 weights,
100 gguf_packed,
101 &format!("{base}.cross_attn_norm.weight"),
102 )?;
103 let (cross_attn_norm_b, _) = take_or_gguf(
104 weights,
105 gguf_packed,
106 &format!("{base}.cross_attn_norm.bias"),
107 )?;
108 let (cross_attend_in_w_t, cross_attend_in_gguf_key) = take_transposed_with_gguf_key(
109 weights,
110 gguf_packed,
111 &format!("{base}.cross_attend_prompt.in_proj_weight"),
112 )?;
113 let (cross_attend_in_b, _) = take_or_gguf(
114 weights,
115 gguf_packed,
116 &format!("{base}.cross_attend_prompt.in_proj_bias"),
117 )?;
118 let (cross_attend_out_w_t, cross_attend_out_gguf_key) = take_transposed_with_gguf_key(
119 weights,
120 gguf_packed,
121 &format!("{base}.cross_attend_prompt.out_proj.weight"),
122 )?;
123 let (cross_attend_out_b, _) = take_or_gguf(
124 weights,
125 gguf_packed,
126 &format!("{base}.cross_attend_prompt.out_proj.bias"),
127 )?;
128
129 let mut pixel_conv_w = Vec::new();
130 let mut pixel_conv_b = Vec::new();
131 let mut pixel_conv_gguf_keys = Vec::new();
132 let mut pixel_gn_w = Vec::new();
133 let mut pixel_gn_b = Vec::new();
134 for i in 0..3 {
135 let (cw, cs, ck) = take_conv3x3_with_gguf_key(
136 weights,
137 gguf_packed,
138 &format!("{base}.pixel_decoder.conv_layers.{i}.weight"),
139 )?;
140 ensure!(
141 cs == vec![D_MODEL, D_MODEL, 3, 3],
142 "pixel_decoder conv {i} shape {cs:?}"
143 );
144 let (cb, _) = take_or_gguf(
145 weights,
146 gguf_packed,
147 &format!("{base}.pixel_decoder.conv_layers.{i}.bias"),
148 )?;
149 let (nw, _) = take_or_gguf(
150 weights,
151 gguf_packed,
152 &format!("{base}.pixel_decoder.norms.{i}.weight"),
153 )?;
154 let (nb, _) = take_or_gguf(
155 weights,
156 gguf_packed,
157 &format!("{base}.pixel_decoder.norms.{i}.bias"),
158 )?;
159 pixel_conv_w.push(cw);
160 pixel_conv_b.push(cb);
161 pixel_conv_gguf_keys.push(ck);
162 pixel_gn_w.push(nw);
163 pixel_gn_b.push(nb);
164 }
165
166 let (inst_w, ins, inst_gguf_key) = take_conv1x1_with_gguf_key(
167 weights,
168 gguf_packed,
169 &format!("{base}.instance_seg_head.weight"),
170 )?;
171 ensure!(
172 ins == vec![D_MODEL, D_MODEL, 1, 1],
173 "instance_seg_head shape {ins:?}"
174 );
175 let (inst_b, _) = take_or_gguf(
176 weights,
177 gguf_packed,
178 &format!("{base}.instance_seg_head.bias"),
179 )?;
180 let (sem_w, ss, sem_gguf_key) = take_conv1x1_with_gguf_key(
181 weights,
182 gguf_packed,
183 &format!("{base}.semantic_seg_head.weight"),
184 )?;
185 ensure!(
186 ss == vec![1, D_MODEL, 1, 1],
187 "semantic_seg_head shape {ss:?}"
188 );
189 let (sem_b, _) = take_or_gguf(
190 weights,
191 gguf_packed,
192 &format!("{base}.semantic_seg_head.bias"),
193 )?;
194
195 let (m0_t, mask_embed_w0_gguf_key) = take_transposed_with_gguf_key(
196 weights,
197 gguf_packed,
198 &format!("{base}.mask_predictor.mask_embed.layers.0.weight"),
199 )?;
200 let (m0_b, _) = take_or_gguf(
201 weights,
202 gguf_packed,
203 &format!("{base}.mask_predictor.mask_embed.layers.0.bias"),
204 )?;
205 let (m1_t, mask_embed_w1_gguf_key) = take_transposed_with_gguf_key(
206 weights,
207 gguf_packed,
208 &format!("{base}.mask_predictor.mask_embed.layers.1.weight"),
209 )?;
210 let (m1_b, _) = take_or_gguf(
211 weights,
212 gguf_packed,
213 &format!("{base}.mask_predictor.mask_embed.layers.1.bias"),
214 )?;
215 let (m2_t, mask_embed_w2_gguf_key) = take_transposed_with_gguf_key(
216 weights,
217 gguf_packed,
218 &format!("{base}.mask_predictor.mask_embed.layers.2.weight"),
219 )?;
220 let (m2_b, _) = take_or_gguf(
221 weights,
222 gguf_packed,
223 &format!("{base}.mask_predictor.mask_embed.layers.2.bias"),
224 )?;
225 let mask_embed = Mlp3 {
226 w0_t: m0_t,
227 b0: m0_b,
228 w1_t: m1_t,
229 b1: m1_b,
230 w2_t: m2_t,
231 b2: m2_b,
232 in_dim: D_MODEL,
233 hidden: D_MODEL,
234 out_dim: D_MODEL,
235 w0_gguf_key: mask_embed_w0_gguf_key.clone(),
236 w1_gguf_key: mask_embed_w1_gguf_key.clone(),
237 w2_gguf_key: mask_embed_w2_gguf_key.clone(),
238 };
239
240 Ok(Sam3SegmentationHeadWeights {
241 loaded: true,
242 cross_attn_norm_w,
243 cross_attn_norm_b,
244 cross_attend_in_w_t,
245 cross_attend_in_b,
246 cross_attend_out_w_t,
247 cross_attend_out_b,
248 cross_attend_in_gguf_key,
249 cross_attend_out_gguf_key,
250 mask_embed_w0_gguf_key,
251 mask_embed_w1_gguf_key,
252 mask_embed_w2_gguf_key,
253 pixel_conv_w,
254 pixel_conv_b,
255 pixel_conv_gguf_keys,
256 pixel_conv_nchw_cache: vec![None; 3],
257 pixel_gn_w,
258 pixel_gn_b,
259 inst_w,
260 inst_b,
261 inst_gguf_key,
262 sem_w,
263 sem_b,
264 sem_gguf_key,
265 mask_embed,
266 pixel_steps: Vec::new(),
267 inst_head: None,
268 sem_head: None,
269 })
270}
271
272pub fn materialize_segmentation_gguf_weights(
274 weights: &mut Sam3SegmentationHeadWeights,
275 gguf_packed: Option<&GgufPackedParams>,
276) -> Result<()> {
277 let Some(gguf) = gguf_packed else {
278 return Ok(());
279 };
280 for i in 0..weights.pixel_conv_gguf_keys.len() {
281 if weights.pixel_conv_w[i].is_empty() {
282 if let Some(key) = &weights.pixel_conv_gguf_keys[i] {
283 let p = packed_linear(gguf, key)
284 .ok_or_else(|| anyhow::anyhow!("missing packed pixel conv: {key}"))?;
285 weights.pixel_conv_w[i] = gguf_packed_conv3_to_f32(p, D_MODEL, D_MODEL)?;
286 }
287 }
288 }
289 if weights.inst_w.is_empty() {
290 if let Some(key) = &weights.inst_gguf_key {
291 weights.inst_w = gguf_packed_conv1_to_nchw(gguf, key, D_MODEL, D_MODEL)?;
292 }
293 }
294 if weights.sem_w.is_empty() {
295 if let Some(key) = &weights.sem_gguf_key {
296 weights.sem_w = gguf_packed_conv1_to_nchw(gguf, key, 1, D_MODEL)?;
297 }
298 }
299 Ok(())
300}
301
302pub fn compile_segmentation_ir(
304 weights: &mut Sam3SegmentationHeadWeights,
305 gguf_packed: Option<&GgufPackedParams>,
306 trunk_grid: usize,
307 device: Device,
308 profile: &rlx_flow::CompileProfile,
309) -> Result<()> {
310 if !weights.loaded {
311 return Ok(());
312 }
313 materialize_segmentation_gguf_weights(weights, gguf_packed)?;
314
315 if !weights.pixel_conv_w[0].is_empty() {
316 weights.pixel_steps = compile_pixel_decoder_steps(
317 &weights.pixel_conv_w,
318 &weights.pixel_conv_b,
319 &weights.pixel_gn_w,
320 &weights.pixel_gn_b,
321 trunk_grid,
322 device,
323 profile,
324 )?;
325 }
326
327 let g2 = trunk_grid * 4;
328 if let Some(gguf) = gguf_packed {
329 if weights.inst_gguf_key.is_some() || !weights.inst_w.is_empty() {
330 weights.inst_head = Some(Sam3Conv1x1Compiled::compile_with_gguf(
331 D_MODEL,
332 D_MODEL,
333 g2,
334 g2,
335 &weights.inst_w,
336 &weights.inst_b,
337 weights.inst_gguf_key.as_deref(),
338 gguf,
339 device,
340 profile,
341 )?);
342 }
343 if weights.sem_gguf_key.is_some() || !weights.sem_w.is_empty() {
344 weights.sem_head = Some(Sam3Conv1x1Compiled::compile_with_gguf(
345 D_MODEL,
346 1,
347 g2,
348 g2,
349 &weights.sem_w,
350 &weights.sem_b,
351 weights.sem_gguf_key.as_deref(),
352 gguf,
353 device,
354 profile,
355 )?);
356 }
357 } else {
358 if !weights.inst_w.is_empty() {
359 weights.inst_head = Some(Sam3Conv1x1Compiled::compile_with_profile(
360 D_MODEL,
361 D_MODEL,
362 g2,
363 g2,
364 &weights.inst_w,
365 &weights.inst_b,
366 device,
367 profile,
368 )?);
369 }
370 if !weights.sem_w.is_empty() {
371 weights.sem_head = Some(Sam3Conv1x1Compiled::compile_with_profile(
372 D_MODEL,
373 1,
374 g2,
375 g2,
376 &weights.sem_w,
377 &weights.sem_b,
378 device,
379 profile,
380 )?);
381 }
382 }
383 Ok(())
384}
385
386pub fn extract_dot_product_scoring_weights(
387 weights: &mut WeightMap,
388 gguf_packed: Option<&GgufPackedParams>,
389) -> Result<Sam3DotProductScoringWeights> {
390 let base = "detector.dot_prod_scoring";
391 let (pm0_t, prompt_mlp_w0_gguf_key) = take_transposed_with_gguf_key(
392 weights,
393 gguf_packed,
394 &format!("{base}.prompt_mlp.layers.0.weight"),
395 )?;
396 let (pm0_b, _) = take_or_gguf(
397 weights,
398 gguf_packed,
399 &format!("{base}.prompt_mlp.layers.0.bias"),
400 )?;
401 let (pm1_t, prompt_mlp_w1_gguf_key) = take_transposed_with_gguf_key(
402 weights,
403 gguf_packed,
404 &format!("{base}.prompt_mlp.layers.1.weight"),
405 )?;
406 let (pm1_b, _) = take_or_gguf(
407 weights,
408 gguf_packed,
409 &format!("{base}.prompt_mlp.layers.1.bias"),
410 )?;
411 let prompt_mlp = Mlp2 {
412 w0_t: pm0_t,
413 b0: pm0_b,
414 w1_t: pm1_t,
415 b1: pm1_b,
416 in_dim: D_MODEL,
417 hidden: 2048,
418 out_dim: D_MODEL,
419 w0_gguf_key: prompt_mlp_w0_gguf_key.clone(),
420 w1_gguf_key: prompt_mlp_w1_gguf_key.clone(),
421 };
422 let (pm_norm_w, _) = take_or_gguf(
423 weights,
424 gguf_packed,
425 &format!("{base}.prompt_mlp.out_norm.weight"),
426 )?;
427 let (pm_norm_b, _) = take_or_gguf(
428 weights,
429 gguf_packed,
430 &format!("{base}.prompt_mlp.out_norm.bias"),
431 )?;
432 let (pp_t, prompt_proj_gguf_key) =
433 take_transposed_with_gguf_key(weights, gguf_packed, &format!("{base}.prompt_proj.weight"))?;
434 let (pp_b, _) = take_or_gguf(weights, gguf_packed, &format!("{base}.prompt_proj.bias"))?;
435 let (hs_t, hs_proj_gguf_key) =
436 take_transposed_with_gguf_key(weights, gguf_packed, &format!("{base}.hs_proj.weight"))?;
437 let (hs_b, _) = take_or_gguf(weights, gguf_packed, &format!("{base}.hs_proj.bias"))?;
438 Ok(Sam3DotProductScoringWeights {
439 loaded: true,
440 prompt_mlp,
441 prompt_mlp_out_norm_w: pm_norm_w,
442 prompt_mlp_out_norm_b: pm_norm_b,
443 prompt_proj_w_t: pp_t,
444 prompt_proj_b: pp_b,
445 hs_proj_w_t: hs_t,
446 hs_proj_b: hs_b,
447 prompt_mlp_w0_gguf_key,
448 prompt_mlp_w1_gguf_key,
449 prompt_proj_gguf_key,
450 hs_proj_gguf_key,
451 })
452}
453
/// Result of `forward_segmentation`.
#[derive(Debug, Clone, Default)]
pub struct Sam3SegmentationOutput {
    /// Per-query mask logits, `batch * num_queries * h_out * w_out` values.
    pub mask_pred: Vec<f32>,
    /// Output of the semantic 1x1 head (single output channel).
    pub semantic_seg: Vec<f32>,
    /// Height of the final pixel-decoder feature map.
    pub h_out: usize,
    /// Width of the final pixel-decoder feature map.
    pub w_out: usize,
    /// Number of object queries the masks were predicted for.
    pub num_queries: usize,
}
462
/// Splits a packed `[embed_dim, 3 * embed_dim]` row-major projection matrix
/// into its three `[embed_dim, embed_dim]` column blocks (Q, K, V in that order).
///
/// Panics if `in_proj_w_t` holds fewer than `3 * embed_dim * embed_dim` values.
fn split_in_proj_w(in_proj_w_t: &[f32], embed_dim: usize) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
    let stride = 3 * embed_dim;
    let block_len = embed_dim * embed_dim;
    let mut q_block = Vec::with_capacity(block_len);
    let mut k_block = Vec::with_capacity(block_len);
    let mut v_block = Vec::with_capacity(block_len);

    for r in 0..embed_dim {
        // One packed row: [ q columns | k columns | v columns ].
        let packed_row = &in_proj_w_t[r * stride..(r + 1) * stride];
        let (q_cols, kv_cols) = packed_row.split_at(embed_dim);
        let (k_cols, v_cols) = kv_cols.split_at(embed_dim);
        q_block.extend_from_slice(q_cols);
        k_block.extend_from_slice(k_cols);
        v_block.extend_from_slice(v_cols);
    }

    (q_block, k_block, v_block)
}
477
/// Re-orders `[batch, seq, num_heads, head_dim]` data in `flat` into
/// `[batch, num_heads, seq, head_dim]` order in `out`.
///
/// Panics if either slice is too short for the requested dimensions.
fn repack_heads(
    flat: &[f32],
    out: &mut [f32],
    batch: usize,
    seq: usize,
    num_heads: usize,
    head_dim: usize,
) {
    let token_width = num_heads * head_dim;
    let head_block = seq * head_dim;

    for bi in 0..batch {
        for l in 0..seq {
            let token_base = (bi * seq + l) * token_width;
            for h in 0..num_heads {
                let from = token_base + h * head_dim;
                let to = (bi * num_heads + h) * head_block + l * head_dim;
                let head_slice = &flat[from..from + head_dim];
                out[to..to + head_dim].copy_from_slice(head_slice);
            }
        }
    }
}
496
497#[allow(clippy::too_many_arguments)]
498fn cross_attend_prompt(
499 q: &[f32],
500 k: &[f32],
501 v: &[f32],
502 in_proj_w_t: &[f32],
503 in_proj_b: &[f32],
504 in_gguf_key: Option<&str>,
505 out_proj_w_t: &[f32],
506 out_proj_b: &[f32],
507 out_gguf_key: Option<&str>,
508 gguf_packed: Option<&GgufPackedParams>,
509 batch: usize,
510 l_q: usize,
511 l_k: usize,
512 embed_dim: usize,
513 num_heads: usize,
514 key_padding_mask: Option<&[u8]>,
515) -> Result<Vec<f32>> {
516 if in_gguf_key.is_none() && out_gguf_key.is_none() {
517 return multihead_attention(
518 q,
519 k,
520 v,
521 in_proj_w_t,
522 in_proj_b,
523 out_proj_w_t,
524 out_proj_b,
525 batch,
526 l_q,
527 l_k,
528 embed_dim,
529 num_heads,
530 key_padding_mask,
531 );
532 }
533 ensure!(
534 embed_dim.is_multiple_of(num_heads),
535 "embed_dim {embed_dim} not divisible by num_heads {num_heads}"
536 );
537 let head_dim = embed_dim / num_heads;
538 let rows_q = batch * l_q;
539 let rows_k = batch * l_k;
540
541 let (q_proj, k_proj, v_proj) = if let Some(in_key) = in_gguf_key {
542 let qkv_q = linear_maybe_gguf(
543 q,
544 rows_q,
545 embed_dim,
546 in_proj_w_t,
547 Some(in_key),
548 gguf_packed,
549 3 * embed_dim,
550 in_proj_b,
551 )?;
552 let qkv_k = linear_maybe_gguf(
553 k,
554 rows_k,
555 embed_dim,
556 in_proj_w_t,
557 Some(in_key),
558 gguf_packed,
559 3 * embed_dim,
560 in_proj_b,
561 )?;
562 let qkv_v = linear_maybe_gguf(
563 v,
564 rows_k,
565 embed_dim,
566 in_proj_w_t,
567 Some(in_key),
568 gguf_packed,
569 3 * embed_dim,
570 in_proj_b,
571 )?;
572 (
573 narrow_last(qkv_q, rows_q, embed_dim, 0, embed_dim),
574 narrow_last(qkv_k, rows_k, embed_dim, embed_dim, embed_dim),
575 narrow_last(qkv_v, rows_k, embed_dim, 2 * embed_dim, embed_dim),
576 )
577 } else {
578 let (wq, wk, wv) = split_in_proj_w(in_proj_w_t, embed_dim);
579 let bq = &in_proj_b[0..embed_dim];
580 let bk = &in_proj_b[embed_dim..2 * embed_dim];
581 let bv = &in_proj_b[2 * embed_dim..3 * embed_dim];
582 (
583 linear_maybe_gguf(q, rows_q, embed_dim, &wq, None, gguf_packed, embed_dim, bq)?,
584 linear_maybe_gguf(k, rows_k, embed_dim, &wk, None, gguf_packed, embed_dim, bk)?,
585 linear_maybe_gguf(v, rows_k, embed_dim, &wv, None, gguf_packed, embed_dim, bv)?,
586 )
587 };
588
589 let bh = batch * num_heads;
590 let mut qh = vec![0f32; bh * l_q * head_dim];
591 let mut kh = vec![0f32; bh * l_k * head_dim];
592 let mut vh = vec![0f32; bh * l_k * head_dim];
593 repack_heads(&q_proj, &mut qh, batch, l_q, num_heads, head_dim);
594 repack_heads(&k_proj, &mut kh, batch, l_k, num_heads, head_dim);
595 repack_heads(&v_proj, &mut vh, batch, l_k, num_heads, head_dim);
596
597 let scale = 1.0f32 / (head_dim as f32).sqrt();
598 let mut scores = vec![0f32; l_q * l_k];
599 let mut attn_out = vec![0f32; bh * l_q * head_dim];
600 for bi in 0..batch {
601 for h in 0..num_heads {
602 let bhi = bi * num_heads + h;
603 let q_h = &qh[bhi * l_q * head_dim..(bhi + 1) * l_q * head_dim];
604 let k_h = &kh[bhi * l_k * head_dim..(bhi + 1) * l_k * head_dim];
605 let v_h = &vh[bhi * l_k * head_dim..(bhi + 1) * l_k * head_dim];
606 matmul_bt(q_h, k_h, &mut scores, l_q, head_dim, l_k, scale);
607 if let Some(mask) = key_padding_mask {
608 let mask_b = &mask[bi * l_k..(bi + 1) * l_k];
609 for r in 0..l_q {
610 let row = &mut scores[r * l_k..(r + 1) * l_k];
611 for (c, m) in mask_b.iter().enumerate() {
612 if *m != 0 {
613 row[c] = f32::NEG_INFINITY;
614 }
615 }
616 }
617 }
618 softmax_rows(&mut scores, l_q, l_k);
619 let out_h = &mut attn_out[bhi * l_q * head_dim..(bhi + 1) * l_q * head_dim];
620 matmul(&scores, v_h, out_h, l_q, l_k, head_dim);
621 }
622 }
623
624 let mut packed = vec![0f32; batch * l_q * embed_dim];
625 for bi in 0..batch {
626 for l in 0..l_q {
627 for h in 0..num_heads {
628 let src = ((bi * num_heads + h) * l_q + l) * head_dim;
629 let dst = (bi * l_q + l) * embed_dim + h * head_dim;
630 packed[dst..dst + head_dim].copy_from_slice(&attn_out[src..src + head_dim]);
631 }
632 }
633 }
634 linear_maybe_gguf(
635 &packed,
636 batch * l_q,
637 embed_dim,
638 out_proj_w_t,
639 out_gguf_key,
640 gguf_packed,
641 embed_dim,
642 out_proj_b,
643 )
644}
645
/// Copies `len` consecutive columns, starting at column `start`, out of every
/// row of a row-major matrix whose rows are `width` values apart.
///
/// Returns a dense `rows * len` buffer. Panics if an addressed element lies
/// outside `qkv`.
fn narrow_last(qkv: Vec<f32>, rows: usize, width: usize, start: usize, len: usize) -> Vec<f32> {
    let mut narrowed = Vec::with_capacity(rows * len);
    for r in 0..rows {
        let first = r * width + start;
        narrowed.extend((0..len).map(|offset| qkv[first + offset]));
    }
    narrowed
}
655
656fn mlp3_forward_gguf(
657 mlp: &Mlp3,
658 w0_key: Option<&str>,
659 w1_key: Option<&str>,
660 w2_key: Option<&str>,
661 gguf_packed: Option<&GgufPackedParams>,
662 x: &[f32],
663 rows: usize,
664) -> Result<Vec<f32>> {
665 let mut h = linear_maybe_gguf(
666 x,
667 rows,
668 mlp.in_dim,
669 &mlp.w0_t,
670 w0_key,
671 gguf_packed,
672 mlp.hidden,
673 &mlp.b0,
674 )?;
675 for v in h.iter_mut() {
676 if *v < 0.0 {
677 *v = 0.0;
678 }
679 }
680 h = linear_maybe_gguf(
681 &h,
682 rows,
683 mlp.hidden,
684 &mlp.w1_t,
685 w1_key,
686 gguf_packed,
687 mlp.hidden,
688 &mlp.b1,
689 )?;
690 for v in h.iter_mut() {
691 if *v < 0.0 {
692 *v = 0.0;
693 }
694 }
695 linear_maybe_gguf(
696 &h,
697 rows,
698 mlp.hidden,
699 &mlp.w2_t,
700 w2_key,
701 gguf_packed,
702 mlp.out_dim,
703 &mlp.b2,
704 )
705}
706
707#[allow(clippy::too_many_arguments)]
708pub fn forward_segmentation(
709 weights: &mut Sam3SegmentationHeadWeights,
710 enc_memory_bf: &[f32],
711 backbone_fpn: &[Vec<f32>],
712 backbone_shapes: &[(usize, usize)],
713 obj_queries_last_bf: &[f32],
714 prompt_seq_first: &[f32],
715 prompt_kpm: &[u8],
716 batch: usize,
717 enc_h: usize,
718 enc_w: usize,
719 num_queries: usize,
720 seq_len: usize,
721 gguf_packed: Option<&GgufPackedParams>,
722) -> Result<Sam3SegmentationOutput> {
723 ensure!(weights.loaded, "SAM3 segmentation head not loaded");
724 ensure!(batch == 1, "batch > 1 not supported yet");
725 ensure!(
726 backbone_fpn.len() == 3,
727 "expected 3 FPN levels (after scalp)"
728 );
729
730 let hw = enc_h * enc_w;
731 let norm_mem = layer_norm(
732 enc_memory_bf,
733 &weights.cross_attn_norm_w,
734 &weights.cross_attn_norm_b,
735 D_MODEL,
736 1e-5,
737 )?;
738 let mut prompt_bf = vec![0f32; batch * seq_len * D_MODEL];
739 for b in 0..batch {
740 for l in 0..seq_len {
741 let s = (l * batch + b) * D_MODEL;
742 let d = (b * seq_len + l) * D_MODEL;
743 prompt_bf[d..d + D_MODEL].copy_from_slice(&prompt_seq_first[s..s + D_MODEL]);
744 }
745 }
746 let ca = cross_attend_prompt(
747 &norm_mem,
748 &prompt_bf,
749 &prompt_bf,
750 &weights.cross_attend_in_w_t,
751 &weights.cross_attend_in_b,
752 weights.cross_attend_in_gguf_key.as_deref(),
753 &weights.cross_attend_out_w_t,
754 &weights.cross_attend_out_b,
755 weights.cross_attend_out_gguf_key.as_deref(),
756 gguf_packed,
757 batch,
758 hw,
759 seq_len,
760 D_MODEL,
761 N_HEADS,
762 Some(prompt_kpm),
763 )?;
764 let mut enc_refined = enc_memory_bf.to_vec();
765 for i in 0..enc_refined.len() {
766 enc_refined[i] += ca[i];
767 }
768 let mut enc_visual = vec![0f32; batch * D_MODEL * hw];
769 for b in 0..batch {
770 for y in 0..enc_h {
771 for xc in 0..enc_w {
772 for c in 0..D_MODEL {
773 enc_visual[((b * D_MODEL + c) * enc_h + y) * enc_w + xc] =
774 enc_refined[(b * hw + y * enc_w + xc) * D_MODEL + c];
775 }
776 }
777 }
778 }
779
780 let mut levels = backbone_fpn.to_vec();
781 levels[2] = enc_visual;
782 let mut shapes = backbone_shapes.to_vec();
783 shapes[2] = (enc_h, enc_w);
784
785 let mut prev = levels.pop().unwrap();
786 let (mut ph, mut pw) = shapes.pop().unwrap();
787
788 if weights.pixel_steps.len() == 2 {
789 for (i, (curr, (ch, cw))) in levels.iter().rev().zip(shapes.iter().rev()).enumerate() {
790 prev = weights.pixel_steps[i].run(&prev, curr)?;
791 ph = *ch;
792 pw = *cw;
793 }
794 } else {
795 for (i, (curr, (ch, cw))) in levels.iter().rev().zip(shapes.iter().rev()).enumerate() {
796 let up = nearest_upsample_nchw(&prev, D_MODEL, ph, pw, *ch, *cw);
797 let mut combined = vec![0f32; curr.len()];
798 for j in 0..combined.len() {
799 combined[j] = curr[j] + up[j];
800 }
801 let conv = conv2d_3x3_pad1_maybe_gguf(
802 &combined,
803 D_MODEL,
804 *ch,
805 *cw,
806 &weights.pixel_conv_w[i],
807 weights.pixel_conv_gguf_keys[i].as_deref(),
808 gguf_packed,
809 &weights.pixel_conv_b[i],
810 &mut weights.pixel_conv_nchw_cache[i],
811 )?;
812 let mut relud = group_norm(
813 &conv,
814 batch,
815 D_MODEL,
816 *ch,
817 *cw,
818 8,
819 &weights.pixel_gn_w[i],
820 &weights.pixel_gn_b[i],
821 );
822 for v in relud.iter_mut() {
823 if *v < 0.0 {
824 *v = 0.0;
825 }
826 }
827 prev = relud;
828 ph = *ch;
829 pw = *cw;
830 }
831 }
832 let pixel_embed = prev;
833
834 let inst = if let Some(ref mut head) = weights.inst_head {
835 head.run(&pixel_embed)?
836 } else {
837 conv2d_1x1_maybe_gguf(
838 &pixel_embed,
839 D_MODEL,
840 D_MODEL,
841 ph,
842 pw,
843 &weights.inst_w,
844 weights.inst_gguf_key.as_deref(),
845 gguf_packed,
846 &weights.inst_b,
847 )?
848 };
849
850 let mask_embed_out = mlp3_forward_gguf(
851 &weights.mask_embed,
852 weights.mask_embed_w0_gguf_key.as_deref(),
853 weights.mask_embed_w1_gguf_key.as_deref(),
854 weights.mask_embed_w2_gguf_key.as_deref(),
855 gguf_packed,
856 obj_queries_last_bf,
857 batch * num_queries,
858 )?;
859 let mut mask_pred = vec![0f32; batch * num_queries * ph * pw];
860 for b in 0..batch {
861 for q in 0..num_queries {
862 for c in 0..D_MODEL {
863 let qcoeff = mask_embed_out[(b * num_queries + q) * D_MODEL + c];
864 if qcoeff == 0.0 {
865 continue;
866 }
867 let plane =
868 &inst[((b * D_MODEL + c) * ph * pw)..((b * D_MODEL + c) * ph * pw + ph * pw)];
869 let dst = &mut mask_pred
870 [(b * num_queries + q) * ph * pw..(b * num_queries + q + 1) * ph * pw];
871 for p in 0..ph * pw {
872 dst[p] += qcoeff * plane[p];
873 }
874 }
875 }
876 }
877
878 let semantic_seg = if let Some(ref mut head) = weights.sem_head {
879 head.run(&pixel_embed)?
880 } else {
881 conv2d_1x1_maybe_gguf(
882 &pixel_embed,
883 D_MODEL,
884 1,
885 ph,
886 pw,
887 &weights.sem_w,
888 weights.sem_gguf_key.as_deref(),
889 gguf_packed,
890 &weights.sem_b,
891 )?
892 };
893
894 Ok(Sam3SegmentationOutput {
895 mask_pred,
896 semantic_seg,
897 h_out: ph,
898 w_out: pw,
899 num_queries,
900 })
901}
902
903#[allow(clippy::too_many_arguments)]
904pub fn forward_dot_prod_scoring(
905 weights: &Sam3DotProductScoringWeights,
906 hs_bf: &[f32],
907 prompt_seq_first: &[f32],
908 prompt_kpm: &[u8],
909 num_layers: usize,
910 batch: usize,
911 num_queries: usize,
912 seq_len: usize,
913 gguf_packed: Option<&GgufPackedParams>,
914) -> Result<Vec<f32>> {
915 ensure!(weights.loaded, "SAM3 dot product scoring not loaded");
916 let rows = seq_len * batch;
917 let pm = &weights.prompt_mlp;
918 let mut h = linear_maybe_gguf(
919 prompt_seq_first,
920 rows,
921 pm.in_dim,
922 &pm.w0_t,
923 weights.prompt_mlp_w0_gguf_key.as_deref(),
924 gguf_packed,
925 pm.hidden,
926 &pm.b0,
927 )?;
928 for v in h.iter_mut() {
929 if *v < 0.0 {
930 *v = 0.0;
931 }
932 }
933 h = linear_maybe_gguf(
934 &h,
935 rows,
936 pm.hidden,
937 &pm.w1_t,
938 weights.prompt_mlp_w1_gguf_key.as_deref(),
939 gguf_packed,
940 pm.out_dim,
941 &pm.b1,
942 )?;
943 for i in 0..h.len() {
944 h[i] += prompt_seq_first[i];
945 }
946 let h = layer_norm(
947 &h,
948 &weights.prompt_mlp_out_norm_w,
949 &weights.prompt_mlp_out_norm_b,
950 D_MODEL,
951 1e-5,
952 )?;
953
954 let mut pooled = vec![0f32; batch * D_MODEL];
955 let mut counts = vec![0.0f32; batch];
956 for b in 0..batch {
957 for l in 0..seq_len {
958 if prompt_kpm[b * seq_len + l] == 0 {
959 let src = (l * batch + b) * D_MODEL;
960 let dst = b * D_MODEL;
961 for c in 0..D_MODEL {
962 pooled[dst + c] += h[src + c];
963 }
964 counts[b] += 1.0;
965 }
966 }
967 }
968 for b in 0..batch {
969 let denom = counts[b].max(1.0);
970 for c in 0..D_MODEL {
971 pooled[b * D_MODEL + c] /= denom;
972 }
973 }
974
975 let proj_pooled = linear_maybe_gguf(
976 &pooled,
977 batch,
978 D_MODEL,
979 &weights.prompt_proj_w_t,
980 weights.prompt_proj_gguf_key.as_deref(),
981 gguf_packed,
982 D_MODEL,
983 &weights.prompt_proj_b,
984 )?;
985 let proj_hs = linear_maybe_gguf(
986 hs_bf,
987 num_layers * batch * num_queries,
988 D_MODEL,
989 &weights.hs_proj_w_t,
990 weights.hs_proj_gguf_key.as_deref(),
991 gguf_packed,
992 D_MODEL,
993 &weights.hs_proj_b,
994 )?;
995
996 let scale = 1.0f32 / (D_MODEL as f32).sqrt();
997 let clamp = 12.0f32;
998 let mut scores = vec![0f32; num_layers * batch * num_queries];
999 for l in 0..num_layers {
1000 for b in 0..batch {
1001 let pp = &proj_pooled[b * D_MODEL..(b + 1) * D_MODEL];
1002 for q in 0..num_queries {
1003 let row = &proj_hs[((l * batch + b) * num_queries + q) * D_MODEL
1004 ..((l * batch + b) * num_queries + q + 1) * D_MODEL];
1005 let mut acc = 0.0f32;
1006 for c in 0..D_MODEL {
1007 acc += row[c] * pp[c];
1008 }
1009 let s = (acc * scale).clamp(-clamp, clamp);
1010 scores[(l * batch + b) * num_queries + q] = s;
1011 }
1012 }
1013 }
1014 Ok(scores)
1015}
1016
/// Nearest-neighbour resize of `c` planes from `src_h x src_w` to
/// `dst_h x dst_w` (planes stored back to back, row-major).
///
/// The source index of a destination pixel is `floor(dst_index * src / dst)`
/// along each axis. Panics if `x` is too short for `c` source planes.
fn nearest_upsample_nchw(
    x: &[f32],
    c: usize,
    src_h: usize,
    src_w: usize,
    dst_h: usize,
    dst_w: usize,
) -> Vec<f32> {
    let src_plane = src_h * src_w;
    let dst_plane = dst_h * dst_w;
    let mut resized = vec![0f32; c * dst_plane];

    for ch in 0..c {
        let source = &x[ch * src_plane..(ch + 1) * src_plane];
        let target = &mut resized[ch * dst_plane..(ch + 1) * dst_plane];
        // `target` is empty whenever a destination side is zero, so the
        // divisions below never see a zero divisor.
        for (flat, pixel) in target.iter_mut().enumerate() {
            let row = flat / dst_w;
            let col = flat % dst_w;
            let src_row = row * src_h / dst_h;
            let src_col = col * src_w / dst_w;
            *pixel = source[src_row * src_w + src_col];
        }
    }
    resized
}
1039
1040fn conv2d_3x3_pad1_maybe_gguf(
1041 input: &[f32],
1042 c: usize,
1043 h: usize,
1044 w: usize,
1045 weight: &[f32],
1046 weight_gguf_key: Option<&str>,
1047 gguf_packed: Option<&GgufPackedParams>,
1048 bias: &[f32],
1049 nchw_cache: &mut Option<Vec<f32>>,
1050) -> Result<Vec<f32>> {
1051 if !weight.is_empty() {
1052 return Ok(conv2d_3x3_nchw_pad1(input, c, h, w, weight, bias));
1053 }
1054 let key = weight_gguf_key
1055 .ok_or_else(|| anyhow::anyhow!("conv3: missing F32 weights and GGUF key"))?;
1056 let p = gguf_packed
1057 .and_then(|m| packed_linear(m, key))
1058 .ok_or_else(|| anyhow::anyhow!("missing packed conv3 weight: {key}"))?;
1059 conv2d_3x3_nchw_gguf(input, c, h, w, p, bias, nchw_cache)
1060}
1061
1062fn conv2d_1x1(
1063 input: &[f32],
1064 in_c: usize,
1065 out_c: usize,
1066 h: usize,
1067 w: usize,
1068 weight: &[f32],
1069 bias: &[f32],
1070) -> Vec<f32> {
1071 let n = h * w;
1072 let mut out = vec![0f32; out_c * n];
1073 rlx_cpu::blas::sgemm(weight, input, &mut out, out_c, in_c, n);
1074 for oc in 0..out_c {
1075 let b = bias[oc];
1076 let row = &mut out[oc * n..(oc + 1) * n];
1077 for v in row {
1078 *v += b;
1079 }
1080 }
1081 out
1082}
1083
1084fn conv2d_1x1_maybe_gguf(
1085 input: &[f32],
1086 in_c: usize,
1087 out_c: usize,
1088 h: usize,
1089 w: usize,
1090 weight: &[f32],
1091 weight_gguf_key: Option<&str>,
1092 gguf_packed: Option<&GgufPackedParams>,
1093 bias: &[f32],
1094) -> Result<Vec<f32>> {
1095 if weight_gguf_key.is_none() {
1096 return Ok(conv2d_1x1(input, in_c, out_c, h, w, weight, bias));
1097 }
1098 let n = h * w;
1099 let mut rows = vec![0f32; n * in_c];
1100 for ic in 0..in_c {
1101 for p in 0..n {
1102 rows[p * in_c + ic] = input[ic * n + p];
1103 }
1104 }
1105 let flat = linear_maybe_gguf(
1106 &rows,
1107 n,
1108 in_c,
1109 weight,
1110 weight_gguf_key,
1111 gguf_packed,
1112 out_c,
1113 bias,
1114 )?;
1115 let mut out = vec![0f32; out_c * n];
1116 for oc in 0..out_c {
1117 for p in 0..n {
1118 out[oc * n + p] = flat[p * out_c + oc];
1119 }
1120 }
1121 Ok(out)
1122}
1123
/// Group normalisation over an NCHW tensor.
///
/// Channels are split into `num_groups` consecutive groups. Each group of
/// each batch element is normalised with its own mean and (biased) variance,
/// using an epsilon of `1e-5`, then scaled by `gamma` and shifted by `beta`
/// per channel.
///
/// # Panics
///
/// Panics when `channels` is not a multiple of `num_groups` or when a slice is
/// too short for the given dimensions.
fn group_norm(
    x: &[f32],
    batch: usize,
    channels: usize,
    h: usize,
    w: usize,
    num_groups: usize,
    gamma: &[f32],
    beta: &[f32],
) -> Vec<f32> {
    assert!(channels.is_multiple_of(num_groups));
    let cpg = channels / num_groups;
    let spatial = h * w;
    let group_len = cpg * spatial;
    let mut out = vec![0f32; batch * channels * spatial];

    for b in 0..batch {
        for g in 0..num_groups {
            let c0 = g * cpg;
            // The channels of one group are adjacent, so the whole group is
            // one contiguous slice.
            let group_start = (b * channels + c0) * spatial;
            let group = &x[group_start..group_start + group_len];
            let n = group_len as f32;

            let mean = group.iter().fold(0.0f32, |acc, v| acc + *v) / n;
            let var = group.iter().fold(0.0f32, |acc, v| {
                let d = *v - mean;
                acc + d * d
            }) / n;
            let inv = 1.0 / (var + 1e-5).sqrt();

            for c in 0..cpg {
                let plane_start = group_start + c * spatial;
                let scale = gamma[c0 + c];
                let shift = beta[c0 + c];
                let src = &x[plane_start..plane_start + spatial];
                let dst = &mut out[plane_start..plane_start + spatial];
                for (d, s) in dst.iter_mut().zip(src.iter()) {
                    *d = (*s - mean) * inv * scale + shift;
                }
            }
        }
    }
    out
}
1177
1178pub fn segmentation_forward_native(
1180 _weights: &Sam3SegmentationHeadWeights,
1181 detector: &Sam3DetectorOutput,
1182 h_out: usize,
1183 w_out: usize,
1184) -> Sam3ImagePrediction {
1185 Sam3ImagePrediction {
1186 masks: vec![0.0; detector.num_queries * h_out * w_out],
1187 mask_shape: vec![detector.num_queries, h_out, w_out],
1188 boxes: vec![0.0; detector.num_queries * 4],
1189 boxes_shape: vec![detector.num_queries, 4],
1190 scores: vec![0.0; detector.num_queries],
1191 scores_shape: vec![detector.num_queries],
1192 num_instances: detector.num_queries,
1193 h_out,
1194 w_out,
1195 }
1196}