1use super::config::Sam3TextConfig;
35use super::tensor::{layer_norm, matmul, matmul_bt, softmax_rows};
36use rlx_core::weight_map::WeightMap;
37use rlx_flow::GgufPackedParams;
38
39use crate::packed_gguf::{linear_maybe_gguf, take_or_gguf, take_transposed_with_gguf_key};
40use anyhow::{Result, ensure};
41
/// Parameters of one residual block of the SAM3 text transformer.
///
/// Field comments name the tensor (relative to the block's key prefix) that
/// `extract_text_encoder_weights` loads into each field. The `*_w_t` matrices
/// come from `take_transposed_with_gguf_key`; the accompanying `*_gguf_key`
/// is the second value that helper returns and is later handed to
/// `linear_maybe_gguf` together with the matrix.
#[derive(Clone)]
pub struct Sam3TextBlock {
    /// `ln_1.weight`: scale of the layer norm applied before attention.
    pub ln1_w: Vec<f32>,
    /// `ln_1.bias`: shift of the layer norm applied before attention.
    pub ln1_b: Vec<f32>,

    /// `attn.in_proj_weight`: fused Q/K/V projection matrix.
    pub qkv_w_t: Vec<f32>,
    /// `attn.in_proj_bias`: fused Q/K/V projection bias.
    pub qkv_b: Vec<f32>,
    /// Optional GGUF key associated with `qkv_w_t`.
    pub qkv_gguf_key: Option<String>,

    /// `attn.out_proj.weight`: attention output projection matrix.
    pub proj_w_t: Vec<f32>,
    /// `attn.out_proj.bias`: attention output projection bias.
    pub proj_b: Vec<f32>,
    /// Optional GGUF key associated with `proj_w_t`.
    pub proj_gguf_key: Option<String>,

    /// `ln_2.weight`: scale of the layer norm applied before the MLP.
    pub ln2_w: Vec<f32>,
    /// `ln_2.bias`: shift of the layer norm applied before the MLP.
    pub ln2_b: Vec<f32>,

    /// `mlp.c_fc.weight`: first MLP projection matrix.
    pub mlp_fc_w_t: Vec<f32>,
    /// `mlp.c_fc.bias`: first MLP bias; its length is used as the MLP hidden size.
    pub mlp_fc_b: Vec<f32>,
    /// Optional GGUF key associated with `mlp_fc_w_t`.
    pub mlp_fc_gguf_key: Option<String>,

    /// `mlp.c_proj.weight`: second MLP projection matrix.
    pub mlp_proj_w_t: Vec<f32>,
    /// `mlp.c_proj.bias`: second MLP bias.
    pub mlp_proj_b: Vec<f32>,
    /// Optional GGUF key associated with `mlp_proj_w_t`.
    pub mlp_proj_gguf_key: Option<String>,
}
61
62#[derive(Clone, Default)]
63pub struct Sam3TextEncoderWeights {
64 pub loaded: bool,
65 pub width: usize,
66 pub heads: usize,
67 pub context_length: usize,
68 pub d_model: usize,
69 pub vocab_size: usize,
70 pub token_embedding: Vec<f32>,
71 pub positional_embedding: Vec<f32>,
72 pub ln_final_w: Vec<f32>,
73 pub ln_final_b: Vec<f32>,
74 pub blocks: Vec<Sam3TextBlock>,
75 pub resizer_w_t: Vec<f32>,
76 pub resizer_b: Vec<f32>,
77 pub resizer_gguf_key: Option<String>,
78}
79
/// Result of running the SAM3 text encoder.
#[derive(Debug, Clone, Default)]
pub struct Sam3TextEncoded {
    /// One byte per input token in `[batch, seq_len]` order:
    /// `1` where the token id is `0`, `0` everywhere else.
    pub attention_mask: Vec<u8>,
    /// Encoder output after the resizer projection, stored sequence-first
    /// (`seq_len * batch` rows of `d_model` values).
    pub text_memory_resized: Vec<f32>,
    /// Raw token embeddings (no positional term), stored sequence-first
    /// (`seq_len * batch` rows of `width` values).
    pub inputs_embeds: Vec<f32>,
    /// Number of tokens per sequence.
    pub seq_len: usize,
    /// Number of sequences.
    pub batch: usize,
    /// Feature size of `text_memory_resized`.
    pub d_model: usize,
    /// Feature size of `inputs_embeds`.
    pub width: usize,
}
93
94pub fn extract_text_encoder_weights(
95 weights: &mut WeightMap,
96 cfg: &Sam3TextConfig,
97 gguf_packed: Option<&GgufPackedParams>,
98) -> Result<Sam3TextEncoderWeights> {
99 let width = cfg.width;
100 let heads = cfg.heads;
101 let layers = cfg.layers;
102 let d_model = cfg.d_model;
103 let context_length = 32usize;
104 let vocab_size = 49408usize;
105 let _mlp_width = width * 4;
106
107 let prefixes = [
108 "detector.backbone.language_backbone",
109 "backbone.language_backbone",
110 "language_backbone",
111 ];
112 let enc_prefix = {
113 let mut found = None;
114 for p in prefixes {
115 let key = format!("{p}.encoder.token_embedding.weight");
116 if weights.has(&key) {
117 found = Some(p);
118 break;
119 }
120 }
121 found.ok_or_else(|| anyhow::anyhow!("SAM3 language_backbone not found"))?
122 };
123
124 let (token_embedding, te_shape) = take_or_gguf(
125 weights,
126 gguf_packed,
127 &format!("{enc_prefix}.encoder.token_embedding.weight"),
128 )?;
129 ensure!(
130 te_shape == vec![vocab_size, width],
131 "token_embedding shape {te_shape:?}"
132 );
133 let (positional_embedding, pe_shape) = take_or_gguf(
134 weights,
135 gguf_packed,
136 &format!("{enc_prefix}.encoder.positional_embedding"),
137 )?;
138 ensure!(
139 pe_shape == vec![context_length, width],
140 "positional_embedding shape {pe_shape:?}"
141 );
142 let (ln_final_w, _) = take_or_gguf(
143 weights,
144 gguf_packed,
145 &format!("{enc_prefix}.encoder.ln_final.weight"),
146 )?;
147 let (ln_final_b, _) = take_or_gguf(
148 weights,
149 gguf_packed,
150 &format!("{enc_prefix}.encoder.ln_final.bias"),
151 )?;
152
153 let _ = weights.take(&format!("{enc_prefix}.encoder.text_projection"));
156
157 let mut blocks = Vec::with_capacity(layers);
158 for i in 0..layers {
159 let bp = format!("{enc_prefix}.encoder.transformer.resblocks.{i}");
160 let (ln1_w, _) = take_or_gguf(weights, gguf_packed, &format!("{bp}.ln_1.weight"))?;
161 let (ln1_b, _) = take_or_gguf(weights, gguf_packed, &format!("{bp}.ln_1.bias"))?;
162 let (qkv_w_t, qkv_gguf_key) = take_transposed_with_gguf_key(
163 weights,
164 gguf_packed,
165 &format!("{bp}.attn.in_proj_weight"),
166 )?;
167 let (qkv_b, _) = take_or_gguf(weights, gguf_packed, &format!("{bp}.attn.in_proj_bias"))?;
168 let (proj_w_t, proj_gguf_key) = take_transposed_with_gguf_key(
169 weights,
170 gguf_packed,
171 &format!("{bp}.attn.out_proj.weight"),
172 )?;
173 let (proj_b, _) = take_or_gguf(weights, gguf_packed, &format!("{bp}.attn.out_proj.bias"))?;
174 let (ln2_w, _) = take_or_gguf(weights, gguf_packed, &format!("{bp}.ln_2.weight"))?;
175 let (ln2_b, _) = take_or_gguf(weights, gguf_packed, &format!("{bp}.ln_2.bias"))?;
176 let (mlp_fc_w_t, mlp_fc_gguf_key) =
177 take_transposed_with_gguf_key(weights, gguf_packed, &format!("{bp}.mlp.c_fc.weight"))?;
178 let (mlp_fc_b, _) = take_or_gguf(weights, gguf_packed, &format!("{bp}.mlp.c_fc.bias"))?;
179 let (mlp_proj_w_t, mlp_proj_gguf_key) = take_transposed_with_gguf_key(
180 weights,
181 gguf_packed,
182 &format!("{bp}.mlp.c_proj.weight"),
183 )?;
184 let (mlp_proj_b, _) = take_or_gguf(weights, gguf_packed, &format!("{bp}.mlp.c_proj.bias"))?;
185 blocks.push(Sam3TextBlock {
186 ln1_w,
187 ln1_b,
188 qkv_w_t,
189 qkv_b,
190 proj_w_t,
191 proj_b,
192 ln2_w,
193 ln2_b,
194 mlp_fc_w_t,
195 mlp_fc_b,
196 mlp_proj_w_t,
197 mlp_proj_b,
198 qkv_gguf_key,
199 proj_gguf_key,
200 mlp_fc_gguf_key,
201 mlp_proj_gguf_key,
202 });
203 }
204
205 let (resizer_w_t, resizer_gguf_key) = take_transposed_with_gguf_key(
206 weights,
207 gguf_packed,
208 &format!("{enc_prefix}.resizer.weight"),
209 )?;
210 let (resizer_b, _) = take_or_gguf(weights, gguf_packed, &format!("{enc_prefix}.resizer.bias"))?;
211
212 Ok(Sam3TextEncoderWeights {
213 loaded: true,
214 width,
215 heads,
216 context_length,
217 d_model,
218 vocab_size,
219 token_embedding,
220 positional_embedding,
221 ln_final_w,
222 ln_final_b,
223 blocks,
224 resizer_w_t,
225 resizer_b,
226 resizer_gguf_key,
227 })
228}
229
230pub fn encode_tokens(
232 weights: &Sam3TextEncoderWeights,
233 tokens: &[u32],
234 batch: usize,
235 seq_len: usize,
236 gguf_packed: Option<&GgufPackedParams>,
237) -> Result<Sam3TextEncoded> {
238 ensure!(weights.loaded, "SAM3 text encoder weights not loaded");
239 ensure!(
240 tokens.len() == batch * seq_len,
241 "expected {} tokens, got {}",
242 batch * seq_len,
243 tokens.len()
244 );
245 ensure!(
246 seq_len <= weights.context_length,
247 "seq_len {seq_len} exceeds context_length {}",
248 weights.context_length
249 );
250 let w = weights.width;
251 let h = weights.heads;
252 let head_dim = w / h;
253 ensure!(head_dim * h == w, "width {w} not divisible by heads {h}");
254
255 let mut x = vec![0f32; batch * seq_len * w];
256 let mut inputs_embeds = vec![0f32; batch * seq_len * w];
257 for b in 0..batch {
258 for l in 0..seq_len {
259 let tok = tokens[b * seq_len + l] as usize;
260 ensure!(tok < weights.vocab_size, "token id {tok} out of vocab");
261 let src = &weights.token_embedding[tok * w..(tok + 1) * w];
262 let dst_x = &mut x[(b * seq_len + l) * w..(b * seq_len + l + 1) * w];
263 let dst_emb = &mut inputs_embeds[(b * seq_len + l) * w..(b * seq_len + l + 1) * w];
264 dst_emb.copy_from_slice(src);
265 let pos = &weights.positional_embedding[l * w..(l + 1) * w];
266 for k in 0..w {
267 dst_x[k] = src[k] + pos[k];
268 }
269 }
270 }
271
272 let neg_inf = f32::NEG_INFINITY;
274 let mut mask = vec![0f32; seq_len * seq_len];
275 for i in 0..seq_len {
276 for j in (i + 1)..seq_len {
277 mask[i * seq_len + j] = neg_inf;
278 }
279 }
280
281 for block in &weights.blocks {
282 x = block_forward(
283 &x,
284 block,
285 batch,
286 seq_len,
287 w,
288 h,
289 head_dim,
290 &mask,
291 gguf_packed,
292 )?;
293 }
294 x = layer_norm(&x, &weights.ln_final_w, &weights.ln_final_b, w, 1e-5)?;
295
296 let mut text_memory_seq_first = vec![0f32; seq_len * batch * w];
298 for b in 0..batch {
299 for l in 0..seq_len {
300 let src = &x[(b * seq_len + l) * w..(b * seq_len + l + 1) * w];
301 let dst = &mut text_memory_seq_first[(l * batch + b) * w..(l * batch + b + 1) * w];
302 dst.copy_from_slice(src);
303 }
304 }
305 let mut inputs_embeds_seq_first = vec![0f32; seq_len * batch * w];
306 for b in 0..batch {
307 for l in 0..seq_len {
308 let src = &inputs_embeds[(b * seq_len + l) * w..(b * seq_len + l + 1) * w];
309 let dst = &mut inputs_embeds_seq_first[(l * batch + b) * w..(l * batch + b + 1) * w];
310 dst.copy_from_slice(src);
311 }
312 }
313
314 let text_memory_resized = linear_maybe_gguf(
315 &text_memory_seq_first,
316 seq_len * batch,
317 w,
318 &weights.resizer_w_t,
319 weights.resizer_gguf_key.as_deref(),
320 gguf_packed,
321 weights.d_model,
322 &weights.resizer_b,
323 )?;
324
325 let mut attention_mask = vec![0u8; batch * seq_len];
326 for i in 0..batch * seq_len {
327 attention_mask[i] = if tokens[i] == 0 { 1 } else { 0 };
328 }
329
330 Ok(Sam3TextEncoded {
331 attention_mask,
332 text_memory_resized,
333 inputs_embeds: inputs_embeds_seq_first,
334 seq_len,
335 batch,
336 d_model: weights.d_model,
337 width: w,
338 })
339}
340
341fn block_forward(
342 x_in: &[f32],
343 block: &Sam3TextBlock,
344 batch: usize,
345 seq_len: usize,
346 width: usize,
347 heads: usize,
348 head_dim: usize,
349 mask: &[f32],
350 gguf_packed: Option<&GgufPackedParams>,
351) -> Result<Vec<f32>> {
352 let rows = batch * seq_len;
353 let n1 = layer_norm(x_in, &block.ln1_w, &block.ln1_b, width, 1e-5)?;
354 let qkv = linear_maybe_gguf(
355 &n1,
356 rows,
357 width,
358 &block.qkv_w_t,
359 block.qkv_gguf_key.as_deref(),
360 gguf_packed,
361 3 * width,
362 &block.qkv_b,
363 )?;
364
365 let bh = batch * heads;
366 let mut q = vec![0f32; bh * seq_len * head_dim];
367 let mut k = vec![0f32; bh * seq_len * head_dim];
368 let mut v = vec![0f32; bh * seq_len * head_dim];
369 for b in 0..batch {
370 for l in 0..seq_len {
371 let src = (b * seq_len + l) * 3 * width;
372 for hd in 0..heads {
373 let qd_src = src + hd * head_dim;
374 let kd_src = src + width + hd * head_dim;
375 let vd_src = src + 2 * width + hd * head_dim;
376 let dst = ((b * heads + hd) * seq_len + l) * head_dim;
377 q[dst..dst + head_dim].copy_from_slice(&qkv[qd_src..qd_src + head_dim]);
378 k[dst..dst + head_dim].copy_from_slice(&qkv[kd_src..kd_src + head_dim]);
379 v[dst..dst + head_dim].copy_from_slice(&qkv[vd_src..vd_src + head_dim]);
380 }
381 }
382 }
383
384 let scale = 1.0f32 / (head_dim as f32).sqrt();
385 let mut attn_out = vec![0f32; bh * seq_len * head_dim];
386 let mut scores = vec![0f32; seq_len * seq_len];
387 for bhi in 0..bh {
388 let q_h = &q[bhi * seq_len * head_dim..(bhi + 1) * seq_len * head_dim];
389 let k_h = &k[bhi * seq_len * head_dim..(bhi + 1) * seq_len * head_dim];
390 let v_h = &v[bhi * seq_len * head_dim..(bhi + 1) * seq_len * head_dim];
391 matmul_bt(q_h, k_h, &mut scores, seq_len, head_dim, seq_len, scale);
392 for r in 0..seq_len {
393 for c in 0..seq_len {
394 scores[r * seq_len + c] += mask[r * seq_len + c];
395 }
396 }
397 softmax_rows(&mut scores, seq_len, seq_len);
398 let out_h = &mut attn_out[bhi * seq_len * head_dim..(bhi + 1) * seq_len * head_dim];
399 matmul(&scores, v_h, out_h, seq_len, seq_len, head_dim);
400 }
401
402 let mut packed = vec![0f32; rows * width];
403 for b in 0..batch {
404 for l in 0..seq_len {
405 for hd in 0..heads {
406 let src = ((b * heads + hd) * seq_len + l) * head_dim;
407 let dst = (b * seq_len + l) * width + hd * head_dim;
408 packed[dst..dst + head_dim].copy_from_slice(&attn_out[src..src + head_dim]);
409 }
410 }
411 }
412 let attn_proj = linear_maybe_gguf(
413 &packed,
414 rows,
415 width,
416 &block.proj_w_t,
417 block.proj_gguf_key.as_deref(),
418 gguf_packed,
419 width,
420 &block.proj_b,
421 )?;
422
423 let mut x = x_in.to_vec();
424 for i in 0..x.len() {
425 x[i] += attn_proj[i];
426 }
427 let n2 = layer_norm(&x, &block.ln2_w, &block.ln2_b, width, 1e-5)?;
428 let mlp_hidden = block.mlp_fc_b.len();
429 let mut mlp = linear_maybe_gguf(
430 &n2,
431 rows,
432 width,
433 &block.mlp_fc_w_t,
434 block.mlp_fc_gguf_key.as_deref(),
435 gguf_packed,
436 mlp_hidden,
437 &block.mlp_fc_b,
438 )?;
439 gelu_exact_inplace(&mut mlp);
440 let ffn = linear_maybe_gguf(
441 &mlp,
442 rows,
443 mlp_hidden,
444 &block.mlp_proj_w_t,
445 block.mlp_proj_gguf_key.as_deref(),
446 gguf_packed,
447 width,
448 &block.mlp_proj_b,
449 )?;
450 for i in 0..x.len() {
451 x[i] += ffn[i];
452 }
453 Ok(x)
454}
455
456fn gelu_exact_inplace(x: &mut [f32]) {
457 let inv_sqrt2 = 1.0f32 / std::f32::consts::SQRT_2;
458 for v in x.iter_mut() {
459 *v = 0.5 * *v * (1.0 + erf_approx(*v * inv_sqrt2));
460 }
461}
462
/// Polynomial approximation of the error function.
///
/// Evaluates `1 - poly(t) * exp(-x^2)` with `t = 1 / (1 + p * |x|)` for the
/// magnitude of `x` and restores the sign afterwards. The constants appear to
/// be those of Abramowitz & Stegun formula 7.1.26; confirm before relying on
/// its published error bound.
fn erf_approx(x: f32) -> f32 {
    const P: f32 = 0.3275911;
    // Coefficients from highest to lowest order, evaluated with Horner's scheme.
    const LEADING: f32 = 1.0614054;
    const REMAINING: [f32; 4] = [-1.4531521, 1.4214138, -0.2844967, 0.2548296];

    let magnitude = x.abs();
    let t = 1.0 / (1.0 + P * magnitude);
    let poly = REMAINING.iter().fold(LEADING, |acc, &coeff| acc * t + coeff);
    let y = 1.0 - poly * t * (-magnitude * magnitude).exp();
    if x < 0.0 { -y } else { y }
}
476
477pub fn encode_text_native(
481 weights: &Sam3TextEncoderWeights,
482 cfg: &Sam3TextConfig,
483 _prompt: Option<&str>,
484 gguf_packed: Option<&GgufPackedParams>,
485) -> Result<Sam3TextEncoded> {
486 if !weights.loaded {
487 return Ok(Sam3TextEncoded {
488 d_model: cfg.d_model,
489 width: cfg.width,
490 ..Default::default()
491 });
492 }
493 let seq_len = weights.context_length;
494 let tokens = vec![0u32; seq_len];
495 encode_tokens(weights, &tokens, 1, seq_len, gguf_packed)
496}