1use super::config::Sam2MemoryConfig;
56use super::transformer::{layer_norm_last, layer_norm_last_cpu, linear};
57use anyhow::{Result, ensure};
58use rlx_core::weight_map::WeightMap;
59
/// Projection weights and RoPE settings for a single attention module used by
/// the memory-attention layers (either self-attention or cross-attention).
///
/// All weight matrices are stored flat; the shapes quoted below are the ones
/// the loader in this file validates.
pub struct Sam2RoPEAttnWeights {
    /// Query projection weight, loaded as `[internal_dim, embedding_dim]`.
    pub q_w: Vec<f32>,
    /// Query projection bias.
    pub q_b: Vec<f32>,
    /// Key projection weight, loaded as `[internal_dim, kv_in_dim]`.
    pub k_w: Vec<f32>,
    /// Key projection bias.
    pub k_b: Vec<f32>,
    /// Value projection weight; applied to inputs of width `kv_in_dim`.
    pub v_w: Vec<f32>,
    /// Value projection bias.
    pub v_b: Vec<f32>,
    /// Output projection weight, loaded as `[embedding_dim, internal_dim]`.
    pub out_w: Vec<f32>,
    /// Output projection bias.
    pub out_b: Vec<f32>,
    /// Width of the query input and of the attention output.
    pub embedding_dim: usize,
    /// Width of the key/value inputs.
    pub kv_in_dim: usize,
    /// Width of the projected q/k/v tensors (split across `num_heads`).
    pub internal_dim: usize,
    /// Number of attention heads; the per-head width is
    /// `internal_dim / num_heads`.
    pub num_heads: usize,
    /// Theta value handed to the axial RoPE routine.
    pub rope_theta: f32,
    /// `[end_x, end_y]` grid extents handed to the axial RoPE routine.
    pub rope_feat_size: [usize; 2],
    /// When set, and the number of rotated keys is a whole multiple of
    /// `end_x * end_y`, that multiple is passed to the RoPE routine as the
    /// key repeat factor; otherwise the factor is 1.
    pub rope_k_repeat: bool,
}
79
80pub struct Sam2MemoryAttentionLayerWeights {
81 pub self_attn: Sam2RoPEAttnWeights,
82 pub cross_attn: Sam2RoPEAttnWeights,
83 pub norm1_g: Vec<f32>,
84 pub norm1_b: Vec<f32>,
85 pub norm2_g: Vec<f32>,
86 pub norm2_b: Vec<f32>,
87 pub norm3_g: Vec<f32>,
88 pub norm3_b: Vec<f32>,
89 pub linear1_w: Vec<f32>, pub linear1_b: Vec<f32>,
91 pub linear2_w: Vec<f32>, pub linear2_b: Vec<f32>,
93 pub pos_enc_at_attn: bool,
94 pub pos_enc_at_cross_attn_queries: bool,
95 pub pos_enc_at_cross_attn_keys: bool,
96 pub d_model: usize,
97}
98
99pub struct Sam2MemoryAttentionWeights {
100 pub layers: Vec<Sam2MemoryAttentionLayerWeights>,
101 pub norm_g: Vec<f32>,
102 pub norm_b: Vec<f32>,
103 pub d_model: usize,
104 pub pos_enc_at_input: bool,
105}
106
107fn load_rope_attn(
110 weights: &mut WeightMap,
111 prefix: &str,
112 cfg: &Sam2MemoryConfig,
113 is_self: bool,
114) -> Result<Sam2RoPEAttnWeights> {
115 let d = cfg.d_model;
116 let internal_dim = d; let kv_in_dim = if is_self { d } else { cfg.kv_in_dim };
118 let (q_w, sh) = weights.take(&format!("{prefix}.q_proj.weight"))?;
119 ensure!(
120 sh == vec![internal_dim, d],
121 "{prefix}.q_proj.weight shape {sh:?} not [{internal_dim}, {d}]"
122 );
123 let (q_b, _) = weights.take(&format!("{prefix}.q_proj.bias"))?;
124 let (k_w, sh) = weights.take(&format!("{prefix}.k_proj.weight"))?;
125 ensure!(
126 sh == vec![internal_dim, kv_in_dim],
127 "{prefix}.k_proj.weight shape {sh:?} not [{internal_dim}, {kv_in_dim}]"
128 );
129 let (k_b, _) = weights.take(&format!("{prefix}.k_proj.bias"))?;
130 let (v_w, _) = weights.take(&format!("{prefix}.v_proj.weight"))?;
131 let (v_b, _) = weights.take(&format!("{prefix}.v_proj.bias"))?;
132 let (out_w, sh) = weights.take(&format!("{prefix}.out_proj.weight"))?;
133 ensure!(
134 sh == vec![d, internal_dim],
135 "{prefix}.out_proj.weight shape {sh:?} not [{d}, {internal_dim}]"
136 );
137 let (out_b, _) = weights.take(&format!("{prefix}.out_proj.bias"))?;
138 Ok(Sam2RoPEAttnWeights {
139 q_w,
140 q_b,
141 k_w,
142 k_b,
143 v_w,
144 v_b,
145 out_w,
146 out_b,
147 embedding_dim: d,
148 kv_in_dim,
149 internal_dim,
150 num_heads: cfg.num_heads,
151 rope_theta: cfg.rope_theta,
152 rope_feat_size: cfg.rope_feat_size,
153 rope_k_repeat: cfg.rope_k_repeat,
154 })
155}
156
157pub fn extract_memory_attention_weights(
158 weights: &mut WeightMap,
159 cfg: &Sam2MemoryConfig,
160) -> Result<Sam2MemoryAttentionWeights> {
161 let mut layers = Vec::with_capacity(cfg.num_layers);
162 for i in 0..cfg.num_layers {
163 let p = format!("memory_attention.layers.{i}");
164 let self_attn = load_rope_attn(
165 weights,
166 &format!("{p}.self_attn"),
167 cfg,
168 true,
169 )?;
170 let cross_attn = load_rope_attn(
171 weights,
172 &format!("{p}.cross_attn_image"),
173 cfg,
174 false,
175 )?;
176 let (norm1_g, _) = weights.take(&format!("{p}.norm1.weight"))?;
177 let (norm1_b, _) = weights.take(&format!("{p}.norm1.bias"))?;
178 let (norm2_g, _) = weights.take(&format!("{p}.norm2.weight"))?;
179 let (norm2_b, _) = weights.take(&format!("{p}.norm2.bias"))?;
180 let (norm3_g, _) = weights.take(&format!("{p}.norm3.weight"))?;
181 let (norm3_b, _) = weights.take(&format!("{p}.norm3.bias"))?;
182 let (linear1_w, sh) = weights.take(&format!("{p}.linear1.weight"))?;
183 ensure!(
184 sh == vec![cfg.dim_feedforward, cfg.d_model],
185 "{p}.linear1.weight shape {sh:?} not [{}, {}]",
186 cfg.dim_feedforward,
187 cfg.d_model
188 );
189 let (linear1_b, _) = weights.take(&format!("{p}.linear1.bias"))?;
190 let (linear2_w, _) = weights.take(&format!("{p}.linear2.weight"))?;
191 let (linear2_b, _) = weights.take(&format!("{p}.linear2.bias"))?;
192 layers.push(Sam2MemoryAttentionLayerWeights {
193 self_attn,
194 cross_attn,
195 norm1_g,
196 norm1_b,
197 norm2_g,
198 norm2_b,
199 norm3_g,
200 norm3_b,
201 linear1_w,
202 linear1_b,
203 linear2_w,
204 linear2_b,
205 pos_enc_at_attn: cfg.pos_enc_at_attn,
206 pos_enc_at_cross_attn_queries: cfg.pos_enc_at_cross_attn_queries,
207 pos_enc_at_cross_attn_keys: cfg.pos_enc_at_cross_attn_keys,
208 d_model: cfg.d_model,
209 });
210 }
211 let (norm_g, _) = weights.take("memory_attention.norm.weight")?;
212 let (norm_b, _) = weights.take("memory_attention.norm.bias")?;
213 Ok(Sam2MemoryAttentionWeights {
214 layers,
215 norm_g,
216 norm_b,
217 d_model: cfg.d_model,
218 pos_enc_at_input: cfg.pos_enc_at_input,
219 })
220}
221
222pub fn memory_attention_forward(
237 w: &Sam2MemoryAttentionWeights,
238 curr: &[f32],
239 curr_pos: &[f32],
240 memory: &[f32],
241 memory_pos: &[f32],
242 n_img: usize,
243 n_mem: usize,
244 kv_in_dim: usize,
245 num_obj_ptr_tokens: usize,
246) -> Result<Vec<f32>> {
247 let d = w.d_model;
248 ensure!(curr.len() == n_img * d, "curr len mismatch");
249 ensure!(curr_pos.len() == n_img * d, "curr_pos len mismatch");
250 ensure!(memory.len() == n_mem * kv_in_dim, "memory len mismatch");
251 ensure!(
252 memory_pos.len() == n_mem * kv_in_dim,
253 "memory_pos len mismatch"
254 );
255
256 let mut output = curr.to_vec();
258 if w.pos_enc_at_input {
259 for i in 0..output.len() {
260 output[i] += 0.1 * curr_pos[i];
261 }
262 }
263
264 for layer in &w.layers {
265 output = memory_attention_layer_forward(
266 layer,
267 output,
268 curr_pos,
269 memory,
270 memory_pos,
271 n_img,
272 n_mem,
273 kv_in_dim,
274 num_obj_ptr_tokens,
275 )?;
276 }
277
278 layer_norm_last(&mut output, n_img, d, &w.norm_g, &w.norm_b, 1e-5);
279 Ok(output)
280}
281
282pub fn memory_attention_forward_layers_only(
284 w: &Sam2MemoryAttentionWeights,
285 curr: &[f32],
286 curr_pos: &[f32],
287 memory: &[f32],
288 memory_pos: &[f32],
289 n_img: usize,
290 n_mem: usize,
291 kv_in_dim: usize,
292 num_obj_ptr_tokens: usize,
293) -> Result<Vec<f32>> {
294 let _d = w.d_model;
295 let mut output = curr.to_vec();
296 if w.pos_enc_at_input {
297 for i in 0..output.len() {
298 output[i] += 0.1 * curr_pos[i];
299 }
300 }
301 for layer in &w.layers {
302 output = memory_attention_layer_forward(
303 layer,
304 output,
305 curr_pos,
306 memory,
307 memory_pos,
308 n_img,
309 n_mem,
310 kv_in_dim,
311 num_obj_ptr_tokens,
312 )?;
313 }
314 Ok(output)
315}
316
317pub fn memory_attention_forward_ir_stack(
321 w: &Sam2MemoryAttentionWeights,
322 curr: &[f32],
323 curr_pos: &[f32],
324 memory: &[f32],
325 memory_pos: &[f32],
326 n_img: usize,
327 n_mem: usize,
328 kv_in_dim: usize,
329 num_obj_ptr_tokens: usize,
330) -> Result<Vec<f32>> {
331 let d = w.d_model;
332 let mut output = memory_attention_forward_layers_only(
333 w,
334 curr,
335 curr_pos,
336 memory,
337 memory_pos,
338 n_img,
339 n_mem,
340 kv_in_dim,
341 num_obj_ptr_tokens,
342 )?;
343 layer_norm_last_cpu(&mut output, n_img, d, &w.norm_g, &w.norm_b, 1e-5);
344 Ok(output)
345}
346
347#[allow(clippy::too_many_arguments)]
348pub(crate) fn memory_attention_layer_forward(
349 w: &Sam2MemoryAttentionLayerWeights,
350 mut tgt: Vec<f32>,
351 query_pos: &[f32],
352 memory: &[f32],
353 memory_pos: &[f32],
354 n_img: usize,
355 n_mem: usize,
356 kv_in_dim: usize,
357 num_obj_ptr_tokens: usize,
358) -> Result<Vec<f32>> {
359 let d = w.d_model;
360
361 let mut tgt2 = tgt.clone();
363 layer_norm_last(&mut tgt2, n_img, d, &w.norm1_g, &w.norm1_b, 1e-5);
364 let q_in = if w.pos_enc_at_attn {
365 let mut x = tgt2.clone();
366 for i in 0..x.len() {
367 x[i] += query_pos[i];
368 }
369 x
370 } else {
371 tgt2.clone()
372 };
373 let k_in = q_in.clone();
374 let v_in = tgt2.clone();
375 let sa_out = rope_attn_forward(
376 &w.self_attn,
377 &q_in,
378 n_img,
379 &k_in,
380 n_img,
381 &v_in,
382 n_img,
383 d,
384 d,
385 0,
386 );
387 for i in 0..tgt.len() {
388 tgt[i] += sa_out[i];
389 }
390
391 let mut tgt2 = tgt.clone();
393 layer_norm_last(&mut tgt2, n_img, d, &w.norm2_g, &w.norm2_b, 1e-5);
394 let q_in = if w.pos_enc_at_cross_attn_queries {
395 let mut x = tgt2.clone();
396 for i in 0..x.len() {
397 x[i] += query_pos[i];
398 }
399 x
400 } else {
401 tgt2
402 };
403 let k_in = if w.pos_enc_at_cross_attn_keys {
404 let mut x = memory.to_vec();
405 for i in 0..x.len() {
406 x[i] += memory_pos[i];
407 }
408 x
409 } else {
410 memory.to_vec()
411 };
412 let ca_out = rope_attn_forward(
413 &w.cross_attn,
414 &q_in,
415 n_img,
416 &k_in,
417 n_mem,
418 memory,
419 n_mem,
420 d,
421 kv_in_dim,
422 num_obj_ptr_tokens,
423 );
424 for i in 0..tgt.len() {
425 tgt[i] += ca_out[i];
426 }
427
428 let mut tgt2 = tgt.clone();
430 layer_norm_last(&mut tgt2, n_img, d, &w.norm3_g, &w.norm3_b, 1e-5);
431 let dim_ff = w.linear1_b.len();
432 let mut mid = linear(&tgt2, &w.linear1_w, &w.linear1_b, n_img, d, dim_ff);
433 for v in mid.iter_mut() {
436 if *v < 0.0 {
437 *v = 0.0;
438 }
439 }
440 let down = linear(&mid, &w.linear2_w, &w.linear2_b, n_img, dim_ff, d);
441 for i in 0..tgt.len() {
442 tgt[i] += down[i];
443 }
444
445 Ok(tgt)
446}
447
448#[allow(clippy::too_many_arguments)]
449fn rope_attn_forward(
450 w: &Sam2RoPEAttnWeights,
451 q: &[f32],
452 q_n: usize,
453 k: &[f32],
454 k_n: usize,
455 v: &[f32],
456 v_n: usize,
457 q_in_dim: usize,
458 kv_in_dim: usize,
459 num_k_exclude_rope: usize,
460) -> Vec<f32> {
461 let d = w.embedding_dim;
462 let id = w.internal_dim;
463 let nh = w.num_heads;
464 let dh = id / nh;
465 let scale = 1.0 / (dh as f32).sqrt();
466 let _ = q_in_dim;
467
468 let q_p = linear(q, &w.q_w, &w.q_b, q_n, d, id);
470 let k_p = linear(k, &w.k_w, &w.k_b, k_n, kv_in_dim, id);
471 let v_p = linear(v, &w.v_w, &w.v_b, v_n, kv_in_dim, id);
472
473 let q_h = separate_heads_b1(&q_p, q_n, nh, dh);
475 let mut k_h = separate_heads_b1(&k_p, k_n, nh, dh);
476 let v_h = separate_heads_b1(&v_p, v_n, nh, dh);
477
478 let num_k_rope = k_n.saturating_sub(num_k_exclude_rope);
482 let [end_x, end_y] = w.rope_feat_size;
483 let spatial = end_x * end_y;
484 let q_h = super::axial_rope::apply_axial_rope_2d(
485 &q_h,
486 nh,
487 q_n,
488 dh,
489 end_x,
490 end_y,
491 w.rope_theta,
492 1,
493 );
494 if num_k_rope > 0 {
495 let r = if w.rope_k_repeat && num_k_rope >= spatial && num_k_rope.is_multiple_of(spatial) {
496 num_k_rope / spatial
497 } else {
498 1
499 };
500 let mut k_prefix = vec![0f32; nh * num_k_rope * dh];
501 for h in 0..nh {
502 let src = &k_h[h * k_n * dh..(h * k_n + num_k_rope) * dh];
503 let dst = &mut k_prefix[h * num_k_rope * dh..(h + 1) * num_k_rope * dh];
504 dst.copy_from_slice(src);
505 }
506 let rotated = super::axial_rope::apply_axial_rope_2d(
507 &k_prefix,
508 nh,
509 num_k_rope,
510 dh,
511 end_x,
512 end_y,
513 w.rope_theta,
514 r,
515 );
516 for h in 0..nh {
517 let src = &rotated[h * num_k_rope * dh..(h + 1) * num_k_rope * dh];
518 let dst = &mut k_h[h * k_n * dh..(h * k_n + num_k_rope) * dh];
519 dst.copy_from_slice(src);
520 }
521 }
522
523 let mut out_h = vec![0f32; nh * q_n * dh];
525 let mut scores = vec![0f32; q_n * k_n];
526 for h in 0..nh {
527 for i in 0..q_n {
528 for j in 0..k_n {
529 let mut acc = 0f32;
530 for dd in 0..dh {
531 acc += q_h[(h * q_n + i) * dh + dd] * k_h[(h * k_n + j) * dh + dd];
532 }
533 scores[i * k_n + j] = acc * scale;
534 }
535 }
536 for i in 0..q_n {
537 let row = &mut scores[i * k_n..(i + 1) * k_n];
538 let m = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
539 let mut s = 0f32;
540 for vv in row.iter_mut() {
541 *vv = (*vv - m).exp();
542 s += *vv;
543 }
544 for vv in row.iter_mut() {
545 *vv /= s;
546 }
547 }
548 for i in 0..q_n {
549 for dd in 0..dh {
550 let mut acc = 0f32;
551 for j in 0..k_n {
552 acc += scores[i * k_n + j] * v_h[(h * v_n + j) * dh + dd];
553 }
554 out_h[(h * q_n + i) * dh + dd] = acc;
555 }
556 }
557 }
558
559 let merged = recombine_heads_b1(&out_h, q_n, nh, dh);
561
562 linear(&merged, &w.out_w, &w.out_b, q_n, id, d)
564}
565
/// Rearranges a token-major buffer laid out as `[n][nh][dh]` into a
/// head-major buffer laid out as `[nh][n][dh]` (batch size 1).
fn separate_heads_b1(x: &[f32], n: usize, nh: usize, dh: usize) -> Vec<f32> {
    let token_width = nh * dh;
    let mut out = vec![0f32; nh * n * dh];
    for head in 0..nh {
        for token in 0..n {
            // Each (head, token) pair owns one contiguous run of `dh` values
            // on both sides, so it can be moved as a single slice.
            let src = token * token_width + head * dh;
            let dst = (head * n + token) * dh;
            out[dst..dst + dh].copy_from_slice(&x[src..src + dh]);
        }
    }
    out
}
577
/// Rearranges a head-major buffer laid out as `[nh][n][dh]` back into a
/// token-major buffer laid out as `[n][nh][dh]` (batch size 1).
fn recombine_heads_b1(x: &[f32], n: usize, nh: usize, dh: usize) -> Vec<f32> {
    let token_width = nh * dh;
    let mut out = vec![0f32; n * nh * dh];
    for token in 0..n {
        for head in 0..nh {
            // Each (token, head) pair owns one contiguous run of `dh` values
            // on both sides, so it can be moved as a single slice.
            let src = (head * n + token) * dh;
            let dst = token * token_width + head * dh;
            out[dst..dst + dh].copy_from_slice(&x[src..src + dh]);
        }
    }
    out
}