1use super::axial_rope::apply_axial_rope_2d;
26use super::memory_attention::{
27 Sam2MemoryAttentionLayerWeights, Sam2MemoryAttentionWeights, Sam2RoPEAttnWeights,
28};
29use anyhow::Result;
30use rlx_ir::infer::GraphExt;
31use rlx_ir::op::{Activation, BinaryOp, MaskKind};
32use rlx_ir::{DType, Graph, NodeId, Shape};
33use rlx_runtime::{CompileOptions, CompiledGraph, Device, Session};
34use std::collections::HashMap;
35
/// Selects how rotary position embedding (RoPE) is applied inside a layer.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum LayerRopeMode {
    /// The layer is split into five compiled graphs (self projection, self
    /// attention, cross projection, cross attention, FFN) and RoPE is applied
    /// on the host between them.
    HostBetweenGraphs,
    /// The whole layer is one compiled graph and RoPE is applied inside it
    /// through the `axial_rope2d` graph op.
    InGraph,
}
44
/// Epsilon passed to the `layer_norm` graph helper in this file.
const LN_EPS: f32 = 1e-5;
/// Scale applied to `curr_pos` when it is added to the input tokens in
/// `MemoryAttentionCompiled::run` (only when `pos_enc_at_input` is set).
const INPUT_POS_SCALE: f32 = 0.1;
47
/// Number of image-sized memory blocks counted by [`max_memory_slots`].
pub const MAX_MEMORY_FRAMES_IN_ATTN: usize = 7;

/// Returns the memory token capacity for a given image token count.
///
/// The result is `MAX_MEMORY_FRAMES_IN_ATTN` blocks of `n_img` tokens plus
/// `max_obj_ptr_tokens` additional tokens.
pub fn max_memory_slots(n_img: usize, max_obj_ptr_tokens: usize) -> usize {
    let frame_tokens = MAX_MEMORY_FRAMES_IN_ATTN * n_img;
    frame_tokens + max_obj_ptr_tokens
}
54
/// Compiled graphs and retained weights for one memory-attention layer.
///
/// Exactly one of the two layouts is populated, depending on `mode`:
/// `fused` for `LayerRopeMode::InGraph`, or the five split graphs for
/// `LayerRopeMode::HostBetweenGraphs`.
struct MemoryAttentionLayerCompiled {
    /// Layout this layer was compiled with.
    mode: LayerRopeMode,
    /// Whole-layer graph; `Some` only in `InGraph` mode.
    fused: Option<CompiledGraph>,
    /// Layer norm 1 plus self-attention Q/K/V projections (host mode only).
    self_proj: Option<CompiledGraph>,
    /// Self-attention, output projection and residual add (host mode only).
    self_attn: Option<CompiledGraph>,
    /// Layer norm 2 plus cross-attention Q/K projections (host mode only).
    cross_proj: Option<CompiledGraph>,
    /// Cross-attention, output projection and residual add (host mode only).
    cross_attn: Option<CompiledGraph>,
    /// Layer norm 3, feed-forward block and residual add (host mode only).
    ffn: Option<CompiledGraph>,
    /// Owned copy of the layer weights; the host path reads dimensions and
    /// RoPE settings from it at run time.
    layer: Sam2MemoryAttentionLayerWeights,
}
66
/// Compiled memory-attention stack: per-layer graphs plus the final layer norm.
pub struct MemoryAttentionCompiled {
    /// Compiled layers, run in order.
    layers: Vec<MemoryAttentionLayerCompiled>,
    /// Graph applying the final layer norm to the last layer's output.
    final_norm: CompiledGraph,
    /// Number of current-frame tokens the graphs were compiled for.
    pub n_img: usize,
    /// Memory token capacity of the graphs; shorter memories are zero-padded
    /// and masked in `run`.
    pub max_n_mem: usize,
    /// Model width, copied from the weights.
    pub d_model: usize,
    /// Width of the memory tokens, taken from the first layer's cross-attention.
    pub kv_in_dim: usize,
    /// Upper bound accepted for `num_obj_ptr_tokens` in `run`; also the number
    /// of trailing keys excluded from RoPE when the fused graphs are built.
    pub max_obj_ptr_tokens: usize,
    /// Whether `run` adds the scaled `curr_pos` to the input tokens.
    pos_enc_at_input: bool,
}
78
79impl MemoryAttentionCompiled {
80 pub fn compile(
81 w: &Sam2MemoryAttentionWeights,
82 n_img: usize,
83 max_n_mem: usize,
84 max_obj_ptr_tokens: usize,
85 device: Device,
86 ) -> Result<Self> {
87 Self::compile_with_profile(
88 w,
89 n_img,
90 max_n_mem,
91 max_obj_ptr_tokens,
92 device,
93 &rlx_flow::CompileProfile::sam_encoder(),
94 )
95 }
96
97 pub fn compile_with_profile(
98 w: &Sam2MemoryAttentionWeights,
99 n_img: usize,
100 max_n_mem: usize,
101 max_obj_ptr_tokens: usize,
102 device: Device,
103 profile: &rlx_flow::CompileProfile,
104 ) -> Result<Self> {
105 Self::compile_with_mode(
106 w,
107 n_img,
108 max_n_mem,
109 max_obj_ptr_tokens,
110 device,
111 LayerRopeMode::HostBetweenGraphs,
112 profile,
113 )
114 }
115
116 pub fn compile_in_graph_rope(
119 w: &Sam2MemoryAttentionWeights,
120 n_img: usize,
121 max_n_mem: usize,
122 max_obj_ptr_tokens: usize,
123 device: Device,
124 ) -> Result<Self> {
125 Self::compile_in_graph_rope_with_profile(
126 w,
127 n_img,
128 max_n_mem,
129 max_obj_ptr_tokens,
130 device,
131 &rlx_flow::CompileProfile::sam_encoder(),
132 )
133 }
134
135 pub fn compile_in_graph_rope_with_profile(
136 w: &Sam2MemoryAttentionWeights,
137 n_img: usize,
138 max_n_mem: usize,
139 max_obj_ptr_tokens: usize,
140 device: Device,
141 profile: &rlx_flow::CompileProfile,
142 ) -> Result<Self> {
143 Self::compile_with_mode(
144 w,
145 n_img,
146 max_n_mem,
147 max_obj_ptr_tokens,
148 device,
149 LayerRopeMode::InGraph,
150 profile,
151 )
152 }
153
154 fn compile_with_mode(
155 w: &Sam2MemoryAttentionWeights,
156 n_img: usize,
157 max_n_mem: usize,
158 max_obj_ptr_tokens: usize,
159 device: Device,
160 mode: LayerRopeMode,
161 profile: &rlx_flow::CompileProfile,
162 ) -> Result<Self> {
163 anyhow::ensure!(
164 w.layers
165 .iter()
166 .all(|l| l.self_attn.num_heads == 1 && l.cross_attn.num_heads == 1),
167 "memory_attention_ir currently requires num_heads=1"
168 );
169 let kv = w.layers[0].cross_attn.kv_in_dim;
170 let mut layers = Vec::with_capacity(w.layers.len());
171 for layer in &w.layers {
172 layers.push(compile_layer(
173 layer,
174 n_img,
175 max_n_mem,
176 kv,
177 max_obj_ptr_tokens,
178 device,
179 mode,
180 profile,
181 )?);
182 }
183 let (fn_g, fn_p) = build_final_norm_graph(&w.norm_g, &w.norm_b, n_img, w.d_model)?;
184 let mut final_norm =
185 Session::new(device).compile_with(fn_g, &compile_opts_no_fusion(device));
186 for (n, d) in &fn_p {
187 final_norm.set_param(n, d);
188 }
189 Ok(Self {
190 layers,
191 final_norm,
192 n_img,
193 max_n_mem,
194 d_model: w.d_model,
195 kv_in_dim: kv,
196 max_obj_ptr_tokens,
197 pos_enc_at_input: w.pos_enc_at_input,
198 })
199 }
200
201 pub fn run(
202 &mut self,
203 curr: &[f32],
204 curr_pos: &[f32],
205 memory: &[f32],
206 memory_pos: &[f32],
207 active_n_mem: usize,
208 num_obj_ptr_tokens: usize,
209 ) -> Result<Vec<f32>> {
210 let d = self.d_model;
211 let kv = self.kv_in_dim;
212 anyhow::ensure!(curr.len() == self.n_img * d);
213 anyhow::ensure!(curr_pos.len() == self.n_img * d);
214 anyhow::ensure!(memory.len() >= active_n_mem * kv);
215 anyhow::ensure!(memory_pos.len() >= active_n_mem * kv);
216 anyhow::ensure!(active_n_mem <= self.max_n_mem);
217 anyhow::ensure!(num_obj_ptr_tokens <= self.max_obj_ptr_tokens);
218
219 let mut tgt = curr.to_vec();
220 if self.pos_enc_at_input {
221 for i in 0..tgt.len() {
222 tgt[i] += INPUT_POS_SCALE * curr_pos[i];
223 }
224 }
225
226 let mut mem_pad = vec![0f32; self.max_n_mem * kv];
227 let mut mem_pos_pad = vec![0f32; self.max_n_mem * kv];
228 mem_pad[..active_n_mem * kv].copy_from_slice(&memory[..active_n_mem * kv]);
229 mem_pos_pad[..active_n_mem * kv].copy_from_slice(&memory_pos[..active_n_mem * kv]);
230
231 let nh = 1usize;
232 let mut mask = vec![0f32; nh * self.n_img * self.max_n_mem];
233 fill_cross_attn_bias(&mut mask, nh, self.n_img, self.max_n_mem, active_n_mem);
234
235 for layer in &mut self.layers {
236 tgt = match layer.mode {
237 LayerRopeMode::InGraph => layer
238 .fused
239 .as_mut()
240 .expect("fused layer")
241 .run(&[
242 ("tgt", &tgt),
243 ("curr_pos", curr_pos),
244 ("memory", &mem_pad),
245 ("memory_pos", &mem_pos_pad),
246 ("mask_ca", &mask),
247 ])
248 .into_iter()
249 .next()
250 .expect("fused layer output"),
251 LayerRopeMode::HostBetweenGraphs => layer.run_host_between(
252 &tgt,
253 curr_pos,
254 &mem_pad,
255 &mem_pos_pad,
256 active_n_mem,
257 num_obj_ptr_tokens,
258 )?,
259 };
260 }
261
262 let outs = self.final_norm.run(&[("tgt", &tgt)]);
263 Ok(outs.into_iter().next().expect("memory_attention output"))
264 }
265}
266
267impl MemoryAttentionLayerCompiled {
268 fn run_host_between(
269 &mut self,
270 tgt: &[f32],
271 curr_pos: &[f32],
272 memory: &[f32],
273 memory_pos: &[f32],
274 active_n_mem: usize,
275 num_obj_ptr_tokens: usize,
276 ) -> Result<Vec<f32>> {
277 let d = self.layer.d_model;
278 let kv = self.layer.cross_attn.kv_in_dim;
279 let n_img = tgt.len() / d;
280 let max_n_mem = memory.len() / kv;
281 let _id = self.layer.self_attn.internal_dim;
282
283 let p = self
284 .self_proj
285 .as_mut()
286 .expect("self_proj")
287 .run(&[("tgt", tgt), ("curr_pos", curr_pos)]);
288 let mut it = p.into_iter();
289 let mut sa_q = it.next().expect("sa_q");
290 let mut sa_k = it.next().expect("sa_k");
291 let sa_v = it.next().expect("sa_v");
292 host_rotate_qk(&mut sa_q, n_img, &self.layer.self_attn);
293 host_rotate_qk(&mut sa_k, n_img, &self.layer.self_attn);
294
295 let mut tgt = self
296 .self_attn
297 .as_mut()
298 .expect("self_attn")
299 .run(&[
300 ("tgt", tgt),
301 ("sa_q", &sa_q),
302 ("sa_k", &sa_k),
303 ("sa_v", &sa_v),
304 ])
305 .into_iter()
306 .next()
307 .expect("tgt after self");
308
309 let c = self.cross_proj.as_mut().expect("cross_proj").run(&[
310 ("tgt", &tgt),
311 ("curr_pos", curr_pos),
312 ("memory", memory),
313 ("memory_pos", memory_pos),
314 ]);
315 let mut it = c.into_iter();
316 let mut ca_q = it.next().expect("ca_q");
317 let mut ca_k = it.next().expect("ca_k");
318 host_rotate_qk(&mut ca_q, n_img, &self.layer.cross_attn);
319 host_rotate_k_partial(
320 &mut ca_k,
321 max_n_mem,
322 active_n_mem,
323 num_obj_ptr_tokens,
324 &self.layer.cross_attn,
325 );
326
327 let nh = self.layer.cross_attn.num_heads;
328 let mut mask = vec![0f32; nh * n_img * max_n_mem];
329 fill_cross_attn_bias(&mut mask, nh, n_img, max_n_mem, active_n_mem);
330
331 tgt = self
332 .cross_attn
333 .as_mut()
334 .expect("cross_attn")
335 .run(&[
336 ("tgt", &tgt),
337 ("ca_q", &ca_q),
338 ("ca_k", &ca_k),
339 ("memory", memory),
340 ("mask_ca", &mask),
341 ])
342 .into_iter()
343 .next()
344 .expect("tgt after cross");
345
346 self.ffn
347 .as_mut()
348 .expect("ffn")
349 .run(&[("tgt", &tgt)])
350 .into_iter()
351 .next()
352 .ok_or_else(|| anyhow::anyhow!("ffn output missing"))
353 }
354}
355
356fn compile_opts_no_fusion(device: Device) -> CompileOptions {
357 rlx_core::flow_bridge::compile_options_sam2_memory_attention(device)
358}
359
360fn compile_layer(
361 layer: &Sam2MemoryAttentionLayerWeights,
362 n_img: usize,
363 n_mem: usize,
364 kv: usize,
365 max_obj_ptr_tokens: usize,
366 device: Device,
367 mode: LayerRopeMode,
368 profile: &rlx_flow::CompileProfile,
369) -> Result<MemoryAttentionLayerCompiled> {
370 let compile =
371 |g: Graph, p: HashMap<String, Vec<f32>>, opts: &CompileOptions| -> Result<CompiledGraph> {
372 let mut c = Session::new(device).compile_with(g, opts);
373 for (n, d) in &p {
374 c.set_param(n, d);
375 }
376 Ok(c)
377 };
378 match mode {
379 LayerRopeMode::InGraph => {
380 let opts = compile_opts_no_fusion(device);
381 let (g, p) = build_layer_graph(layer, n_img, n_mem, kv, max_obj_ptr_tokens)?;
382 Ok(MemoryAttentionLayerCompiled {
383 mode,
384 fused: Some(compile(g, p, &opts)?),
385 self_proj: None,
386 self_attn: None,
387 cross_proj: None,
388 cross_attn: None,
389 ffn: None,
390 layer: clone_layer(layer),
391 })
392 }
393 LayerRopeMode::HostBetweenGraphs => {
394 let opts = rlx_core::flow_bridge::compile_options_for_profile(profile, device);
395 let (g1, p1) = build_self_proj_graph(layer, n_img)?;
396 let (g2, p2) = build_self_attn_graph(layer, n_img)?;
397 let (g3, p3) = build_cross_proj_graph(layer, n_img, n_mem, kv)?;
398 let (g4, p4) = build_cross_attn_graph(layer, n_img, n_mem, kv)?;
399 let (g5, p5) = build_ffn_graph(layer, n_img)?;
400 Ok(MemoryAttentionLayerCompiled {
401 mode,
402 fused: None,
403 self_proj: Some(compile(g1, p1, &opts)?),
404 self_attn: Some(compile(g2, p2, &opts)?),
405 cross_proj: Some(compile(g3, p3, &opts)?),
406 cross_attn: Some(compile(g4, p4, &opts)?),
407 ffn: Some(compile(g5, p5, &opts)?),
408 layer: clone_layer(layer),
409 })
410 }
411 }
412}
413
/// Fills `out` with the additive cross-attention bias for padded memory.
///
/// `out` is treated as `nh * n_img` rows of `max_n_mem` entries. The whole
/// buffer is zeroed first; then, in each row, the entries at key positions
/// `active_n_mem..max_n_mem` are set to `-1e4`.
///
/// # Panics
/// Panics if a row that needs masking lies beyond the end of `out`.
fn fill_cross_attn_bias(
    out: &mut [f32],
    nh: usize,
    n_img: usize,
    max_n_mem: usize,
    active_n_mem: usize,
) {
    out.fill(0.0);
    // Nothing is padded when every slot is active.
    if active_n_mem >= max_n_mem {
        return;
    }
    for row in 0..nh * n_img {
        let row_start = row * max_n_mem;
        out[row_start + active_n_mem..row_start + max_n_mem].fill(-1e4);
    }
}
430
431fn host_rotate_qk(seq: &mut [f32], n_tokens: usize, w: &Sam2RoPEAttnWeights) {
432 let nh = w.num_heads;
433 let dh = w.internal_dim / nh;
434 let [ex, ey] = w.rope_feat_size;
435 let out = apply_axial_rope_2d(seq, nh, n_tokens, dh, ex, ey, w.rope_theta, 1);
436 seq.copy_from_slice(&out);
437}
438
439fn host_rotate_k_partial(
440 k: &mut [f32],
441 buf_tokens: usize,
442 active_tokens: usize,
443 num_k_exclude_rope: usize,
444 w: &Sam2RoPEAttnWeights,
445) {
446 let nh = w.num_heads;
447 let dh = w.internal_dim / nh;
448 let [ex, ey] = w.rope_feat_size;
449 let spatial = ex * ey;
450 let num_k_rope = active_tokens.saturating_sub(num_k_exclude_rope);
451 if num_k_rope == 0 {
452 return;
453 }
454 let _ = buf_tokens;
455 let r = if w.rope_k_repeat && num_k_rope >= spatial && num_k_rope.is_multiple_of(spatial) {
456 num_k_rope / spatial
457 } else {
458 1
459 };
460 let prefix_len = nh * num_k_rope * dh;
461 let out = apply_axial_rope_2d(
462 &k[..prefix_len],
463 nh,
464 num_k_rope,
465 dh,
466 ex,
467 ey,
468 w.rope_theta,
469 r,
470 );
471 k[..prefix_len].copy_from_slice(&out);
472}
473
474fn clone_layer(l: &Sam2MemoryAttentionLayerWeights) -> Sam2MemoryAttentionLayerWeights {
475 Sam2MemoryAttentionLayerWeights {
476 self_attn: clone_rope(&l.self_attn),
477 cross_attn: clone_rope(&l.cross_attn),
478 norm1_g: l.norm1_g.clone(),
479 norm1_b: l.norm1_b.clone(),
480 norm2_g: l.norm2_g.clone(),
481 norm2_b: l.norm2_b.clone(),
482 norm3_g: l.norm3_g.clone(),
483 norm3_b: l.norm3_b.clone(),
484 linear1_w: l.linear1_w.clone(),
485 linear1_b: l.linear1_b.clone(),
486 linear2_w: l.linear2_w.clone(),
487 linear2_b: l.linear2_b.clone(),
488 pos_enc_at_attn: l.pos_enc_at_attn,
489 pos_enc_at_cross_attn_queries: l.pos_enc_at_cross_attn_queries,
490 pos_enc_at_cross_attn_keys: l.pos_enc_at_cross_attn_keys,
491 d_model: l.d_model,
492 }
493}
494
495fn clone_rope(w: &Sam2RoPEAttnWeights) -> Sam2RoPEAttnWeights {
496 Sam2RoPEAttnWeights {
497 q_w: w.q_w.clone(),
498 q_b: w.q_b.clone(),
499 k_w: w.k_w.clone(),
500 k_b: w.k_b.clone(),
501 v_w: w.v_w.clone(),
502 v_b: w.v_b.clone(),
503 out_w: w.out_w.clone(),
504 out_b: w.out_b.clone(),
505 embedding_dim: w.embedding_dim,
506 kv_in_dim: w.kv_in_dim,
507 internal_dim: w.internal_dim,
508 num_heads: w.num_heads,
509 rope_theta: w.rope_theta,
510 rope_feat_size: w.rope_feat_size,
511 rope_k_repeat: w.rope_k_repeat,
512 }
513}
514
/// Transposes a row-major `[out_d, in_d]` weight into row-major
/// `[in_d, out_d]`, the layout bound to the graph's matmul parameters.
///
/// # Panics
/// Panics if `w_out_in` holds fewer than `in_d * out_d` values.
fn matmul_weight(w_out_in: &[f32], in_d: usize, out_d: usize) -> Vec<f32> {
    let mut transposed = Vec::with_capacity(in_d * out_d);
    for col in 0..in_d {
        // Walk down one input column of the source to emit one output row.
        transposed.extend((0..out_d).map(|row| w_out_in[row * in_d + col]));
    }
    transposed
}
524
525fn bind_linear(
526 g: &mut Graph,
527 params: &mut HashMap<String, Vec<f32>>,
528 prefix: &str,
529 w: &[f32],
530 b: &[f32],
531 in_d: usize,
532 out_d: usize,
533) -> (NodeId, NodeId) {
534 let f = DType::F32;
535 let w_id = g.param(format!("{prefix}.w"), Shape::new(&[in_d, out_d], f));
536 let b_id = g.param(format!("{prefix}.b"), Shape::new(&[out_d], f));
537 params.insert(format!("{prefix}.w"), matmul_weight(w, in_d, out_d));
538 params.insert(format!("{prefix}.b"), b.to_vec());
539 (w_id, b_id)
540}
541
542fn linear(
543 g: &mut Graph,
544 params: &mut HashMap<String, Vec<f32>>,
545 prefix: &str,
546 x: NodeId,
547 w: &[f32],
548 b: &[f32],
549 in_d: usize,
550 out_d: usize,
551 seq: usize,
552) -> NodeId {
553 let f = DType::F32;
554 let (w_id, b_id) = bind_linear(g, params, prefix, w, b, in_d, out_d);
555 g.fused_matmul_bias_act(x, w_id, b_id, None, Shape::new(&[1, seq, out_d], f))
556}
557
558fn bind_ln(
559 g: &mut Graph,
560 params: &mut HashMap<String, Vec<f32>>,
561 prefix: &str,
562 gamm: &[f32],
563 bet: &[f32],
564 e: usize,
565) -> (NodeId, NodeId) {
566 let f = DType::F32;
567 let g_id = g.param(format!("{prefix}.g"), Shape::new(&[e], f));
568 let b_id = g.param(format!("{prefix}.b"), Shape::new(&[e], f));
569 params.insert(format!("{prefix}.g"), gamm.to_vec());
570 params.insert(format!("{prefix}.b"), bet.to_vec());
571 (g_id, b_id)
572}
573
574fn layer_norm(
575 g: &mut Graph,
576 params: &mut HashMap<String, Vec<f32>>,
577 prefix: &str,
578 x: NodeId,
579 gamm: &[f32],
580 bet: &[f32],
581 seq: usize,
582 e: usize,
583) -> NodeId {
584 let f = DType::F32;
585 let shape = Shape::new(&[1, seq, e], f);
586 let (g_id, b_id) = bind_ln(g, params, prefix, gamm, bet, e);
587 g.layer_norm(x, g_id, b_id, -1, LN_EPS, shape)
588}
589
590fn maybe_add_pos(
591 g: &mut Graph,
592 x: NodeId,
593 pos: NodeId,
594 seq: usize,
595 e: usize,
596 enabled: bool,
597) -> NodeId {
598 if enabled {
599 let f = DType::F32;
600 g.binary(BinaryOp::Add, x, pos, Shape::new(&[1, seq, e], f))
601 } else {
602 x
603 }
604}
605
606fn apply_axial_rope_graph(
607 g: &mut Graph,
608 x: NodeId,
609 w: &Sam2RoPEAttnWeights,
610 _seq: usize,
611 repeat_factor: usize,
612) -> NodeId {
613 let nh = w.num_heads;
614 let dh = w.internal_dim / nh;
615 let [ex, ey] = w.rope_feat_size;
616 g.axial_rope2d(x, ex, ey, dh, nh, w.rope_theta, repeat_factor)
617}
618
619fn build_rope_attn(
620 g: &mut Graph,
621 params: &mut HashMap<String, Vec<f32>>,
622 prefix: &str,
623 w: &Sam2RoPEAttnWeights,
624 q_in: NodeId,
625 k_in: NodeId,
626 v_in: NodeId,
627 q_len: usize,
628 k_len: usize,
629 q_in_dim: usize,
630 kv_in_dim: usize,
631 num_k_exclude_rope: usize,
632 bias: Option<NodeId>,
633) -> NodeId {
634 let d = w.embedding_dim;
635 let id = w.internal_dim;
636 let nh = w.num_heads;
637 let dh = id / nh;
638 let f = DType::F32;
639 let [end_x, end_y] = w.rope_feat_size;
640 let spatial = end_x * end_y;
641
642 let q_proj = linear(
643 g,
644 params,
645 &format!("{prefix}.q"),
646 q_in,
647 &w.q_w,
648 &w.q_b,
649 q_in_dim,
650 id,
651 q_len,
652 );
653 let k_proj = linear(
654 g,
655 params,
656 &format!("{prefix}.k"),
657 k_in,
658 &w.k_w,
659 &w.k_b,
660 kv_in_dim,
661 id,
662 k_len,
663 );
664 let v_proj = linear(
665 g,
666 params,
667 &format!("{prefix}.v"),
668 v_in,
669 &w.v_w,
670 &w.v_b,
671 kv_in_dim,
672 id,
673 k_len,
674 );
675
676 let q_rot = apply_axial_rope_graph(g, q_proj, w, q_len, 1);
677
678 let num_k_rope = k_len.saturating_sub(num_k_exclude_rope);
679 let k_rot = if num_k_rope == 0 {
680 k_proj
681 } else if num_k_rope == k_len {
682 let r = if w.rope_k_repeat && num_k_rope >= spatial && num_k_rope.is_multiple_of(spatial) {
683 num_k_rope / spatial
684 } else {
685 1
686 };
687 apply_axial_rope_graph(g, k_proj, w, k_len, r)
688 } else {
689 let k_prefix = g.narrow_(k_proj, 1, 0, num_k_rope);
690 let k_suffix = g.narrow_(k_proj, 1, num_k_rope, k_len - num_k_rope);
691 let r = if w.rope_k_repeat && num_k_rope >= spatial && num_k_rope.is_multiple_of(spatial) {
692 num_k_rope / spatial
693 } else {
694 1
695 };
696 let k_pre_rot = apply_axial_rope_graph(g, k_prefix, w, num_k_rope, r);
697 g.concat_(vec![k_pre_rot, k_suffix], 1)
698 };
699
700 let out_shape = Shape::new(&[1, q_len, id], f);
701 let attn = if let Some(b) = bias {
702 g.attention_bias(q_rot, k_rot, v_proj, b, nh, dh, out_shape.clone())
703 } else {
704 g.attention_kind(
705 q_rot,
706 k_rot,
707 v_proj,
708 nh,
709 dh,
710 MaskKind::None,
711 out_shape.clone(),
712 )
713 };
714 linear(
715 g,
716 params,
717 &format!("{prefix}.o"),
718 attn,
719 &w.out_w,
720 &w.out_b,
721 id,
722 d,
723 q_len,
724 )
725}
726
727fn build_layer_graph(
728 layer: &Sam2MemoryAttentionLayerWeights,
729 n_img: usize,
730 n_mem: usize,
731 kv_in_dim: usize,
732 num_obj_ptr_tokens: usize,
733) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
734 let d = layer.d_model;
735 let f = DType::F32;
736 let mut g = Graph::new("sam2_mem_attn_layer");
737 let mut params = HashMap::new();
738
739 let tgt = g.input("tgt", Shape::new(&[1, n_img, d], f));
740 let curr_pos = g.input("curr_pos", Shape::new(&[1, n_img, d], f));
741 let memory = g.input("memory", Shape::new(&[1, n_mem, kv_in_dim], f));
742 let memory_pos = g.input("memory_pos", Shape::new(&[1, n_mem, kv_in_dim], f));
743 let mask_ca = g.input(
744 "mask_ca",
745 Shape::new(&[1, layer.cross_attn.num_heads, n_img, n_mem], f),
746 );
747
748 let seq_shape = Shape::new(&[1, n_img, d], f);
749 let mut tgt2 = layer_norm(
750 &mut g,
751 &mut params,
752 "n1",
753 tgt,
754 &layer.norm1_g,
755 &layer.norm1_b,
756 n_img,
757 d,
758 );
759 let q_sa = maybe_add_pos(&mut g, tgt2, curr_pos, n_img, d, layer.pos_enc_at_attn);
760 let sa = build_rope_attn(
761 &mut g,
762 &mut params,
763 "sa",
764 &layer.self_attn,
765 q_sa,
766 tgt2,
767 tgt2,
768 n_img,
769 n_img,
770 d,
771 d,
772 0,
773 None,
774 );
775 let mut out = g.binary(BinaryOp::Add, tgt, sa, seq_shape.clone());
776
777 tgt2 = layer_norm(
778 &mut g,
779 &mut params,
780 "n2",
781 out,
782 &layer.norm2_g,
783 &layer.norm2_b,
784 n_img,
785 d,
786 );
787 let q_ca = maybe_add_pos(
788 &mut g,
789 tgt2,
790 curr_pos,
791 n_img,
792 d,
793 layer.pos_enc_at_cross_attn_queries,
794 );
795 let k_ca = maybe_add_pos(
796 &mut g,
797 memory,
798 memory_pos,
799 n_mem,
800 kv_in_dim,
801 layer.pos_enc_at_cross_attn_keys,
802 );
803 let ca = build_rope_attn(
804 &mut g,
805 &mut params,
806 "ca",
807 &layer.cross_attn,
808 q_ca,
809 k_ca,
810 memory,
811 n_img,
812 n_mem,
813 d,
814 kv_in_dim,
815 num_obj_ptr_tokens,
816 Some(mask_ca),
817 );
818 out = g.binary(BinaryOp::Add, out, ca, seq_shape.clone());
819
820 tgt2 = layer_norm(
821 &mut g,
822 &mut params,
823 "n3",
824 out,
825 &layer.norm3_g,
826 &layer.norm3_b,
827 n_img,
828 d,
829 );
830 let dim_ff = layer.linear1_b.len();
831 let mid = linear(
832 &mut g,
833 &mut params,
834 "ff1",
835 tgt2,
836 &layer.linear1_w,
837 &layer.linear1_b,
838 d,
839 dim_ff,
840 n_img,
841 );
842 let mid = g.activation(Activation::Relu, mid, Shape::new(&[1, n_img, dim_ff], f));
843 let down = linear(
844 &mut g,
845 &mut params,
846 "ff2",
847 mid,
848 &layer.linear2_w,
849 &layer.linear2_b,
850 dim_ff,
851 d,
852 n_img,
853 );
854 out = g.binary(BinaryOp::Add, out, down, seq_shape);
855
856 g.set_outputs(vec![out]);
857 Ok((g, params))
858}
859
860fn build_qkv_proj(
861 g: &mut Graph,
862 params: &mut HashMap<String, Vec<f32>>,
863 prefix: &str,
864 w: &Sam2RoPEAttnWeights,
865 q_in: NodeId,
866 k_in: NodeId,
867 v_in: NodeId,
868 q_len: usize,
869 k_len: usize,
870 q_in_dim: usize,
871 kv_in_dim: usize,
872) -> (NodeId, NodeId, NodeId) {
873 let id = w.internal_dim;
874 let q = linear(
875 g,
876 params,
877 &format!("{prefix}.q"),
878 q_in,
879 &w.q_w,
880 &w.q_b,
881 q_in_dim,
882 id,
883 q_len,
884 );
885 let k = linear(
886 g,
887 params,
888 &format!("{prefix}.k"),
889 k_in,
890 &w.k_w,
891 &w.k_b,
892 kv_in_dim,
893 id,
894 k_len,
895 );
896 let v = linear(
897 g,
898 params,
899 &format!("{prefix}.v"),
900 v_in,
901 &w.v_w,
902 &w.v_b,
903 kv_in_dim,
904 id,
905 k_len,
906 );
907 (q, k, v)
908}
909
910fn build_attention_out(
911 g: &mut Graph,
912 params: &mut HashMap<String, Vec<f32>>,
913 prefix: &str,
914 w: &Sam2RoPEAttnWeights,
915 q: NodeId,
916 k: NodeId,
917 v: NodeId,
918 q_len: usize,
919 _k_len: usize,
920 mask: Option<NodeId>,
921) -> NodeId {
922 let d = w.embedding_dim;
923 let id = w.internal_dim;
924 let nh = w.num_heads;
925 let dh = id / nh;
926 let f = DType::F32;
927 let out_shape = Shape::new(&[1, q_len, id], f);
928 let attn = if let Some(m) = mask {
929 g.attention_bias(q, k, v, m, nh, dh, out_shape.clone())
930 } else {
931 g.attention_kind(q, k, v, nh, dh, MaskKind::None, out_shape.clone())
932 };
933 linear(
934 g,
935 params,
936 &format!("{prefix}.o"),
937 attn,
938 &w.out_w,
939 &w.out_b,
940 id,
941 d,
942 q_len,
943 )
944}
945
946fn build_self_proj_graph(
947 layer: &Sam2MemoryAttentionLayerWeights,
948 n_img: usize,
949) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
950 let d = layer.d_model;
951 let f = DType::F32;
952 let mut g = Graph::new("sam2_mem_self_proj");
953 let mut params = HashMap::new();
954 let tgt = g.input("tgt", Shape::new(&[1, n_img, d], f));
955 let curr_pos = g.input("curr_pos", Shape::new(&[1, n_img, d], f));
956 let tgt2 = layer_norm(
957 &mut g,
958 &mut params,
959 "n1",
960 tgt,
961 &layer.norm1_g,
962 &layer.norm1_b,
963 n_img,
964 d,
965 );
966 let q_in = maybe_add_pos(&mut g, tgt2, curr_pos, n_img, d, layer.pos_enc_at_attn);
967 let (sa_q, sa_k, sa_v) = build_qkv_proj(
968 &mut g,
969 &mut params,
970 "sa",
971 &layer.self_attn,
972 q_in,
973 tgt2,
974 tgt2,
975 n_img,
976 n_img,
977 d,
978 d,
979 );
980 g.set_outputs(vec![sa_q, sa_k, sa_v]);
981 Ok((g, params))
982}
983
984fn build_self_attn_graph(
985 layer: &Sam2MemoryAttentionLayerWeights,
986 n_img: usize,
987) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
988 let d = layer.d_model;
989 let f = DType::F32;
990 let mut g = Graph::new("sam2_mem_self_attn");
991 let mut params = HashMap::new();
992 let tgt = g.input("tgt", Shape::new(&[1, n_img, d], f));
993 let sa_q = g.input(
994 "sa_q",
995 Shape::new(&[1, n_img, layer.self_attn.internal_dim], f),
996 );
997 let sa_k = g.input(
998 "sa_k",
999 Shape::new(&[1, n_img, layer.self_attn.internal_dim], f),
1000 );
1001 let sa_v = g.input(
1002 "sa_v",
1003 Shape::new(&[1, n_img, layer.self_attn.internal_dim], f),
1004 );
1005 let sa = build_attention_out(
1006 &mut g,
1007 &mut params,
1008 "sa",
1009 &layer.self_attn,
1010 sa_q,
1011 sa_k,
1012 sa_v,
1013 n_img,
1014 n_img,
1015 None,
1016 );
1017 let out = g.binary(BinaryOp::Add, tgt, sa, Shape::new(&[1, n_img, d], f));
1018 g.set_outputs(vec![out]);
1019 Ok((g, params))
1020}
1021
1022fn build_cross_proj_graph(
1023 layer: &Sam2MemoryAttentionLayerWeights,
1024 n_img: usize,
1025 n_mem: usize,
1026 kv: usize,
1027) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
1028 let d = layer.d_model;
1029 let f = DType::F32;
1030 let mut g = Graph::new("sam2_mem_cross_proj");
1031 let mut params = HashMap::new();
1032 let tgt = g.input("tgt", Shape::new(&[1, n_img, d], f));
1033 let curr_pos = g.input("curr_pos", Shape::new(&[1, n_img, d], f));
1034 let memory = g.input("memory", Shape::new(&[1, n_mem, kv], f));
1035 let memory_pos = g.input("memory_pos", Shape::new(&[1, n_mem, kv], f));
1036 let tgt2 = layer_norm(
1037 &mut g,
1038 &mut params,
1039 "n2",
1040 tgt,
1041 &layer.norm2_g,
1042 &layer.norm2_b,
1043 n_img,
1044 d,
1045 );
1046 let q_in = maybe_add_pos(
1047 &mut g,
1048 tgt2,
1049 curr_pos,
1050 n_img,
1051 d,
1052 layer.pos_enc_at_cross_attn_queries,
1053 );
1054 let k_in = maybe_add_pos(
1055 &mut g,
1056 memory,
1057 memory_pos,
1058 n_mem,
1059 kv,
1060 layer.pos_enc_at_cross_attn_keys,
1061 );
1062 let (ca_q, ca_k, _) = build_qkv_proj(
1063 &mut g,
1064 &mut params,
1065 "ca",
1066 &layer.cross_attn,
1067 q_in,
1068 k_in,
1069 memory,
1070 n_img,
1071 n_mem,
1072 d,
1073 kv,
1074 );
1075 g.set_outputs(vec![ca_q, ca_k]);
1076 Ok((g, params))
1077}
1078
1079fn build_cross_attn_graph(
1080 layer: &Sam2MemoryAttentionLayerWeights,
1081 n_img: usize,
1082 n_mem: usize,
1083 kv: usize,
1084) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
1085 let d = layer.d_model;
1086 let f = DType::F32;
1087 let mut g = Graph::new("sam2_mem_cross_attn");
1088 let mut params = HashMap::new();
1089 let tgt = g.input("tgt", Shape::new(&[1, n_img, d], f));
1090 let ca_q = g.input(
1091 "ca_q",
1092 Shape::new(&[1, n_img, layer.cross_attn.internal_dim], f),
1093 );
1094 let ca_k = g.input(
1095 "ca_k",
1096 Shape::new(&[1, n_mem, layer.cross_attn.internal_dim], f),
1097 );
1098 let memory = g.input("memory", Shape::new(&[1, n_mem, kv], f));
1099 let mask_ca = g.input(
1100 "mask_ca",
1101 Shape::new(&[1, layer.cross_attn.num_heads, n_img, n_mem], f),
1102 );
1103 let ca = build_attention_out(
1104 &mut g,
1105 &mut params,
1106 "ca",
1107 &layer.cross_attn,
1108 ca_q,
1109 ca_k,
1110 memory,
1111 n_img,
1112 n_mem,
1113 Some(mask_ca),
1114 );
1115 let out = g.binary(BinaryOp::Add, tgt, ca, Shape::new(&[1, n_img, d], f));
1116 g.set_outputs(vec![out]);
1117 Ok((g, params))
1118}
1119
1120fn build_ffn_graph(
1121 layer: &Sam2MemoryAttentionLayerWeights,
1122 n_img: usize,
1123) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
1124 let d = layer.d_model;
1125 let f = DType::F32;
1126 let mut g = Graph::new("sam2_mem_ffn");
1127 let mut params = HashMap::new();
1128 let tgt = g.input("tgt", Shape::new(&[1, n_img, d], f));
1129 let seq_shape = Shape::new(&[1, n_img, d], f);
1130 let normed = layer_norm(
1131 &mut g,
1132 &mut params,
1133 "n3",
1134 tgt,
1135 &layer.norm3_g,
1136 &layer.norm3_b,
1137 n_img,
1138 d,
1139 );
1140 let dim_ff = layer.linear1_b.len();
1141 let mid = linear(
1142 &mut g,
1143 &mut params,
1144 "ff1",
1145 normed,
1146 &layer.linear1_w,
1147 &layer.linear1_b,
1148 d,
1149 dim_ff,
1150 n_img,
1151 );
1152 let mid = g.activation(Activation::Relu, mid, Shape::new(&[1, n_img, dim_ff], f));
1153 let down = linear(
1154 &mut g,
1155 &mut params,
1156 "ff2",
1157 mid,
1158 &layer.linear2_w,
1159 &layer.linear2_b,
1160 dim_ff,
1161 d,
1162 n_img,
1163 );
1164 let out = g.binary(BinaryOp::Add, tgt, down, seq_shape);
1165 g.set_outputs(vec![out]);
1166 Ok((g, params))
1167}
1168
1169fn build_final_norm_graph(
1170 norm_g: &[f32],
1171 norm_b: &[f32],
1172 n_img: usize,
1173 d: usize,
1174) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
1175 let f = DType::F32;
1176 let mut g = Graph::new("sam2_mem_attn_final");
1177 let mut params = HashMap::new();
1178 let tgt = g.input("tgt", Shape::new(&[1, n_img, d], f));
1179 let out = layer_norm(
1180 &mut g,
1181 &mut params,
1182 "out_norm",
1183 tgt,
1184 norm_g,
1185 norm_b,
1186 n_img,
1187 d,
1188 );
1189 g.set_outputs(vec![out]);
1190 Ok((g, params))
1191}
1192
1193#[cfg(test)]
1194mod tests {
1195 use super::*;
1196 use crate::axial_rope::apply_axial_rope_2d;
1197 use crate::memory_attention::{
1198 Sam2MemoryAttentionLayerWeights, Sam2MemoryAttentionWeights, Sam2RoPEAttnWeights,
1199 memory_attention_forward, memory_attention_layer_forward,
1200 };
1201 use crate::transformer::layer_norm_last_cpu;
1202 use rlx_ir::Graph;
1203
1204 #[test]
1205 fn axial_rope2d_op_matches_host_merged_layout() {
1206 let nh = 1usize;
1207 let n = 64usize;
1208 let dh = 256usize;
1209 let feat = [8usize, 8usize];
1210 let x: Vec<f32> = (0..n * nh * dh).map(|i| i as f32 * 0.001).collect();
1211 let host = apply_axial_rope_2d(&x, nh, n, dh, feat[0], feat[1], 10000.0, 1);
1212
1213 let mut g = Graph::new("axial_rope_check");
1214 let f = rlx_ir::DType::F32;
1215 let inp = g.input("x", Shape::new(&[1, n, nh * dh], f));
1216 let out = g.axial_rope2d(inp, feat[0], feat[1], dh, nh, 10000.0, 1);
1217 g.set_outputs(vec![out]);
1218 let mut compiled =
1219 rlx_core::flow_bridge::compile_graph_sam(Device::Cpu, g).expect("compile");
1220 let ir = compiled.run(&[("x", &x)]).into_iter().next().unwrap();
1221
1222 let fd = host
1223 .iter()
1224 .zip(&ir)
1225 .map(|(a, b)| (a - b).abs())
1226 .fold(0f32, f32::max);
1227 assert!(fd < 1e-5, "axial_rope2d op vs host max |Δ| = {fd:.3e}");
1228 }
1229
1230 fn synth_rope_attn(d: usize, kv: usize, feat: [usize; 2]) -> Sam2RoPEAttnWeights {
1231 let id = d;
1232 Sam2RoPEAttnWeights {
1233 q_w: vec![0.01; id * d],
1234 q_b: vec![0.0; id],
1235 k_w: vec![0.02; id * kv],
1236 k_b: vec![0.0; id],
1237 v_w: vec![0.03; id * kv],
1238 v_b: vec![0.0; id],
1239 out_w: vec![0.04; d * id],
1240 out_b: vec![0.0; d],
1241 embedding_dim: d,
1242 kv_in_dim: kv,
1243 internal_dim: id,
1244 num_heads: 1,
1245 rope_theta: 10000.0,
1246 rope_feat_size: feat,
1247 rope_k_repeat: true,
1248 }
1249 }
1250
1251 #[test]
1252 fn memory_attention_ir_matches_host_small_grid() {
1253 let d = 256usize;
1254 let kv = 64usize;
1255 let feat = [8usize, 8usize];
1256 let n_img = 64usize;
1257 let n_mem = 64usize;
1258 let layer = Sam2MemoryAttentionLayerWeights {
1259 self_attn: synth_rope_attn(d, d, feat),
1260 cross_attn: synth_rope_attn(d, kv, feat),
1261 norm1_g: vec![1.0; d],
1262 norm1_b: vec![0.0; d],
1263 norm2_g: vec![1.0; d],
1264 norm2_b: vec![0.0; d],
1265 norm3_g: vec![1.0; d],
1266 norm3_b: vec![0.0; d],
1267 linear1_w: vec![0.01; 2048 * d],
1268 linear1_b: vec![0.0; 2048],
1269 linear2_w: vec![0.02; d * 2048],
1270 linear2_b: vec![0.0; d],
1271 pos_enc_at_attn: false,
1272 pos_enc_at_cross_attn_queries: false,
1273 pos_enc_at_cross_attn_keys: true,
1274 d_model: d,
1275 };
1276 let w = Sam2MemoryAttentionWeights {
1277 layers: vec![layer],
1278 norm_g: vec![1.0; d],
1279 norm_b: vec![0.0; d],
1280 d_model: d,
1281 pos_enc_at_input: true,
1282 };
1283
1284 let curr: Vec<f32> = (0..n_img * d).map(|i| (i as f32) * 1e-4).collect();
1285 let curr_pos: Vec<f32> = (0..n_img * d).map(|i| (i as f32) * 1e-5).collect();
1286 let memory: Vec<f32> = (0..n_mem * kv).map(|i| (i as f32) * 2e-4).collect();
1287 let memory_pos: Vec<f32> = (0..n_mem * kv).map(|i| (i as f32) * 2e-5).collect();
1288
1289 let host = memory_attention_forward(
1290 &w,
1291 &curr,
1292 &curr_pos,
1293 &memory,
1294 &memory_pos,
1295 n_img,
1296 n_mem,
1297 kv,
1298 0,
1299 )
1300 .unwrap();
1301
1302 let mut ir = MemoryAttentionCompiled::compile(&w, n_img, n_mem, 0, Device::Cpu).unwrap();
1303 let got = ir
1304 .run(&curr, &curr_pos, &memory, &memory_pos, n_mem, 0)
1305 .unwrap();
1306
1307 let fd = host
1308 .iter()
1309 .zip(&got)
1310 .map(|(a, b)| (a - b).abs())
1311 .fold(0f32, f32::max);
1312 assert!(fd < 3e-2, "memory attention max |Δ| = {fd:.3e}");
1313 }
1314
/// In-graph-RoPE pipeline check on an 8x8 feature grid (64 image tokens, 64 memory tokens).
///
/// Two comparisons are made:
/// 1. the fused layer graph of the compiled pipeline, run on its own, against the host
///    `memory_attention_forward_layers_only` result (tolerance 5e-2);
/// 2. the full pipeline output against the stand-alone final-norm graph applied to that
///    fused-layer output (tolerance 1e-4).
#[test]
fn memory_attention_in_graph_rope_matches_host_small_grid() {
    /// Largest element-wise `|a - b|`.
    ///
    /// Panics when the buffers differ in length (a plain `zip` would silently truncate),
    /// and keeps NaN sticky so that a NaN output fails the tolerance assertion instead of
    /// being discarded the way `f32::max` discards it.
    fn max_abs_diff(a: &[f32], b: &[f32]) -> f32 {
        assert_eq!(a.len(), b.len(), "compared buffers differ in length");
        a.iter()
            .zip(b)
            .map(|(x, y)| (x - y).abs())
            .fold(0f32, |worst, delta| {
                if delta > worst || delta.is_nan() {
                    delta
                } else {
                    worst
                }
            })
    }

    let d = 256usize;
    let kv = 64usize;
    let feat = [8usize, 8usize];
    let n_img = 64usize;
    let n_mem = 64usize;

    // Single layer with identity layer norms (unit gain, zero bias) and constant FFN weights.
    let layer = Sam2MemoryAttentionLayerWeights {
        self_attn: synth_rope_attn(d, d, feat),
        cross_attn: synth_rope_attn(d, kv, feat),
        norm1_g: vec![1.0; d],
        norm1_b: vec![0.0; d],
        norm2_g: vec![1.0; d],
        norm2_b: vec![0.0; d],
        norm3_g: vec![1.0; d],
        norm3_b: vec![0.0; d],
        linear1_w: vec![0.01; 2048 * d],
        linear1_b: vec![0.0; 2048],
        linear2_w: vec![0.02; d * 2048],
        linear2_b: vec![0.0; d],
        pos_enc_at_attn: false,
        pos_enc_at_cross_attn_queries: false,
        pos_enc_at_cross_attn_keys: true,
        d_model: d,
    };
    let w = Sam2MemoryAttentionWeights {
        layers: vec![layer],
        norm_g: vec![1.0; d],
        norm_b: vec![0.0; d],
        d_model: d,
        pos_enc_at_input: true,
    };

    // Deterministic ramp inputs.
    let curr: Vec<f32> = (0..n_img * d).map(|i| (i as f32) * 1e-4).collect();
    let curr_pos: Vec<f32> = (0..n_img * d).map(|i| (i as f32) * 1e-5).collect();
    let memory: Vec<f32> = (0..n_mem * kv).map(|i| (i as f32) * 2e-4).collect();
    let memory_pos: Vec<f32> = (0..n_mem * kv).map(|i| (i as f32) * 2e-5).collect();

    // Host reference for the layer stack only (no final norm).
    let host_mid = crate::memory_attention::memory_attention_forward_layers_only(
        &w,
        &curr,
        &curr_pos,
        &memory,
        &memory_pos,
        n_img,
        n_mem,
        kv,
        0,
    )
    .unwrap();

    let mut ir =
        MemoryAttentionCompiled::compile_in_graph_rope(&w, n_img, n_mem, 0, Device::Cpu)
            .unwrap();
    let got = ir
        .run(&curr, &curr_pos, &memory, &memory_pos, n_mem, 0)
        .unwrap();

    // Rebuild the fused layer's inputs by hand: tgt = curr + INPUT_POS_SCALE * curr_pos.
    let mut tgt = curr.clone();
    if w.pos_enc_at_input {
        for (t, p) in tgt.iter_mut().zip(&curr_pos) {
            *t += INPUT_POS_SCALE * p;
        }
    }
    let nh = w.layers[0].cross_attn.num_heads;
    let mut mask = vec![0f32; nh * n_img * n_mem];
    fill_cross_attn_bias(&mut mask, nh, n_img, n_mem, n_mem);
    // `memory` / `memory_pos` already hold exactly n_mem * kv values, so they are fed
    // as-is rather than through zero-padded copies.
    let layer_inputs = [
        ("tgt", tgt.as_slice()),
        ("curr_pos", curr_pos.as_slice()),
        ("memory", memory.as_slice()),
        ("memory_pos", memory_pos.as_slice()),
        ("mask_ca", mask.as_slice()),
    ];
    let ir_mid = ir.layers[0]
        .fused
        .as_mut()
        .expect("fused")
        .run(&layer_inputs)
        .into_iter()
        .next()
        .unwrap();
    let fd_layer = max_abs_diff(&host_mid, &ir_mid);
    assert!(fd_layer < 5e-2, "in-graph layer max |Δ| = {fd_layer:.3e}");

    // The pipeline output must equal the final-norm graph applied to the layer output.
    let (fg, fp) = build_final_norm_graph(&w.norm_g, &w.norm_b, n_img, d).unwrap();
    let mut fn_alone =
        Session::new(Device::Cpu).compile_with(fg, &compile_opts_no_fusion(Device::Cpu));
    for (n, data) in &fp {
        fn_alone.set_param(n, data);
    }
    let ir_via_fn = fn_alone
        .run(&[("tgt", &ir_mid)])
        .into_iter()
        .next()
        .unwrap();
    let fd_got_fn = max_abs_diff(&got, &ir_via_fn);
    assert!(
        fd_got_fn < 1e-4,
        "pipeline output vs final_norm(ir_mid) max |Δ| = {fd_got_fn:.3e}"
    );
}
1431
/// Checks that `rlx_cpu::kernels::layer_norm_row`, applied one row at a time, agrees with
/// `layer_norm_last_cpu` on a 64 x 256 ramp input to within 1e-5 (unit gain, zero bias).
#[test]
fn cpu_layer_norm_row_matches_host_last_cpu() {
    /// Largest element-wise `|a - b|`.
    ///
    /// Panics when the buffers differ in length (a plain `zip` would silently truncate),
    /// and keeps NaN sticky so that a NaN output fails the tolerance assertion instead of
    /// being discarded the way `f32::max` discards it.
    fn max_abs_diff(a: &[f32], b: &[f32]) -> f32 {
        assert_eq!(a.len(), b.len(), "compared buffers differ in length");
        a.iter()
            .zip(b)
            .map(|(x, y)| (x - y).abs())
            .fold(0f32, |worst, delta| {
                if delta > worst || delta.is_nan() {
                    delta
                } else {
                    worst
                }
            })
    }

    let rows = 64usize;
    let h = 256usize;
    let x: Vec<f32> = (0..rows * h).map(|i| (i as f32) * 1e-3 - 0.5).collect();
    let g = vec![1.0; h];
    let b = vec![0.0; h];

    // Reference: normalise the whole buffer in place.
    let mut host = x.clone();
    layer_norm_last_cpu(&mut host, rows, h, &g, &b, LN_EPS);

    // Kernel under test: one call per row of width `h`, reading `x` and writing `cpu`.
    let mut cpu = x.clone();
    for (src, dst) in x.chunks_exact(h).zip(cpu.chunks_exact_mut(h)) {
        rlx_cpu::kernels::layer_norm_row(src, &g, &b, dst, h, LN_EPS);
    }

    let fd = max_abs_diff(&host, &cpu);
    assert!(
        fd < 1e-5,
        "cpu layer_norm_row vs host_last_cpu max |Δ| = {fd:.3e}"
    );
}
1462
/// Checks that the graph from `build_final_norm_graph`, compiled with
/// `compile_opts_no_fusion` on the CPU device, matches `layer_norm_last_cpu` on a
/// 64 x 256 ramp input to within 1e-4 (unit gain, zero bias).
#[test]
fn layer_norm_ir_matches_host_synthetic() {
    /// Largest element-wise `|a - b|`.
    ///
    /// Panics when the buffers differ in length (a plain `zip` would silently truncate),
    /// and keeps NaN sticky so that a NaN output fails the tolerance assertion instead of
    /// being discarded the way `f32::max` discards it.
    fn max_abs_diff(a: &[f32], b: &[f32]) -> f32 {
        assert_eq!(a.len(), b.len(), "compared buffers differ in length");
        a.iter()
            .zip(b)
            .map(|(x, y)| (x - y).abs())
            .fold(0f32, |worst, delta| {
                if delta > worst || delta.is_nan() {
                    delta
                } else {
                    worst
                }
            })
    }

    let n_img = 64usize;
    let d = 256usize;
    let x: Vec<f32> = (0..n_img * d).map(|i| (i as f32) * 1e-3 - 0.5).collect();
    let norm_g = vec![1.0; d];
    let norm_b = vec![0.0; d];

    // Host reference, normalised in place.
    let mut host = x.clone();
    layer_norm_last_cpu(&mut host, n_img, d, &norm_g, &norm_b, LN_EPS);

    // IR path: build, compile, upload the parameters, run on the same input.
    let (fg, fp) = build_final_norm_graph(&norm_g, &norm_b, n_img, d).unwrap();
    let mut compiled =
        Session::new(Device::Cpu).compile_with(fg, &compile_opts_no_fusion(Device::Cpu));
    for (n, data) in &fp {
        compiled.set_param(n, data);
    }
    let ir = compiled.run(&[("tgt", &x)]).into_iter().next().unwrap();

    let fd = max_abs_diff(&host, &ir);
    assert!(
        fd < 1e-4,
        "synthetic final norm IR vs host max |Δ| = {fd:.3e}"
    );
}
1489
/// Checks the final-norm graph on a realistic input: the output of one host
/// `memory_attention_layer_forward` call is normalised both by `layer_norm_last_cpu` and
/// by the compiled `build_final_norm_graph` graph, and the two must agree to within 1e-4.
#[test]
fn stack_final_norm_ir_matches_host_layer_output() {
    /// Largest element-wise `|a - b|`.
    ///
    /// Panics when the buffers differ in length (a plain `zip` would silently truncate),
    /// and keeps NaN sticky so that a NaN output fails the tolerance assertion instead of
    /// being discarded the way `f32::max` discards it.
    fn max_abs_diff(a: &[f32], b: &[f32]) -> f32 {
        assert_eq!(a.len(), b.len(), "compared buffers differ in length");
        a.iter()
            .zip(b)
            .map(|(x, y)| (x - y).abs())
            .fold(0f32, |worst, delta| {
                if delta > worst || delta.is_nan() {
                    delta
                } else {
                    worst
                }
            })
    }

    let d = 256usize;
    let kv = 64usize;
    let feat = [8usize, 8usize];
    let n_img = 64usize;
    // This test uses as many memory tokens as image tokens.
    let n_mem = n_img;

    let layer_tgt: Vec<f32> = (0..n_img * d).map(|i| (i as f32) * 1e-4).collect();
    let curr_pos: Vec<f32> = (0..n_img * d).map(|i| (i as f32) * 1e-5).collect();
    let memory: Vec<f32> = (0..n_mem * kv).map(|i| (i as f32) * 2e-4).collect();
    let memory_pos: Vec<f32> = (0..n_mem * kv).map(|i| (i as f32) * 2e-5).collect();

    // Single layer with identity layer norms (unit gain, zero bias) and constant FFN weights.
    let layer = Sam2MemoryAttentionLayerWeights {
        self_attn: synth_rope_attn(d, d, feat),
        cross_attn: synth_rope_attn(d, kv, feat),
        norm1_g: vec![1.0; d],
        norm1_b: vec![0.0; d],
        norm2_g: vec![1.0; d],
        norm2_b: vec![0.0; d],
        norm3_g: vec![1.0; d],
        norm3_b: vec![0.0; d],
        linear1_w: vec![0.01; 2048 * d],
        linear1_b: vec![0.0; 2048],
        linear2_w: vec![0.02; d * 2048],
        linear2_b: vec![0.0; d],
        pos_enc_at_attn: false,
        pos_enc_at_cross_attn_queries: false,
        pos_enc_at_cross_attn_keys: true,
        d_model: d,
    };

    // Host layer output; this is the shared input to both final-norm implementations.
    let host_layer = memory_attention_layer_forward(
        &layer,
        layer_tgt,
        &curr_pos,
        &memory,
        &memory_pos,
        n_img,
        n_mem,
        kv,
        0,
    )
    .unwrap();

    let norm_g = vec![1.0; d];
    let norm_b = vec![0.0; d];
    let mut host_final = host_layer.clone();
    layer_norm_last_cpu(&mut host_final, n_img, d, &norm_g, &norm_b, LN_EPS);

    let (fg, fp) = build_final_norm_graph(&norm_g, &norm_b, n_img, d).unwrap();
    let mut compiled =
        Session::new(Device::Cpu).compile_with(fg, &compile_opts_no_fusion(Device::Cpu));
    for (n, data) in &fp {
        compiled.set_param(n, data);
    }
    let ir_final = compiled
        .run(&[("tgt", &host_layer)])
        .into_iter()
        .next()
        .unwrap();

    let fd = max_abs_diff(&host_final, &ir_final);
    assert!(fd < 1e-4, "stack final norm IR vs host max |Δ| = {fd:.3e}");
}
1551
/// Bisection aid: compares a single layer graph from `build_layer_graph`, compiled with
/// `compile_opts_no_fusion`, against the host `memory_attention_layer_forward` on the same
/// inputs (tolerance 3e-2). No final norm and no pipeline wrapper are involved, so a
/// failure here points at the layer graph itself.
#[test]
fn memory_attention_layer_in_graph_rope_bisect() {
    /// Largest element-wise `|a - b|`.
    ///
    /// Panics when the buffers differ in length (a plain `zip` would silently truncate),
    /// and keeps NaN sticky so that a NaN output fails the tolerance assertion instead of
    /// being discarded the way `f32::max` discards it.
    fn max_abs_diff(a: &[f32], b: &[f32]) -> f32 {
        assert_eq!(a.len(), b.len(), "compared buffers differ in length");
        a.iter()
            .zip(b)
            .map(|(x, y)| (x - y).abs())
            .fold(0f32, |worst, delta| {
                if delta > worst || delta.is_nan() {
                    delta
                } else {
                    worst
                }
            })
    }

    let d = 256usize;
    let kv = 64usize;
    let feat = [8usize, 8usize];
    let n_img = 64usize;
    let n_mem = 64usize;

    // Single layer with identity layer norms (unit gain, zero bias) and constant FFN weights.
    let layer = Sam2MemoryAttentionLayerWeights {
        self_attn: synth_rope_attn(d, d, feat),
        cross_attn: synth_rope_attn(d, kv, feat),
        norm1_g: vec![1.0; d],
        norm1_b: vec![0.0; d],
        norm2_g: vec![1.0; d],
        norm2_b: vec![0.0; d],
        norm3_g: vec![1.0; d],
        norm3_b: vec![0.0; d],
        linear1_w: vec![0.01; 2048 * d],
        linear1_b: vec![0.0; 2048],
        linear2_w: vec![0.02; d * 2048],
        linear2_b: vec![0.0; d],
        pos_enc_at_attn: false,
        pos_enc_at_cross_attn_queries: false,
        pos_enc_at_cross_attn_keys: true,
        d_model: d,
    };

    // Layer input: a ramp with a scaled positional ramp already added.
    let mut tgt: Vec<f32> = (0..n_img * d).map(|i| (i as f32) * 1e-4).collect();
    for (i, t) in tgt.iter_mut().enumerate() {
        *t += INPUT_POS_SCALE * (i as f32) * 1e-5;
    }
    let curr_pos: Vec<f32> = (0..n_img * d).map(|i| (i as f32) * 1e-5).collect();
    let memory: Vec<f32> = (0..n_mem * kv).map(|i| (i as f32) * 2e-4).collect();
    let memory_pos: Vec<f32> = (0..n_mem * kv).map(|i| (i as f32) * 2e-5).collect();

    // The host forward takes `tgt` by value, so hand it a copy and keep ours for the graph.
    let host = crate::memory_attention::memory_attention_layer_forward(
        &layer,
        tgt.clone(),
        &curr_pos,
        &memory,
        &memory_pos,
        n_img,
        n_mem,
        kv,
        0,
    )
    .unwrap();

    let (g, p) = build_layer_graph(&layer, n_img, n_mem, kv, 0).unwrap();
    let nh = layer.cross_attn.num_heads;
    let mut mask = vec![0f32; nh * n_img * n_mem];
    fill_cross_attn_bias(&mut mask, nh, n_img, n_mem, n_mem);
    let mut compiled =
        Session::new(Device::Cpu).compile_with(g, &compile_opts_no_fusion(Device::Cpu));
    for (n, data) in &p {
        compiled.set_param(n, data);
    }
    let got = compiled
        .run(&[
            ("tgt", &tgt),
            ("curr_pos", &curr_pos),
            ("memory", &memory),
            ("memory_pos", &memory_pos),
            ("mask_ca", &mask),
        ])
        .into_iter()
        .next()
        .unwrap();

    let fd = max_abs_diff(&host, &got);
    assert!(fd < 3e-2, "layer in-graph rope max |Δ| = {fd:.3e}");
}
1628
/// Smoke/timing test: compiles the default and the in-graph-RoPE pipelines for the same
/// single-layer weights, runs each `RUNS` times, and prints compile time and mean run time
/// to stderr. The only assertion is that the in-graph timings are strictly positive.
#[test]
fn memory_attention_in_graph_rope_timing_quick_check() {
    use std::time::Instant;

    const RUNS: usize = 5;

    // Milliseconds elapsed since `since`.
    let elapsed_ms = |since: Instant| since.elapsed().as_secs_f64() * 1000.0;
    // Ramp of `len` values: 0, step, 2 * step, ...
    let ramp = |len: usize, step: f32| -> Vec<f32> {
        (0..len).map(|i| (i as f32) * step).collect()
    };

    let (d, kv) = (256usize, 64usize);
    let feat = [8usize, 8usize];
    let (n_img, n_mem) = (64usize, 64usize);

    let layer = Sam2MemoryAttentionLayerWeights {
        self_attn: synth_rope_attn(d, d, feat),
        cross_attn: synth_rope_attn(d, kv, feat),
        norm1_g: vec![1.0; d],
        norm1_b: vec![0.0; d],
        norm2_g: vec![1.0; d],
        norm2_b: vec![0.0; d],
        norm3_g: vec![1.0; d],
        norm3_b: vec![0.0; d],
        linear1_w: vec![0.01; 2048 * d],
        linear1_b: vec![0.0; 2048],
        linear2_w: vec![0.02; d * 2048],
        linear2_b: vec![0.0; d],
        pos_enc_at_attn: false,
        pos_enc_at_cross_attn_queries: false,
        pos_enc_at_cross_attn_keys: true,
        d_model: d,
    };
    let w = Sam2MemoryAttentionWeights {
        layers: vec![layer],
        norm_g: vec![1.0; d],
        norm_b: vec![0.0; d],
        d_model: d,
        pos_enc_at_input: true,
    };

    let curr = ramp(n_img * d, 1e-4);
    let curr_pos = ramp(n_img * d, 1e-5);
    let memory = ramp(n_mem * kv, 2e-4);
    let memory_pos = ramp(n_mem * kv, 2e-5);

    // Compile both variants, timing each compile separately.
    let started = Instant::now();
    let mut default =
        MemoryAttentionCompiled::compile(&w, n_img, n_mem, 0, Device::Cpu).unwrap();
    let compile_default_ms = elapsed_ms(started);

    let started = Instant::now();
    let mut in_graph =
        MemoryAttentionCompiled::compile_in_graph_rope(&w, n_img, n_mem, 0, Device::Cpu)
            .unwrap();
    let compile_in_graph_ms = elapsed_ms(started);

    // Mean wall-clock time per run; outputs are discarded.
    let started = Instant::now();
    for _ in 0..RUNS {
        drop(
            default
                .run(&curr, &curr_pos, &memory, &memory_pos, n_mem, 0)
                .unwrap(),
        );
    }
    let run_default_ms = elapsed_ms(started) / RUNS as f64;

    let started = Instant::now();
    for _ in 0..RUNS {
        drop(
            in_graph
                .run(&curr, &curr_pos, &memory, &memory_pos, n_mem, 0)
                .unwrap(),
        );
    }
    let run_in_graph_ms = elapsed_ms(started) / RUNS as f64;

    eprintln!(
        "mem_attn compile ms: default={compile_default_ms:.2} in_graph={compile_in_graph_ms:.2}; \
         run ms (avg/{RUNS}): default={run_default_ms:.2} in_graph={run_in_graph_ms:.2}"
    );
    assert!(compile_in_graph_ms > 0.0 && run_in_graph_ms > 0.0);
}
1703}