1use super::config::{SAM2_IMG_SIZE, Sam2MemoryEncoderConfig};
43use super::memory_mask_ir::{
44 Sam2MemoryConv1x1Compiled, Sam2MemoryFuserCompiled, Sam2MemoryMaskDownCompiled,
45 Sam2MemoryPrefixCompiled,
46};
47use super::prompt_encoder::{conv2d_1x1, gelu_erf_inplace, layernorm2d_nchw, sigmoid_inplace};
48use anyhow::{Result, ensure};
49use rlx_core::weight_map::WeightMap;
50use rlx_runtime::Device;
51use std::f32::consts::PI;
52
53pub struct Sam2MaskDownSamplerWeights {
56 pub levels: Vec<DownSampleLevel>,
59 pub final_conv_w: Vec<f32>,
61 pub final_conv_b: Vec<f32>,
62 pub kernel: usize,
63 pub stride: usize,
64 pub padding: usize,
65 pub embed_dim: usize,
66}
67
/// One stage of the mask downsampler: a strided convolution followed by a
/// channel layer norm.
pub struct DownSampleLevel {
    /// Convolution weight, validated at load time as `[out_c, in_c, k, k]`.
    pub conv_w: Vec<f32>,
    /// Convolution bias (indexed per output channel by the host conv).
    pub conv_b: Vec<f32>,
    /// Layer-norm scale.
    pub ln_g: Vec<f32>,
    /// Layer-norm shift.
    pub ln_b: Vec<f32>,
    /// Input channel count of this stage.
    pub in_c: usize,
    /// Output channel count of this stage.
    pub out_c: usize,
}
76
/// Weights of a single fuser block: spatial conv → layer norm → two
/// point-wise linear layers, with an optional per-channel scale.
pub struct Sam2CXBlockWeights {
    /// Spatial conv weight. Validated at load time as `[dim, 1, k, k]` when
    /// the config requests a depthwise conv, otherwise `[dim, dim, k, k]`.
    pub dw_conv_w: Vec<f32>,
    /// Spatial conv bias.
    pub dw_conv_b: Vec<f32>,
    /// Layer-norm scale.
    pub ln_g: Vec<f32>,
    /// Layer-norm shift.
    pub ln_b: Vec<f32>,
    /// First point-wise layer weight, validated as `[4 * dim, dim]`.
    pub pw1_w: Vec<f32>,
    /// First point-wise layer bias.
    pub pw1_b: Vec<f32>,
    /// Second point-wise layer weight (read row-major as `[dim, 4 * dim]`).
    pub pw2_w: Vec<f32>,
    /// Second point-wise layer bias.
    pub pw2_b: Vec<f32>,
    /// Optional per-channel scale applied to the block output before the
    /// residual add; only loaded when the config's layer-scale init value
    /// is positive.
    pub gamma: Option<Vec<f32>>,
    /// Channel count of the block.
    pub dim: usize,
    /// Square kernel size of the spatial conv.
    pub kernel: usize,
    /// Zero padding of the spatial conv.
    pub padding: usize,
}
93
94pub struct Sam2FuserWeights {
95 pub input_proj_w: Option<Vec<f32>>,
97 pub input_proj_b: Option<Vec<f32>>,
98 pub layers: Vec<Sam2CXBlockWeights>,
99 pub dim: usize,
100}
101
102pub struct Sam2MemoryEncoderWeights {
103 pub mask_downsampler: Sam2MaskDownSamplerWeights,
104 pub prefix: Option<Sam2MemoryPrefixCompiled>,
105 pub mask_down: Option<Sam2MemoryMaskDownCompiled>,
106 pub pix_proj: Option<Sam2MemoryConv1x1Compiled>,
107 pub fuser_ir: Option<Sam2MemoryFuserCompiled>,
108 pub out_proj_ir: Option<Sam2MemoryConv1x1Compiled>,
109 pub pix_feat_proj_w: Vec<f32>, pub pix_feat_proj_b: Vec<f32>,
111 pub fuser: Sam2FuserWeights,
112 pub out_proj_w: Option<Vec<f32>>,
115 pub out_proj_b: Option<Vec<f32>>,
116 pub in_dim: usize,
117 pub out_dim: usize,
118 pub pe_num_pos_feats: usize,
119 pub pe_temperature: f32,
120}
121
122pub fn extract_memory_encoder_weights(
125 weights: &mut WeightMap,
126 cfg: &Sam2MemoryEncoderConfig,
127) -> Result<Sam2MemoryEncoderWeights> {
128 let mask_downsampler = extract_mask_downsampler(weights, cfg)?;
129
130 let (pix_feat_proj_w, sh) = weights.take("memory_encoder.pix_feat_proj.weight")?;
131 ensure!(
132 sh == vec![cfg.in_dim, cfg.in_dim, 1, 1],
133 "pix_feat_proj.weight shape {sh:?} not [{}, {}, 1, 1]",
134 cfg.in_dim,
135 cfg.in_dim
136 );
137 let (pix_feat_proj_b, _) = weights.take("memory_encoder.pix_feat_proj.bias")?;
138
139 let fuser = extract_fuser(weights, cfg)?;
140
141 let (out_proj_w, out_proj_b) = if cfg.in_dim == cfg.out_dim {
142 (None, None)
143 } else {
144 let (w, sh) = weights.take("memory_encoder.out_proj.weight")?;
145 ensure!(
146 sh == vec![cfg.out_dim, cfg.in_dim, 1, 1],
147 "out_proj.weight shape {sh:?} not [{}, {}, 1, 1]",
148 cfg.out_dim,
149 cfg.in_dim
150 );
151 let (b, _) = weights.take("memory_encoder.out_proj.bias")?;
152 (Some(w), Some(b))
153 };
154
155 Ok(Sam2MemoryEncoderWeights {
156 mask_downsampler,
157 prefix: None,
158 mask_down: None,
159 pix_proj: None,
160 fuser_ir: None,
161 out_proj_ir: None,
162 pix_feat_proj_w,
163 pix_feat_proj_b,
164 fuser,
165 out_proj_w,
166 out_proj_b,
167 in_dim: cfg.in_dim,
168 out_dim: cfg.out_dim,
169 pe_num_pos_feats: cfg.pe_num_pos_feats,
170 pe_temperature: cfg.pe_temperature,
171 })
172}
173
174fn extract_mask_downsampler(
175 weights: &mut WeightMap,
176 cfg: &Sam2MemoryEncoderConfig,
177) -> Result<Sam2MaskDownSamplerWeights> {
178 let mut num_layers = 0;
181 let mut acc = 1usize;
182 while acc < cfg.mask_downsampler_total_stride {
183 acc *= cfg.mask_downsampler_stride;
184 num_layers += 1;
185 }
186 ensure!(
187 acc == cfg.mask_downsampler_total_stride,
188 "mask_downsampler total_stride {} must be a power of stride {}",
189 cfg.mask_downsampler_total_stride,
190 cfg.mask_downsampler_stride
191 );
192
193 let mut levels = Vec::with_capacity(num_layers);
194 let mut in_c = 1usize;
195 let stride2 = cfg.mask_downsampler_stride * cfg.mask_downsampler_stride;
196 for li in 0..num_layers {
200 let out_c = in_c * stride2;
201 let conv_idx = li * 3;
202 let ln_idx = conv_idx + 1;
203 let (conv_w, sh) = weights.take(&format!(
204 "memory_encoder.mask_downsampler.encoder.{conv_idx}.weight"
205 ))?;
206 ensure!(
207 sh == vec![
208 out_c,
209 in_c,
210 cfg.mask_downsampler_kernel,
211 cfg.mask_downsampler_kernel
212 ],
213 "mask_downsampler conv {li} weight shape {sh:?} not [{out_c}, {in_c}, {}, {}]",
214 cfg.mask_downsampler_kernel,
215 cfg.mask_downsampler_kernel
216 );
217 let (conv_b, _) = weights.take(&format!(
218 "memory_encoder.mask_downsampler.encoder.{conv_idx}.bias"
219 ))?;
220 let (ln_g, _) = weights.take(&format!(
221 "memory_encoder.mask_downsampler.encoder.{ln_idx}.weight"
222 ))?;
223 let (ln_b, _) = weights.take(&format!(
224 "memory_encoder.mask_downsampler.encoder.{ln_idx}.bias"
225 ))?;
226 levels.push(DownSampleLevel {
227 conv_w,
228 conv_b,
229 ln_g,
230 ln_b,
231 in_c,
232 out_c,
233 });
234 in_c = out_c;
235 }
236 let final_idx = num_layers * 3;
238 let (final_conv_w, sh) = weights.take(&format!(
239 "memory_encoder.mask_downsampler.encoder.{final_idx}.weight"
240 ))?;
241 ensure!(
242 sh == vec![cfg.in_dim, in_c, 1, 1],
243 "mask_downsampler final conv weight shape {sh:?} not [{}, {in_c}, 1, 1]",
244 cfg.in_dim
245 );
246 let (final_conv_b, _) = weights.take(&format!(
247 "memory_encoder.mask_downsampler.encoder.{final_idx}.bias"
248 ))?;
249
250 Ok(Sam2MaskDownSamplerWeights {
251 levels,
252 final_conv_w,
253 final_conv_b,
254 kernel: cfg.mask_downsampler_kernel,
255 stride: cfg.mask_downsampler_stride,
256 padding: cfg.mask_downsampler_padding,
257 embed_dim: cfg.in_dim,
258 })
259}
260
261fn extract_fuser(
262 weights: &mut WeightMap,
263 cfg: &Sam2MemoryEncoderConfig,
264) -> Result<Sam2FuserWeights> {
265 let (input_proj_w, input_proj_b) = if cfg.fuser_input_projection {
266 let (w, sh) = weights.take("memory_encoder.fuser.proj.weight")?;
267 ensure!(
268 sh == vec![cfg.fuser_dim, cfg.fuser_dim, 1, 1],
269 "fuser.proj.weight shape {sh:?} not [{}, {}, 1, 1]",
270 cfg.fuser_dim,
271 cfg.fuser_dim
272 );
273 let (b, _) = weights.take("memory_encoder.fuser.proj.bias")?;
274 (Some(w), Some(b))
275 } else {
276 (None, None)
277 };
278
279 let mut layers = Vec::with_capacity(cfg.fuser_num_layers);
280 for i in 0..cfg.fuser_num_layers {
281 let p = format!("memory_encoder.fuser.layers.{i}");
282 let (dw_conv_w, sh) = weights.take(&format!("{p}.dwconv.weight"))?;
283 let dim = cfg.fuser_dim;
285 let k = cfg.fuser_kernel;
286 if cfg.fuser_use_dwconv {
287 ensure!(
288 sh == vec![dim, 1, k, k],
289 "{p}.dwconv.weight shape {sh:?} not [{dim}, 1, {k}, {k}]"
290 );
291 } else {
292 ensure!(
293 sh == vec![dim, dim, k, k],
294 "{p}.dwconv.weight shape {sh:?} not [{dim}, {dim}, {k}, {k}]"
295 );
296 }
297 let (dw_conv_b, _) = weights.take(&format!("{p}.dwconv.bias"))?;
298 let (ln_g, _) = weights.take(&format!("{p}.norm.weight"))?;
299 let (ln_b, _) = weights.take(&format!("{p}.norm.bias"))?;
300 let (pw1_w, sh) = weights.take(&format!("{p}.pwconv1.weight"))?;
301 ensure!(
302 sh == vec![4 * dim, dim],
303 "{p}.pwconv1.weight shape {sh:?} not [{}, {dim}]",
304 4 * dim
305 );
306 let (pw1_b, _) = weights.take(&format!("{p}.pwconv1.bias"))?;
307 let (pw2_w, _) = weights.take(&format!("{p}.pwconv2.weight"))?;
308 let (pw2_b, _) = weights.take(&format!("{p}.pwconv2.bias"))?;
309 let gamma = if cfg.fuser_layer_scale_init_value > 0.0 {
310 let (g, _) = weights.take(&format!("{p}.gamma"))?;
311 Some(g)
312 } else {
313 None
314 };
315 layers.push(Sam2CXBlockWeights {
316 dw_conv_w,
317 dw_conv_b,
318 ln_g,
319 ln_b,
320 pw1_w,
321 pw1_b,
322 pw2_w,
323 pw2_b,
324 gamma,
325 dim,
326 kernel: k,
327 padding: cfg.fuser_padding,
328 });
329 }
330 Ok(Sam2FuserWeights {
331 input_proj_w,
332 input_proj_b,
333 layers,
334 dim: cfg.fuser_dim,
335 })
336}
337
338pub fn compile_memory_encoder_ir(
340 weights: &mut Sam2MemoryEncoderWeights,
341 mask_in_h: usize,
342 mask_in_w: usize,
343 feat_h: usize,
344 feat_w: usize,
345 device: Device,
346 profile: &rlx_flow::CompileProfile,
347) -> Result<()> {
348 weights.prefix = Some(Sam2MemoryPrefixCompiled::compile_with_profile(
349 &weights.mask_downsampler,
350 weights.in_dim,
351 mask_in_h,
352 mask_in_w,
353 feat_h,
354 feat_w,
355 &weights.pix_feat_proj_w,
356 &weights.pix_feat_proj_b,
357 device,
358 profile,
359 )?);
360 weights.fuser_ir = Some(Sam2MemoryFuserCompiled::compile_with_profile(
361 &weights.fuser,
362 feat_h,
363 feat_w,
364 device,
365 profile,
366 )?);
367 if let (Some(opw), Some(opb)) = (&weights.out_proj_w, &weights.out_proj_b) {
368 weights.out_proj_ir = Some(Sam2MemoryConv1x1Compiled::compile_with_profile(
369 weights.in_dim,
370 weights.out_dim,
371 feat_h,
372 feat_w,
373 opw,
374 opb,
375 device,
376 profile,
377 )?);
378 }
379 Ok(())
380}
381
382pub fn compile_memory_mask_ir(
384 weights: &mut Sam2MemoryEncoderWeights,
385 mask_in_h: usize,
386 mask_in_w: usize,
387 device: Device,
388) -> Result<()> {
389 compile_memory_encoder_ir(
390 weights,
391 mask_in_h,
392 mask_in_w,
393 mask_in_h / total_stride(&weights.mask_downsampler),
394 mask_in_w / total_stride(&weights.mask_downsampler),
395 device,
396 &rlx_flow::CompileProfile::sam2(),
397 )
398}
399
/// Result of one memory-encoder forward pass.
pub struct Sam2MemoryEncoderOutput {
    /// Encoded memory features on the `h × w` grid.
    pub features: Vec<f32>,
    /// Sinusoidal positional encoding generated for the same `h × w` grid.
    pub pos: Vec<f32>,
    /// Grid height (the `pix_h` passed to the forward pass).
    pub h: usize,
    /// Grid width (the `pix_w` passed to the forward pass).
    pub w: usize,
}
410
411pub fn memory_encoder_forward(
420 w: &mut Sam2MemoryEncoderWeights,
421 pix_feat: &[f32],
422 masks: &[f32],
423 pix_h: usize,
424 pix_w: usize,
425 skip_mask_sigmoid: bool,
426) -> Result<Sam2MemoryEncoderOutput> {
427 ensure!(
428 pix_feat.len() == w.in_dim * pix_h * pix_w,
429 "pix_feat len {} ≠ in_dim·h·w ({}·{pix_h}·{pix_w})",
430 pix_feat.len(),
431 w.in_dim
432 );
433 let in_h = SAM2_IMG_SIZE;
434 let in_w = SAM2_IMG_SIZE;
435 ensure!(
436 masks.len() == in_h * in_w,
437 "masks len {} ≠ H·W ({in_h}·{in_w}); pass a full-resolution mask",
438 masks.len()
439 );
440
441 let mut m: Vec<f32> = masks.to_vec();
443 if !skip_mask_sigmoid {
444 sigmoid_inplace(&mut m);
445 }
446
447 let x = if let Some(ref mut prefix) = w.prefix {
449 prefix.run(&m, pix_feat)?
450 } else {
451 let m_down = if let Some(ref mut md) = w.mask_down {
452 md.run(&m)?
453 } else {
454 mask_downsampler_forward(&w.mask_downsampler, &m, in_h, in_w)?
455 };
456 let down_h = in_h / total_stride(&w.mask_downsampler);
457 let down_w = in_w / total_stride(&w.mask_downsampler);
458 ensure!(
459 down_h == pix_h && down_w == pix_w,
460 "mask after downsampling ({down_h}×{down_w}) doesn't match pix_feat ({pix_h}×{pix_w})"
461 );
462 let mut x = if let Some(ref mut p) = w.pix_proj {
463 p.run(pix_feat)?
464 } else {
465 conv2d_1x1(
466 pix_feat,
467 w.in_dim,
468 w.in_dim,
469 pix_h,
470 pix_w,
471 &w.pix_feat_proj_w,
472 &w.pix_feat_proj_b,
473 )
474 };
475 for i in 0..x.len() {
476 x[i] += m_down[i];
477 }
478 x
479 };
480
481 let x = if let Some(ref mut f) = w.fuser_ir {
483 f.run(&x)?
484 } else {
485 fuser_forward(&w.fuser, x, pix_h, pix_w)
486 };
487
488 let features = if let Some(ref mut o) = w.out_proj_ir {
490 o.run(&x)?
491 } else if let (Some(opw), Some(opb)) = (&w.out_proj_w, &w.out_proj_b) {
492 conv2d_1x1(&x, w.in_dim, w.out_dim, pix_h, pix_w, opw, opb)
493 } else {
494 x
495 };
496
497 let pos = sinusoidal_pos_2d(2 * w.pe_num_pos_feats, pix_h, pix_w, w.pe_temperature);
499
500 Ok(Sam2MemoryEncoderOutput {
501 features,
502 pos,
503 h: pix_h,
504 w: pix_w,
505 })
506}
507
508fn total_stride(d: &Sam2MaskDownSamplerWeights) -> usize {
509 d.stride.pow(d.levels.len() as u32)
510}
511
512fn mask_downsampler_forward(
516 w: &Sam2MaskDownSamplerWeights,
517 input: &[f32],
518 h: usize,
519 ww: usize,
520) -> Result<Vec<f32>> {
521 let mut cur = input.to_vec();
522 let mut cur_c = 1usize;
523 let mut cur_h = h;
524 let mut cur_w = ww;
525 for level in &w.levels {
526 let out_h = (cur_h + 2 * w.padding - w.kernel) / w.stride + 1;
527 let out_w = (cur_w + 2 * w.padding - w.kernel) / w.stride + 1;
528 cur = conv2d_general(
529 &cur,
530 cur_c,
531 level.out_c,
532 cur_h,
533 cur_w,
534 w.kernel,
535 w.stride,
536 w.padding,
537 &level.conv_w,
538 &level.conv_b,
539 );
540 cur_c = level.out_c;
541 cur_h = out_h;
542 cur_w = out_w;
543 layernorm2d_nchw(
544 &mut cur,
545 cur_c,
546 cur_h,
547 cur_w,
548 &level.ln_g,
549 &level.ln_b,
550 1e-6,
551 );
552 gelu_erf_inplace(&mut cur);
553 }
554 let out = conv2d_1x1(
556 &cur,
557 cur_c,
558 w.embed_dim,
559 cur_h,
560 cur_w,
561 &w.final_conv_w,
562 &w.final_conv_b,
563 );
564 Ok(out)
565}
566
567fn fuser_forward(w: &Sam2FuserWeights, mut x: Vec<f32>, h: usize, ww: usize) -> Vec<f32> {
568 if let (Some(pw), Some(pb)) = (&w.input_proj_w, &w.input_proj_b) {
569 x = conv2d_1x1(&x, w.dim, w.dim, h, ww, pw, pb);
570 }
571 for layer in &w.layers {
572 x = cx_block_forward(layer, x, h, ww);
573 }
574 x
575}
576
577fn cx_block_forward(w: &Sam2CXBlockWeights, x: Vec<f32>, h: usize, ww: usize) -> Vec<f32> {
578 let dim = w.dim;
579 let mut y = conv2d_depthwise_k_pad(
581 &x,
582 dim,
583 h,
584 ww,
585 w.kernel,
586 w.padding,
587 &w.dw_conv_w,
588 &w.dw_conv_b,
589 );
590 layernorm2d_nchw(&mut y, dim, h, ww, &w.ln_g, &w.ln_b, 1e-6);
592 let mut nhwc = vec![0f32; h * ww * dim];
595 for c in 0..dim {
596 for yy in 0..h {
597 for xx in 0..ww {
598 nhwc[(yy * ww + xx) * dim + c] = y[c * h * ww + yy * ww + xx];
599 }
600 }
601 }
602 let four_d = 4 * dim;
603 let mut up = vec![0f32; h * ww * four_d];
604 for r in 0..h * ww {
605 for o in 0..four_d {
606 let mut acc = w.pw1_b[o];
607 for k in 0..dim {
608 acc += nhwc[r * dim + k] * w.pw1_w[o * dim + k];
609 }
610 up[r * four_d + o] = acc;
611 }
612 }
613 gelu_erf_inplace(&mut up);
614 let mut down = vec![0f32; h * ww * dim];
615 for r in 0..h * ww {
616 for o in 0..dim {
617 let mut acc = w.pw2_b[o];
618 for k in 0..four_d {
619 acc += up[r * four_d + k] * w.pw2_w[o * four_d + k];
620 }
621 down[r * dim + o] = acc;
622 }
623 }
624 if let Some(gamma) = &w.gamma {
625 for r in 0..h * ww {
626 for c in 0..dim {
627 down[r * dim + c] *= gamma[c];
628 }
629 }
630 }
631 let mut out = x;
633 for c in 0..dim {
634 for yy in 0..h {
635 for xx in 0..ww {
636 out[c * h * ww + yy * ww + xx] += down[(yy * ww + xx) * dim + c];
637 }
638 }
639 }
640 out
641}
642
/// Dense 2-D convolution over a channel-major `in_c × h × w` input.
///
/// `weight` is read row-major as `[out_c, in_c, k, k]` and `bias` per output
/// channel. `s` is the stride and `p` the zero padding on every side; taps
/// that fall outside the input contribute nothing. The output is
/// channel-major `out_c × out_h × out_w` with
/// `out = (size + 2·p − k) / s + 1` per spatial dimension.
fn conv2d_general(
    input: &[f32],
    in_c: usize,
    out_c: usize,
    h: usize,
    w: usize,
    k: usize,
    s: usize,
    p: usize,
    weight: &[f32],
    bias: &[f32],
) -> Vec<f32> {
    let out_h = (h + 2 * p - k) / s + 1;
    let out_w = (w + 2 * p - k) / s + 1;
    let in_plane = h * w;
    let out_plane = out_h * out_w;
    let mut out = vec![0f32; out_c * out_plane];

    for oc in 0..out_c {
        let channel_bias = bias[oc];
        for oy in 0..out_h {
            // Kernel rows whose input row lands inside [0, h).
            let top = oy * s;
            let ky_start = p.saturating_sub(top);
            let ky_end = k.min((h + p).saturating_sub(top));
            for ox in 0..out_w {
                // Kernel columns whose input column lands inside [0, w).
                let left = ox * s;
                let kx_start = p.saturating_sub(left);
                let kx_end = k.min((w + p).saturating_sub(left));

                let mut acc = channel_bias;
                for ic in 0..in_c {
                    for ky in ky_start..ky_end {
                        let in_row = ic * in_plane + (top + ky - p) * w;
                        let w_row = ((oc * in_c + ic) * k + ky) * k;
                        for kx in kx_start..kx_end {
                            let ix = left + kx - p;
                            acc += input[in_row + ix] * weight[w_row + kx];
                        }
                    }
                }
                out[oc * out_plane + oy * out_w + ox] = acc;
            }
        }
    }
    out
}
690
/// Depthwise stride-1 2-D convolution over a channel-major `dim × h × w`
/// input.
///
/// Each channel `c` is filtered with its own `k × k` kernel stored at
/// `weight[c·k·k ..]` and offset by `bias[c]`. `p` is the zero padding; taps
/// outside the input contribute nothing. The output always has the same
/// `dim × h × w` size as the input.
fn conv2d_depthwise_k_pad(
    input: &[f32],
    dim: usize,
    h: usize,
    w: usize,
    k: usize,
    p: usize,
    weight: &[f32],
    bias: &[f32],
) -> Vec<f32> {
    let plane = h * w;
    let taps = k * k;
    let mut out = vec![0f32; dim * plane];

    for c in 0..dim {
        let channel_bias = bias[c];
        let kernel_at = c * taps;
        let plane_at = c * plane;
        for oy in 0..h {
            // Kernel rows whose input row lands inside [0, h).
            let ky_start = p.saturating_sub(oy);
            let ky_end = k.min((h + p).saturating_sub(oy));
            for ox in 0..w {
                // Kernel columns whose input column lands inside [0, w).
                let kx_start = p.saturating_sub(ox);
                let kx_end = k.min((w + p).saturating_sub(ox));

                let mut acc = channel_bias;
                for ky in ky_start..ky_end {
                    let in_row = plane_at + (oy + ky - p) * w;
                    let w_row = kernel_at + ky * k;
                    for kx in kx_start..kx_end {
                        acc += input[in_row + (ox + kx - p)] * weight[w_row + kx];
                    }
                }
                out[plane_at + oy * w + ox] = acc;
            }
        }
    }
    out
}
729
730pub(super) fn sinusoidal_pos_2d(d_model: usize, h: usize, w: usize, temperature: f32) -> Vec<f32> {
734 let nf = d_model / 2;
735 let scale: f32 = 2.0 * PI;
736 let eps: f32 = 1e-6;
737 let mut out = vec![0f32; d_model * h * w];
738 let mut dim_t = vec![0f32; nf];
739 for i in 0..nf {
740 let exp = 2.0 * ((i / 2) as f32) / (nf as f32);
741 dim_t[i] = temperature.powf(exp);
742 }
743 for y in 0..h {
744 let y_emb = ((y + 1) as f32) / ((h as f32) + eps) * scale;
745 for x in 0..w {
746 let x_emb = ((x + 1) as f32) / ((w as f32) + eps) * scale;
747 for i in 0..nf {
748 let py = y_emb / dim_t[i];
749 let v = if i % 2 == 0 { py.sin() } else { py.cos() };
750 out[i * h * w + y * w + x] = v;
751 }
752 for i in 0..nf {
753 let px = x_emb / dim_t[i];
754 let v = if i % 2 == 0 { px.sin() } else { px.cos() };
755 out[(nf + i) * h * w + y * w + x] = v;
756 }
757 }
758 }
759 out
760}