Skip to main content

rlx_sam3/
segmentation_head.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Native SAM3 segmentation head + dot-product scoring.
17//!
18//! Mirrors `sam3.model.maskformer_segmentation.UniversalSegmentationHead`
19//! and `sam3.model.model_misc.DotProductScoring` as configured in
20//! `model_builder._create_segmentation_head` / `_create_dot_product_scoring`.
21
22use super::detector::Sam3DetectorOutput;
23use super::detector_decoder::{Mlp2, Mlp3};
24use super::sam3::Sam3ImagePrediction;
25use super::segmentation_pixel_ir::{
26    Sam3Conv1x1Compiled, Sam3PixelDecoderStepCompiled, compile_pixel_decoder_steps,
27};
28use super::tensor::{layer_norm, matmul, matmul_bt, multihead_attention, softmax_rows};
29use rlx_core::weight_map::WeightMap;
30use rlx_flow::GgufPackedParams;
31
32use crate::packed_gguf::{
33    conv2d_3x3_nchw_gguf, conv2d_3x3_nchw_pad1, gguf_packed_conv1_to_nchw,
34    gguf_packed_conv3_to_f32, linear_maybe_gguf, packed_linear, take_conv1x1_with_gguf_key,
35    take_conv3x3_with_gguf_key, take_or_gguf, take_transposed_with_gguf_key,
36};
37use anyhow::{Result, ensure};
38use rlx_runtime::Device;
39
/// Channel / embedding width shared by every tensor in this module: LayerNorm width,
/// pixel-decoder conv channels, mask-embedding MLP dims and attention embed dim.
const D_MODEL: usize = 256;
/// Number of attention heads used by the prompt cross-attention in `forward_segmentation`.
const N_HEADS: usize = 8;
42
/// Weights for the native SAM3 segmentation head (`UniversalSegmentationHead`):
/// prompt cross-attention, pixel decoder, instance / semantic 1×1 heads and the
/// mask-embedding MLP.
///
/// Populated by `extract_segmentation_head_weights`. `pixel_steps`, `inst_head` and
/// `sem_head` are filled later by `compile_segmentation_ir`. A weight buffer paired with
/// a `*_gguf_key` may be empty when the tensor stays packed in GGUF:
/// `materialize_segmentation_gguf_weights` refills empty conv buffers from those keys.
/// NOTE(review): this is assumed to hold for the linear weights too — confirm in
/// `packed_gguf`.
#[derive(Default)]
pub struct Sam3SegmentationHeadWeights {
    /// Set to `true` by `extract_segmentation_head_weights`; forward / compile check it.
    pub loaded: bool,
    /// LayerNorm weight applied to the encoder memory before prompt cross-attention.
    pub cross_attn_norm_w: Vec<f32>,
    /// LayerNorm bias applied to the encoder memory before prompt cross-attention.
    pub cross_attn_norm_b: Vec<f32>,
    /// `cross_attend_prompt.in_proj_weight`, transposed: row-major `[D_MODEL, 3 * D_MODEL]`
    /// with Q | K | V column blocks.
    pub cross_attend_in_w_t: Vec<f32>,
    /// `cross_attend_prompt.in_proj_bias`: Q | K | V blocks of `D_MODEL` values each.
    pub cross_attend_in_b: Vec<f32>,
    /// `cross_attend_prompt.out_proj.weight`, transposed.
    pub cross_attend_out_w_t: Vec<f32>,
    /// `cross_attend_prompt.out_proj.bias`.
    pub cross_attend_out_b: Vec<f32>,
    /// GGUF key of the packed `in_proj_weight`, when it is served from GGUF.
    pub cross_attend_in_gguf_key: Option<String>,
    /// GGUF key of the packed `out_proj.weight`, when it is served from GGUF.
    pub cross_attend_out_gguf_key: Option<String>,
    /// GGUF key of `mask_embed.layers.0.weight` (also copied into `mask_embed`).
    pub mask_embed_w0_gguf_key: Option<String>,
    /// GGUF key of `mask_embed.layers.1.weight` (also copied into `mask_embed`).
    pub mask_embed_w1_gguf_key: Option<String>,
    /// GGUF key of `mask_embed.layers.2.weight` (also copied into `mask_embed`).
    pub mask_embed_w2_gguf_key: Option<String>,
    /// Pixel-decoder 3×3 conv weights, one per layer (shape checked as
    /// `[D_MODEL, D_MODEL, 3, 3]` at load time).
    pub pixel_conv_w: Vec<Vec<f32>>,
    /// Pixel-decoder conv biases, one per layer.
    pub pixel_conv_b: Vec<Vec<f32>>,
    /// GGUF keys of the packed pixel-decoder conv weights, one per layer.
    pub pixel_conv_gguf_keys: Vec<Option<String>>,
    /// Lazy NCHW F32 `[c,c,3,3]` for packed 3×3 conv (host path when IR not compiled).
    pub pixel_conv_nchw_cache: Vec<Option<Vec<f32>>>,
    /// GroupNorm weights applied after each pixel-decoder conv (8 groups on the host path).
    pub pixel_gn_w: Vec<Vec<f32>>,
    /// GroupNorm biases applied after each pixel-decoder conv.
    pub pixel_gn_b: Vec<Vec<f32>>,
    /// `instance_seg_head` 1×1 conv weight (shape checked as `[D_MODEL, D_MODEL, 1, 1]`).
    pub inst_w: Vec<f32>,
    /// `instance_seg_head` bias.
    pub inst_b: Vec<f32>,
    /// GGUF key of the packed `instance_seg_head` weight, when served from GGUF.
    pub inst_gguf_key: Option<String>,
    /// `semantic_seg_head` 1×1 conv weight (shape checked as `[1, D_MODEL, 1, 1]`).
    pub sem_w: Vec<f32>,
    /// `semantic_seg_head` bias.
    pub sem_b: Vec<f32>,
    /// GGUF key of the packed `semantic_seg_head` weight, when served from GGUF.
    pub sem_gguf_key: Option<String>,
    /// `mask_predictor.mask_embed`: three linear layers, `D_MODEL` wide throughout.
    pub mask_embed: Mlp3,
    /// Compiled pixel-decoder steps; the forward uses them only when exactly two exist.
    pub pixel_steps: Vec<Sam3PixelDecoderStepCompiled>,
    /// Compiled instance head; when `None` the forward falls back to the host 1×1 conv.
    pub inst_head: Option<Sam3Conv1x1Compiled>,
    /// Compiled semantic head; when `None` the forward falls back to the host 1×1 conv.
    pub sem_head: Option<Sam3Conv1x1Compiled>,
}
75
/// Weights for the native `DotProductScoring` module. It scores decoder queries against
/// a mean-pooled, MLP-refined prompt.
///
/// Populated by `extract_dot_product_scoring_weights` and consumed by
/// `forward_dot_prod_scoring`.
#[derive(Clone, Default)]
pub struct Sam3DotProductScoringWeights {
    /// Set to `true` by `extract_dot_product_scoring_weights`; the forward checks it.
    pub loaded: bool,
    /// `prompt_mlp.layers.{0,1}`: `D_MODEL -> 2048 -> D_MODEL` with a ReLU in between.
    pub prompt_mlp: Mlp2,
    /// `prompt_mlp.out_norm` LayerNorm weight (applied after the MLP residual).
    pub prompt_mlp_out_norm_w: Vec<f32>,
    /// `prompt_mlp.out_norm` LayerNorm bias.
    pub prompt_mlp_out_norm_b: Vec<f32>,
    /// `prompt_proj.weight`, transposed; projects the pooled prompt.
    pub prompt_proj_w_t: Vec<f32>,
    /// `prompt_proj.bias`.
    pub prompt_proj_b: Vec<f32>,
    /// `hs_proj.weight`, transposed; projects every decoder query row.
    pub hs_proj_w_t: Vec<f32>,
    /// `hs_proj.bias`.
    pub hs_proj_b: Vec<f32>,
    /// GGUF key of `prompt_mlp.layers.0.weight`, when served from GGUF.
    pub prompt_mlp_w0_gguf_key: Option<String>,
    /// GGUF key of `prompt_mlp.layers.1.weight`, when served from GGUF.
    pub prompt_mlp_w1_gguf_key: Option<String>,
    /// GGUF key of `prompt_proj.weight`, when served from GGUF.
    pub prompt_proj_gguf_key: Option<String>,
    /// GGUF key of `hs_proj.weight`, when served from GGUF.
    pub hs_proj_gguf_key: Option<String>,
}
91
92pub fn extract_segmentation_head_weights(
93    weights: &mut WeightMap,
94    gguf_packed: Option<&GgufPackedParams>,
95) -> Result<Sam3SegmentationHeadWeights> {
96    let base = "detector.segmentation_head";
97
98    let (cross_attn_norm_w, _) = take_or_gguf(
99        weights,
100        gguf_packed,
101        &format!("{base}.cross_attn_norm.weight"),
102    )?;
103    let (cross_attn_norm_b, _) = take_or_gguf(
104        weights,
105        gguf_packed,
106        &format!("{base}.cross_attn_norm.bias"),
107    )?;
108    let (cross_attend_in_w_t, cross_attend_in_gguf_key) = take_transposed_with_gguf_key(
109        weights,
110        gguf_packed,
111        &format!("{base}.cross_attend_prompt.in_proj_weight"),
112    )?;
113    let (cross_attend_in_b, _) = take_or_gguf(
114        weights,
115        gguf_packed,
116        &format!("{base}.cross_attend_prompt.in_proj_bias"),
117    )?;
118    let (cross_attend_out_w_t, cross_attend_out_gguf_key) = take_transposed_with_gguf_key(
119        weights,
120        gguf_packed,
121        &format!("{base}.cross_attend_prompt.out_proj.weight"),
122    )?;
123    let (cross_attend_out_b, _) = take_or_gguf(
124        weights,
125        gguf_packed,
126        &format!("{base}.cross_attend_prompt.out_proj.bias"),
127    )?;
128
129    let mut pixel_conv_w = Vec::new();
130    let mut pixel_conv_b = Vec::new();
131    let mut pixel_conv_gguf_keys = Vec::new();
132    let mut pixel_gn_w = Vec::new();
133    let mut pixel_gn_b = Vec::new();
134    for i in 0..3 {
135        let (cw, cs, ck) = take_conv3x3_with_gguf_key(
136            weights,
137            gguf_packed,
138            &format!("{base}.pixel_decoder.conv_layers.{i}.weight"),
139        )?;
140        ensure!(
141            cs == vec![D_MODEL, D_MODEL, 3, 3],
142            "pixel_decoder conv {i} shape {cs:?}"
143        );
144        let (cb, _) = take_or_gguf(
145            weights,
146            gguf_packed,
147            &format!("{base}.pixel_decoder.conv_layers.{i}.bias"),
148        )?;
149        let (nw, _) = take_or_gguf(
150            weights,
151            gguf_packed,
152            &format!("{base}.pixel_decoder.norms.{i}.weight"),
153        )?;
154        let (nb, _) = take_or_gguf(
155            weights,
156            gguf_packed,
157            &format!("{base}.pixel_decoder.norms.{i}.bias"),
158        )?;
159        pixel_conv_w.push(cw);
160        pixel_conv_b.push(cb);
161        pixel_conv_gguf_keys.push(ck);
162        pixel_gn_w.push(nw);
163        pixel_gn_b.push(nb);
164    }
165
166    let (inst_w, ins, inst_gguf_key) = take_conv1x1_with_gguf_key(
167        weights,
168        gguf_packed,
169        &format!("{base}.instance_seg_head.weight"),
170    )?;
171    ensure!(
172        ins == vec![D_MODEL, D_MODEL, 1, 1],
173        "instance_seg_head shape {ins:?}"
174    );
175    let (inst_b, _) = take_or_gguf(
176        weights,
177        gguf_packed,
178        &format!("{base}.instance_seg_head.bias"),
179    )?;
180    let (sem_w, ss, sem_gguf_key) = take_conv1x1_with_gguf_key(
181        weights,
182        gguf_packed,
183        &format!("{base}.semantic_seg_head.weight"),
184    )?;
185    ensure!(
186        ss == vec![1, D_MODEL, 1, 1],
187        "semantic_seg_head shape {ss:?}"
188    );
189    let (sem_b, _) = take_or_gguf(
190        weights,
191        gguf_packed,
192        &format!("{base}.semantic_seg_head.bias"),
193    )?;
194
195    let (m0_t, mask_embed_w0_gguf_key) = take_transposed_with_gguf_key(
196        weights,
197        gguf_packed,
198        &format!("{base}.mask_predictor.mask_embed.layers.0.weight"),
199    )?;
200    let (m0_b, _) = take_or_gguf(
201        weights,
202        gguf_packed,
203        &format!("{base}.mask_predictor.mask_embed.layers.0.bias"),
204    )?;
205    let (m1_t, mask_embed_w1_gguf_key) = take_transposed_with_gguf_key(
206        weights,
207        gguf_packed,
208        &format!("{base}.mask_predictor.mask_embed.layers.1.weight"),
209    )?;
210    let (m1_b, _) = take_or_gguf(
211        weights,
212        gguf_packed,
213        &format!("{base}.mask_predictor.mask_embed.layers.1.bias"),
214    )?;
215    let (m2_t, mask_embed_w2_gguf_key) = take_transposed_with_gguf_key(
216        weights,
217        gguf_packed,
218        &format!("{base}.mask_predictor.mask_embed.layers.2.weight"),
219    )?;
220    let (m2_b, _) = take_or_gguf(
221        weights,
222        gguf_packed,
223        &format!("{base}.mask_predictor.mask_embed.layers.2.bias"),
224    )?;
225    let mask_embed = Mlp3 {
226        w0_t: m0_t,
227        b0: m0_b,
228        w1_t: m1_t,
229        b1: m1_b,
230        w2_t: m2_t,
231        b2: m2_b,
232        in_dim: D_MODEL,
233        hidden: D_MODEL,
234        out_dim: D_MODEL,
235        w0_gguf_key: mask_embed_w0_gguf_key.clone(),
236        w1_gguf_key: mask_embed_w1_gguf_key.clone(),
237        w2_gguf_key: mask_embed_w2_gguf_key.clone(),
238    };
239
240    Ok(Sam3SegmentationHeadWeights {
241        loaded: true,
242        cross_attn_norm_w,
243        cross_attn_norm_b,
244        cross_attend_in_w_t,
245        cross_attend_in_b,
246        cross_attend_out_w_t,
247        cross_attend_out_b,
248        cross_attend_in_gguf_key,
249        cross_attend_out_gguf_key,
250        mask_embed_w0_gguf_key,
251        mask_embed_w1_gguf_key,
252        mask_embed_w2_gguf_key,
253        pixel_conv_w,
254        pixel_conv_b,
255        pixel_conv_gguf_keys,
256        pixel_conv_nchw_cache: vec![None; 3],
257        pixel_gn_w,
258        pixel_gn_b,
259        inst_w,
260        inst_b,
261        inst_gguf_key,
262        sem_w,
263        sem_b,
264        sem_gguf_key,
265        mask_embed,
266        pixel_steps: Vec::new(),
267        inst_head: None,
268        sem_head: None,
269    })
270}
271
272/// Fill empty F32 conv tensors from packed GGUF so IR compile can run (one-time dequant).
273pub fn materialize_segmentation_gguf_weights(
274    weights: &mut Sam3SegmentationHeadWeights,
275    gguf_packed: Option<&GgufPackedParams>,
276) -> Result<()> {
277    let Some(gguf) = gguf_packed else {
278        return Ok(());
279    };
280    for i in 0..weights.pixel_conv_gguf_keys.len() {
281        if weights.pixel_conv_w[i].is_empty() {
282            if let Some(key) = &weights.pixel_conv_gguf_keys[i] {
283                let p = packed_linear(gguf, key)
284                    .ok_or_else(|| anyhow::anyhow!("missing packed pixel conv: {key}"))?;
285                weights.pixel_conv_w[i] = gguf_packed_conv3_to_f32(p, D_MODEL, D_MODEL)?;
286            }
287        }
288    }
289    if weights.inst_w.is_empty() {
290        if let Some(key) = &weights.inst_gguf_key {
291            weights.inst_w = gguf_packed_conv1_to_nchw(gguf, key, D_MODEL, D_MODEL)?;
292        }
293    }
294    if weights.sem_w.is_empty() {
295        if let Some(key) = &weights.sem_gguf_key {
296            weights.sem_w = gguf_packed_conv1_to_nchw(gguf, key, 1, D_MODEL)?;
297        }
298    }
299    Ok(())
300}
301
302/// Compile pixel-decoder IR graphs for SAM3 base (72×72 trunk grid).
303pub fn compile_segmentation_ir(
304    weights: &mut Sam3SegmentationHeadWeights,
305    gguf_packed: Option<&GgufPackedParams>,
306    trunk_grid: usize,
307    device: Device,
308    profile: &rlx_flow::CompileProfile,
309) -> Result<()> {
310    if !weights.loaded {
311        return Ok(());
312    }
313    materialize_segmentation_gguf_weights(weights, gguf_packed)?;
314
315    if !weights.pixel_conv_w[0].is_empty() {
316        weights.pixel_steps = compile_pixel_decoder_steps(
317            &weights.pixel_conv_w,
318            &weights.pixel_conv_b,
319            &weights.pixel_gn_w,
320            &weights.pixel_gn_b,
321            trunk_grid,
322            device,
323            profile,
324        )?;
325    }
326
327    let g2 = trunk_grid * 4;
328    if let Some(gguf) = gguf_packed {
329        if weights.inst_gguf_key.is_some() || !weights.inst_w.is_empty() {
330            weights.inst_head = Some(Sam3Conv1x1Compiled::compile_with_gguf(
331                D_MODEL,
332                D_MODEL,
333                g2,
334                g2,
335                &weights.inst_w,
336                &weights.inst_b,
337                weights.inst_gguf_key.as_deref(),
338                gguf,
339                device,
340                profile,
341            )?);
342        }
343        if weights.sem_gguf_key.is_some() || !weights.sem_w.is_empty() {
344            weights.sem_head = Some(Sam3Conv1x1Compiled::compile_with_gguf(
345                D_MODEL,
346                1,
347                g2,
348                g2,
349                &weights.sem_w,
350                &weights.sem_b,
351                weights.sem_gguf_key.as_deref(),
352                gguf,
353                device,
354                profile,
355            )?);
356        }
357    } else {
358        if !weights.inst_w.is_empty() {
359            weights.inst_head = Some(Sam3Conv1x1Compiled::compile_with_profile(
360                D_MODEL,
361                D_MODEL,
362                g2,
363                g2,
364                &weights.inst_w,
365                &weights.inst_b,
366                device,
367                profile,
368            )?);
369        }
370        if !weights.sem_w.is_empty() {
371            weights.sem_head = Some(Sam3Conv1x1Compiled::compile_with_profile(
372                D_MODEL,
373                1,
374                g2,
375                g2,
376                &weights.sem_w,
377                &weights.sem_b,
378                device,
379                profile,
380            )?);
381        }
382    }
383    Ok(())
384}
385
386pub fn extract_dot_product_scoring_weights(
387    weights: &mut WeightMap,
388    gguf_packed: Option<&GgufPackedParams>,
389) -> Result<Sam3DotProductScoringWeights> {
390    let base = "detector.dot_prod_scoring";
391    let (pm0_t, prompt_mlp_w0_gguf_key) = take_transposed_with_gguf_key(
392        weights,
393        gguf_packed,
394        &format!("{base}.prompt_mlp.layers.0.weight"),
395    )?;
396    let (pm0_b, _) = take_or_gguf(
397        weights,
398        gguf_packed,
399        &format!("{base}.prompt_mlp.layers.0.bias"),
400    )?;
401    let (pm1_t, prompt_mlp_w1_gguf_key) = take_transposed_with_gguf_key(
402        weights,
403        gguf_packed,
404        &format!("{base}.prompt_mlp.layers.1.weight"),
405    )?;
406    let (pm1_b, _) = take_or_gguf(
407        weights,
408        gguf_packed,
409        &format!("{base}.prompt_mlp.layers.1.bias"),
410    )?;
411    let prompt_mlp = Mlp2 {
412        w0_t: pm0_t,
413        b0: pm0_b,
414        w1_t: pm1_t,
415        b1: pm1_b,
416        in_dim: D_MODEL,
417        hidden: 2048,
418        out_dim: D_MODEL,
419        w0_gguf_key: prompt_mlp_w0_gguf_key.clone(),
420        w1_gguf_key: prompt_mlp_w1_gguf_key.clone(),
421    };
422    let (pm_norm_w, _) = take_or_gguf(
423        weights,
424        gguf_packed,
425        &format!("{base}.prompt_mlp.out_norm.weight"),
426    )?;
427    let (pm_norm_b, _) = take_or_gguf(
428        weights,
429        gguf_packed,
430        &format!("{base}.prompt_mlp.out_norm.bias"),
431    )?;
432    let (pp_t, prompt_proj_gguf_key) =
433        take_transposed_with_gguf_key(weights, gguf_packed, &format!("{base}.prompt_proj.weight"))?;
434    let (pp_b, _) = take_or_gguf(weights, gguf_packed, &format!("{base}.prompt_proj.bias"))?;
435    let (hs_t, hs_proj_gguf_key) =
436        take_transposed_with_gguf_key(weights, gguf_packed, &format!("{base}.hs_proj.weight"))?;
437    let (hs_b, _) = take_or_gguf(weights, gguf_packed, &format!("{base}.hs_proj.bias"))?;
438    Ok(Sam3DotProductScoringWeights {
439        loaded: true,
440        prompt_mlp,
441        prompt_mlp_out_norm_w: pm_norm_w,
442        prompt_mlp_out_norm_b: pm_norm_b,
443        prompt_proj_w_t: pp_t,
444        prompt_proj_b: pp_b,
445        hs_proj_w_t: hs_t,
446        hs_proj_b: hs_b,
447        prompt_mlp_w0_gguf_key,
448        prompt_mlp_w1_gguf_key,
449        prompt_proj_gguf_key,
450        hs_proj_gguf_key,
451    })
452}
453
/// Result of `forward_segmentation`.
#[derive(Debug, Clone, Default)]
pub struct Sam3SegmentationOutput {
    /// Per-query mask logits, flat `[batch, num_queries, h_out, w_out]`.
    pub mask_pred: Vec<f32>,
    /// Output of the single-channel semantic 1×1 head at `h_out × w_out`
    /// (presumably `[batch, 1, h_out, w_out]` NCHW — confirm against the head kernels).
    pub semantic_seg: Vec<f32>,
    /// Height of the finest pixel-decoder level (spatial size of the masks).
    pub h_out: usize,
    /// Width of the finest pixel-decoder level (spatial size of the masks).
    pub w_out: usize,
    /// Number of object queries represented in `mask_pred`.
    pub num_queries: usize,
}
462
/// Splits a transposed packed `in_proj` weight into three row-major `[e, e]` matrices
/// `(wq, wk, wv)`. The input is row-major `[e, 3e]`, with its columns laid out as
/// Q | K | V.
fn split_in_proj_w(in_proj_w_t: &[f32], embed_dim: usize) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
    let e = embed_dim;
    let row_stride = 3 * e;
    // Gathers the `e`-wide column block that starts at `first_col` from each of the `e` rows.
    let column_block = |first_col: usize| -> Vec<f32> {
        (0..e)
            .flat_map(move |row| {
                let from = row * row_stride + first_col;
                in_proj_w_t[from..from + e].iter().copied()
            })
            .collect()
    };
    (column_block(0), column_block(e), column_block(2 * e))
}
477
/// Reorders `[batch, seq, num_heads * head_dim]` rows (`flat`) into a head-major
/// `[batch * num_heads, seq, head_dim]` layout written to `out`.
fn repack_heads(
    flat: &[f32],
    out: &mut [f32],
    batch: usize,
    seq: usize,
    num_heads: usize,
    head_dim: usize,
) {
    let token_stride = num_heads * head_dim;
    for b in 0..batch {
        for h in 0..num_heads {
            // Start of the contiguous `[seq, head_dim]` block for this (batch, head).
            let block = (b * num_heads + h) * seq * head_dim;
            for t in 0..seq {
                let from = (b * seq + t) * token_stride + h * head_dim;
                let to = block + t * head_dim;
                out[to..to + head_dim].copy_from_slice(&flat[from..from + head_dim]);
            }
        }
    }
}
496
/// Prompt cross-attention used by the segmentation head: multi-head attention of `q`
/// over `k` / `v`, where the input and output projections may be GGUF-packed.
///
/// `q` has `batch * l_q` rows and `k` / `v` have `batch * l_k` rows, each `embed_dim`
/// wide and batch-major. `key_padding_mask`, when given, is indexed as `[batch, l_k]`;
/// non-zero entries exclude that key from attention. Returns the out-projected
/// attention output: `batch * l_q` rows of `embed_dim`.
///
/// When neither projection has a GGUF key, the whole computation is delegated to
/// `tensor::multihead_attention`. Otherwise it is done here on the host, so that
/// `linear_maybe_gguf` can consume the packed weights.
///
/// # Errors
/// On the packed path, fails when `embed_dim` is not a multiple of `num_heads`. Also
/// fails when any projection fails.
#[allow(clippy::too_many_arguments)]
fn cross_attend_prompt(
    q: &[f32],
    k: &[f32],
    v: &[f32],
    in_proj_w_t: &[f32],
    in_proj_b: &[f32],
    in_gguf_key: Option<&str>,
    out_proj_w_t: &[f32],
    out_proj_b: &[f32],
    out_gguf_key: Option<&str>,
    gguf_packed: Option<&GgufPackedParams>,
    batch: usize,
    l_q: usize,
    l_k: usize,
    embed_dim: usize,
    num_heads: usize,
    key_padding_mask: Option<&[u8]>,
) -> Result<Vec<f32>> {
    // Fast path: both projections are plain F32 weights.
    if in_gguf_key.is_none() && out_gguf_key.is_none() {
        return multihead_attention(
            q,
            k,
            v,
            in_proj_w_t,
            in_proj_b,
            out_proj_w_t,
            out_proj_b,
            batch,
            l_q,
            l_k,
            embed_dim,
            num_heads,
            key_padding_mask,
        );
    }
    ensure!(
        embed_dim.is_multiple_of(num_heads),
        "embed_dim {embed_dim} not divisible by num_heads {num_heads}"
    );
    let head_dim = embed_dim / num_heads;
    let rows_q = batch * l_q;
    let rows_k = batch * l_k;

    let (q_proj, k_proj, v_proj) = if let Some(in_key) = in_gguf_key {
        // Packed in_proj: each input goes through the full `3 * embed_dim` projection,
        // and only its own third (Q, K or V) is kept. That is 3x the strictly required
        // matmul work, presumably because the packed tensor is not split column-wise here.
        let qkv_q = linear_maybe_gguf(
            q,
            rows_q,
            embed_dim,
            in_proj_w_t,
            Some(in_key),
            gguf_packed,
            3 * embed_dim,
            in_proj_b,
        )?;
        let qkv_k = linear_maybe_gguf(
            k,
            rows_k,
            embed_dim,
            in_proj_w_t,
            Some(in_key),
            gguf_packed,
            3 * embed_dim,
            in_proj_b,
        )?;
        let qkv_v = linear_maybe_gguf(
            v,
            rows_k,
            embed_dim,
            in_proj_w_t,
            Some(in_key),
            gguf_packed,
            3 * embed_dim,
            in_proj_b,
        )?;
        (
            narrow_last(qkv_q, rows_q, embed_dim, 0, embed_dim),
            narrow_last(qkv_k, rows_k, embed_dim, embed_dim, embed_dim),
            narrow_last(qkv_v, rows_k, embed_dim, 2 * embed_dim, embed_dim),
        )
    } else {
        // F32 in_proj (only out_proj is packed): split into separate Q/K/V weights and biases.
        let (wq, wk, wv) = split_in_proj_w(in_proj_w_t, embed_dim);
        let bq = &in_proj_b[0..embed_dim];
        let bk = &in_proj_b[embed_dim..2 * embed_dim];
        let bv = &in_proj_b[2 * embed_dim..3 * embed_dim];
        (
            linear_maybe_gguf(q, rows_q, embed_dim, &wq, None, gguf_packed, embed_dim, bq)?,
            linear_maybe_gguf(k, rows_k, embed_dim, &wk, None, gguf_packed, embed_dim, bk)?,
            linear_maybe_gguf(v, rows_k, embed_dim, &wv, None, gguf_packed, embed_dim, bv)?,
        )
    };

    // Split heads: [batch, seq, heads * head_dim] -> [batch * heads, seq, head_dim].
    let bh = batch * num_heads;
    let mut qh = vec![0f32; bh * l_q * head_dim];
    let mut kh = vec![0f32; bh * l_k * head_dim];
    let mut vh = vec![0f32; bh * l_k * head_dim];
    repack_heads(&q_proj, &mut qh, batch, l_q, num_heads, head_dim);
    repack_heads(&k_proj, &mut kh, batch, l_k, num_heads, head_dim);
    repack_heads(&v_proj, &mut vh, batch, l_k, num_heads, head_dim);

    let scale = 1.0f32 / (head_dim as f32).sqrt();
    // `scores` is an `l_q x l_k` scratch buffer reused for every (batch, head) pair.
    let mut scores = vec![0f32; l_q * l_k];
    let mut attn_out = vec![0f32; bh * l_q * head_dim];
    for bi in 0..batch {
        for h in 0..num_heads {
            let bhi = bi * num_heads + h;
            let q_h = &qh[bhi * l_q * head_dim..(bhi + 1) * l_q * head_dim];
            let k_h = &kh[bhi * l_k * head_dim..(bhi + 1) * l_k * head_dim];
            let v_h = &vh[bhi * l_k * head_dim..(bhi + 1) * l_k * head_dim];
            // Presumably scores = scale * q_h . k_h^T (see `tensor::matmul_bt`) — confirm.
            matmul_bt(q_h, k_h, &mut scores, l_q, head_dim, l_k, scale);
            if let Some(mask) = key_padding_mask {
                // Non-zero mask entries are excluded by forcing their logit to -inf.
                let mask_b = &mask[bi * l_k..(bi + 1) * l_k];
                for r in 0..l_q {
                    let row = &mut scores[r * l_k..(r + 1) * l_k];
                    for (c, m) in mask_b.iter().enumerate() {
                        if *m != 0 {
                            row[c] = f32::NEG_INFINITY;
                        }
                    }
                }
            }
            softmax_rows(&mut scores, l_q, l_k);
            let out_h = &mut attn_out[bhi * l_q * head_dim..(bhi + 1) * l_q * head_dim];
            matmul(&scores, v_h, out_h, l_q, l_k, head_dim);
        }
    }

    // Merge heads: [batch * heads, l_q, head_dim] -> [batch, l_q, heads * head_dim].
    let mut packed = vec![0f32; batch * l_q * embed_dim];
    for bi in 0..batch {
        for l in 0..l_q {
            for h in 0..num_heads {
                let src = ((bi * num_heads + h) * l_q + l) * head_dim;
                let dst = (bi * l_q + l) * embed_dim + h * head_dim;
                packed[dst..dst + head_dim].copy_from_slice(&attn_out[src..src + head_dim]);
            }
        }
    }
    // Output projection (possibly GGUF-packed).
    linear_maybe_gguf(
        &packed,
        batch * l_q,
        embed_dim,
        out_proj_w_t,
        out_gguf_key,
        gguf_packed,
        embed_dim,
        out_proj_b,
    )
}
645
/// Copies columns `start..start + len` out of a row-major `rows × width` buffer and
/// returns them as a row-major `rows × len` buffer.
fn narrow_last(qkv: Vec<f32>, rows: usize, width: usize, start: usize, len: usize) -> Vec<f32> {
    // A zero-width slice selects nothing and must not touch `qkv` at all.
    if len == 0 {
        return Vec::new();
    }
    let mut narrowed = Vec::with_capacity(rows * len);
    for r in 0..rows {
        let first = r * width + start;
        narrowed.extend_from_slice(&qkv[first..first + len]);
    }
    narrowed
}
655
656fn mlp3_forward_gguf(
657    mlp: &Mlp3,
658    w0_key: Option<&str>,
659    w1_key: Option<&str>,
660    w2_key: Option<&str>,
661    gguf_packed: Option<&GgufPackedParams>,
662    x: &[f32],
663    rows: usize,
664) -> Result<Vec<f32>> {
665    let mut h = linear_maybe_gguf(
666        x,
667        rows,
668        mlp.in_dim,
669        &mlp.w0_t,
670        w0_key,
671        gguf_packed,
672        mlp.hidden,
673        &mlp.b0,
674    )?;
675    for v in h.iter_mut() {
676        if *v < 0.0 {
677            *v = 0.0;
678        }
679    }
680    h = linear_maybe_gguf(
681        &h,
682        rows,
683        mlp.hidden,
684        &mlp.w1_t,
685        w1_key,
686        gguf_packed,
687        mlp.hidden,
688        &mlp.b1,
689    )?;
690    for v in h.iter_mut() {
691        if *v < 0.0 {
692            *v = 0.0;
693        }
694    }
695    linear_maybe_gguf(
696        &h,
697        rows,
698        mlp.hidden,
699        &mlp.w2_t,
700        w2_key,
701        gguf_packed,
702        mlp.out_dim,
703        &mlp.b2,
704    )
705}
706
707#[allow(clippy::too_many_arguments)]
708pub fn forward_segmentation(
709    weights: &mut Sam3SegmentationHeadWeights,
710    enc_memory_bf: &[f32],
711    backbone_fpn: &[Vec<f32>],
712    backbone_shapes: &[(usize, usize)],
713    obj_queries_last_bf: &[f32],
714    prompt_seq_first: &[f32],
715    prompt_kpm: &[u8],
716    batch: usize,
717    enc_h: usize,
718    enc_w: usize,
719    num_queries: usize,
720    seq_len: usize,
721    gguf_packed: Option<&GgufPackedParams>,
722) -> Result<Sam3SegmentationOutput> {
723    ensure!(weights.loaded, "SAM3 segmentation head not loaded");
724    ensure!(batch == 1, "batch > 1 not supported yet");
725    ensure!(
726        backbone_fpn.len() == 3,
727        "expected 3 FPN levels (after scalp)"
728    );
729
730    let hw = enc_h * enc_w;
731    let norm_mem = layer_norm(
732        enc_memory_bf,
733        &weights.cross_attn_norm_w,
734        &weights.cross_attn_norm_b,
735        D_MODEL,
736        1e-5,
737    )?;
738    let mut prompt_bf = vec![0f32; batch * seq_len * D_MODEL];
739    for b in 0..batch {
740        for l in 0..seq_len {
741            let s = (l * batch + b) * D_MODEL;
742            let d = (b * seq_len + l) * D_MODEL;
743            prompt_bf[d..d + D_MODEL].copy_from_slice(&prompt_seq_first[s..s + D_MODEL]);
744        }
745    }
746    let ca = cross_attend_prompt(
747        &norm_mem,
748        &prompt_bf,
749        &prompt_bf,
750        &weights.cross_attend_in_w_t,
751        &weights.cross_attend_in_b,
752        weights.cross_attend_in_gguf_key.as_deref(),
753        &weights.cross_attend_out_w_t,
754        &weights.cross_attend_out_b,
755        weights.cross_attend_out_gguf_key.as_deref(),
756        gguf_packed,
757        batch,
758        hw,
759        seq_len,
760        D_MODEL,
761        N_HEADS,
762        Some(prompt_kpm),
763    )?;
764    let mut enc_refined = enc_memory_bf.to_vec();
765    for i in 0..enc_refined.len() {
766        enc_refined[i] += ca[i];
767    }
768    let mut enc_visual = vec![0f32; batch * D_MODEL * hw];
769    for b in 0..batch {
770        for y in 0..enc_h {
771            for xc in 0..enc_w {
772                for c in 0..D_MODEL {
773                    enc_visual[((b * D_MODEL + c) * enc_h + y) * enc_w + xc] =
774                        enc_refined[(b * hw + y * enc_w + xc) * D_MODEL + c];
775                }
776            }
777        }
778    }
779
780    let mut levels = backbone_fpn.to_vec();
781    levels[2] = enc_visual;
782    let mut shapes = backbone_shapes.to_vec();
783    shapes[2] = (enc_h, enc_w);
784
785    let mut prev = levels.pop().unwrap();
786    let (mut ph, mut pw) = shapes.pop().unwrap();
787
788    if weights.pixel_steps.len() == 2 {
789        for (i, (curr, (ch, cw))) in levels.iter().rev().zip(shapes.iter().rev()).enumerate() {
790            prev = weights.pixel_steps[i].run(&prev, curr)?;
791            ph = *ch;
792            pw = *cw;
793        }
794    } else {
795        for (i, (curr, (ch, cw))) in levels.iter().rev().zip(shapes.iter().rev()).enumerate() {
796            let up = nearest_upsample_nchw(&prev, D_MODEL, ph, pw, *ch, *cw);
797            let mut combined = vec![0f32; curr.len()];
798            for j in 0..combined.len() {
799                combined[j] = curr[j] + up[j];
800            }
801            let conv = conv2d_3x3_pad1_maybe_gguf(
802                &combined,
803                D_MODEL,
804                *ch,
805                *cw,
806                &weights.pixel_conv_w[i],
807                weights.pixel_conv_gguf_keys[i].as_deref(),
808                gguf_packed,
809                &weights.pixel_conv_b[i],
810                &mut weights.pixel_conv_nchw_cache[i],
811            )?;
812            let mut relud = group_norm(
813                &conv,
814                batch,
815                D_MODEL,
816                *ch,
817                *cw,
818                8,
819                &weights.pixel_gn_w[i],
820                &weights.pixel_gn_b[i],
821            );
822            for v in relud.iter_mut() {
823                if *v < 0.0 {
824                    *v = 0.0;
825                }
826            }
827            prev = relud;
828            ph = *ch;
829            pw = *cw;
830        }
831    }
832    let pixel_embed = prev;
833
834    let inst = if let Some(ref mut head) = weights.inst_head {
835        head.run(&pixel_embed)?
836    } else {
837        conv2d_1x1_maybe_gguf(
838            &pixel_embed,
839            D_MODEL,
840            D_MODEL,
841            ph,
842            pw,
843            &weights.inst_w,
844            weights.inst_gguf_key.as_deref(),
845            gguf_packed,
846            &weights.inst_b,
847        )?
848    };
849
850    let mask_embed_out = mlp3_forward_gguf(
851        &weights.mask_embed,
852        weights.mask_embed_w0_gguf_key.as_deref(),
853        weights.mask_embed_w1_gguf_key.as_deref(),
854        weights.mask_embed_w2_gguf_key.as_deref(),
855        gguf_packed,
856        obj_queries_last_bf,
857        batch * num_queries,
858    )?;
859    let mut mask_pred = vec![0f32; batch * num_queries * ph * pw];
860    for b in 0..batch {
861        for q in 0..num_queries {
862            for c in 0..D_MODEL {
863                let qcoeff = mask_embed_out[(b * num_queries + q) * D_MODEL + c];
864                if qcoeff == 0.0 {
865                    continue;
866                }
867                let plane =
868                    &inst[((b * D_MODEL + c) * ph * pw)..((b * D_MODEL + c) * ph * pw + ph * pw)];
869                let dst = &mut mask_pred
870                    [(b * num_queries + q) * ph * pw..(b * num_queries + q + 1) * ph * pw];
871                for p in 0..ph * pw {
872                    dst[p] += qcoeff * plane[p];
873                }
874            }
875        }
876    }
877
878    let semantic_seg = if let Some(ref mut head) = weights.sem_head {
879        head.run(&pixel_embed)?
880    } else {
881        conv2d_1x1_maybe_gguf(
882            &pixel_embed,
883            D_MODEL,
884            1,
885            ph,
886            pw,
887            &weights.sem_w,
888            weights.sem_gguf_key.as_deref(),
889            gguf_packed,
890            &weights.sem_b,
891        )?
892    };
893
894    Ok(Sam3SegmentationOutput {
895        mask_pred,
896        semantic_seg,
897        h_out: ph,
898        w_out: pw,
899        num_queries,
900    })
901}
902
903#[allow(clippy::too_many_arguments)]
904pub fn forward_dot_prod_scoring(
905    weights: &Sam3DotProductScoringWeights,
906    hs_bf: &[f32],
907    prompt_seq_first: &[f32],
908    prompt_kpm: &[u8],
909    num_layers: usize,
910    batch: usize,
911    num_queries: usize,
912    seq_len: usize,
913    gguf_packed: Option<&GgufPackedParams>,
914) -> Result<Vec<f32>> {
915    ensure!(weights.loaded, "SAM3 dot product scoring not loaded");
916    let rows = seq_len * batch;
917    let pm = &weights.prompt_mlp;
918    let mut h = linear_maybe_gguf(
919        prompt_seq_first,
920        rows,
921        pm.in_dim,
922        &pm.w0_t,
923        weights.prompt_mlp_w0_gguf_key.as_deref(),
924        gguf_packed,
925        pm.hidden,
926        &pm.b0,
927    )?;
928    for v in h.iter_mut() {
929        if *v < 0.0 {
930            *v = 0.0;
931        }
932    }
933    h = linear_maybe_gguf(
934        &h,
935        rows,
936        pm.hidden,
937        &pm.w1_t,
938        weights.prompt_mlp_w1_gguf_key.as_deref(),
939        gguf_packed,
940        pm.out_dim,
941        &pm.b1,
942    )?;
943    for i in 0..h.len() {
944        h[i] += prompt_seq_first[i];
945    }
946    let h = layer_norm(
947        &h,
948        &weights.prompt_mlp_out_norm_w,
949        &weights.prompt_mlp_out_norm_b,
950        D_MODEL,
951        1e-5,
952    )?;
953
954    let mut pooled = vec![0f32; batch * D_MODEL];
955    let mut counts = vec![0.0f32; batch];
956    for b in 0..batch {
957        for l in 0..seq_len {
958            if prompt_kpm[b * seq_len + l] == 0 {
959                let src = (l * batch + b) * D_MODEL;
960                let dst = b * D_MODEL;
961                for c in 0..D_MODEL {
962                    pooled[dst + c] += h[src + c];
963                }
964                counts[b] += 1.0;
965            }
966        }
967    }
968    for b in 0..batch {
969        let denom = counts[b].max(1.0);
970        for c in 0..D_MODEL {
971            pooled[b * D_MODEL + c] /= denom;
972        }
973    }
974
975    let proj_pooled = linear_maybe_gguf(
976        &pooled,
977        batch,
978        D_MODEL,
979        &weights.prompt_proj_w_t,
980        weights.prompt_proj_gguf_key.as_deref(),
981        gguf_packed,
982        D_MODEL,
983        &weights.prompt_proj_b,
984    )?;
985    let proj_hs = linear_maybe_gguf(
986        hs_bf,
987        num_layers * batch * num_queries,
988        D_MODEL,
989        &weights.hs_proj_w_t,
990        weights.hs_proj_gguf_key.as_deref(),
991        gguf_packed,
992        D_MODEL,
993        &weights.hs_proj_b,
994    )?;
995
996    let scale = 1.0f32 / (D_MODEL as f32).sqrt();
997    let clamp = 12.0f32;
998    let mut scores = vec![0f32; num_layers * batch * num_queries];
999    for l in 0..num_layers {
1000        for b in 0..batch {
1001            let pp = &proj_pooled[b * D_MODEL..(b + 1) * D_MODEL];
1002            for q in 0..num_queries {
1003                let row = &proj_hs[((l * batch + b) * num_queries + q) * D_MODEL
1004                    ..((l * batch + b) * num_queries + q + 1) * D_MODEL];
1005                let mut acc = 0.0f32;
1006                for c in 0..D_MODEL {
1007                    acc += row[c] * pp[c];
1008                }
1009                let s = (acc * scale).clamp(-clamp, clamp);
1010                scores[(l * batch + b) * num_queries + q] = s;
1011            }
1012        }
1013    }
1014    Ok(scores)
1015}
1016
/// Nearest-neighbour resize of `c` planar `src_h × src_w` channels to `dst_h × dst_w`.
///
/// Output pixel `(oy, ox)` copies the input pixel at
/// `(oy * src_h / dst_h, ox * src_w / dst_w)` using integer (floor) division.
/// Channels are stored back to back, plane by plane, in both input and output.
fn nearest_upsample_nchw(
    x: &[f32],
    c: usize,
    src_h: usize,
    src_w: usize,
    dst_h: usize,
    dst_w: usize,
) -> Vec<f32> {
    let src_plane = src_h * src_w;
    let dst_plane = dst_h * dst_w;
    let mut out = vec![0f32; c * dst_plane];
    for ch in 0..c {
        let src = &x[ch * src_plane..(ch + 1) * src_plane];
        let dst = &mut out[ch * dst_plane..(ch + 1) * dst_plane];
        for oy in 0..dst_h {
            // Source row chosen once per output row.
            let src_row = &src[(oy * src_h / dst_h) * src_w..];
            let dst_row = &mut dst[oy * dst_w..(oy + 1) * dst_w];
            for (ox, px) in dst_row.iter_mut().enumerate() {
                *px = src_row[ox * src_w / dst_w];
            }
        }
    }
    out
}
1039
1040fn conv2d_3x3_pad1_maybe_gguf(
1041    input: &[f32],
1042    c: usize,
1043    h: usize,
1044    w: usize,
1045    weight: &[f32],
1046    weight_gguf_key: Option<&str>,
1047    gguf_packed: Option<&GgufPackedParams>,
1048    bias: &[f32],
1049    nchw_cache: &mut Option<Vec<f32>>,
1050) -> Result<Vec<f32>> {
1051    if !weight.is_empty() {
1052        return Ok(conv2d_3x3_nchw_pad1(input, c, h, w, weight, bias));
1053    }
1054    let key = weight_gguf_key
1055        .ok_or_else(|| anyhow::anyhow!("conv3: missing F32 weights and GGUF key"))?;
1056    let p = gguf_packed
1057        .and_then(|m| packed_linear(m, key))
1058        .ok_or_else(|| anyhow::anyhow!("missing packed conv3 weight: {key}"))?;
1059    conv2d_3x3_nchw_gguf(input, c, h, w, p, bias, nchw_cache)
1060}
1061
1062fn conv2d_1x1(
1063    input: &[f32],
1064    in_c: usize,
1065    out_c: usize,
1066    h: usize,
1067    w: usize,
1068    weight: &[f32],
1069    bias: &[f32],
1070) -> Vec<f32> {
1071    let n = h * w;
1072    let mut out = vec![0f32; out_c * n];
1073    rlx_cpu::blas::sgemm(weight, input, &mut out, out_c, in_c, n);
1074    for oc in 0..out_c {
1075        let b = bias[oc];
1076        let row = &mut out[oc * n..(oc + 1) * n];
1077        for v in row {
1078            *v += b;
1079        }
1080    }
1081    out
1082}
1083
1084fn conv2d_1x1_maybe_gguf(
1085    input: &[f32],
1086    in_c: usize,
1087    out_c: usize,
1088    h: usize,
1089    w: usize,
1090    weight: &[f32],
1091    weight_gguf_key: Option<&str>,
1092    gguf_packed: Option<&GgufPackedParams>,
1093    bias: &[f32],
1094) -> Result<Vec<f32>> {
1095    if weight_gguf_key.is_none() {
1096        return Ok(conv2d_1x1(input, in_c, out_c, h, w, weight, bias));
1097    }
1098    let n = h * w;
1099    let mut rows = vec![0f32; n * in_c];
1100    for ic in 0..in_c {
1101        for p in 0..n {
1102            rows[p * in_c + ic] = input[ic * n + p];
1103        }
1104    }
1105    let flat = linear_maybe_gguf(
1106        &rows,
1107        n,
1108        in_c,
1109        weight,
1110        weight_gguf_key,
1111        gguf_packed,
1112        out_c,
1113        bias,
1114    )?;
1115    let mut out = vec![0f32; out_c * n];
1116    for oc in 0..out_c {
1117        for p in 0..n {
1118            out[oc * n + p] = flat[p * out_c + oc];
1119        }
1120    }
1121    Ok(out)
1122}
1123
/// GroupNorm over an NCHW buffer (biased variance, eps `1e-5`), with a per-channel affine.
///
/// For each batch element and each of `num_groups` groups of `channels / num_groups`
/// consecutive channels, the values are normalized to zero mean and unit variance.
/// Each channel `c` is then scaled by `gamma[c]` and shifted by `beta[c]`.
///
/// Mean and variance are accumulated in `f64`. A group holds `cpg * h * w` values,
/// often hundreds of thousands. Summing that many values sequentially in `f32`
/// accumulates rounding error, and the error grows large when the activations
/// carry a DC offset. The element-wise output is still computed in `f32`.
///
/// # Panics
/// Panics if `num_groups` is zero or does not divide `channels`, if `x` is shorter
/// than `batch * channels * h * w`, or if `gamma`/`beta` are shorter than `channels`.
fn group_norm(
    x: &[f32],
    batch: usize,
    channels: usize,
    h: usize,
    w: usize,
    num_groups: usize,
    gamma: &[f32],
    beta: &[f32],
) -> Vec<f32> {
    assert!(
        num_groups > 0 && channels.is_multiple_of(num_groups),
        "group_norm: channels ({channels}) must be divisible by num_groups ({num_groups})"
    );
    let cpg = channels / num_groups;
    let spatial = h * w;
    let group_len = cpg * spatial;
    let mut out = vec![0f32; batch * channels * spatial];
    for b in 0..batch {
        for g in 0..num_groups {
            let c0 = g * cpg;
            // In NCHW the `cpg` planes of one group are contiguous, so slice them once.
            let start = (b * channels + c0) * spatial;
            let src = &x[start..start + group_len];
            let dst = &mut out[start..start + group_len];

            let n = group_len as f64;
            let mean = src.iter().map(|&v| f64::from(v)).sum::<f64>() / n;
            let var = src
                .iter()
                .map(|&v| {
                    let d = f64::from(v) - mean;
                    d * d
                })
                .sum::<f64>()
                / n;
            let inv = (1.0 / (var + 1e-5).sqrt()) as f32;
            let mean = mean as f32;

            for c in 0..cpg {
                let g_ = gamma[c0 + c];
                let bias = beta[c0 + c];
                let plane = c * spatial..(c + 1) * spatial;
                for (s, d) in src[plane.clone()].iter().zip(&mut dst[plane]) {
                    *d = (*s - mean) * inv * g_ + bias;
                }
            }
        }
    }
    out
}
1177
1178/// Legacy stub used by the not-yet-finished `Sam3::predict_image` path.
1179pub fn segmentation_forward_native(
1180    _weights: &Sam3SegmentationHeadWeights,
1181    detector: &Sam3DetectorOutput,
1182    h_out: usize,
1183    w_out: usize,
1184) -> Sam3ImagePrediction {
1185    Sam3ImagePrediction {
1186        masks: vec![0.0; detector.num_queries * h_out * w_out],
1187        mask_shape: vec![detector.num_queries, h_out, w_out],
1188        boxes: vec![0.0; detector.num_queries * 4],
1189        boxes_shape: vec![detector.num_queries, 4],
1190        scores: vec![0.0; detector.num_queries],
1191        scores_shape: vec![detector.num_queries],
1192        num_instances: detector.num_queries,
1193        h_out,
1194        w_out,
1195    }
1196}