1use super::detector_decoder::mha_with_bias_maybe_gguf;
34use super::tensor::layer_norm;
35use rlx_core::weight_map::WeightMap;
36use rlx_flow::GgufPackedParams;
37
38use crate::packed_gguf::{linear_maybe_gguf, take_or_gguf, take_transposed_with_gguf_key};
39use anyhow::{Result, ensure};
40
/// Number of encoder layers that `extract_encoder_weights` loads (`layers.0` .. `layers.5`).
pub const N_LAYERS: usize = 6;
/// Channel width of every encoder token; image and prompt tokens are both this wide.
const D_MODEL: usize = 256;
/// Number of attention heads passed to both the self- and the cross-attention sublayers.
const N_HEADS: usize = 8;
/// Output width of `linear1` (and input width of `linear2`) in the feed-forward sublayer.
const DIM_FF: usize = 2048;
45
/// Host-side weights for one detector-encoder layer, filled by `extract_encoder_weights`.
///
/// Projection weights (`*_w_t`) come from `take_transposed_with_gguf_key`. Each one is paired with
/// the optional GGUF key that helper reported; `forward_encoder` forwards that key to the
/// attention/linear helpers. Biases and norm parameters are plain `f32` vectors.
#[derive(Clone)]
pub struct Sam3EncoderLayerWeights {
    // --- self-attention over image tokens (`self_attn.*`) ---
    /// Input (packed q/k/v) projection weight, as returned by the transposing loader.
    pub self_attn_in_w_t: Vec<f32>,
    /// Input projection bias (`self_attn.in_proj_bias`).
    pub self_attn_in_b: Vec<f32>,
    /// Optional GGUF key for the input projection.
    pub self_attn_in_gguf_key: Option<String>,
    /// Output projection weight, as returned by the transposing loader.
    pub self_attn_out_w_t: Vec<f32>,
    /// Output projection bias (`self_attn.out_proj.bias`).
    pub self_attn_out_b: Vec<f32>,
    /// Optional GGUF key for the output projection.
    pub self_attn_out_gguf_key: Option<String>,

    // --- cross-attention from image tokens to prompt tokens (`cross_attn_image.*`) ---
    /// Input projection weight, as returned by the transposing loader.
    pub cross_attn_in_w_t: Vec<f32>,
    /// Input projection bias (`cross_attn_image.in_proj_bias`).
    pub cross_attn_in_b: Vec<f32>,
    /// Optional GGUF key for the input projection.
    pub cross_attn_in_gguf_key: Option<String>,
    /// Output projection weight, as returned by the transposing loader.
    pub cross_attn_out_w_t: Vec<f32>,
    /// Output projection bias (`cross_attn_image.out_proj.bias`).
    pub cross_attn_out_b: Vec<f32>,
    /// Optional GGUF key for the output projection.
    pub cross_attn_out_gguf_key: Option<String>,

    // --- feed-forward sublayer ---
    /// First linear weight (`linear1.weight`), as returned by the transposing loader.
    pub linear1_w_t: Vec<f32>,
    /// First linear bias.
    pub linear1_b: Vec<f32>,
    /// Optional GGUF key for `linear1`.
    pub linear1_gguf_key: Option<String>,
    /// Second linear weight (`linear2.weight`), as returned by the transposing loader.
    pub linear2_w_t: Vec<f32>,
    /// Second linear bias.
    pub linear2_b: Vec<f32>,
    /// Optional GGUF key for `linear2`.
    pub linear2_gguf_key: Option<String>,

    // --- layer norms (applied before self-attn, cross-attn and FFN respectively) ---
    /// `norm1` scale.
    pub norm1_w: Vec<f32>,
    /// `norm1` shift.
    pub norm1_b: Vec<f32>,
    /// `norm2` scale.
    pub norm2_w: Vec<f32>,
    /// `norm2` shift.
    pub norm2_b: Vec<f32>,
    /// `norm3` scale.
    pub norm3_w: Vec<f32>,
    /// `norm3` shift.
    pub norm3_b: Vec<f32>,
}
73
74#[derive(Clone, Default)]
75pub struct Sam3EncoderWeights {
76 pub loaded: bool,
77 pub prefix: String,
79 pub layers: Vec<Sam3EncoderLayerWeights>,
80}
81
82pub fn extract_encoder_weights(
83 weights: &mut WeightMap,
84 gguf_packed: Option<&GgufPackedParams>,
85) -> Result<Sam3EncoderWeights> {
86 let prefixes = ["detector.transformer.encoder", "transformer.encoder"];
87 let base = {
88 let mut found = None;
89 for p in prefixes {
90 let k = format!("{p}.layers.0.self_attn.in_proj_weight");
91 if weights.has(&k) {
92 found = Some(p);
93 break;
94 }
95 }
96 found.ok_or_else(|| anyhow::anyhow!("SAM3 detector encoder not found"))?
97 };
98
99 let mut layers = Vec::with_capacity(N_LAYERS);
100 for i in 0..N_LAYERS {
101 let p = format!("{base}.layers.{i}");
102 let (self_attn_in_w_t, self_attn_in_gguf_key) = take_transposed_with_gguf_key(
103 weights,
104 gguf_packed,
105 &format!("{p}.self_attn.in_proj_weight"),
106 )?;
107 let (self_attn_in_b, _) =
108 take_or_gguf(weights, gguf_packed, &format!("{p}.self_attn.in_proj_bias"))?;
109 let (self_attn_out_w_t, self_attn_out_gguf_key) = take_transposed_with_gguf_key(
110 weights,
111 gguf_packed,
112 &format!("{p}.self_attn.out_proj.weight"),
113 )?;
114 let (self_attn_out_b, _) = take_or_gguf(
115 weights,
116 gguf_packed,
117 &format!("{p}.self_attn.out_proj.bias"),
118 )?;
119 let (cross_attn_in_w_t, cross_attn_in_gguf_key) = take_transposed_with_gguf_key(
120 weights,
121 gguf_packed,
122 &format!("{p}.cross_attn_image.in_proj_weight"),
123 )?;
124 let (cross_attn_in_b, _) = take_or_gguf(
125 weights,
126 gguf_packed,
127 &format!("{p}.cross_attn_image.in_proj_bias"),
128 )?;
129 let (cross_attn_out_w_t, cross_attn_out_gguf_key) = take_transposed_with_gguf_key(
130 weights,
131 gguf_packed,
132 &format!("{p}.cross_attn_image.out_proj.weight"),
133 )?;
134 let (cross_attn_out_b, _) = take_or_gguf(
135 weights,
136 gguf_packed,
137 &format!("{p}.cross_attn_image.out_proj.bias"),
138 )?;
139 let (linear1_w_t, linear1_gguf_key) =
140 take_transposed_with_gguf_key(weights, gguf_packed, &format!("{p}.linear1.weight"))?;
141 let (linear1_b, _) = take_or_gguf(weights, gguf_packed, &format!("{p}.linear1.bias"))?;
142 let (linear2_w_t, linear2_gguf_key) =
143 take_transposed_with_gguf_key(weights, gguf_packed, &format!("{p}.linear2.weight"))?;
144 let (linear2_b, _) = take_or_gguf(weights, gguf_packed, &format!("{p}.linear2.bias"))?;
145 let (norm1_w, _) = weights.take(&format!("{p}.norm1.weight"))?;
146 let (norm1_b, _) = weights.take(&format!("{p}.norm1.bias"))?;
147 let (norm2_w, _) = weights.take(&format!("{p}.norm2.weight"))?;
148 let (norm2_b, _) = weights.take(&format!("{p}.norm2.bias"))?;
149 let (norm3_w, _) = weights.take(&format!("{p}.norm3.weight"))?;
150 let (norm3_b, _) = weights.take(&format!("{p}.norm3.bias"))?;
151 layers.push(Sam3EncoderLayerWeights {
152 self_attn_in_w_t,
153 self_attn_in_b,
154 self_attn_in_gguf_key,
155 self_attn_out_w_t,
156 self_attn_out_b,
157 self_attn_out_gguf_key,
158 cross_attn_in_w_t,
159 cross_attn_in_b,
160 cross_attn_in_gguf_key,
161 cross_attn_out_w_t,
162 cross_attn_out_b,
163 cross_attn_out_gguf_key,
164 linear1_w_t,
165 linear1_b,
166 linear1_gguf_key,
167 linear2_w_t,
168 linear2_b,
169 linear2_gguf_key,
170 norm1_w,
171 norm1_b,
172 norm2_w,
173 norm2_b,
174 norm3_w,
175 norm3_b,
176 });
177 }
178 Ok(Sam3EncoderWeights {
179 loaded: true,
180 prefix: base.to_string(),
181 layers,
182 })
183}
184
185#[allow(clippy::too_many_arguments)]
190pub fn forward_encoder(
191 weights: &Sam3EncoderWeights,
192 src_bchw: &[f32],
193 src_pos_bchw: &[f32],
194 prompt_seq_first: &[f32],
195 prompt_kpm: &[u8],
196 batch: usize,
197 src_h: usize,
198 src_w: usize,
199 prompt_len: usize,
200 gguf_packed: Option<&GgufPackedParams>,
201) -> Result<Vec<f32>> {
202 ensure!(weights.loaded, "SAM3 detector encoder not loaded");
203 ensure!(
204 src_bchw.len() == batch * D_MODEL * src_h * src_w,
205 "encoder src shape mismatch"
206 );
207 ensure!(
208 prompt_seq_first.len() == prompt_len * batch * D_MODEL,
209 "encoder prompt shape mismatch"
210 );
211 ensure!(
212 prompt_kpm.len() == batch * prompt_len,
213 "encoder prompt mask shape mismatch"
214 );
215
216 let hw = src_h * src_w;
217
218 let mut tgt = vec![0f32; batch * hw * D_MODEL];
221 let mut pos = vec![0f32; batch * hw * D_MODEL];
222 for b in 0..batch {
223 for s in 0..hw {
224 for c in 0..D_MODEL {
225 tgt[(b * hw + s) * D_MODEL + c] = src_bchw[((b * D_MODEL + c) * hw) + s];
226 pos[(b * hw + s) * D_MODEL + c] = src_pos_bchw[((b * D_MODEL + c) * hw) + s];
227 }
228 }
229 }
230
231 let mut prompt_bf = vec![0f32; batch * prompt_len * D_MODEL];
233 for b in 0..batch {
234 for l in 0..prompt_len {
235 let src = (l * batch + b) * D_MODEL;
236 let dst = (b * prompt_len + l) * D_MODEL;
237 prompt_bf[dst..dst + D_MODEL].copy_from_slice(&prompt_seq_first[src..src + D_MODEL]);
238 }
239 }
240
241 for layer in &weights.layers {
242 let n1 = layer_norm(&tgt, &layer.norm1_w, &layer.norm1_b, D_MODEL, 1e-5)?;
244 let mut q = vec![0f32; n1.len()];
245 for i in 0..n1.len() {
246 q[i] = n1[i] + pos[i];
247 }
248 let sa = mha_with_bias_maybe_gguf(
249 &q,
250 &q,
251 &n1,
252 &layer.self_attn_in_w_t,
253 &layer.self_attn_in_b,
254 layer.self_attn_in_gguf_key.as_deref(),
255 &layer.self_attn_out_w_t,
256 &layer.self_attn_out_b,
257 layer.self_attn_out_gguf_key.as_deref(),
258 gguf_packed,
259 batch,
260 hw,
261 hw,
262 D_MODEL,
263 N_HEADS,
264 None,
265 None,
266 )?;
267 for i in 0..tgt.len() {
268 tgt[i] += sa[i];
269 }
270
271 let n2 = layer_norm(&tgt, &layer.norm2_w, &layer.norm2_b, D_MODEL, 1e-5)?;
273 let ca = mha_with_bias_maybe_gguf(
274 &n2,
275 &prompt_bf,
276 &prompt_bf,
277 &layer.cross_attn_in_w_t,
278 &layer.cross_attn_in_b,
279 layer.cross_attn_in_gguf_key.as_deref(),
280 &layer.cross_attn_out_w_t,
281 &layer.cross_attn_out_b,
282 layer.cross_attn_out_gguf_key.as_deref(),
283 gguf_packed,
284 batch,
285 hw,
286 prompt_len,
287 D_MODEL,
288 N_HEADS,
289 None,
290 Some(prompt_kpm),
291 )?;
292 for i in 0..tgt.len() {
293 tgt[i] += ca[i];
294 }
295
296 let n3 = layer_norm(&tgt, &layer.norm3_w, &layer.norm3_b, D_MODEL, 1e-5)?;
298 let mut ff = linear_maybe_gguf(
299 &n3,
300 batch * hw,
301 D_MODEL,
302 &layer.linear1_w_t,
303 layer.linear1_gguf_key.as_deref(),
304 gguf_packed,
305 DIM_FF,
306 &layer.linear1_b,
307 )?;
308 for v in ff.iter_mut() {
309 if *v < 0.0 {
310 *v = 0.0;
311 }
312 }
313 let ffn = linear_maybe_gguf(
314 &ff,
315 batch * hw,
316 DIM_FF,
317 &layer.linear2_w_t,
318 layer.linear2_gguf_key.as_deref(),
319 gguf_packed,
320 D_MODEL,
321 &layer.linear2_b,
322 )?;
323 for i in 0..tgt.len() {
324 tgt[i] += ffn[i];
325 }
326 }
327
328 Ok(tgt)
329}