1use super::config::{SAM3_PATCH_GRID, Sam3VitConfig};
33use super::preprocess::{Sam3PreprocessWeights, assemble_patch_tokens, extract_preprocess_weights};
34use super::tensor::{gelu_tanh, layer_norm, linear, matmul, matmul_bt, softmax_rows};
35use anyhow::{Result, ensure};
36use rlx_core::weight_map::WeightMap;
37use rlx_flow::{GgufPackedLinear, GgufPackedParams};
38
/// Weights of a single SAM3 ViT transformer block.
///
/// Every linear layer is described by three fields: a dense weight buffer
/// (`*_w_t`, loaded through `take_transposed` as `[in_dim, out_dim]`), a bias,
/// and an optional GGUF prefix. When the prefix is `Some`, the dense buffer is
/// left empty and the weight is looked up in the packed GGUF parameters under
/// `"<prefix>.weight"` at inference time.
#[derive(Clone)]
pub struct Sam3VitBlockWeights {
    /// Scale of the layer norm applied before attention.
    pub norm1_w: Vec<f32>,
    /// Shift of the layer norm applied before attention.
    pub norm1_b: Vec<f32>,
    /// Dense fused QKV projection weight (empty when served from GGUF).
    pub qkv_w_t: Vec<f32>,
    /// Fused QKV projection bias (`3 * embed_dim` values).
    pub qkv_b: Vec<f32>,
    /// GGUF key prefix for the QKV projection, if it is stored packed.
    pub qkv_gguf_prefix: Option<String>,
    /// Dense attention output projection weight (empty when served from GGUF).
    pub proj_w_t: Vec<f32>,
    /// Attention output projection bias.
    pub proj_b: Vec<f32>,
    /// GGUF key prefix for the output projection, if it is stored packed.
    pub proj_gguf_prefix: Option<String>,
    /// Scale of the layer norm applied before the MLP.
    pub norm2_w: Vec<f32>,
    /// Shift of the layer norm applied before the MLP.
    pub norm2_b: Vec<f32>,
    /// Dense first MLP layer weight (empty when served from GGUF).
    pub mlp_fc1_w_t: Vec<f32>,
    /// First MLP layer bias; its length is used as the MLP hidden width.
    pub mlp_fc1_b: Vec<f32>,
    /// GGUF key prefix for the first MLP layer, if it is stored packed.
    pub mlp_fc1_gguf_prefix: Option<String>,
    /// Dense second MLP layer weight (empty when served from GGUF).
    pub mlp_fc2_w_t: Vec<f32>,
    /// Second MLP layer bias.
    pub mlp_fc2_b: Vec<f32>,
    /// GGUF key prefix for the second MLP layer, if it is stored packed.
    pub mlp_fc2_gguf_prefix: Option<String>,
}
59
60#[derive(Clone)]
61pub struct Sam3VisionEncoderWeights {
62 pub pre: Sam3PreprocessWeights,
63 pub ln_pre_w: Vec<f32>,
64 pub ln_pre_b: Vec<f32>,
65 pub blocks: Vec<Sam3VitBlockWeights>,
66}
67
/// Result of running the vision trunk.
///
/// `tokens` is the flat token buffer produced by the last block, `grid` is the
/// patch-grid side length the encoder ran with, and `dim` is the embedding
/// width of each token.
#[derive(Debug, Clone)]
pub struct Sam3VisionOutput {
    /// Flat token buffer after the final transformer block.
    pub tokens: Vec<f32>,
    /// Side length of the square patch grid.
    pub grid: usize,
    /// Embedding dimension of every token.
    pub dim: usize,
}
73
74pub fn extract_vision_encoder_weights(
75 weights: &mut WeightMap,
76 cfg: &Sam3VitConfig,
77 gguf_packed: Option<&GgufPackedParams>,
78) -> Result<Sam3VisionEncoderWeights> {
79 let pre = extract_preprocess_weights(weights, cfg)?;
80 let e = cfg.embed_dim;
81 let (ln_pre_w, ln_pre_b) = take_layer_norm(weights, &prefixes("ln_pre"), e)?;
82 let hidden = (e as f64 * cfg.mlp_ratio) as usize;
83 let mut blocks = Vec::with_capacity(cfg.depth);
84 for i in 0..cfg.depth {
85 let p = format!("blocks.{i}");
86 let pref = prefixes(&p);
87 let (norm1_w, norm1_b) = take_layer_norm(weights, &prefixed(&pref, "norm1"), e)?;
88 let (qkv_w_t, qkv_gguf_prefix) =
89 take_linear_w_or_gguf(weights, gguf_packed, &prefixed(&pref, "attn.qkv"), e, 3 * e)?;
90 let qkv_b = take_linear_b(weights, &prefixed(&pref, "attn.qkv"), 3 * e)?;
91 let (proj_w_t, proj_gguf_prefix) =
92 take_linear_w_or_gguf(weights, gguf_packed, &prefixed(&pref, "attn.proj"), e, e)?;
93 let proj_b = take_linear_b(weights, &prefixed(&pref, "attn.proj"), e)?;
94 let (norm2_w, norm2_b) = take_layer_norm(weights, &prefixed(&pref, "norm2"), e)?;
95 let (mlp_fc1_w_t, mlp_fc1_gguf_prefix) = take_linear_w_any_or_gguf(
96 weights,
97 gguf_packed,
98 &pref,
99 &["mlp.fc1", "mlp.lin1"],
100 e,
101 hidden,
102 )?;
103 let mlp_fc1_b = take_linear_b_any(weights, &pref, &["mlp.fc1", "mlp.lin1"], hidden)?;
104 let (mlp_fc2_w_t, mlp_fc2_gguf_prefix) = take_linear_w_any_or_gguf(
105 weights,
106 gguf_packed,
107 &pref,
108 &["mlp.fc2", "mlp.lin2"],
109 hidden,
110 e,
111 )?;
112 let mlp_fc2_b = take_linear_b_any(weights, &pref, &["mlp.fc2", "mlp.lin2"], e)?;
113 blocks.push(Sam3VitBlockWeights {
114 norm1_w,
115 norm1_b,
116 qkv_w_t,
117 qkv_b,
118 qkv_gguf_prefix,
119 proj_w_t,
120 proj_b,
121 proj_gguf_prefix,
122 norm2_w,
123 norm2_b,
124 mlp_fc1_w_t,
125 mlp_fc1_b,
126 mlp_fc1_gguf_prefix,
127 mlp_fc2_w_t,
128 mlp_fc2_b,
129 mlp_fc2_gguf_prefix,
130 });
131 }
132 Ok(Sam3VisionEncoderWeights {
133 pre,
134 ln_pre_w,
135 ln_pre_b,
136 blocks,
137 })
138}
139
140pub fn encode_image_native(
141 weights: &Sam3VisionEncoderWeights,
142 gguf_packed: Option<&GgufPackedParams>,
143 cfg: &Sam3VitConfig,
144 image_nchw: &[f32],
145) -> Result<Sam3VisionOutput> {
146 let e = cfg.embed_dim;
147 let grid = cfg.patch_grid();
148 ensure!(
149 grid == SAM3_PATCH_GRID,
150 "SAM3 base grid must be {SAM3_PATCH_GRID}"
151 );
152 let head_dim = e / cfg.num_heads;
153 ensure!(
154 head_dim * cfg.num_heads == e,
155 "embed_dim {e} not divisible by num_heads {}",
156 cfg.num_heads
157 );
158 let rope_pt = if cfg.window_size > 0 {
159 cfg.window_size
160 } else {
161 grid
162 };
163
164 let mut x = assemble_patch_tokens(&weights.pre, image_nchw)?;
166 x = layer_norm(
167 &x,
168 &weights.ln_pre_w,
169 &weights.ln_pre_b,
170 e,
171 cfg.layer_norm_eps as f32,
172 )?;
173
174 let global_set: std::collections::HashSet<usize> =
175 cfg.global_att_blocks.iter().copied().collect();
176 let rope_global = build_rope_freqs(head_dim, grid, grid, 10000.0, rope_pt as f32 / grid as f32);
177 let rope_window = build_rope_freqs(head_dim, cfg.window_size, cfg.window_size, 10000.0, 1.0);
178
179 for (i, block) in weights.blocks.iter().enumerate() {
180 let is_global = global_set.contains(&i);
181 block_forward(
182 &mut x,
183 block,
184 gguf_packed,
185 cfg,
186 grid,
187 if is_global { 0 } else { cfg.window_size },
188 if is_global {
189 &rope_global
190 } else {
191 &rope_window
192 },
193 head_dim,
194 cfg.num_heads,
195 )?;
196 }
197 Ok(Sam3VisionOutput {
200 tokens: x,
201 grid,
202 dim: e,
203 })
204}
205
/// Builds a 2-D axial RoPE table for an `end_x` by `end_y` grid.
///
/// The result holds `end_x * end_y` rows of `head_dim` values. Each row stores
/// interleaved `(cos, sin)` pairs: the first `head_dim / 4` pairs rotate by the
/// column position, the next `head_dim / 4` pairs by the row position.
/// Positions are multiplied by `scale_pos` and the k-th frequency is
/// `1 / theta^(4k / head_dim)`.
///
/// # Panics
/// Panics when `head_dim` is not a multiple of four.
fn build_rope_freqs(
    head_dim: usize,
    end_x: usize,
    end_y: usize,
    theta: f32,
    scale_pos: f32,
) -> Vec<f32> {
    assert!(
        head_dim.is_multiple_of(4),
        "RoPE head_dim must be divisible by 4"
    );
    let pairs_per_axis = head_dim / 4;
    let axis_freqs: Vec<f32> = (0..pairs_per_axis)
        .map(|k| 1.0 / theta.powf((4 * k) as f32 / head_dim as f32))
        .collect();

    let positions = end_x * end_y;
    let mut table = vec![0f32; positions * head_dim];
    // `positions > 0` implies `end_x > 0`, so the `%` and `/` below are safe.
    for pos in 0..positions {
        let t_x = (pos % end_x) as f32 * scale_pos;
        let t_y = (pos / end_x) as f32 * scale_pos;
        let row = &mut table[pos * head_dim..(pos + 1) * head_dim];
        // First half of the row belongs to the x axis, second half to y.
        let (x_half, y_half) = row.split_at_mut(2 * pairs_per_axis);
        let pair_slots = x_half.chunks_exact_mut(2).zip(y_half.chunks_exact_mut(2));
        for ((x_pair, y_pair), &freq) in pair_slots.zip(&axis_freqs) {
            let ang_x = t_x * freq;
            let ang_y = t_y * freq;
            x_pair[0] = ang_x.cos();
            x_pair[1] = ang_x.sin();
            y_pair[0] = ang_y.cos();
            y_pair[1] = ang_y.sin();
        }
    }
    table
}
244
/// Rotates each `(re, im)` pair of `qk` by the matching `(cos, sin)` pair.
///
/// `qk` is read as `rows` vectors of `head_dim` values. Row `r` uses the table
/// row `r % seq_len` of `freqs_cis`, which must hold `head_dim` values per
/// position as interleaved `(cos, sin)` pairs. A trailing unpaired element of
/// an odd `head_dim` is left untouched.
///
/// # Panics
/// Panics when `seq_len` is zero while `rows` is not, or when either slice is
/// too short for the requested rows.
fn rope_apply_inplace(
    qk: &mut [f32],
    freqs_cis: &[f32],
    rows: usize,
    seq_len: usize,
    head_dim: usize,
) {
    for row in 0..rows {
        let pos = row % seq_len;
        let rotation = &freqs_cis[pos * head_dim..(pos + 1) * head_dim];
        let vector = &mut qk[row * head_dim..(row + 1) * head_dim];
        for (pair, cos_sin) in vector.chunks_exact_mut(2).zip(rotation.chunks_exact(2)) {
            let (re, im) = (pair[0], pair[1]);
            let (cos, sin) = (cos_sin[0], cos_sin[1]);
            // Complex multiplication (re + i*im) * (cos + i*sin).
            pair[0] = re * cos - im * sin;
            pair[1] = re * sin + im * cos;
        }
    }
}
270
271fn linear_maybe_gguf(
274 x: &[f32],
275 m: usize,
276 k: usize,
277 w_t: &[f32],
278 gguf: Option<&GgufPackedLinear>,
279 n: usize,
280 b: &[f32],
281) -> Result<Vec<f32>> {
282 let mut out = vec![0f32; m * n];
283 if let Some(p) = gguf {
284 ensure!(
285 p.in_dim == k && p.out_dim == n,
286 "packed linear shape {k}x{n} vs gguf {}x{}",
287 p.in_dim,
288 p.out_dim
289 );
290 rlx_cpu::gguf_matmul::gguf_matmul_bt(x, &p.w_q, &mut out, m, k, n, p.scheme);
291 } else {
292 ensure!(
293 !w_t.is_empty(),
294 "linear: missing F32 weights and no GGUF packed entry"
295 );
296 return linear(x, m, k, w_t, n, b);
297 }
298 for row in 0..m {
299 for col in 0..n {
300 out[row * n + col] += b[col];
301 }
302 }
303 Ok(out)
304}
305
306fn packed_for_prefix<'a>(
307 packed: Option<&'a GgufPackedParams>,
308 prefix: Option<&String>,
309) -> Option<&'a GgufPackedLinear> {
310 let key = format!("{}.weight", prefix.as_ref()?);
311 packed?.get_linear(&key)
312}
313
314fn block_forward(
315 x: &mut [f32],
316 block: &Sam3VitBlockWeights,
317 gguf_packed: Option<&GgufPackedParams>,
318 cfg: &Sam3VitConfig,
319 grid: usize,
320 window_size: usize,
321 freqs_cis: &[f32],
322 head_dim: usize,
323 num_heads: usize,
324) -> Result<()> {
325 let e = cfg.embed_dim;
326 let n = grid * grid;
327 let eps = cfg.layer_norm_eps as f32;
328
329 let n1 = layer_norm(x, &block.norm1_w, &block.norm1_b, e, eps)?;
331 let qkv_gguf = packed_for_prefix(gguf_packed, block.qkv_gguf_prefix.as_ref());
332 let proj_gguf = packed_for_prefix(gguf_packed, block.proj_gguf_prefix.as_ref());
333 let attn_out = if window_size == 0 {
334 attention_native(
335 &n1,
336 1,
337 n,
338 &block.qkv_w_t,
339 qkv_gguf,
340 &block.qkv_b,
341 &block.proj_w_t,
342 proj_gguf,
343 &block.proj_b,
344 freqs_cis,
345 num_heads,
346 head_dim,
347 )?
348 } else {
349 attention_windowed(
350 &n1,
351 grid,
352 grid,
353 window_size,
354 e,
355 &block.qkv_w_t,
356 qkv_gguf,
357 &block.qkv_b,
358 &block.proj_w_t,
359 proj_gguf,
360 &block.proj_b,
361 freqs_cis,
362 num_heads,
363 head_dim,
364 )?
365 };
366 for i in 0..x.len() {
367 x[i] += attn_out[i];
368 }
369
370 let n2 = layer_norm(x, &block.norm2_w, &block.norm2_b, e, eps)?;
371 let hidden = block.mlp_fc1_b.len();
372 let fc1_gguf = packed_for_prefix(gguf_packed, block.mlp_fc1_gguf_prefix.as_ref());
373 let fc2_gguf = packed_for_prefix(gguf_packed, block.mlp_fc2_gguf_prefix.as_ref());
374 let mut mlp = linear_maybe_gguf(
375 &n2,
376 n,
377 e,
378 &block.mlp_fc1_w_t,
379 fc1_gguf,
380 hidden,
381 &block.mlp_fc1_b,
382 )?;
383 gelu_tanh(&mut mlp);
384 let ffn = linear_maybe_gguf(
385 &mlp,
386 n,
387 hidden,
388 &block.mlp_fc2_w_t,
389 fc2_gguf,
390 e,
391 &block.mlp_fc2_b,
392 )?;
393 for i in 0..x.len() {
394 x[i] += ffn[i];
395 }
396 Ok(())
397}
398
399fn attention_windowed(
400 x: &[f32],
401 h: usize,
402 w: usize,
403 ws: usize,
404 e: usize,
405 qkv_w_t: &[f32],
406 qkv_gguf: Option<&GgufPackedLinear>,
407 qkv_b: &[f32],
408 proj_w_t: &[f32],
409 proj_gguf: Option<&GgufPackedLinear>,
410 proj_b: &[f32],
411 freqs_cis: &[f32],
412 num_heads: usize,
413 head_dim: usize,
414) -> Result<Vec<f32>> {
415 let pad_h = (ws - h % ws) % ws;
416 let pad_w = (ws - w % ws) % ws;
417 let hp = h + pad_h;
418 let wp = w + pad_w;
419 let nh = hp / ws;
420 let nw = wp / ws;
421 let num_windows = nh * nw;
422 let win_len = ws * ws;
423
424 let mut win = vec![0f32; num_windows * win_len * e];
426 for y in 0..hp {
427 for xc in 0..wp {
428 let wy = y / ws;
429 let wx = xc / ws;
430 let ry = y % ws;
431 let rx = xc % ws;
432 let widx = wy * nw + wx;
433 let dst = ((widx * ws + ry) * ws + rx) * e;
434 if y < h && xc < w {
435 let src = (y * w + xc) * e;
436 win[dst..dst + e].copy_from_slice(&x[src..src + e]);
437 }
438 }
440 }
441
442 let attn = attention_native(
443 &win,
444 num_windows,
445 win_len,
446 qkv_w_t,
447 qkv_gguf,
448 qkv_b,
449 proj_w_t,
450 proj_gguf,
451 proj_b,
452 freqs_cis,
453 num_heads,
454 head_dim,
455 )?;
456
457 let mut out = vec![0f32; h * w * e];
460 for y in 0..h {
461 for xc in 0..w {
462 let wy = y / ws;
463 let wx = xc / ws;
464 let ry = y % ws;
465 let rx = xc % ws;
466 let widx = wy * nw + wx;
467 let src = ((widx * ws + ry) * ws + rx) * e;
468 let dst = (y * w + xc) * e;
469 out[dst..dst + e].copy_from_slice(&attn[src..src + e]);
470 }
471 }
472 Ok(out)
473}
474
475fn attention_native(
478 x: &[f32],
479 b: usize,
480 l: usize,
481 qkv_w_t: &[f32],
482 qkv_gguf: Option<&GgufPackedLinear>,
483 qkv_b: &[f32],
484 proj_w_t: &[f32],
485 proj_gguf: Option<&GgufPackedLinear>,
486 proj_b: &[f32],
487 freqs_cis: &[f32],
488 num_heads: usize,
489 head_dim: usize,
490) -> Result<Vec<f32>> {
491 let e = num_heads * head_dim;
492 let rows = b * l;
493 let qkv = linear_maybe_gguf(x, rows, e, qkv_w_t, qkv_gguf, 3 * e, qkv_b)?;
494
495 let bh = b * num_heads;
498 let mut q = vec![0f32; bh * l * head_dim];
499 let mut k = vec![0f32; bh * l * head_dim];
500 let mut v = vec![0f32; bh * l * head_dim];
501 for bi in 0..b {
502 for li in 0..l {
503 let src = (bi * l + li) * 3 * e;
504 for hd in 0..num_heads {
505 let qd_src = src + hd * head_dim;
506 let kd_src = src + e + hd * head_dim;
507 let vd_src = src + 2 * e + hd * head_dim;
508 let dst = ((bi * num_heads + hd) * l + li) * head_dim;
509 q[dst..dst + head_dim].copy_from_slice(&qkv[qd_src..qd_src + head_dim]);
510 k[dst..dst + head_dim].copy_from_slice(&qkv[kd_src..kd_src + head_dim]);
511 v[dst..dst + head_dim].copy_from_slice(&qkv[vd_src..vd_src + head_dim]);
512 }
513 }
514 }
515
516 rope_apply_inplace(&mut q, freqs_cis, bh * l, l, head_dim);
517 rope_apply_inplace(&mut k, freqs_cis, bh * l, l, head_dim);
518
519 let scale = 1.0f32 / (head_dim as f32).sqrt();
520 let mut attn_out = vec![0f32; bh * l * head_dim];
521 let mut scores = vec![0f32; l * l];
522
523 for bhi in 0..bh {
524 let q_h = &q[bhi * l * head_dim..(bhi + 1) * l * head_dim];
525 let k_h = &k[bhi * l * head_dim..(bhi + 1) * l * head_dim];
526 let v_h = &v[bhi * l * head_dim..(bhi + 1) * l * head_dim];
527 matmul_bt(q_h, k_h, &mut scores, l, head_dim, l, scale);
529 softmax_rows(&mut scores, l, l);
530 let out_h = &mut attn_out[bhi * l * head_dim..(bhi + 1) * l * head_dim];
532 matmul(&scores, v_h, out_h, l, l, head_dim);
533 }
534
535 let mut packed = vec![0f32; rows * e];
537 for bi in 0..b {
538 for li in 0..l {
539 for hd in 0..num_heads {
540 let src = ((bi * num_heads + hd) * l + li) * head_dim;
541 let dst = (bi * l + li) * e + hd * head_dim;
542 packed[dst..dst + head_dim].copy_from_slice(&attn_out[src..src + head_dim]);
543 }
544 }
545 }
546 linear_maybe_gguf(&packed, rows, e, proj_w_t, proj_gguf, e, proj_b)
547}
548
/// Returns `"<root>.<suffix>"` for every known location of the vision trunk,
/// from the most to the least qualified root.
fn prefixes(suffix: &str) -> Vec<String> {
    // Checkpoint roots under which the trunk may be stored.
    const TRUNK_ROOTS: [&str; 6] = [
        "detector.backbone.vision_backbone.trunk",
        "detector.backbone.visual.trunk",
        "backbone.vision_backbone.trunk",
        "backbone.visual.trunk",
        "visual.trunk",
        "trunk",
    ];

    let mut candidates = Vec::with_capacity(TRUNK_ROOTS.len());
    for root in TRUNK_ROOTS {
        candidates.push(format!("{root}.{suffix}"));
    }
    candidates
}
562
/// Appends `.<suffix>` to every entry of `prefixes`, preserving order.
fn prefixed(prefixes: &[String], suffix: &str) -> Vec<String> {
    let mut joined = Vec::with_capacity(prefixes.len());
    for base in prefixes {
        joined.push(format!("{base}.{suffix}"));
    }
    joined
}
566
567fn take_layer_norm(
568 weights: &mut WeightMap,
569 bases: &[String],
570 dim: usize,
571) -> Result<(Vec<f32>, Vec<f32>)> {
572 let w = take_shape(weights, &suffixes(bases, "weight"), &[dim])?;
573 let b = take_shape(weights, &suffixes(bases, "bias"), &[dim])?;
574 Ok((w, b))
575}
576
577fn take_linear_w_or_gguf(
578 weights: &mut WeightMap,
579 gguf_packed: Option<&GgufPackedParams>,
580 bases: &[String],
581 in_dim: usize,
582 out_dim: usize,
583) -> Result<(Vec<f32>, Option<String>)> {
584 let keys = suffixes(bases, "weight");
585 for key in &keys {
586 if weights.has(key) {
587 let w = take_linear_w(weights, bases, in_dim, out_dim)?;
588 return Ok((w, None));
589 }
590 if let Some(packed) = gguf_packed {
591 if let Some(prefix) = key.strip_suffix(".weight") {
592 if packed.get_linear(key).is_some() {
593 return Ok((Vec::new(), Some(prefix.to_string())));
594 }
595 }
596 }
597 }
598 anyhow::bail!("none of the SAM3 linear weight keys were found: {keys:?}")
599}
600
601fn take_linear_w_any_or_gguf(
602 weights: &mut WeightMap,
603 gguf_packed: Option<&GgufPackedParams>,
604 block_prefixes: &[String],
605 names: &[&str],
606 in_dim: usize,
607 out_dim: usize,
608) -> Result<(Vec<f32>, Option<String>)> {
609 let bases: Vec<String> = block_prefixes
610 .iter()
611 .flat_map(|p| names.iter().map(move |name| format!("{p}.{name}")))
612 .collect();
613 take_linear_w_or_gguf(weights, gguf_packed, &bases, in_dim, out_dim)
614}
615
616fn take_linear_w(
617 weights: &mut WeightMap,
618 bases: &[String],
619 in_dim: usize,
620 out_dim: usize,
621) -> Result<Vec<f32>> {
622 let keys = suffixes(bases, "weight");
623 for key in &keys {
624 if weights.has(key) {
625 let (data, shape) = weights.take_transposed(key)?;
626 ensure!(
627 shape == vec![in_dim, out_dim],
628 "{key} expected [{in_dim}, {out_dim}], got {shape:?}"
629 );
630 return Ok(data);
631 }
632 }
633 anyhow::bail!("none of the SAM3 linear weight keys were found: {keys:?}")
634}
635
636#[allow(dead_code)]
637fn take_linear_w_any(
638 weights: &mut WeightMap,
639 block_prefixes: &[String],
640 names: &[&str],
641 in_dim: usize,
642 out_dim: usize,
643) -> Result<Vec<f32>> {
644 let bases: Vec<String> = block_prefixes
645 .iter()
646 .flat_map(|p| names.iter().map(move |name| format!("{p}.{name}")))
647 .collect();
648 take_linear_w(weights, &bases, in_dim, out_dim)
649}
650
651fn take_linear_b(weights: &mut WeightMap, bases: &[String], dim: usize) -> Result<Vec<f32>> {
652 take_shape(weights, &suffixes(bases, "bias"), &[dim])
653}
654
655fn take_linear_b_any(
656 weights: &mut WeightMap,
657 block_prefixes: &[String],
658 names: &[&str],
659 dim: usize,
660) -> Result<Vec<f32>> {
661 let bases: Vec<String> = block_prefixes
662 .iter()
663 .flat_map(|p| names.iter().map(move |name| format!("{p}.{name}")))
664 .collect();
665 take_linear_b(weights, &bases, dim)
666}
667
/// Builds the key `"<base>.<suffix>"` for every base, preserving order.
fn suffixes(bases: &[String], suffix: &str) -> Vec<String> {
    let mut keys = Vec::with_capacity(bases.len());
    for base in bases {
        keys.push(format!("{base}.{suffix}"));
    }
    keys
}
671
672fn take_shape(weights: &mut WeightMap, keys: &[String], expected: &[usize]) -> Result<Vec<f32>> {
673 for key in keys {
674 if weights.has(key) {
675 let (data, shape) = weights.take(key)?;
676 ensure!(
677 shape == expected,
678 "{key} expected {expected:?}, got {shape:?}"
679 );
680 return Ok(data);
681 }
682 }
683 anyhow::bail!("none of the SAM3 weight keys were found: {keys:?}")
684}