Skip to main content

rlx_sam2/
memory_attention_ir.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! SAM2 memory attention IR.
17//!
18//! Default: five compiled subgraphs per layer with host axial RoPE between Q/K
19//! projection and attention (best parity today).
20//!
21//! [`MemoryAttentionCompiled::compile_in_graph_rope`] fuses each layer into one
22//! graph using [`Op::AxialRope2d`] (faster compile/run). Both paths share a
23//! compiled stack final-norm subgraph (`skip_fusion` to avoid bad LN fusion).
24
25use super::axial_rope::apply_axial_rope_2d;
26use super::memory_attention::{
27    Sam2MemoryAttentionLayerWeights, Sam2MemoryAttentionWeights, Sam2RoPEAttnWeights,
28};
29use anyhow::Result;
30use rlx_ir::infer::GraphExt;
31use rlx_ir::op::{Activation, BinaryOp, MaskKind};
32use rlx_ir::{DType, Graph, NodeId, Shape};
33use rlx_runtime::{CompileOptions, CompiledGraph, Device, Session};
34use std::collections::HashMap;
35
/// How each layer applies axial RoPE relative to attention.
///
/// The mode is chosen once when the stack is compiled and is stored on every
/// compiled layer; `MemoryAttentionCompiled::run` dispatches on it per layer.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum LayerRopeMode {
    /// Host `apply_axial_rope_2d` between projection and attention subgraphs.
    /// The layer is compiled as five subgraphs (self projection, self
    /// attention, cross projection, cross attention, FFN).
    HostBetweenGraphs,
    /// `Op::AxialRope2d` inside a single compiled layer graph.
    InGraph,
}
44
/// Epsilon handed to the graph `layer_norm` op by the `layer_norm` helper.
const LN_EPS: f32 = 1e-5;
/// Scale applied to `curr_pos` before it is added to the input tokens in
/// `MemoryAttentionCompiled::run` (only when `pos_enc_at_input` is set).
const INPUT_POS_SCALE: f32 = 0.1;
47
/// Max prior-frame spatial memory banks in cross-attention (SAM2 ≈7 maskmem slots).
pub const MAX_MEMORY_FRAMES_IN_ATTN: usize = 7;

/// Upper bound on the number of memory tokens: `n_img` spatial tokens for each
/// of the [`MAX_MEMORY_FRAMES_IN_ATTN`] banks plus `max_obj_ptr_tokens`.
pub fn max_memory_slots(n_img: usize, max_obj_ptr_tokens: usize) -> usize {
    let spatial_tokens = MAX_MEMORY_FRAMES_IN_ATTN * n_img;
    spatial_tokens + max_obj_ptr_tokens
}
54
/// Compiled graphs for one memory-attention layer plus a copy of its weights.
///
/// Exactly one family of graphs is populated, depending on `mode`:
/// `fused` for `InGraph`, the five split subgraphs for `HostBetweenGraphs`.
struct MemoryAttentionLayerCompiled {
    /// RoPE strategy this layer was compiled for.
    mode: LayerRopeMode,
    /// Fused layer graph when `mode == InGraph`.
    fused: Option<CompiledGraph>,
    /// Norm + Q/K/V projection for self-attention (host mode only).
    self_proj: Option<CompiledGraph>,
    /// Self-attention, output projection and residual add (host mode only).
    self_attn: Option<CompiledGraph>,
    /// Norm + Q/K projection for cross-attention (host mode only).
    cross_proj: Option<CompiledGraph>,
    /// Cross-attention over the memory tokens (host mode only).
    cross_attn: Option<CompiledGraph>,
    /// Feed-forward subgraph (host mode only).
    ffn: Option<CompiledGraph>,
    /// Cloned layer weights; the host path reads dimensions and RoPE settings from here.
    layer: Sam2MemoryAttentionLayerWeights,
}
66
/// Compiled SAM2 memory-attention stack: per-layer graphs plus the final norm.
///
/// All graphs are compiled for fixed sizes (`n_img`, `max_n_mem`); `run` pads
/// the active memory up to `max_n_mem`.
pub struct MemoryAttentionCompiled {
    /// Compiled layers, executed in order by `run`.
    layers: Vec<MemoryAttentionLayerCompiled>,
    /// Stack final norm (compiled; `skip_fusion` at compile).
    final_norm: CompiledGraph,
    /// Number of current-frame tokens the graphs were compiled for.
    pub n_img: usize,
    /// Memory-token capacity the graphs were compiled for.
    pub max_n_mem: usize,
    /// Token width of the current-frame stream (copied from the weights).
    pub d_model: usize,
    /// Memory token width, taken from the first layer's cross-attention.
    pub kv_in_dim: usize,
    /// Largest object-pointer token count accepted by `run`.
    pub max_obj_ptr_tokens: usize,
    /// Whether `run` adds scaled `curr_pos` to the input before the first layer.
    pos_enc_at_input: bool,
}
78
79impl MemoryAttentionCompiled {
80    pub fn compile(
81        w: &Sam2MemoryAttentionWeights,
82        n_img: usize,
83        max_n_mem: usize,
84        max_obj_ptr_tokens: usize,
85        device: Device,
86    ) -> Result<Self> {
87        Self::compile_with_profile(
88            w,
89            n_img,
90            max_n_mem,
91            max_obj_ptr_tokens,
92            device,
93            &rlx_flow::CompileProfile::sam_encoder(),
94        )
95    }
96
97    pub fn compile_with_profile(
98        w: &Sam2MemoryAttentionWeights,
99        n_img: usize,
100        max_n_mem: usize,
101        max_obj_ptr_tokens: usize,
102        device: Device,
103        profile: &rlx_flow::CompileProfile,
104    ) -> Result<Self> {
105        Self::compile_with_mode(
106            w,
107            n_img,
108            max_n_mem,
109            max_obj_ptr_tokens,
110            device,
111            LayerRopeMode::HostBetweenGraphs,
112            profile,
113        )
114    }
115
116    /// One compiled graph per layer (`Op::AxialRope2d` in-graph). Fusion is disabled
117    /// so `FuseAttentionBlock` cannot reorder RoPE relative to attention.
118    pub fn compile_in_graph_rope(
119        w: &Sam2MemoryAttentionWeights,
120        n_img: usize,
121        max_n_mem: usize,
122        max_obj_ptr_tokens: usize,
123        device: Device,
124    ) -> Result<Self> {
125        Self::compile_in_graph_rope_with_profile(
126            w,
127            n_img,
128            max_n_mem,
129            max_obj_ptr_tokens,
130            device,
131            &rlx_flow::CompileProfile::sam_encoder(),
132        )
133    }
134
135    pub fn compile_in_graph_rope_with_profile(
136        w: &Sam2MemoryAttentionWeights,
137        n_img: usize,
138        max_n_mem: usize,
139        max_obj_ptr_tokens: usize,
140        device: Device,
141        profile: &rlx_flow::CompileProfile,
142    ) -> Result<Self> {
143        Self::compile_with_mode(
144            w,
145            n_img,
146            max_n_mem,
147            max_obj_ptr_tokens,
148            device,
149            LayerRopeMode::InGraph,
150            profile,
151        )
152    }
153
154    fn compile_with_mode(
155        w: &Sam2MemoryAttentionWeights,
156        n_img: usize,
157        max_n_mem: usize,
158        max_obj_ptr_tokens: usize,
159        device: Device,
160        mode: LayerRopeMode,
161        profile: &rlx_flow::CompileProfile,
162    ) -> Result<Self> {
163        anyhow::ensure!(
164            w.layers
165                .iter()
166                .all(|l| l.self_attn.num_heads == 1 && l.cross_attn.num_heads == 1),
167            "memory_attention_ir currently requires num_heads=1"
168        );
169        let kv = w.layers[0].cross_attn.kv_in_dim;
170        let mut layers = Vec::with_capacity(w.layers.len());
171        for layer in &w.layers {
172            layers.push(compile_layer(
173                layer,
174                n_img,
175                max_n_mem,
176                kv,
177                max_obj_ptr_tokens,
178                device,
179                mode,
180                profile,
181            )?);
182        }
183        let (fn_g, fn_p) = build_final_norm_graph(&w.norm_g, &w.norm_b, n_img, w.d_model)?;
184        let mut final_norm =
185            Session::new(device).compile_with(fn_g, &compile_opts_no_fusion(device));
186        for (n, d) in &fn_p {
187            final_norm.set_param(n, d);
188        }
189        Ok(Self {
190            layers,
191            final_norm,
192            n_img,
193            max_n_mem,
194            d_model: w.d_model,
195            kv_in_dim: kv,
196            max_obj_ptr_tokens,
197            pos_enc_at_input: w.pos_enc_at_input,
198        })
199    }
200
201    pub fn run(
202        &mut self,
203        curr: &[f32],
204        curr_pos: &[f32],
205        memory: &[f32],
206        memory_pos: &[f32],
207        active_n_mem: usize,
208        num_obj_ptr_tokens: usize,
209    ) -> Result<Vec<f32>> {
210        let d = self.d_model;
211        let kv = self.kv_in_dim;
212        anyhow::ensure!(curr.len() == self.n_img * d);
213        anyhow::ensure!(curr_pos.len() == self.n_img * d);
214        anyhow::ensure!(memory.len() >= active_n_mem * kv);
215        anyhow::ensure!(memory_pos.len() >= active_n_mem * kv);
216        anyhow::ensure!(active_n_mem <= self.max_n_mem);
217        anyhow::ensure!(num_obj_ptr_tokens <= self.max_obj_ptr_tokens);
218
219        let mut tgt = curr.to_vec();
220        if self.pos_enc_at_input {
221            for i in 0..tgt.len() {
222                tgt[i] += INPUT_POS_SCALE * curr_pos[i];
223            }
224        }
225
226        let mut mem_pad = vec![0f32; self.max_n_mem * kv];
227        let mut mem_pos_pad = vec![0f32; self.max_n_mem * kv];
228        mem_pad[..active_n_mem * kv].copy_from_slice(&memory[..active_n_mem * kv]);
229        mem_pos_pad[..active_n_mem * kv].copy_from_slice(&memory_pos[..active_n_mem * kv]);
230
231        let nh = 1usize;
232        let mut mask = vec![0f32; nh * self.n_img * self.max_n_mem];
233        fill_cross_attn_bias(&mut mask, nh, self.n_img, self.max_n_mem, active_n_mem);
234
235        for layer in &mut self.layers {
236            tgt = match layer.mode {
237                LayerRopeMode::InGraph => layer
238                    .fused
239                    .as_mut()
240                    .expect("fused layer")
241                    .run(&[
242                        ("tgt", &tgt),
243                        ("curr_pos", curr_pos),
244                        ("memory", &mem_pad),
245                        ("memory_pos", &mem_pos_pad),
246                        ("mask_ca", &mask),
247                    ])
248                    .into_iter()
249                    .next()
250                    .expect("fused layer output"),
251                LayerRopeMode::HostBetweenGraphs => layer.run_host_between(
252                    &tgt,
253                    curr_pos,
254                    &mem_pad,
255                    &mem_pos_pad,
256                    active_n_mem,
257                    num_obj_ptr_tokens,
258                )?,
259            };
260        }
261
262        let outs = self.final_norm.run(&[("tgt", &tgt)]);
263        Ok(outs.into_iter().next().expect("memory_attention output"))
264    }
265}
266
267impl MemoryAttentionLayerCompiled {
268    fn run_host_between(
269        &mut self,
270        tgt: &[f32],
271        curr_pos: &[f32],
272        memory: &[f32],
273        memory_pos: &[f32],
274        active_n_mem: usize,
275        num_obj_ptr_tokens: usize,
276    ) -> Result<Vec<f32>> {
277        let d = self.layer.d_model;
278        let kv = self.layer.cross_attn.kv_in_dim;
279        let n_img = tgt.len() / d;
280        let max_n_mem = memory.len() / kv;
281        let _id = self.layer.self_attn.internal_dim;
282
283        let p = self
284            .self_proj
285            .as_mut()
286            .expect("self_proj")
287            .run(&[("tgt", tgt), ("curr_pos", curr_pos)]);
288        let mut it = p.into_iter();
289        let mut sa_q = it.next().expect("sa_q");
290        let mut sa_k = it.next().expect("sa_k");
291        let sa_v = it.next().expect("sa_v");
292        host_rotate_qk(&mut sa_q, n_img, &self.layer.self_attn);
293        host_rotate_qk(&mut sa_k, n_img, &self.layer.self_attn);
294
295        let mut tgt = self
296            .self_attn
297            .as_mut()
298            .expect("self_attn")
299            .run(&[
300                ("tgt", tgt),
301                ("sa_q", &sa_q),
302                ("sa_k", &sa_k),
303                ("sa_v", &sa_v),
304            ])
305            .into_iter()
306            .next()
307            .expect("tgt after self");
308
309        let c = self.cross_proj.as_mut().expect("cross_proj").run(&[
310            ("tgt", &tgt),
311            ("curr_pos", curr_pos),
312            ("memory", memory),
313            ("memory_pos", memory_pos),
314        ]);
315        let mut it = c.into_iter();
316        let mut ca_q = it.next().expect("ca_q");
317        let mut ca_k = it.next().expect("ca_k");
318        host_rotate_qk(&mut ca_q, n_img, &self.layer.cross_attn);
319        host_rotate_k_partial(
320            &mut ca_k,
321            max_n_mem,
322            active_n_mem,
323            num_obj_ptr_tokens,
324            &self.layer.cross_attn,
325        );
326
327        let nh = self.layer.cross_attn.num_heads;
328        let mut mask = vec![0f32; nh * n_img * max_n_mem];
329        fill_cross_attn_bias(&mut mask, nh, n_img, max_n_mem, active_n_mem);
330
331        tgt = self
332            .cross_attn
333            .as_mut()
334            .expect("cross_attn")
335            .run(&[
336                ("tgt", &tgt),
337                ("ca_q", &ca_q),
338                ("ca_k", &ca_k),
339                ("memory", memory),
340                ("mask_ca", &mask),
341            ])
342            .into_iter()
343            .next()
344            .expect("tgt after cross");
345
346        self.ffn
347            .as_mut()
348            .expect("ffn")
349            .run(&[("tgt", &tgt)])
350            .into_iter()
351            .next()
352            .ok_or_else(|| anyhow::anyhow!("ffn output missing"))
353    }
354}
355
356fn compile_opts_no_fusion(device: Device) -> CompileOptions {
357    rlx_core::flow_bridge::compile_options_sam2_memory_attention(device)
358}
359
360fn compile_layer(
361    layer: &Sam2MemoryAttentionLayerWeights,
362    n_img: usize,
363    n_mem: usize,
364    kv: usize,
365    max_obj_ptr_tokens: usize,
366    device: Device,
367    mode: LayerRopeMode,
368    profile: &rlx_flow::CompileProfile,
369) -> Result<MemoryAttentionLayerCompiled> {
370    let compile =
371        |g: Graph, p: HashMap<String, Vec<f32>>, opts: &CompileOptions| -> Result<CompiledGraph> {
372            let mut c = Session::new(device).compile_with(g, opts);
373            for (n, d) in &p {
374                c.set_param(n, d);
375            }
376            Ok(c)
377        };
378    match mode {
379        LayerRopeMode::InGraph => {
380            let opts = compile_opts_no_fusion(device);
381            let (g, p) = build_layer_graph(layer, n_img, n_mem, kv, max_obj_ptr_tokens)?;
382            Ok(MemoryAttentionLayerCompiled {
383                mode,
384                fused: Some(compile(g, p, &opts)?),
385                self_proj: None,
386                self_attn: None,
387                cross_proj: None,
388                cross_attn: None,
389                ffn: None,
390                layer: clone_layer(layer),
391            })
392        }
393        LayerRopeMode::HostBetweenGraphs => {
394            let opts = rlx_core::flow_bridge::compile_options_for_profile(profile, device);
395            let (g1, p1) = build_self_proj_graph(layer, n_img)?;
396            let (g2, p2) = build_self_attn_graph(layer, n_img)?;
397            let (g3, p3) = build_cross_proj_graph(layer, n_img, n_mem, kv)?;
398            let (g4, p4) = build_cross_attn_graph(layer, n_img, n_mem, kv)?;
399            let (g5, p5) = build_ffn_graph(layer, n_img)?;
400            Ok(MemoryAttentionLayerCompiled {
401                mode,
402                fused: None,
403                self_proj: Some(compile(g1, p1, &opts)?),
404                self_attn: Some(compile(g2, p2, &opts)?),
405                cross_proj: Some(compile(g3, p3, &opts)?),
406                cross_attn: Some(compile(g4, p4, &opts)?),
407                ffn: Some(compile(g5, p5, &opts)?),
408                layer: clone_layer(layer),
409            })
410        }
411    }
412}
413
/// Fills `out` (laid out as `[nh, n_img, max_n_mem]`) with an additive
/// attention bias: `0.0` for the first `active_n_mem` key positions of every
/// row and `-1e4` for the remaining (padded) positions.
///
/// The whole buffer is zeroed first, so any previous contents are discarded.
fn fill_cross_attn_bias(
    out: &mut [f32],
    nh: usize,
    n_img: usize,
    max_n_mem: usize,
    active_n_mem: usize,
) {
    out.fill(0.0);
    // Nothing is padded when every key slot is active.
    if active_n_mem >= max_n_mem {
        return;
    }
    let rows = nh * n_img;
    for row in out[..rows * max_n_mem].chunks_exact_mut(max_n_mem) {
        row[active_n_mem..].fill(-1e4);
    }
}
430
431fn host_rotate_qk(seq: &mut [f32], n_tokens: usize, w: &Sam2RoPEAttnWeights) {
432    let nh = w.num_heads;
433    let dh = w.internal_dim / nh;
434    let [ex, ey] = w.rope_feat_size;
435    let out = apply_axial_rope_2d(seq, nh, n_tokens, dh, ex, ey, w.rope_theta, 1);
436    seq.copy_from_slice(&out);
437}
438
439fn host_rotate_k_partial(
440    k: &mut [f32],
441    buf_tokens: usize,
442    active_tokens: usize,
443    num_k_exclude_rope: usize,
444    w: &Sam2RoPEAttnWeights,
445) {
446    let nh = w.num_heads;
447    let dh = w.internal_dim / nh;
448    let [ex, ey] = w.rope_feat_size;
449    let spatial = ex * ey;
450    let num_k_rope = active_tokens.saturating_sub(num_k_exclude_rope);
451    if num_k_rope == 0 {
452        return;
453    }
454    let _ = buf_tokens;
455    let r = if w.rope_k_repeat && num_k_rope >= spatial && num_k_rope.is_multiple_of(spatial) {
456        num_k_rope / spatial
457    } else {
458        1
459    };
460    let prefix_len = nh * num_k_rope * dh;
461    let out = apply_axial_rope_2d(
462        &k[..prefix_len],
463        nh,
464        num_k_rope,
465        dh,
466        ex,
467        ey,
468        w.rope_theta,
469        r,
470    );
471    k[..prefix_len].copy_from_slice(&out);
472}
473
474fn clone_layer(l: &Sam2MemoryAttentionLayerWeights) -> Sam2MemoryAttentionLayerWeights {
475    Sam2MemoryAttentionLayerWeights {
476        self_attn: clone_rope(&l.self_attn),
477        cross_attn: clone_rope(&l.cross_attn),
478        norm1_g: l.norm1_g.clone(),
479        norm1_b: l.norm1_b.clone(),
480        norm2_g: l.norm2_g.clone(),
481        norm2_b: l.norm2_b.clone(),
482        norm3_g: l.norm3_g.clone(),
483        norm3_b: l.norm3_b.clone(),
484        linear1_w: l.linear1_w.clone(),
485        linear1_b: l.linear1_b.clone(),
486        linear2_w: l.linear2_w.clone(),
487        linear2_b: l.linear2_b.clone(),
488        pos_enc_at_attn: l.pos_enc_at_attn,
489        pos_enc_at_cross_attn_queries: l.pos_enc_at_cross_attn_queries,
490        pos_enc_at_cross_attn_keys: l.pos_enc_at_cross_attn_keys,
491        d_model: l.d_model,
492    }
493}
494
495fn clone_rope(w: &Sam2RoPEAttnWeights) -> Sam2RoPEAttnWeights {
496    Sam2RoPEAttnWeights {
497        q_w: w.q_w.clone(),
498        q_b: w.q_b.clone(),
499        k_w: w.k_w.clone(),
500        k_b: w.k_b.clone(),
501        v_w: w.v_w.clone(),
502        v_b: w.v_b.clone(),
503        out_w: w.out_w.clone(),
504        out_b: w.out_b.clone(),
505        embedding_dim: w.embedding_dim,
506        kv_in_dim: w.kv_in_dim,
507        internal_dim: w.internal_dim,
508        num_heads: w.num_heads,
509        rope_theta: w.rope_theta,
510        rope_feat_size: w.rope_feat_size,
511        rope_k_repeat: w.rope_k_repeat,
512    }
513}
514
/// Transposes a row-major `[out_d, in_d]` weight matrix into a row-major
/// `[in_d, out_d]` matrix, the layout the graph parameters are declared with.
///
/// Panics if `w_out_in` holds fewer than `in_d * out_d` values.
fn matmul_weight(w_out_in: &[f32], in_d: usize, out_d: usize) -> Vec<f32> {
    (0..in_d)
        .flat_map(|col| (0..out_d).map(move |row| w_out_in[row * in_d + col]))
        .collect()
}
524
525fn bind_linear(
526    g: &mut Graph,
527    params: &mut HashMap<String, Vec<f32>>,
528    prefix: &str,
529    w: &[f32],
530    b: &[f32],
531    in_d: usize,
532    out_d: usize,
533) -> (NodeId, NodeId) {
534    let f = DType::F32;
535    let w_id = g.param(format!("{prefix}.w"), Shape::new(&[in_d, out_d], f));
536    let b_id = g.param(format!("{prefix}.b"), Shape::new(&[out_d], f));
537    params.insert(format!("{prefix}.w"), matmul_weight(w, in_d, out_d));
538    params.insert(format!("{prefix}.b"), b.to_vec());
539    (w_id, b_id)
540}
541
542fn linear(
543    g: &mut Graph,
544    params: &mut HashMap<String, Vec<f32>>,
545    prefix: &str,
546    x: NodeId,
547    w: &[f32],
548    b: &[f32],
549    in_d: usize,
550    out_d: usize,
551    seq: usize,
552) -> NodeId {
553    let f = DType::F32;
554    let (w_id, b_id) = bind_linear(g, params, prefix, w, b, in_d, out_d);
555    g.fused_matmul_bias_act(x, w_id, b_id, None, Shape::new(&[1, seq, out_d], f))
556}
557
558fn bind_ln(
559    g: &mut Graph,
560    params: &mut HashMap<String, Vec<f32>>,
561    prefix: &str,
562    gamm: &[f32],
563    bet: &[f32],
564    e: usize,
565) -> (NodeId, NodeId) {
566    let f = DType::F32;
567    let g_id = g.param(format!("{prefix}.g"), Shape::new(&[e], f));
568    let b_id = g.param(format!("{prefix}.b"), Shape::new(&[e], f));
569    params.insert(format!("{prefix}.g"), gamm.to_vec());
570    params.insert(format!("{prefix}.b"), bet.to_vec());
571    (g_id, b_id)
572}
573
574fn layer_norm(
575    g: &mut Graph,
576    params: &mut HashMap<String, Vec<f32>>,
577    prefix: &str,
578    x: NodeId,
579    gamm: &[f32],
580    bet: &[f32],
581    seq: usize,
582    e: usize,
583) -> NodeId {
584    let f = DType::F32;
585    let shape = Shape::new(&[1, seq, e], f);
586    let (g_id, b_id) = bind_ln(g, params, prefix, gamm, bet, e);
587    g.layer_norm(x, g_id, b_id, -1, LN_EPS, shape)
588}
589
590fn maybe_add_pos(
591    g: &mut Graph,
592    x: NodeId,
593    pos: NodeId,
594    seq: usize,
595    e: usize,
596    enabled: bool,
597) -> NodeId {
598    if enabled {
599        let f = DType::F32;
600        g.binary(BinaryOp::Add, x, pos, Shape::new(&[1, seq, e], f))
601    } else {
602        x
603    }
604}
605
606fn apply_axial_rope_graph(
607    g: &mut Graph,
608    x: NodeId,
609    w: &Sam2RoPEAttnWeights,
610    _seq: usize,
611    repeat_factor: usize,
612) -> NodeId {
613    let nh = w.num_heads;
614    let dh = w.internal_dim / nh;
615    let [ex, ey] = w.rope_feat_size;
616    g.axial_rope2d(x, ex, ey, dh, nh, w.rope_theta, repeat_factor)
617}
618
619fn build_rope_attn(
620    g: &mut Graph,
621    params: &mut HashMap<String, Vec<f32>>,
622    prefix: &str,
623    w: &Sam2RoPEAttnWeights,
624    q_in: NodeId,
625    k_in: NodeId,
626    v_in: NodeId,
627    q_len: usize,
628    k_len: usize,
629    q_in_dim: usize,
630    kv_in_dim: usize,
631    num_k_exclude_rope: usize,
632    bias: Option<NodeId>,
633) -> NodeId {
634    let d = w.embedding_dim;
635    let id = w.internal_dim;
636    let nh = w.num_heads;
637    let dh = id / nh;
638    let f = DType::F32;
639    let [end_x, end_y] = w.rope_feat_size;
640    let spatial = end_x * end_y;
641
642    let q_proj = linear(
643        g,
644        params,
645        &format!("{prefix}.q"),
646        q_in,
647        &w.q_w,
648        &w.q_b,
649        q_in_dim,
650        id,
651        q_len,
652    );
653    let k_proj = linear(
654        g,
655        params,
656        &format!("{prefix}.k"),
657        k_in,
658        &w.k_w,
659        &w.k_b,
660        kv_in_dim,
661        id,
662        k_len,
663    );
664    let v_proj = linear(
665        g,
666        params,
667        &format!("{prefix}.v"),
668        v_in,
669        &w.v_w,
670        &w.v_b,
671        kv_in_dim,
672        id,
673        k_len,
674    );
675
676    let q_rot = apply_axial_rope_graph(g, q_proj, w, q_len, 1);
677
678    let num_k_rope = k_len.saturating_sub(num_k_exclude_rope);
679    let k_rot = if num_k_rope == 0 {
680        k_proj
681    } else if num_k_rope == k_len {
682        let r = if w.rope_k_repeat && num_k_rope >= spatial && num_k_rope.is_multiple_of(spatial) {
683            num_k_rope / spatial
684        } else {
685            1
686        };
687        apply_axial_rope_graph(g, k_proj, w, k_len, r)
688    } else {
689        let k_prefix = g.narrow_(k_proj, 1, 0, num_k_rope);
690        let k_suffix = g.narrow_(k_proj, 1, num_k_rope, k_len - num_k_rope);
691        let r = if w.rope_k_repeat && num_k_rope >= spatial && num_k_rope.is_multiple_of(spatial) {
692            num_k_rope / spatial
693        } else {
694            1
695        };
696        let k_pre_rot = apply_axial_rope_graph(g, k_prefix, w, num_k_rope, r);
697        g.concat_(vec![k_pre_rot, k_suffix], 1)
698    };
699
700    let out_shape = Shape::new(&[1, q_len, id], f);
701    let attn = if let Some(b) = bias {
702        g.attention_bias(q_rot, k_rot, v_proj, b, nh, dh, out_shape.clone())
703    } else {
704        g.attention_kind(
705            q_rot,
706            k_rot,
707            v_proj,
708            nh,
709            dh,
710            MaskKind::None,
711            out_shape.clone(),
712        )
713    };
714    linear(
715        g,
716        params,
717        &format!("{prefix}.o"),
718        attn,
719        &w.out_w,
720        &w.out_b,
721        id,
722        d,
723        q_len,
724    )
725}
726
727fn build_layer_graph(
728    layer: &Sam2MemoryAttentionLayerWeights,
729    n_img: usize,
730    n_mem: usize,
731    kv_in_dim: usize,
732    num_obj_ptr_tokens: usize,
733) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
734    let d = layer.d_model;
735    let f = DType::F32;
736    let mut g = Graph::new("sam2_mem_attn_layer");
737    let mut params = HashMap::new();
738
739    let tgt = g.input("tgt", Shape::new(&[1, n_img, d], f));
740    let curr_pos = g.input("curr_pos", Shape::new(&[1, n_img, d], f));
741    let memory = g.input("memory", Shape::new(&[1, n_mem, kv_in_dim], f));
742    let memory_pos = g.input("memory_pos", Shape::new(&[1, n_mem, kv_in_dim], f));
743    let mask_ca = g.input(
744        "mask_ca",
745        Shape::new(&[1, layer.cross_attn.num_heads, n_img, n_mem], f),
746    );
747
748    let seq_shape = Shape::new(&[1, n_img, d], f);
749    let mut tgt2 = layer_norm(
750        &mut g,
751        &mut params,
752        "n1",
753        tgt,
754        &layer.norm1_g,
755        &layer.norm1_b,
756        n_img,
757        d,
758    );
759    let q_sa = maybe_add_pos(&mut g, tgt2, curr_pos, n_img, d, layer.pos_enc_at_attn);
760    let sa = build_rope_attn(
761        &mut g,
762        &mut params,
763        "sa",
764        &layer.self_attn,
765        q_sa,
766        tgt2,
767        tgt2,
768        n_img,
769        n_img,
770        d,
771        d,
772        0,
773        None,
774    );
775    let mut out = g.binary(BinaryOp::Add, tgt, sa, seq_shape.clone());
776
777    tgt2 = layer_norm(
778        &mut g,
779        &mut params,
780        "n2",
781        out,
782        &layer.norm2_g,
783        &layer.norm2_b,
784        n_img,
785        d,
786    );
787    let q_ca = maybe_add_pos(
788        &mut g,
789        tgt2,
790        curr_pos,
791        n_img,
792        d,
793        layer.pos_enc_at_cross_attn_queries,
794    );
795    let k_ca = maybe_add_pos(
796        &mut g,
797        memory,
798        memory_pos,
799        n_mem,
800        kv_in_dim,
801        layer.pos_enc_at_cross_attn_keys,
802    );
803    let ca = build_rope_attn(
804        &mut g,
805        &mut params,
806        "ca",
807        &layer.cross_attn,
808        q_ca,
809        k_ca,
810        memory,
811        n_img,
812        n_mem,
813        d,
814        kv_in_dim,
815        num_obj_ptr_tokens,
816        Some(mask_ca),
817    );
818    out = g.binary(BinaryOp::Add, out, ca, seq_shape.clone());
819
820    tgt2 = layer_norm(
821        &mut g,
822        &mut params,
823        "n3",
824        out,
825        &layer.norm3_g,
826        &layer.norm3_b,
827        n_img,
828        d,
829    );
830    let dim_ff = layer.linear1_b.len();
831    let mid = linear(
832        &mut g,
833        &mut params,
834        "ff1",
835        tgt2,
836        &layer.linear1_w,
837        &layer.linear1_b,
838        d,
839        dim_ff,
840        n_img,
841    );
842    let mid = g.activation(Activation::Relu, mid, Shape::new(&[1, n_img, dim_ff], f));
843    let down = linear(
844        &mut g,
845        &mut params,
846        "ff2",
847        mid,
848        &layer.linear2_w,
849        &layer.linear2_b,
850        dim_ff,
851        d,
852        n_img,
853    );
854    out = g.binary(BinaryOp::Add, out, down, seq_shape);
855
856    g.set_outputs(vec![out]);
857    Ok((g, params))
858}
859
860fn build_qkv_proj(
861    g: &mut Graph,
862    params: &mut HashMap<String, Vec<f32>>,
863    prefix: &str,
864    w: &Sam2RoPEAttnWeights,
865    q_in: NodeId,
866    k_in: NodeId,
867    v_in: NodeId,
868    q_len: usize,
869    k_len: usize,
870    q_in_dim: usize,
871    kv_in_dim: usize,
872) -> (NodeId, NodeId, NodeId) {
873    let id = w.internal_dim;
874    let q = linear(
875        g,
876        params,
877        &format!("{prefix}.q"),
878        q_in,
879        &w.q_w,
880        &w.q_b,
881        q_in_dim,
882        id,
883        q_len,
884    );
885    let k = linear(
886        g,
887        params,
888        &format!("{prefix}.k"),
889        k_in,
890        &w.k_w,
891        &w.k_b,
892        kv_in_dim,
893        id,
894        k_len,
895    );
896    let v = linear(
897        g,
898        params,
899        &format!("{prefix}.v"),
900        v_in,
901        &w.v_w,
902        &w.v_b,
903        kv_in_dim,
904        id,
905        k_len,
906    );
907    (q, k, v)
908}
909
910fn build_attention_out(
911    g: &mut Graph,
912    params: &mut HashMap<String, Vec<f32>>,
913    prefix: &str,
914    w: &Sam2RoPEAttnWeights,
915    q: NodeId,
916    k: NodeId,
917    v: NodeId,
918    q_len: usize,
919    _k_len: usize,
920    mask: Option<NodeId>,
921) -> NodeId {
922    let d = w.embedding_dim;
923    let id = w.internal_dim;
924    let nh = w.num_heads;
925    let dh = id / nh;
926    let f = DType::F32;
927    let out_shape = Shape::new(&[1, q_len, id], f);
928    let attn = if let Some(m) = mask {
929        g.attention_bias(q, k, v, m, nh, dh, out_shape.clone())
930    } else {
931        g.attention_kind(q, k, v, nh, dh, MaskKind::None, out_shape.clone())
932    };
933    linear(
934        g,
935        params,
936        &format!("{prefix}.o"),
937        attn,
938        &w.out_w,
939        &w.out_b,
940        id,
941        d,
942        q_len,
943    )
944}
945
946fn build_self_proj_graph(
947    layer: &Sam2MemoryAttentionLayerWeights,
948    n_img: usize,
949) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
950    let d = layer.d_model;
951    let f = DType::F32;
952    let mut g = Graph::new("sam2_mem_self_proj");
953    let mut params = HashMap::new();
954    let tgt = g.input("tgt", Shape::new(&[1, n_img, d], f));
955    let curr_pos = g.input("curr_pos", Shape::new(&[1, n_img, d], f));
956    let tgt2 = layer_norm(
957        &mut g,
958        &mut params,
959        "n1",
960        tgt,
961        &layer.norm1_g,
962        &layer.norm1_b,
963        n_img,
964        d,
965    );
966    let q_in = maybe_add_pos(&mut g, tgt2, curr_pos, n_img, d, layer.pos_enc_at_attn);
967    let (sa_q, sa_k, sa_v) = build_qkv_proj(
968        &mut g,
969        &mut params,
970        "sa",
971        &layer.self_attn,
972        q_in,
973        tgt2,
974        tgt2,
975        n_img,
976        n_img,
977        d,
978        d,
979    );
980    g.set_outputs(vec![sa_q, sa_k, sa_v]);
981    Ok((g, params))
982}
983
984fn build_self_attn_graph(
985    layer: &Sam2MemoryAttentionLayerWeights,
986    n_img: usize,
987) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
988    let d = layer.d_model;
989    let f = DType::F32;
990    let mut g = Graph::new("sam2_mem_self_attn");
991    let mut params = HashMap::new();
992    let tgt = g.input("tgt", Shape::new(&[1, n_img, d], f));
993    let sa_q = g.input(
994        "sa_q",
995        Shape::new(&[1, n_img, layer.self_attn.internal_dim], f),
996    );
997    let sa_k = g.input(
998        "sa_k",
999        Shape::new(&[1, n_img, layer.self_attn.internal_dim], f),
1000    );
1001    let sa_v = g.input(
1002        "sa_v",
1003        Shape::new(&[1, n_img, layer.self_attn.internal_dim], f),
1004    );
1005    let sa = build_attention_out(
1006        &mut g,
1007        &mut params,
1008        "sa",
1009        &layer.self_attn,
1010        sa_q,
1011        sa_k,
1012        sa_v,
1013        n_img,
1014        n_img,
1015        None,
1016    );
1017    let out = g.binary(BinaryOp::Add, tgt, sa, Shape::new(&[1, n_img, d], f));
1018    g.set_outputs(vec![out]);
1019    Ok((g, params))
1020}
1021
1022fn build_cross_proj_graph(
1023    layer: &Sam2MemoryAttentionLayerWeights,
1024    n_img: usize,
1025    n_mem: usize,
1026    kv: usize,
1027) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
1028    let d = layer.d_model;
1029    let f = DType::F32;
1030    let mut g = Graph::new("sam2_mem_cross_proj");
1031    let mut params = HashMap::new();
1032    let tgt = g.input("tgt", Shape::new(&[1, n_img, d], f));
1033    let curr_pos = g.input("curr_pos", Shape::new(&[1, n_img, d], f));
1034    let memory = g.input("memory", Shape::new(&[1, n_mem, kv], f));
1035    let memory_pos = g.input("memory_pos", Shape::new(&[1, n_mem, kv], f));
1036    let tgt2 = layer_norm(
1037        &mut g,
1038        &mut params,
1039        "n2",
1040        tgt,
1041        &layer.norm2_g,
1042        &layer.norm2_b,
1043        n_img,
1044        d,
1045    );
1046    let q_in = maybe_add_pos(
1047        &mut g,
1048        tgt2,
1049        curr_pos,
1050        n_img,
1051        d,
1052        layer.pos_enc_at_cross_attn_queries,
1053    );
1054    let k_in = maybe_add_pos(
1055        &mut g,
1056        memory,
1057        memory_pos,
1058        n_mem,
1059        kv,
1060        layer.pos_enc_at_cross_attn_keys,
1061    );
1062    let (ca_q, ca_k, _) = build_qkv_proj(
1063        &mut g,
1064        &mut params,
1065        "ca",
1066        &layer.cross_attn,
1067        q_in,
1068        k_in,
1069        memory,
1070        n_img,
1071        n_mem,
1072        d,
1073        kv,
1074    );
1075    g.set_outputs(vec![ca_q, ca_k]);
1076    Ok((g, params))
1077}
1078
1079fn build_cross_attn_graph(
1080    layer: &Sam2MemoryAttentionLayerWeights,
1081    n_img: usize,
1082    n_mem: usize,
1083    kv: usize,
1084) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
1085    let d = layer.d_model;
1086    let f = DType::F32;
1087    let mut g = Graph::new("sam2_mem_cross_attn");
1088    let mut params = HashMap::new();
1089    let tgt = g.input("tgt", Shape::new(&[1, n_img, d], f));
1090    let ca_q = g.input(
1091        "ca_q",
1092        Shape::new(&[1, n_img, layer.cross_attn.internal_dim], f),
1093    );
1094    let ca_k = g.input(
1095        "ca_k",
1096        Shape::new(&[1, n_mem, layer.cross_attn.internal_dim], f),
1097    );
1098    let memory = g.input("memory", Shape::new(&[1, n_mem, kv], f));
1099    let mask_ca = g.input(
1100        "mask_ca",
1101        Shape::new(&[1, layer.cross_attn.num_heads, n_img, n_mem], f),
1102    );
1103    let ca = build_attention_out(
1104        &mut g,
1105        &mut params,
1106        "ca",
1107        &layer.cross_attn,
1108        ca_q,
1109        ca_k,
1110        memory,
1111        n_img,
1112        n_mem,
1113        Some(mask_ca),
1114    );
1115    let out = g.binary(BinaryOp::Add, tgt, ca, Shape::new(&[1, n_img, d], f));
1116    g.set_outputs(vec![out]);
1117    Ok((g, params))
1118}
1119
1120fn build_ffn_graph(
1121    layer: &Sam2MemoryAttentionLayerWeights,
1122    n_img: usize,
1123) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
1124    let d = layer.d_model;
1125    let f = DType::F32;
1126    let mut g = Graph::new("sam2_mem_ffn");
1127    let mut params = HashMap::new();
1128    let tgt = g.input("tgt", Shape::new(&[1, n_img, d], f));
1129    let seq_shape = Shape::new(&[1, n_img, d], f);
1130    let normed = layer_norm(
1131        &mut g,
1132        &mut params,
1133        "n3",
1134        tgt,
1135        &layer.norm3_g,
1136        &layer.norm3_b,
1137        n_img,
1138        d,
1139    );
1140    let dim_ff = layer.linear1_b.len();
1141    let mid = linear(
1142        &mut g,
1143        &mut params,
1144        "ff1",
1145        normed,
1146        &layer.linear1_w,
1147        &layer.linear1_b,
1148        d,
1149        dim_ff,
1150        n_img,
1151    );
1152    let mid = g.activation(Activation::Relu, mid, Shape::new(&[1, n_img, dim_ff], f));
1153    let down = linear(
1154        &mut g,
1155        &mut params,
1156        "ff2",
1157        mid,
1158        &layer.linear2_w,
1159        &layer.linear2_b,
1160        dim_ff,
1161        d,
1162        n_img,
1163    );
1164    let out = g.binary(BinaryOp::Add, tgt, down, seq_shape);
1165    g.set_outputs(vec![out]);
1166    Ok((g, params))
1167}
1168
1169fn build_final_norm_graph(
1170    norm_g: &[f32],
1171    norm_b: &[f32],
1172    n_img: usize,
1173    d: usize,
1174) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
1175    let f = DType::F32;
1176    let mut g = Graph::new("sam2_mem_attn_final");
1177    let mut params = HashMap::new();
1178    let tgt = g.input("tgt", Shape::new(&[1, n_img, d], f));
1179    let out = layer_norm(
1180        &mut g,
1181        &mut params,
1182        "out_norm",
1183        tgt,
1184        norm_g,
1185        norm_b,
1186        n_img,
1187        d,
1188    );
1189    g.set_outputs(vec![out]);
1190    Ok((g, params))
1191}
1192
#[cfg(test)]
mod tests {
    use super::*;
    use crate::axial_rope::apply_axial_rope_2d;
    use crate::memory_attention::{
        Sam2MemoryAttentionLayerWeights, Sam2MemoryAttentionWeights, Sam2RoPEAttnWeights,
        memory_attention_forward, memory_attention_layer_forward,
    };
    use crate::transformer::layer_norm_last_cpu;
    use rlx_ir::Graph;

    /// Hidden width of the synthetic feed-forward block used by every fixture.
    const DIM_FF: usize = 2048;

    /// Largest element-wise `|expected - actual|`.
    ///
    /// Unlike a bare `zip(..).fold(0.0, f32::max)` this refuses to compare
    /// buffers of different lengths (where `zip` would silently truncate and an
    /// empty output would score a perfect 0) and propagates NaN (which
    /// `f32::max` would otherwise discard), so a `result < tol` assertion fails
    /// on truncated or NaN-contaminated outputs.
    fn max_abs_diff(expected: &[f32], actual: &[f32]) -> f32 {
        assert_eq!(
            expected.len(),
            actual.len(),
            "compared buffers differ in length"
        );
        expected
            .iter()
            .zip(actual)
            .map(|(a, b)| (a - b).abs())
            .fold(0f32, |worst, delta| {
                if worst.is_nan() || delta.is_nan() {
                    f32::NAN
                } else {
                    worst.max(delta)
                }
            })
    }

    /// `[0, step, 2*step, ...]` with `len` entries.
    fn ramp(len: usize, step: f32) -> Vec<f32> {
        (0..len).map(|i| (i as f32) * step).collect()
    }

    /// Single-head RoPE attention weights filled with small constants.
    fn synth_rope_attn(d: usize, kv: usize, feat: [usize; 2]) -> Sam2RoPEAttnWeights {
        let id = d;
        Sam2RoPEAttnWeights {
            q_w: vec![0.01; id * d],
            q_b: vec![0.0; id],
            k_w: vec![0.02; id * kv],
            k_b: vec![0.0; id],
            v_w: vec![0.03; id * kv],
            v_b: vec![0.0; id],
            out_w: vec![0.04; d * id],
            out_b: vec![0.0; d],
            embedding_dim: d,
            kv_in_dim: kv,
            internal_dim: id,
            num_heads: 1,
            rope_theta: 10000.0,
            rope_feat_size: feat,
            rope_k_repeat: true,
        }
    }

    /// One synthetic memory-attention layer: identity layer norms, constant
    /// linear weights, positional encoding only on cross-attention keys.
    fn synth_layer(d: usize, kv: usize, feat: [usize; 2]) -> Sam2MemoryAttentionLayerWeights {
        Sam2MemoryAttentionLayerWeights {
            self_attn: synth_rope_attn(d, d, feat),
            cross_attn: synth_rope_attn(d, kv, feat),
            norm1_g: vec![1.0; d],
            norm1_b: vec![0.0; d],
            norm2_g: vec![1.0; d],
            norm2_b: vec![0.0; d],
            norm3_g: vec![1.0; d],
            norm3_b: vec![0.0; d],
            linear1_w: vec![0.01; DIM_FF * d],
            linear1_b: vec![0.0; DIM_FF],
            linear2_w: vec![0.02; d * DIM_FF],
            linear2_b: vec![0.0; d],
            pos_enc_at_attn: false,
            pos_enc_at_cross_attn_queries: false,
            pos_enc_at_cross_attn_keys: true,
            d_model: d,
        }
    }

    /// Single-layer stack with an identity final norm and `pos_enc_at_input` on.
    fn synth_stack(layer: Sam2MemoryAttentionLayerWeights, d: usize) -> Sam2MemoryAttentionWeights {
        Sam2MemoryAttentionWeights {
            layers: vec![layer],
            norm_g: vec![1.0; d],
            norm_b: vec![0.0; d],
            d_model: d,
            pos_enc_at_input: true,
        }
    }

    /// The `AxialRope2d` graph op must agree with the host implementation.
    #[test]
    fn axial_rope2d_op_matches_host_merged_layout() {
        let nh = 1usize;
        let n = 64usize;
        let dh = 256usize;
        let feat = [8usize, 8usize];
        let x = ramp(n * nh * dh, 0.001);
        let host = apply_axial_rope_2d(&x, nh, n, dh, feat[0], feat[1], 10000.0, 1);

        let mut g = Graph::new("axial_rope_check");
        let f = rlx_ir::DType::F32;
        let inp = g.input("x", Shape::new(&[1, n, nh * dh], f));
        let out = g.axial_rope2d(inp, feat[0], feat[1], dh, nh, 10000.0, 1);
        g.set_outputs(vec![out]);
        let mut compiled =
            rlx_core::flow_bridge::compile_graph_sam(Device::Cpu, g).expect("compile");
        let ir = compiled.run(&[("x", &x)]).into_iter().next().unwrap();

        let fd = max_abs_diff(&host, &ir);
        assert!(fd < 1e-5, "axial_rope2d op vs host max |Δ| = {fd:.3e}");
    }

    /// Default (host-RoPE) compiled pipeline vs the pure host forward pass.
    #[test]
    fn memory_attention_ir_matches_host_small_grid() {
        let d = 256usize;
        let kv = 64usize;
        let feat = [8usize, 8usize];
        let n_img = 64usize;
        let n_mem = 64usize;
        let w = synth_stack(synth_layer(d, kv, feat), d);

        let curr = ramp(n_img * d, 1e-4);
        let curr_pos = ramp(n_img * d, 1e-5);
        let memory = ramp(n_mem * kv, 2e-4);
        let memory_pos = ramp(n_mem * kv, 2e-5);

        let host = memory_attention_forward(
            &w,
            &curr,
            &curr_pos,
            &memory,
            &memory_pos,
            n_img,
            n_mem,
            kv,
            0,
        )
        .unwrap();

        let mut ir = MemoryAttentionCompiled::compile(&w, n_img, n_mem, 0, Device::Cpu).unwrap();
        let got = ir
            .run(&curr, &curr_pos, &memory, &memory_pos, n_mem, 0)
            .unwrap();

        let fd = max_abs_diff(&host, &got);
        assert!(fd < 3e-2, "memory attention max |Δ| = {fd:.3e}");
    }

    /// In-graph RoPE pipeline: the fused layer must track the host layer output,
    /// and the pipeline result must equal the final norm applied to that layer.
    #[test]
    fn memory_attention_in_graph_rope_matches_host_small_grid() {
        let d = 256usize;
        let kv = 64usize;
        let feat = [8usize, 8usize];
        let n_img = 64usize;
        let n_mem = 64usize;
        let w = synth_stack(synth_layer(d, kv, feat), d);

        let curr = ramp(n_img * d, 1e-4);
        let curr_pos = ramp(n_img * d, 1e-5);
        let memory = ramp(n_mem * kv, 2e-4);
        let memory_pos = ramp(n_mem * kv, 2e-5);

        let host_mid = crate::memory_attention::memory_attention_forward_layers_only(
            &w,
            &curr,
            &curr_pos,
            &memory,
            &memory_pos,
            n_img,
            n_mem,
            kv,
            0,
        )
        .unwrap();

        let mut ir =
            MemoryAttentionCompiled::compile_in_graph_rope(&w, n_img, n_mem, 0, Device::Cpu)
                .unwrap();
        let got = ir
            .run(&curr, &curr_pos, &memory, &memory_pos, n_mem, 0)
            .unwrap();

        // Rebuild the fused layer's inputs by hand. The memory buffers already
        // fill all `n_mem` slots, so they are fed as-is (no zero padding).
        let mut tgt = curr.clone();
        if w.pos_enc_at_input {
            for (t, p) in tgt.iter_mut().zip(&curr_pos) {
                *t += INPUT_POS_SCALE * p;
            }
        }
        let nh = w.layers[0].cross_attn.num_heads;
        let mut mask = vec![0f32; nh * n_img * n_mem];
        fill_cross_attn_bias(&mut mask, nh, n_img, n_mem, n_mem);
        let layer_inputs = [
            ("tgt", tgt.as_slice()),
            ("curr_pos", curr_pos.as_slice()),
            ("memory", memory.as_slice()),
            ("memory_pos", memory_pos.as_slice()),
            ("mask_ca", mask.as_slice()),
        ];
        let ir_mid = ir.layers[0]
            .fused
            .as_mut()
            .expect("fused")
            .run(&layer_inputs)
            .into_iter()
            .next()
            .unwrap();
        let fd_layer = max_abs_diff(&host_mid, &ir_mid);
        assert!(fd_layer < 5e-2, "in-graph layer max |Δ| = {fd_layer:.3e}");

        let (fg, fp) = build_final_norm_graph(&w.norm_g, &w.norm_b, n_img, d).unwrap();
        let mut fn_alone =
            Session::new(Device::Cpu).compile_with(fg, &compile_opts_no_fusion(Device::Cpu));
        for (n, data) in &fp {
            fn_alone.set_param(n, data);
        }
        let ir_via_fn = fn_alone
            .run(&[("tgt", &ir_mid)])
            .into_iter()
            .next()
            .unwrap();
        let fd_got_fn = max_abs_diff(&got, &ir_via_fn);
        assert!(
            fd_got_fn < 1e-4,
            "pipeline output vs final_norm(ir_mid) max |Δ| = {fd_got_fn:.3e}"
        );
        // End-to-end vs host layers is dominated by stack LN sensitivity on ~2e-2
        // layer deltas; see `stack_final_norm_ir_matches_host_layer_output`.
    }

    /// CPU kernel `layer_norm_row` vs host `layer_norm_last_cpu`, row by row.
    #[test]
    fn cpu_layer_norm_row_matches_host_last_cpu() {
        let rows = 64usize;
        let h = 256usize;
        let x: Vec<f32> = (0..rows * h).map(|i| (i as f32) * 1e-3 - 0.5).collect();
        let g = vec![1.0; h];
        let b = vec![0.0; h];
        let mut host = x.clone();
        layer_norm_last_cpu(&mut host, rows, h, &g, &b, LN_EPS);
        let mut cpu = x.clone();
        for r in 0..rows {
            rlx_cpu::kernels::layer_norm_row(
                &x[r * h..(r + 1) * h],
                &g,
                &b,
                &mut cpu[r * h..(r + 1) * h],
                h,
                LN_EPS,
            );
        }
        let fd = max_abs_diff(&host, &cpu);
        assert!(
            fd < 1e-5,
            "cpu layer_norm_row vs host_last_cpu max |Δ| = {fd:.3e}"
        );
    }

    /// Unfused final-norm graph vs host layer norm on a synthetic ramp.
    #[test]
    fn layer_norm_ir_matches_host_synthetic() {
        let n_img = 64usize;
        let d = 256usize;
        let x: Vec<f32> = (0..n_img * d).map(|i| (i as f32) * 1e-3 - 0.5).collect();
        let norm_g = vec![1.0; d];
        let norm_b = vec![0.0; d];
        let mut host = x.clone();
        layer_norm_last_cpu(&mut host, n_img, d, &norm_g, &norm_b, LN_EPS);
        let (fg, fp) = build_final_norm_graph(&norm_g, &norm_b, n_img, d).unwrap();
        let mut compiled =
            Session::new(Device::Cpu).compile_with(fg, &compile_opts_no_fusion(Device::Cpu));
        for (n, data) in &fp {
            compiled.set_param(n, data);
        }
        let ir = compiled.run(&[("tgt", &x)]).into_iter().next().unwrap();
        let fd = max_abs_diff(&host, &ir);
        assert!(
            fd < 1e-4,
            "synthetic final norm IR vs host max |Δ| = {fd:.3e}"
        );
    }

    /// Unfused final-norm graph vs host layer norm on a real host layer output.
    #[test]
    fn stack_final_norm_ir_matches_host_layer_output() {
        let d = 256usize;
        let kv = 64usize;
        let n_img = 64usize;
        // This test uses as many memory tokens as image tokens.
        let n_mem = n_img;
        let layer_tgt = ramp(n_img * d, 1e-4);
        let curr_pos = ramp(n_img * d, 1e-5);
        let memory = ramp(n_mem * kv, 2e-4);
        let memory_pos = ramp(n_mem * kv, 2e-5);
        let layer = synth_layer(d, kv, [8, 8]);
        let host_layer = memory_attention_layer_forward(
            &layer,
            layer_tgt,
            &curr_pos,
            &memory,
            &memory_pos,
            n_img,
            n_mem,
            kv,
            0,
        )
        .unwrap();
        let mut host_final = host_layer.clone();
        let norm_g = vec![1.0; d];
        let norm_b = vec![0.0; d];
        layer_norm_last_cpu(&mut host_final, n_img, d, &norm_g, &norm_b, LN_EPS);

        let (fg, fp) = build_final_norm_graph(&norm_g, &norm_b, n_img, d).unwrap();
        let mut compiled =
            Session::new(Device::Cpu).compile_with(fg, &compile_opts_no_fusion(Device::Cpu));
        for (n, data) in &fp {
            compiled.set_param(n, data);
        }
        let ir_final = compiled
            .run(&[("tgt", &host_layer)])
            .into_iter()
            .next()
            .unwrap();
        let fd = max_abs_diff(&host_final, &ir_final);
        assert!(fd < 1e-4, "stack final norm IR vs host max |Δ| = {fd:.3e}");
    }

    /// Bisect monolithic layer: compare fused layer graph vs host `memory_attention_layer_forward`.
    #[test]
    fn memory_attention_layer_in_graph_rope_bisect() {
        let d = 256usize;
        let kv = 64usize;
        let feat = [8usize, 8usize];
        let n_img = 64usize;
        let n_mem = 64usize;
        let layer = synth_layer(d, kv, feat);

        let mut tgt = ramp(n_img * d, 1e-4);
        // Same expression (and evaluation order) as the original fixture.
        for (i, t) in tgt.iter_mut().enumerate() {
            *t += INPUT_POS_SCALE * (i as f32) * 1e-5;
        }
        let curr_pos = ramp(n_img * d, 1e-5);
        let memory = ramp(n_mem * kv, 2e-4);
        let memory_pos = ramp(n_mem * kv, 2e-5);

        let host = memory_attention_layer_forward(
            &layer,
            tgt.clone(),
            &curr_pos,
            &memory,
            &memory_pos,
            n_img,
            n_mem,
            kv,
            0,
        )
        .unwrap();

        let (g, p) = build_layer_graph(&layer, n_img, n_mem, kv, 0).unwrap();
        let nh = layer.cross_attn.num_heads;
        let mut mask = vec![0f32; nh * n_img * n_mem];
        fill_cross_attn_bias(&mut mask, nh, n_img, n_mem, n_mem);
        let mut compiled =
            Session::new(Device::Cpu).compile_with(g, &compile_opts_no_fusion(Device::Cpu));
        for (n, data) in &p {
            compiled.set_param(n, data);
        }
        let got = compiled
            .run(&[
                ("tgt", &tgt),
                ("curr_pos", &curr_pos),
                ("memory", &memory),
                ("memory_pos", &memory_pos),
                ("mask_ca", &mask),
            ])
            .into_iter()
            .next()
            .unwrap();

        let fd = max_abs_diff(&host, &got);
        assert!(fd < 3e-2, "layer in-graph rope max |Δ| = {fd:.3e}");
    }

    /// Quick check: in-graph RoPE compiles and runs; log timing vs default (no hard SLA).
    #[test]
    fn memory_attention_in_graph_rope_timing_quick_check() {
        use std::time::Instant;

        let d = 256usize;
        let kv = 64usize;
        let feat = [8usize, 8usize];
        let n_img = 64usize;
        let n_mem = 64usize;
        let w = synth_stack(synth_layer(d, kv, feat), d);
        let curr = ramp(n_img * d, 1e-4);
        let curr_pos = ramp(n_img * d, 1e-5);
        let memory = ramp(n_mem * kv, 2e-4);
        let memory_pos = ramp(n_mem * kv, 2e-5);

        let t0 = Instant::now();
        let mut default =
            MemoryAttentionCompiled::compile(&w, n_img, n_mem, 0, Device::Cpu).unwrap();
        let compile_default_ms = t0.elapsed().as_secs_f64() * 1000.0;

        let t1 = Instant::now();
        let mut in_graph =
            MemoryAttentionCompiled::compile_in_graph_rope(&w, n_img, n_mem, 0, Device::Cpu)
                .unwrap();
        let compile_in_graph_ms = t1.elapsed().as_secs_f64() * 1000.0;

        const RUNS: usize = 5;
        let t2 = Instant::now();
        for _ in 0..RUNS {
            let _ = default
                .run(&curr, &curr_pos, &memory, &memory_pos, n_mem, 0)
                .unwrap();
        }
        let run_default_ms = t2.elapsed().as_secs_f64() * 1000.0 / RUNS as f64;

        let t3 = Instant::now();
        for _ in 0..RUNS {
            let _ = in_graph
                .run(&curr, &curr_pos, &memory, &memory_pos, n_mem, 0)
                .unwrap();
        }
        let run_in_graph_ms = t3.elapsed().as_secs_f64() * 1000.0 / RUNS as f64;

        eprintln!(
            "mem_attn compile ms: default={compile_default_ms:.2} in_graph={compile_in_graph_ms:.2}; \
             run ms (avg/{RUNS}): default={run_default_ms:.2} in_graph={run_in_graph_ms:.2}"
        );
        assert!(compile_in_graph_ms > 0.0 && run_in_graph_ms > 0.0);
    }
}