Skip to main content

rlx_sam2/
memory_attention.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! SAM 2 memory attention — host-side.
//!
//! Mirrors `sam2/modeling/memory_attention.py` and the `RoPEAttention`
//! module from `sam2/modeling/sam/transformer.py`. The memory attention
//! is the video-tracking core: each layer self-attends the current
//! frame's image tokens, then cross-attends them to the memory bank
//! (spatial memory tokens from prior frames + object-pointer tokens
//! from prior frames' decoder outputs).
//!
//! ## Per-layer structure
//!
//! ```text
//!   tgt = curr_image_tokens   # [B, N_img, d_model]
//!   memory = [spatial_memory_tokens; object_pointer_tokens]  # [B, N_mem, kv_in_dim]
//!
//!   # ── Self-attention (RoPE on Q and K) ──
//!   x = LN(tgt)
//!   tgt = tgt + SelfAttn(q = x + pos, k = x + pos, v = x)
//!
//!   # ── Cross-attention to memory (RoPE on Q + spatial part of K) ──
//!   #   `num_k_exclude_rope` = number of obj_ptr tokens at the end of K;
//!   #   those get no rotary encoding (they're positionless).
//!   x = LN(tgt)
//!   tgt = tgt + CrossAttn(q = x + pos, k = memory + memory_pos, v = memory,
//!                         num_k_exclude_rope = N_obj_ptr)
//!
//!   # ── FFN ──
//!   tgt = tgt + Linear(ReLU(Linear(LN(tgt))))
//! ```
//!
//! Each `+ pos` / `+ memory_pos` term is applied only when the matching
//! `pos_enc_at_*` flag is set, and each `LN` has its own weights.
//!
//! ## Axial 2-D RoPE
//!
//! The reference's `compute_axial_cis(dim, end_x, end_y, theta)` builds a
//! per-position complex rotation table where the first `dim/2` head
//! channels rotate by x-coordinate frequencies and the second `dim/2`
//! by y-coordinate frequencies. `apply_rotary_enc` then multiplies
//! each contiguous (real, imag) channel pair by its complex factor.
//! When the memory bank holds `r` frames, `repeat_freqs_k=True` tiles
//! the per-position freqs `r` times along the sequence axis (frame-major,
//! as `Tensor.repeat` does — not `repeat_interleave`), so every stacked
//! memory frame receives the same rotation as the query grid.
54
55use super::config::Sam2MemoryConfig;
56use super::transformer::{layer_norm_last, layer_norm_last_cpu, linear};
57use anyhow::{Result, ensure};
58use rlx_core::weight_map::WeightMap;
59
60// ─── Weight structs ─────────────────────────────────────────────────
61
/// One RoPE attention module: its four linear projections plus the rotary
/// settings it runs with.
///
/// Weight matrices follow the PyTorch `nn.Linear` convention
/// `[out_features, in_features]` (the loader checks these shapes).
pub struct Sam2RoPEAttnWeights {
    /// Width of the query input and of the module output.
    pub embedding_dim: usize,
    /// Width of the key/value inputs (`embedding_dim` for self-attention).
    pub kv_in_dim: usize,
    /// Width of the projected Q/K/V space, split evenly across heads.
    pub internal_dim: usize,
    /// Number of attention heads.
    pub num_heads: usize,

    /// Query projection `[internal_dim, embedding_dim]`.
    pub q_w: Vec<f32>,
    /// Query projection bias.
    pub q_b: Vec<f32>,
    /// Key projection `[internal_dim, kv_in_dim]`.
    pub k_w: Vec<f32>,
    /// Key projection bias.
    pub k_b: Vec<f32>,
    /// Value projection `[internal_dim, kv_in_dim]`.
    pub v_w: Vec<f32>,
    /// Value projection bias.
    pub v_b: Vec<f32>,
    /// Output projection `[embedding_dim, internal_dim]`.
    pub out_w: Vec<f32>,
    /// Output projection bias.
    pub out_b: Vec<f32>,

    /// Base frequency of the axial rotary table.
    pub rope_theta: f32,
    /// `[end_x, end_y]` grid the rotary table is built for.
    pub rope_feat_size: [usize; 2],
    /// Whether K's rotary table may be repeated across stacked memory frames.
    pub rope_k_repeat: bool,
}
79
80pub struct Sam2MemoryAttentionLayerWeights {
81    pub self_attn: Sam2RoPEAttnWeights,
82    pub cross_attn: Sam2RoPEAttnWeights,
83    pub norm1_g: Vec<f32>,
84    pub norm1_b: Vec<f32>,
85    pub norm2_g: Vec<f32>,
86    pub norm2_b: Vec<f32>,
87    pub norm3_g: Vec<f32>,
88    pub norm3_b: Vec<f32>,
89    pub linear1_w: Vec<f32>, // [dim_ff, d_model]
90    pub linear1_b: Vec<f32>,
91    pub linear2_w: Vec<f32>, // [d_model, dim_ff]
92    pub linear2_b: Vec<f32>,
93    pub pos_enc_at_attn: bool,
94    pub pos_enc_at_cross_attn_queries: bool,
95    pub pos_enc_at_cross_attn_keys: bool,
96    pub d_model: usize,
97}
98
99pub struct Sam2MemoryAttentionWeights {
100    pub layers: Vec<Sam2MemoryAttentionLayerWeights>,
101    pub norm_g: Vec<f32>,
102    pub norm_b: Vec<f32>,
103    pub d_model: usize,
104    pub pos_enc_at_input: bool,
105}
106
107// ─── Weight extraction ─────────────────────────────────────────────
108
109fn load_rope_attn(
110    weights: &mut WeightMap,
111    prefix: &str,
112    cfg: &Sam2MemoryConfig,
113    is_self: bool,
114) -> Result<Sam2RoPEAttnWeights> {
115    let d = cfg.d_model;
116    let internal_dim = d; // downsample_rate=1 in published configs
117    let kv_in_dim = if is_self { d } else { cfg.kv_in_dim };
118    let (q_w, sh) = weights.take(&format!("{prefix}.q_proj.weight"))?;
119    ensure!(
120        sh == vec![internal_dim, d],
121        "{prefix}.q_proj.weight shape {sh:?} not [{internal_dim}, {d}]"
122    );
123    let (q_b, _) = weights.take(&format!("{prefix}.q_proj.bias"))?;
124    let (k_w, sh) = weights.take(&format!("{prefix}.k_proj.weight"))?;
125    ensure!(
126        sh == vec![internal_dim, kv_in_dim],
127        "{prefix}.k_proj.weight shape {sh:?} not [{internal_dim}, {kv_in_dim}]"
128    );
129    let (k_b, _) = weights.take(&format!("{prefix}.k_proj.bias"))?;
130    let (v_w, _) = weights.take(&format!("{prefix}.v_proj.weight"))?;
131    let (v_b, _) = weights.take(&format!("{prefix}.v_proj.bias"))?;
132    let (out_w, sh) = weights.take(&format!("{prefix}.out_proj.weight"))?;
133    ensure!(
134        sh == vec![d, internal_dim],
135        "{prefix}.out_proj.weight shape {sh:?} not [{d}, {internal_dim}]"
136    );
137    let (out_b, _) = weights.take(&format!("{prefix}.out_proj.bias"))?;
138    Ok(Sam2RoPEAttnWeights {
139        q_w,
140        q_b,
141        k_w,
142        k_b,
143        v_w,
144        v_b,
145        out_w,
146        out_b,
147        embedding_dim: d,
148        kv_in_dim,
149        internal_dim,
150        num_heads: cfg.num_heads,
151        rope_theta: cfg.rope_theta,
152        rope_feat_size: cfg.rope_feat_size,
153        rope_k_repeat: cfg.rope_k_repeat,
154    })
155}
156
157pub fn extract_memory_attention_weights(
158    weights: &mut WeightMap,
159    cfg: &Sam2MemoryConfig,
160) -> Result<Sam2MemoryAttentionWeights> {
161    let mut layers = Vec::with_capacity(cfg.num_layers);
162    for i in 0..cfg.num_layers {
163        let p = format!("memory_attention.layers.{i}");
164        let self_attn = load_rope_attn(
165            weights,
166            &format!("{p}.self_attn"),
167            cfg,
168            /*is_self=*/ true,
169        )?;
170        let cross_attn = load_rope_attn(
171            weights,
172            &format!("{p}.cross_attn_image"),
173            cfg,
174            /*is_self=*/ false,
175        )?;
176        let (norm1_g, _) = weights.take(&format!("{p}.norm1.weight"))?;
177        let (norm1_b, _) = weights.take(&format!("{p}.norm1.bias"))?;
178        let (norm2_g, _) = weights.take(&format!("{p}.norm2.weight"))?;
179        let (norm2_b, _) = weights.take(&format!("{p}.norm2.bias"))?;
180        let (norm3_g, _) = weights.take(&format!("{p}.norm3.weight"))?;
181        let (norm3_b, _) = weights.take(&format!("{p}.norm3.bias"))?;
182        let (linear1_w, sh) = weights.take(&format!("{p}.linear1.weight"))?;
183        ensure!(
184            sh == vec![cfg.dim_feedforward, cfg.d_model],
185            "{p}.linear1.weight shape {sh:?} not [{}, {}]",
186            cfg.dim_feedforward,
187            cfg.d_model
188        );
189        let (linear1_b, _) = weights.take(&format!("{p}.linear1.bias"))?;
190        let (linear2_w, _) = weights.take(&format!("{p}.linear2.weight"))?;
191        let (linear2_b, _) = weights.take(&format!("{p}.linear2.bias"))?;
192        layers.push(Sam2MemoryAttentionLayerWeights {
193            self_attn,
194            cross_attn,
195            norm1_g,
196            norm1_b,
197            norm2_g,
198            norm2_b,
199            norm3_g,
200            norm3_b,
201            linear1_w,
202            linear1_b,
203            linear2_w,
204            linear2_b,
205            pos_enc_at_attn: cfg.pos_enc_at_attn,
206            pos_enc_at_cross_attn_queries: cfg.pos_enc_at_cross_attn_queries,
207            pos_enc_at_cross_attn_keys: cfg.pos_enc_at_cross_attn_keys,
208            d_model: cfg.d_model,
209        });
210    }
211    let (norm_g, _) = weights.take("memory_attention.norm.weight")?;
212    let (norm_b, _) = weights.take("memory_attention.norm.bias")?;
213    Ok(Sam2MemoryAttentionWeights {
214        layers,
215        norm_g,
216        norm_b,
217        d_model: cfg.d_model,
218        pos_enc_at_input: cfg.pos_enc_at_input,
219    })
220}
221
222// ─── Forward ────────────────────────────────────────────────────────
223
224/// Memory attention forward.
225///
226/// `curr`: current frame's image tokens `[N_img, d_model]` (B=1).
227/// `curr_pos`: same-shape positional encoding (sinusoidal 2-D, from
228///     the FpnNeck stride-32 level).
229/// `memory`: memory bank `[N_mem, kv_in_dim]` — concatenation of
230///     `[spatial_tokens; object_pointer_tokens]` in that order.
231/// `memory_pos`: same-shape positional encoding for memory. Object-
232///     pointer tokens may use zeros for their pos slots — they're
233///     excluded from RoPE via `num_obj_ptr_tokens`.
234/// `num_obj_ptr_tokens`: count of obj-ptr tokens at the *end* of
235///     memory (i.e. `N_mem - num_obj_ptr_tokens` spatial tokens).
236pub fn memory_attention_forward(
237    w: &Sam2MemoryAttentionWeights,
238    curr: &[f32],
239    curr_pos: &[f32],
240    memory: &[f32],
241    memory_pos: &[f32],
242    n_img: usize,
243    n_mem: usize,
244    kv_in_dim: usize,
245    num_obj_ptr_tokens: usize,
246) -> Result<Vec<f32>> {
247    let d = w.d_model;
248    ensure!(curr.len() == n_img * d, "curr len mismatch");
249    ensure!(curr_pos.len() == n_img * d, "curr_pos len mismatch");
250    ensure!(memory.len() == n_mem * kv_in_dim, "memory len mismatch");
251    ensure!(
252        memory_pos.len() == n_mem * kv_in_dim,
253        "memory_pos len mismatch"
254    );
255
256    // Apply 0.1·curr_pos at input (reference uses `output = output + 0.1 * curr_pos`).
257    let mut output = curr.to_vec();
258    if w.pos_enc_at_input {
259        for i in 0..output.len() {
260            output[i] += 0.1 * curr_pos[i];
261        }
262    }
263
264    for layer in &w.layers {
265        output = memory_attention_layer_forward(
266            layer,
267            output,
268            curr_pos,
269            memory,
270            memory_pos,
271            n_img,
272            n_mem,
273            kv_in_dim,
274            num_obj_ptr_tokens,
275        )?;
276    }
277
278    layer_norm_last(&mut output, n_img, d, &w.norm_g, &w.norm_b, 1e-5);
279    Ok(output)
280}
281
282/// Layers + input pos only (no stack final norm).
283pub fn memory_attention_forward_layers_only(
284    w: &Sam2MemoryAttentionWeights,
285    curr: &[f32],
286    curr_pos: &[f32],
287    memory: &[f32],
288    memory_pos: &[f32],
289    n_img: usize,
290    n_mem: usize,
291    kv_in_dim: usize,
292    num_obj_ptr_tokens: usize,
293) -> Result<Vec<f32>> {
294    let _d = w.d_model;
295    let mut output = curr.to_vec();
296    if w.pos_enc_at_input {
297        for i in 0..output.len() {
298            output[i] += 0.1 * curr_pos[i];
299        }
300    }
301    for layer in &w.layers {
302        output = memory_attention_layer_forward(
303            layer,
304            output,
305            curr_pos,
306            memory,
307            memory_pos,
308            n_img,
309            n_mem,
310            kv_in_dim,
311            num_obj_ptr_tokens,
312        )?;
313    }
314    Ok(output)
315}
316
317/// Same as [`memory_attention_forward`] but stack final norm uses the CPU/IR
318/// kernel (`layer_norm_last_cpu`). Use when comparing to compiled IR that ends
319/// with `Op::LayerNorm`.
320pub fn memory_attention_forward_ir_stack(
321    w: &Sam2MemoryAttentionWeights,
322    curr: &[f32],
323    curr_pos: &[f32],
324    memory: &[f32],
325    memory_pos: &[f32],
326    n_img: usize,
327    n_mem: usize,
328    kv_in_dim: usize,
329    num_obj_ptr_tokens: usize,
330) -> Result<Vec<f32>> {
331    let d = w.d_model;
332    let mut output = memory_attention_forward_layers_only(
333        w,
334        curr,
335        curr_pos,
336        memory,
337        memory_pos,
338        n_img,
339        n_mem,
340        kv_in_dim,
341        num_obj_ptr_tokens,
342    )?;
343    layer_norm_last_cpu(&mut output, n_img, d, &w.norm_g, &w.norm_b, 1e-5);
344    Ok(output)
345}
346
347#[allow(clippy::too_many_arguments)]
348pub(crate) fn memory_attention_layer_forward(
349    w: &Sam2MemoryAttentionLayerWeights,
350    mut tgt: Vec<f32>,
351    query_pos: &[f32],
352    memory: &[f32],
353    memory_pos: &[f32],
354    n_img: usize,
355    n_mem: usize,
356    kv_in_dim: usize,
357    num_obj_ptr_tokens: usize,
358) -> Result<Vec<f32>> {
359    let d = w.d_model;
360
361    // ── Self-attention ──
362    let mut tgt2 = tgt.clone();
363    layer_norm_last(&mut tgt2, n_img, d, &w.norm1_g, &w.norm1_b, 1e-5);
364    let q_in = if w.pos_enc_at_attn {
365        let mut x = tgt2.clone();
366        for i in 0..x.len() {
367            x[i] += query_pos[i];
368        }
369        x
370    } else {
371        tgt2.clone()
372    };
373    let k_in = q_in.clone();
374    let v_in = tgt2.clone();
375    let sa_out = rope_attn_forward(
376        &w.self_attn,
377        &q_in,
378        n_img,
379        &k_in,
380        n_img,
381        &v_in,
382        n_img,
383        d,
384        d,
385        /*num_k_exclude_rope=*/ 0,
386    );
387    for i in 0..tgt.len() {
388        tgt[i] += sa_out[i];
389    }
390
391    // ── Cross-attention to memory ──
392    let mut tgt2 = tgt.clone();
393    layer_norm_last(&mut tgt2, n_img, d, &w.norm2_g, &w.norm2_b, 1e-5);
394    let q_in = if w.pos_enc_at_cross_attn_queries {
395        let mut x = tgt2.clone();
396        for i in 0..x.len() {
397            x[i] += query_pos[i];
398        }
399        x
400    } else {
401        tgt2
402    };
403    let k_in = if w.pos_enc_at_cross_attn_keys {
404        let mut x = memory.to_vec();
405        for i in 0..x.len() {
406            x[i] += memory_pos[i];
407        }
408        x
409    } else {
410        memory.to_vec()
411    };
412    let ca_out = rope_attn_forward(
413        &w.cross_attn,
414        &q_in,
415        n_img,
416        &k_in,
417        n_mem,
418        memory,
419        n_mem,
420        d,
421        kv_in_dim,
422        num_obj_ptr_tokens,
423    );
424    for i in 0..tgt.len() {
425        tgt[i] += ca_out[i];
426    }
427
428    // ── FFN ──
429    let mut tgt2 = tgt.clone();
430    layer_norm_last(&mut tgt2, n_img, d, &w.norm3_g, &w.norm3_b, 1e-5);
431    let dim_ff = w.linear1_b.len();
432    let mut mid = linear(&tgt2, &w.linear1_w, &w.linear1_b, n_img, d, dim_ff);
433    // Reference uses ReLU activation in `memory_attention` (`activation:
434    // relu` in the YAML).
435    for v in mid.iter_mut() {
436        if *v < 0.0 {
437            *v = 0.0;
438        }
439    }
440    let down = linear(&mid, &w.linear2_w, &w.linear2_b, n_img, dim_ff, d);
441    for i in 0..tgt.len() {
442        tgt[i] += down[i];
443    }
444
445    Ok(tgt)
446}
447
448#[allow(clippy::too_many_arguments)]
449fn rope_attn_forward(
450    w: &Sam2RoPEAttnWeights,
451    q: &[f32],
452    q_n: usize,
453    k: &[f32],
454    k_n: usize,
455    v: &[f32],
456    v_n: usize,
457    q_in_dim: usize,
458    kv_in_dim: usize,
459    num_k_exclude_rope: usize,
460) -> Vec<f32> {
461    let d = w.embedding_dim;
462    let id = w.internal_dim;
463    let nh = w.num_heads;
464    let dh = id / nh;
465    let scale = 1.0 / (dh as f32).sqrt();
466    let _ = q_in_dim;
467
468    // 1) Projections.
469    let q_p = linear(q, &w.q_w, &w.q_b, q_n, d, id);
470    let k_p = linear(k, &w.k_w, &w.k_b, k_n, kv_in_dim, id);
471    let v_p = linear(v, &w.v_w, &w.v_b, v_n, kv_in_dim, id);
472
473    // 2) Separate heads: [N, id] → [nh, N, dh] (B=1 implicit).
474    let q_h = separate_heads_b1(&q_p, q_n, nh, dh);
475    let mut k_h = separate_heads_b1(&k_p, k_n, nh, dh);
476    let v_h = separate_heads_b1(&v_p, v_n, nh, dh);
477
478    // 3) Apply axial 2-D RoPE to Q and the first `k_n - num_k_exclude_rope`
479    //    K positions. Memory bank may be `r` spatial frames stacked → use
480    //    `rope_k_repeat=true` to repeat-interleave the freqs `r` times.
481    let num_k_rope = k_n.saturating_sub(num_k_exclude_rope);
482    let [end_x, end_y] = w.rope_feat_size;
483    let spatial = end_x * end_y;
484    let q_h = super::axial_rope::apply_axial_rope_2d(
485        &q_h,
486        nh,
487        q_n,
488        dh,
489        end_x,
490        end_y,
491        w.rope_theta,
492        /*repeat_factor=*/ 1,
493    );
494    if num_k_rope > 0 {
495        let r = if w.rope_k_repeat && num_k_rope >= spatial && num_k_rope.is_multiple_of(spatial) {
496            num_k_rope / spatial
497        } else {
498            1
499        };
500        let mut k_prefix = vec![0f32; nh * num_k_rope * dh];
501        for h in 0..nh {
502            let src = &k_h[h * k_n * dh..(h * k_n + num_k_rope) * dh];
503            let dst = &mut k_prefix[h * num_k_rope * dh..(h + 1) * num_k_rope * dh];
504            dst.copy_from_slice(src);
505        }
506        let rotated = super::axial_rope::apply_axial_rope_2d(
507            &k_prefix,
508            nh,
509            num_k_rope,
510            dh,
511            end_x,
512            end_y,
513            w.rope_theta,
514            r,
515        );
516        for h in 0..nh {
517            let src = &rotated[h * num_k_rope * dh..(h + 1) * num_k_rope * dh];
518            let dst = &mut k_h[h * k_n * dh..(h * k_n + num_k_rope) * dh];
519            dst.copy_from_slice(src);
520        }
521    }
522
523    // 4) Scaled dot-product attention (no mask).
524    let mut out_h = vec![0f32; nh * q_n * dh];
525    let mut scores = vec![0f32; q_n * k_n];
526    for h in 0..nh {
527        for i in 0..q_n {
528            for j in 0..k_n {
529                let mut acc = 0f32;
530                for dd in 0..dh {
531                    acc += q_h[(h * q_n + i) * dh + dd] * k_h[(h * k_n + j) * dh + dd];
532                }
533                scores[i * k_n + j] = acc * scale;
534            }
535        }
536        for i in 0..q_n {
537            let row = &mut scores[i * k_n..(i + 1) * k_n];
538            let m = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
539            let mut s = 0f32;
540            for vv in row.iter_mut() {
541                *vv = (*vv - m).exp();
542                s += *vv;
543            }
544            for vv in row.iter_mut() {
545                *vv /= s;
546            }
547        }
548        for i in 0..q_n {
549            for dd in 0..dh {
550                let mut acc = 0f32;
551                for j in 0..k_n {
552                    acc += scores[i * k_n + j] * v_h[(h * v_n + j) * dh + dd];
553                }
554                out_h[(h * q_n + i) * dh + dd] = acc;
555            }
556        }
557    }
558
559    // 5) Recombine heads → [q_n, id]
560    let merged = recombine_heads_b1(&out_h, q_n, nh, dh);
561
562    // 6) Output projection.
563    linear(&merged, &w.out_w, &w.out_b, q_n, id, d)
564}
565
/// Split a token-major `[n, nh * dh]` buffer into head-major `[nh, n, dh]`.
fn separate_heads_b1(x: &[f32], n: usize, nh: usize, dh: usize) -> Vec<f32> {
    let row_stride = nh * dh;
    let mut out = vec![0f32; nh * n * dh];
    for head in 0..nh {
        for tok in 0..n {
            // Head `head`'s channels of token `tok` are contiguous in both layouts.
            let src = tok * row_stride + head * dh;
            let dst = (head * n + tok) * dh;
            out[dst..dst + dh].copy_from_slice(&x[src..src + dh]);
        }
    }
    out
}
577
/// Merge a head-major `[nh, n, dh]` buffer back into token-major `[n, nh * dh]`.
fn recombine_heads_b1(x: &[f32], n: usize, nh: usize, dh: usize) -> Vec<f32> {
    let row_stride = nh * dh;
    let mut out = vec![0f32; n * row_stride];
    for tok in 0..n {
        for head in 0..nh {
            // Copy one head's `dh` channels for this token into its slot in the row.
            let src = (head * n + tok) * dh;
            let dst = tok * row_stride + head * dh;
            out[dst..dst + dh].copy_from_slice(&x[src..src + dh]);
        }
    }
    out
}