Skip to main content

rlx_sam3/
detector_encoder.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

//! Native SAM3 detector encoder fusion (pre-norm, 6 layers, d_model=256).
//!
//! Mirrors `sam3.model.encoder.TransformerEncoderFusion` configured by
//! `model_builder._create_transformer_encoder`. Each layer runs:
//!
//!   `tgt2 = norm1(tgt); q=k=tgt2 + pos`
//!   `tgt += self_attn(q, k, v=tgt2, key_padding_mask=src_kpm)`
//!   `tgt2 = norm2(tgt)`
//!   `tgt += cross_attn(q=tgt2, k=v=prompt, key_padding_mask=prompt_kpm)`
//!   `tgt2 = norm3(tgt)`
//!   `tgt += linear2(relu(linear1(tgt2)))`
//!
//! Builder flags (encoder fusion): `pre_norm=True`, `pos_enc_at_attn=True`,
//! `pos_enc_at_cross_attn_keys=False`, `pos_enc_at_cross_attn_queries=False`,
//! `num_feature_levels=1`, `add_pooled_text_to_img_feat=False`.
//! Hence no `level_embed` or `text_pooling_proj` weights are loaded.
//! This port applies no source key-padding mask: `src_kpm` is treated as absent.

33use super::detector_decoder::mha_with_bias_maybe_gguf;
34use super::tensor::layer_norm;
35use rlx_core::weight_map::WeightMap;
36use rlx_flow::GgufPackedParams;
37
38use crate::packed_gguf::{linear_maybe_gguf, take_or_gguf, take_transposed_with_gguf_key};
39use anyhow::{Result, ensure};
40
/// Model width (`d_model`) of every encoder token and prompt token.
const D_MODEL: usize = 256;
/// Hidden width of each layer's feed-forward block (`linear1` output size).
const DIM_FF: usize = 2048;
/// Number of attention heads used by both self- and cross-attention.
const N_HEADS: usize = 8;
/// Number of encoder layers loaded by `extract_encoder_weights`.
pub const N_LAYERS: usize = 6;

// Multi-head attention (as in PyTorch `nn.MultiheadAttention`, which this
// encoder mirrors) requires the embedding width to split evenly across heads.
// Reject an inconsistent edit of the constants at compile time.
const _: () = assert!(D_MODEL % N_HEADS == 0, "D_MODEL must be divisible by N_HEADS");
45
/// Host-side parameters of a single SAM3 encoder-fusion layer.
///
/// Fields ending in `_w_t` hold projection matrices exactly as returned by
/// `take_transposed_with_gguf_key` (presumably transposed relative to the
/// checkpoint layout — confirm in `packed_gguf`). Each `*_gguf_key` is the
/// optional key returned alongside that matrix; `forward_encoder` forwards it
/// (via `as_deref`) to the matching `*_maybe_gguf` helper.
#[derive(Clone)]
pub struct Sam3EncoderLayerWeights {
    // --- self-attention (`self_attn.*`) ---
    /// Q/K/V input projection (`self_attn.in_proj_weight`).
    pub self_attn_in_w_t: Vec<f32>,
    /// Q/K/V input projection bias (`self_attn.in_proj_bias`).
    pub self_attn_in_b: Vec<f32>,
    /// Optional GGUF key for the self-attention input projection.
    pub self_attn_in_gguf_key: Option<String>,
    /// Output projection (`self_attn.out_proj.weight`).
    pub self_attn_out_w_t: Vec<f32>,
    /// Output projection bias (`self_attn.out_proj.bias`).
    pub self_attn_out_b: Vec<f32>,
    /// Optional GGUF key for the self-attention output projection.
    pub self_attn_out_gguf_key: Option<String>,

    // --- cross-attention to the prompt (`cross_attn_image.*`) ---
    /// Q/K/V input projection (`cross_attn_image.in_proj_weight`).
    pub cross_attn_in_w_t: Vec<f32>,
    /// Q/K/V input projection bias (`cross_attn_image.in_proj_bias`).
    pub cross_attn_in_b: Vec<f32>,
    /// Optional GGUF key for the cross-attention input projection.
    pub cross_attn_in_gguf_key: Option<String>,
    /// Output projection (`cross_attn_image.out_proj.weight`).
    pub cross_attn_out_w_t: Vec<f32>,
    /// Output projection bias (`cross_attn_image.out_proj.bias`).
    pub cross_attn_out_b: Vec<f32>,
    /// Optional GGUF key for the cross-attention output projection.
    pub cross_attn_out_gguf_key: Option<String>,

    // --- feed-forward block ---
    /// `linear1.weight`, applied as a `D_MODEL -> DIM_FF` projection.
    pub linear1_w_t: Vec<f32>,
    /// `linear1.bias`.
    pub linear1_b: Vec<f32>,
    /// Optional GGUF key for `linear1`.
    pub linear1_gguf_key: Option<String>,
    /// `linear2.weight`, applied as a `DIM_FF -> D_MODEL` projection.
    pub linear2_w_t: Vec<f32>,
    /// `linear2.bias`.
    pub linear2_b: Vec<f32>,
    /// Optional GGUF key for `linear2`.
    pub linear2_gguf_key: Option<String>,

    // --- pre-norm LayerNorms (before self-attn, cross-attn, FFN) ---
    /// `norm1.weight` (before self-attention).
    pub norm1_w: Vec<f32>,
    /// `norm1.bias`.
    pub norm1_b: Vec<f32>,
    /// `norm2.weight` (before cross-attention).
    pub norm2_w: Vec<f32>,
    /// `norm2.bias`.
    pub norm2_b: Vec<f32>,
    /// `norm3.weight` (before the feed-forward block).
    pub norm3_w: Vec<f32>,
    /// `norm3.bias`.
    pub norm3_b: Vec<f32>,
}
73
74#[derive(Clone, Default)]
75pub struct Sam3EncoderWeights {
76    pub loaded: bool,
77    /// Checkpoint prefix (`detector.transformer.encoder` or `transformer.encoder`).
78    pub prefix: String,
79    pub layers: Vec<Sam3EncoderLayerWeights>,
80}
81
82pub fn extract_encoder_weights(
83    weights: &mut WeightMap,
84    gguf_packed: Option<&GgufPackedParams>,
85) -> Result<Sam3EncoderWeights> {
86    let prefixes = ["detector.transformer.encoder", "transformer.encoder"];
87    let base = {
88        let mut found = None;
89        for p in prefixes {
90            let k = format!("{p}.layers.0.self_attn.in_proj_weight");
91            if weights.has(&k) {
92                found = Some(p);
93                break;
94            }
95        }
96        found.ok_or_else(|| anyhow::anyhow!("SAM3 detector encoder not found"))?
97    };
98
99    let mut layers = Vec::with_capacity(N_LAYERS);
100    for i in 0..N_LAYERS {
101        let p = format!("{base}.layers.{i}");
102        let (self_attn_in_w_t, self_attn_in_gguf_key) = take_transposed_with_gguf_key(
103            weights,
104            gguf_packed,
105            &format!("{p}.self_attn.in_proj_weight"),
106        )?;
107        let (self_attn_in_b, _) =
108            take_or_gguf(weights, gguf_packed, &format!("{p}.self_attn.in_proj_bias"))?;
109        let (self_attn_out_w_t, self_attn_out_gguf_key) = take_transposed_with_gguf_key(
110            weights,
111            gguf_packed,
112            &format!("{p}.self_attn.out_proj.weight"),
113        )?;
114        let (self_attn_out_b, _) = take_or_gguf(
115            weights,
116            gguf_packed,
117            &format!("{p}.self_attn.out_proj.bias"),
118        )?;
119        let (cross_attn_in_w_t, cross_attn_in_gguf_key) = take_transposed_with_gguf_key(
120            weights,
121            gguf_packed,
122            &format!("{p}.cross_attn_image.in_proj_weight"),
123        )?;
124        let (cross_attn_in_b, _) = take_or_gguf(
125            weights,
126            gguf_packed,
127            &format!("{p}.cross_attn_image.in_proj_bias"),
128        )?;
129        let (cross_attn_out_w_t, cross_attn_out_gguf_key) = take_transposed_with_gguf_key(
130            weights,
131            gguf_packed,
132            &format!("{p}.cross_attn_image.out_proj.weight"),
133        )?;
134        let (cross_attn_out_b, _) = take_or_gguf(
135            weights,
136            gguf_packed,
137            &format!("{p}.cross_attn_image.out_proj.bias"),
138        )?;
139        let (linear1_w_t, linear1_gguf_key) =
140            take_transposed_with_gguf_key(weights, gguf_packed, &format!("{p}.linear1.weight"))?;
141        let (linear1_b, _) = take_or_gguf(weights, gguf_packed, &format!("{p}.linear1.bias"))?;
142        let (linear2_w_t, linear2_gguf_key) =
143            take_transposed_with_gguf_key(weights, gguf_packed, &format!("{p}.linear2.weight"))?;
144        let (linear2_b, _) = take_or_gguf(weights, gguf_packed, &format!("{p}.linear2.bias"))?;
145        let (norm1_w, _) = weights.take(&format!("{p}.norm1.weight"))?;
146        let (norm1_b, _) = weights.take(&format!("{p}.norm1.bias"))?;
147        let (norm2_w, _) = weights.take(&format!("{p}.norm2.weight"))?;
148        let (norm2_b, _) = weights.take(&format!("{p}.norm2.bias"))?;
149        let (norm3_w, _) = weights.take(&format!("{p}.norm3.weight"))?;
150        let (norm3_b, _) = weights.take(&format!("{p}.norm3.bias"))?;
151        layers.push(Sam3EncoderLayerWeights {
152            self_attn_in_w_t,
153            self_attn_in_b,
154            self_attn_in_gguf_key,
155            self_attn_out_w_t,
156            self_attn_out_b,
157            self_attn_out_gguf_key,
158            cross_attn_in_w_t,
159            cross_attn_in_b,
160            cross_attn_in_gguf_key,
161            cross_attn_out_w_t,
162            cross_attn_out_b,
163            cross_attn_out_gguf_key,
164            linear1_w_t,
165            linear1_b,
166            linear1_gguf_key,
167            linear2_w_t,
168            linear2_b,
169            linear2_gguf_key,
170            norm1_w,
171            norm1_b,
172            norm2_w,
173            norm2_b,
174            norm3_w,
175            norm3_b,
176        });
177    }
178    Ok(Sam3EncoderWeights {
179        loaded: true,
180        prefix: base.to_string(),
181        layers,
182    })
183}
184
185/// Run the encoder fusion. `src` is the FPN feature flat in NCHW
186/// `[B, C, H, W]`. `src_pos` matches. `prompt` is sequence-first
187/// `[L_p, B, C]`. Returns the encoded memory in batch-first flat
188/// `[B, H*W, C]`.
189#[allow(clippy::too_many_arguments)]
190pub fn forward_encoder(
191    weights: &Sam3EncoderWeights,
192    src_bchw: &[f32],
193    src_pos_bchw: &[f32],
194    prompt_seq_first: &[f32],
195    prompt_kpm: &[u8],
196    batch: usize,
197    src_h: usize,
198    src_w: usize,
199    prompt_len: usize,
200    gguf_packed: Option<&GgufPackedParams>,
201) -> Result<Vec<f32>> {
202    ensure!(weights.loaded, "SAM3 detector encoder not loaded");
203    ensure!(
204        src_bchw.len() == batch * D_MODEL * src_h * src_w,
205        "encoder src shape mismatch"
206    );
207    ensure!(
208        prompt_seq_first.len() == prompt_len * batch * D_MODEL,
209        "encoder prompt shape mismatch"
210    );
211    ensure!(
212        prompt_kpm.len() == batch * prompt_len,
213        "encoder prompt mask shape mismatch"
214    );
215
216    let hw = src_h * src_w;
217
218    // Flatten src and pos from NCHW → [B, H*W, C] (batch-first), matching
219    // `src.flatten(2).transpose(1, 2)` upstream.
220    let mut tgt = vec![0f32; batch * hw * D_MODEL];
221    let mut pos = vec![0f32; batch * hw * D_MODEL];
222    for b in 0..batch {
223        for s in 0..hw {
224            for c in 0..D_MODEL {
225                tgt[(b * hw + s) * D_MODEL + c] = src_bchw[((b * D_MODEL + c) * hw) + s];
226                pos[(b * hw + s) * D_MODEL + c] = src_pos_bchw[((b * D_MODEL + c) * hw) + s];
227            }
228        }
229    }
230
231    // Reorder prompt from [L, B, C] to [B, L, C] for batch-first attention.
232    let mut prompt_bf = vec![0f32; batch * prompt_len * D_MODEL];
233    for b in 0..batch {
234        for l in 0..prompt_len {
235            let src = (l * batch + b) * D_MODEL;
236            let dst = (b * prompt_len + l) * D_MODEL;
237            prompt_bf[dst..dst + D_MODEL].copy_from_slice(&prompt_seq_first[src..src + D_MODEL]);
238        }
239    }
240
241    for layer in &weights.layers {
242        // Pre-norm self-attention with pos added to Q and K.
243        let n1 = layer_norm(&tgt, &layer.norm1_w, &layer.norm1_b, D_MODEL, 1e-5)?;
244        let mut q = vec![0f32; n1.len()];
245        for i in 0..n1.len() {
246            q[i] = n1[i] + pos[i];
247        }
248        let sa = mha_with_bias_maybe_gguf(
249            &q,
250            &q,
251            &n1,
252            &layer.self_attn_in_w_t,
253            &layer.self_attn_in_b,
254            layer.self_attn_in_gguf_key.as_deref(),
255            &layer.self_attn_out_w_t,
256            &layer.self_attn_out_b,
257            layer.self_attn_out_gguf_key.as_deref(),
258            gguf_packed,
259            batch,
260            hw,
261            hw,
262            D_MODEL,
263            N_HEADS,
264            None,
265            None,
266        )?;
267        for i in 0..tgt.len() {
268            tgt[i] += sa[i];
269        }
270
271        // Pre-norm cross-attention to prompt (text). No pos added to Q/K.
272        let n2 = layer_norm(&tgt, &layer.norm2_w, &layer.norm2_b, D_MODEL, 1e-5)?;
273        let ca = mha_with_bias_maybe_gguf(
274            &n2,
275            &prompt_bf,
276            &prompt_bf,
277            &layer.cross_attn_in_w_t,
278            &layer.cross_attn_in_b,
279            layer.cross_attn_in_gguf_key.as_deref(),
280            &layer.cross_attn_out_w_t,
281            &layer.cross_attn_out_b,
282            layer.cross_attn_out_gguf_key.as_deref(),
283            gguf_packed,
284            batch,
285            hw,
286            prompt_len,
287            D_MODEL,
288            N_HEADS,
289            None,
290            Some(prompt_kpm),
291        )?;
292        for i in 0..tgt.len() {
293            tgt[i] += ca[i];
294        }
295
296        // Pre-norm FFN with ReLU.
297        let n3 = layer_norm(&tgt, &layer.norm3_w, &layer.norm3_b, D_MODEL, 1e-5)?;
298        let mut ff = linear_maybe_gguf(
299            &n3,
300            batch * hw,
301            D_MODEL,
302            &layer.linear1_w_t,
303            layer.linear1_gguf_key.as_deref(),
304            gguf_packed,
305            DIM_FF,
306            &layer.linear1_b,
307        )?;
308        for v in ff.iter_mut() {
309            if *v < 0.0 {
310                *v = 0.0;
311            }
312        }
313        let ffn = linear_maybe_gguf(
314            &ff,
315            batch * hw,
316            DIM_FF,
317            &layer.linear2_w_t,
318            layer.linear2_gguf_key.as_deref(),
319            gguf_packed,
320            D_MODEL,
321            &layer.linear2_b,
322        )?;
323        for i in 0..tgt.len() {
324            tgt[i] += ffn[i];
325        }
326    }
327
328    Ok(tgt)
329}