Skip to main content

rlx_sam2/
transformer.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! SAM 2 two-way transformer — host-side.
17//!
18//! Mirrors `sam2/modeling/sam/transformer.py::TwoWayTransformer` and
19//! `TwoWayAttentionBlock`. Structurally identical to SAM v1 (2 layers
20//! of self-attn → cross t→i → ReLU MLP → cross i→t, then final
21//! token→image attn + LayerNorm) — the only differences are:
22//!
23//!   - Weight key prefix: `sam_mask_decoder.transformer.*` instead of
24//!     `mask_decoder.transformer.*`.
//!   - The cross-attention `downsample_rate` is configurable in the
//!     reference (defaults to 2 for the decoder transformer, matching
//!     v1). Here the rate is a parameter of the private `load_attention`
//!     helper; [`extract_two_way_transformer_weights`] passes 1 for
//!     self-attention and 2 for every cross-attention.
29//!
30//! Decoder transformer compute is small (q_n ≤ ~10 tokens, k_n = 64²),
31//! so staying on the CPU is the right tradeoff vs. growing the IR
32//! surface with multi-shape cross-attention.
33
34use anyhow::{Result, ensure};
35use rlx_core::weight_map::WeightMap;
36
/// Weights for one `Attention` layer (`embed_dim → internal_dim → embed_dim`).
///
/// Every tensor is a flat `f32` buffer. The forward pass feeds the weight
/// matrices to `linear`, which reads them row-major as `[out, in]`.
#[derive(Debug, Clone)]
pub struct Sam2AttentionWeights {
    /// Query projection weight, `[internal_dim, embed_dim]`.
    pub q_w: Vec<f32>,
    /// Query projection bias (one value per `internal_dim` output).
    pub q_b: Vec<f32>,
    /// Key projection weight, consumed as `[internal_dim, embed_dim]`.
    pub k_w: Vec<f32>,
    /// Key projection bias.
    pub k_b: Vec<f32>,
    /// Value projection weight, consumed as `[internal_dim, embed_dim]`.
    pub v_w: Vec<f32>,
    /// Value projection bias.
    pub v_b: Vec<f32>,
    /// Output projection weight, `[embed_dim, internal_dim]`.
    pub out_w: Vec<f32>,
    /// Output projection bias (one value per `embed_dim` output).
    pub out_b: Vec<f32>,
    /// Number of attention heads; the per-head width is `internal_dim / num_heads`.
    pub num_heads: usize,
    /// Width of the layer's inputs and outputs.
    pub embed_dim: usize,
    /// Width of the projected q/k/v (`embed_dim / downsample_rate` at load time).
    pub internal_dim: usize,
}

/// Weights for one `TwoWayAttentionBlock`: self-attention, token→image
/// cross-attention, a two-layer MLP and image→token cross-attention, each
/// followed by its own LayerNorm (`norm1` … `norm4`).
#[derive(Debug, Clone)]
pub struct Sam2TwoWayAttentionBlockWeights {
    /// Self-attention over the query tokens.
    pub self_attn: Sam2AttentionWeights,
    /// LayerNorm gain applied after self-attention.
    pub norm1_g: Vec<f32>,
    /// LayerNorm bias applied after self-attention.
    pub norm1_b: Vec<f32>,
    /// Cross-attention with tokens as queries and the image as keys/values.
    pub cross_token_to_image: Sam2AttentionWeights,
    /// LayerNorm gain applied after token→image cross-attention.
    pub norm2_g: Vec<f32>,
    /// LayerNorm bias applied after token→image cross-attention.
    pub norm2_b: Vec<f32>,
    /// First MLP linear weight, `[mlp_dim, embed_dim]`.
    pub mlp_lin1_w: Vec<f32>,
    /// First MLP linear bias; its length is used as `mlp_dim` by the forward pass.
    pub mlp_lin1_b: Vec<f32>,
    /// Second MLP linear weight, consumed as `[embed_dim, mlp_dim]`.
    pub mlp_lin2_w: Vec<f32>,
    /// Second MLP linear bias.
    pub mlp_lin2_b: Vec<f32>,
    /// LayerNorm gain applied after the MLP.
    pub norm3_g: Vec<f32>,
    /// LayerNorm bias applied after the MLP.
    pub norm3_b: Vec<f32>,
    /// Cross-attention with the image as queries and tokens as keys/values.
    pub cross_image_to_token: Sam2AttentionWeights,
    /// LayerNorm gain applied to the image keys after image→token attention.
    pub norm4_g: Vec<f32>,
    /// LayerNorm bias applied to the image keys after image→token attention.
    pub norm4_b: Vec<f32>,
    /// When `true`, self-attention runs on the raw queries (no positional
    /// encoding added) and its output replaces the queries instead of being
    /// added to them. The loader sets this for layer 0 only.
    pub skip_first_layer_pe: bool,
}

/// Weights for the whole two-way transformer: the stacked blocks plus the
/// final token→image attention and its LayerNorm.
#[derive(Debug, Clone)]
pub struct Sam2TwoWayTransformerWeights {
    /// The attention blocks, in execution order.
    pub layers: Vec<Sam2TwoWayAttentionBlockWeights>,
    /// Final token→image cross-attention applied after all blocks.
    pub final_attn_token_to_image: Sam2AttentionWeights,
    /// Gain of the LayerNorm applied to the queries after the final attention.
    pub norm_final_g: Vec<f32>,
    /// Bias of the LayerNorm applied to the queries after the final attention.
    pub norm_final_b: Vec<f32>,
    /// Token width used by the final LayerNorm.
    pub embed_dim: usize,
}
78
79fn load_attention(
80    weights: &mut WeightMap,
81    prefix: &str,
82    embed_dim: usize,
83    num_heads: usize,
84    downsample_rate: usize,
85) -> Result<Sam2AttentionWeights> {
86    let internal_dim = embed_dim / downsample_rate;
87    let (q_w, sh) = weights.take(&format!("{prefix}.q_proj.weight"))?;
88    ensure!(
89        sh == vec![internal_dim, embed_dim],
90        "{prefix}.q_proj.weight shape {sh:?} not [{internal_dim}, {embed_dim}]"
91    );
92    let (q_b, _) = weights.take(&format!("{prefix}.q_proj.bias"))?;
93    let (k_w, _) = weights.take(&format!("{prefix}.k_proj.weight"))?;
94    let (k_b, _) = weights.take(&format!("{prefix}.k_proj.bias"))?;
95    let (v_w, _) = weights.take(&format!("{prefix}.v_proj.weight"))?;
96    let (v_b, _) = weights.take(&format!("{prefix}.v_proj.bias"))?;
97    let (out_w, sh) = weights.take(&format!("{prefix}.out_proj.weight"))?;
98    ensure!(
99        sh == vec![embed_dim, internal_dim],
100        "{prefix}.out_proj.weight shape {sh:?} not [{embed_dim}, {internal_dim}]"
101    );
102    let (out_b, _) = weights.take(&format!("{prefix}.out_proj.bias"))?;
103    Ok(Sam2AttentionWeights {
104        q_w,
105        q_b,
106        k_w,
107        k_b,
108        v_w,
109        v_b,
110        out_w,
111        out_b,
112        num_heads,
113        embed_dim,
114        internal_dim,
115    })
116}
117
118pub(super) fn extract_two_way_transformer_weights(
119    weights: &mut WeightMap,
120    embed_dim: usize,
121    depth: usize,
122    num_heads: usize,
123    mlp_dim: usize,
124) -> Result<Sam2TwoWayTransformerWeights> {
125    let mut layers = Vec::with_capacity(depth);
126    for i in 0..depth {
127        let p = format!("sam_mask_decoder.transformer.layers.{i}");
128        let self_attn =
129            load_attention(weights, &format!("{p}.self_attn"), embed_dim, num_heads, 1)?;
130        let (norm1_g, _) = weights.take(&format!("{p}.norm1.weight"))?;
131        let (norm1_b, _) = weights.take(&format!("{p}.norm1.bias"))?;
132        let cross_t2i = load_attention(
133            weights,
134            &format!("{p}.cross_attn_token_to_image"),
135            embed_dim,
136            num_heads,
137            2,
138        )?;
139        let (norm2_g, _) = weights.take(&format!("{p}.norm2.weight"))?;
140        let (norm2_b, _) = weights.take(&format!("{p}.norm2.bias"))?;
141        let (mlp_lin1_w, sh) = weights.take(&format!("{p}.mlp.layers.0.weight"))?;
142        ensure!(
143            sh == vec![mlp_dim, embed_dim],
144            "{p}.mlp.layers.0.weight shape {sh:?} not [{mlp_dim}, {embed_dim}]"
145        );
146        let (mlp_lin1_b, _) = weights.take(&format!("{p}.mlp.layers.0.bias"))?;
147        let (mlp_lin2_w, _) = weights.take(&format!("{p}.mlp.layers.1.weight"))?;
148        let (mlp_lin2_b, _) = weights.take(&format!("{p}.mlp.layers.1.bias"))?;
149        let (norm3_g, _) = weights.take(&format!("{p}.norm3.weight"))?;
150        let (norm3_b, _) = weights.take(&format!("{p}.norm3.bias"))?;
151        let cross_i2t = load_attention(
152            weights,
153            &format!("{p}.cross_attn_image_to_token"),
154            embed_dim,
155            num_heads,
156            2,
157        )?;
158        let (norm4_g, _) = weights.take(&format!("{p}.norm4.weight"))?;
159        let (norm4_b, _) = weights.take(&format!("{p}.norm4.bias"))?;
160        layers.push(Sam2TwoWayAttentionBlockWeights {
161            self_attn,
162            norm1_g,
163            norm1_b,
164            cross_token_to_image: cross_t2i,
165            norm2_g,
166            norm2_b,
167            mlp_lin1_w,
168            mlp_lin1_b,
169            mlp_lin2_w,
170            mlp_lin2_b,
171            norm3_g,
172            norm3_b,
173            cross_image_to_token: cross_i2t,
174            norm4_g,
175            norm4_b,
176            skip_first_layer_pe: i == 0,
177        });
178    }
179    let final_attn = load_attention(
180        weights,
181        "sam_mask_decoder.transformer.final_attn_token_to_image",
182        embed_dim,
183        num_heads,
184        2,
185    )?;
186    let (norm_final_g, _) = weights.take("sam_mask_decoder.transformer.norm_final_attn.weight")?;
187    let (norm_final_b, _) = weights.take("sam_mask_decoder.transformer.norm_final_attn.bias")?;
188    Ok(Sam2TwoWayTransformerWeights {
189        layers,
190        final_attn_token_to_image: final_attn,
191        norm_final_g,
192        norm_final_b,
193        embed_dim,
194    })
195}
196
197// ─── Host-side execution ─────────────────────────────────────────
198
199/// Standard scaled-dot-product multi-head attention.
200/// All inputs `[B, N_*, embed_dim]`. `b` is the batch dim.
201pub fn sam2_attention_forward(
202    w: &Sam2AttentionWeights,
203    q: &[f32],
204    q_n: usize,
205    k: &[f32],
206    k_n: usize,
207    v: &[f32],
208    v_n: usize,
209    b: usize,
210) -> Vec<f32> {
211    let e = w.embed_dim;
212    let id = w.internal_dim;
213    let nh = w.num_heads;
214    let dh = id / nh;
215    let scale = 1.0 / (dh as f32).sqrt();
216
217    let q_p = linear(q, &w.q_w, &w.q_b, b * q_n, e, id);
218    let k_p = linear(k, &w.k_w, &w.k_b, b * k_n, e, id);
219    let v_p = linear(v, &w.v_w, &w.v_b, b * v_n, e, id);
220
221    let q_h = separate_heads(&q_p, b, q_n, nh, dh);
222    let k_h = separate_heads(&k_p, b, k_n, nh, dh);
223    let v_h = separate_heads(&v_p, b, v_n, nh, dh);
224
225    let mut out_h = vec![0f32; b * nh * q_n * dh];
226    let mut scores = vec![0f32; q_n * k_n];
227    for bi in 0..b {
228        for h in 0..nh {
229            let q_off = ((bi * nh) + h) * q_n * dh;
230            let k_off = ((bi * nh) + h) * k_n * dh;
231            let v_off = ((bi * nh) + h) * v_n * dh;
232            let out_off = ((bi * nh) + h) * q_n * dh;
233
234            for i in 0..q_n {
235                for j in 0..k_n {
236                    let mut acc = 0f32;
237                    for d in 0..dh {
238                        acc += q_h[q_off + i * dh + d] * k_h[k_off + j * dh + d];
239                    }
240                    scores[i * k_n + j] = acc * scale;
241                }
242            }
243            for i in 0..q_n {
244                let row = &mut scores[i * k_n..(i + 1) * k_n];
245                let m = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
246                let mut s = 0f32;
247                for v in row.iter_mut() {
248                    *v = (*v - m).exp();
249                    s += *v;
250                }
251                for v in row.iter_mut() {
252                    *v /= s;
253                }
254            }
255            for i in 0..q_n {
256                for d in 0..dh {
257                    let mut acc = 0f32;
258                    for j in 0..k_n {
259                        acc += scores[i * k_n + j] * v_h[v_off + j * dh + d];
260                    }
261                    out_h[out_off + i * dh + d] = acc;
262                }
263            }
264        }
265    }
266
267    let merged = recombine_heads(&out_h, b, q_n, nh, dh);
268    linear(&merged, &w.out_w, &w.out_b, b * q_n, id, e)
269}
270
/// PyTorch-style `Linear`: `y = x @ W^T + b`.
///
/// `x` is `[rows, in_d]`, `w` is `[out_d, in_d]` and `b` has `out_d`
/// entries, all flat and row-major. Returns `[rows, out_d]`. Each output
/// starts from its bias and accumulates the products in input order.
pub fn linear(x: &[f32], w: &[f32], b: &[f32], rows: usize, in_d: usize, out_d: usize) -> Vec<f32> {
    let mut y = Vec::with_capacity(rows * out_d);
    for row in 0..rows {
        for col in 0..out_d {
            let bias = b[col];
            let input = &x[row * in_d..(row + 1) * in_d];
            let weight = &w[col * in_d..(col + 1) * in_d];
            let value = input
                .iter()
                .zip(weight)
                .fold(bias, |acc, (&xi, &wi)| acc + xi * wi);
            y.push(value);
        }
    }
    y
}
285
/// Rearranges `[b, n, nh * dh]` into `[b, nh, n, dh]` so that each head's
/// rows are contiguous.
fn separate_heads(x: &[f32], b: usize, n: usize, nh: usize, dh: usize) -> Vec<f32> {
    let width = nh * dh;
    let mut split = Vec::with_capacity(b * nh * n * dh);
    // Walk the destination in order and copy one `dh`-wide slice at a time.
    for batch in 0..b {
        for head in 0..nh {
            for tok in 0..n {
                let start = (batch * n + tok) * width + head * dh;
                split.extend_from_slice(&x[start..start + dh]);
            }
        }
    }
    split
}
300
/// Inverse of the head split: rearranges `[b, nh, n, dh]` back into
/// `[b, n, nh * dh]`.
fn recombine_heads(x: &[f32], b: usize, n: usize, nh: usize, dh: usize) -> Vec<f32> {
    let mut merged = Vec::with_capacity(b * n * nh * dh);
    // Walk the destination in order and copy one `dh`-wide slice at a time.
    for batch in 0..b {
        for tok in 0..n {
            for head in 0..nh {
                let start = ((batch * nh + head) * n + tok) * dh;
                merged.extend_from_slice(&x[start..start + dh]);
            }
        }
    }
    merged
}
315
/// In-place LayerNorm over the last axis of `x: [rows, n]`.
///
/// Per row, the mean is computed first and the (biased) variance in a second
/// pass; each element then becomes `(v - mean) / sqrt(var + eps) * g[k] + b[k]`.
pub fn layer_norm_last(x: &mut [f32], rows: usize, n: usize, g: &[f32], b: &[f32], eps: f32) {
    let count = n as f32;
    for r in 0..rows {
        let row = &mut x[r * n..(r + 1) * n];
        let mean = row.iter().fold(0f32, |acc, &v| acc + v) / count;
        let var = row.iter().fold(0f32, |acc, &v| {
            let centered = v - mean;
            acc + centered * centered
        }) / count;
        let inv = 1.0 / (var + eps).sqrt();
        for (k, v) in row.iter_mut().enumerate() {
            *v = (*v - mean) * inv * g[k] + b[k];
        }
    }
}
337
338/// Stack / IR final norm — matches compiled [`Op::LayerNorm`] on CPU.
339pub fn layer_norm_last_cpu(x: &mut [f32], rows: usize, n: usize, g: &[f32], b: &[f32], eps: f32) {
340    let mut tmp = vec![0f32; n];
341    for r in 0..rows {
342        let base = r * n;
343        rlx_cpu::kernels::layer_norm_row(&x[base..base + n], g, b, &mut tmp, n, eps);
344        x[base..base + n].copy_from_slice(&tmp);
345    }
346}
347
348pub(super) fn add_inplace(dst: &mut [f32], src: &[f32]) {
349    for (d, s) in dst.iter_mut().zip(src.iter()) {
350        *d += *s;
351    }
352}
353
/// In-place ReLU: every strictly negative element becomes `0.0`. Values
/// that do not compare below zero (including NaN and `-0.0`) are left as-is.
fn relu_inplace(x: &mut [f32]) {
    x.iter_mut()
        .filter(|v| **v < 0.0)
        .for_each(|v| *v = 0.0);
}
361
362/// One `TwoWayAttentionBlock` forward.
363pub fn two_way_attention_block_forward(
364    w: &Sam2TwoWayAttentionBlockWeights,
365    queries: Vec<f32>,
366    keys: Vec<f32>,
367    query_pe: &[f32],
368    key_pe: &[f32],
369    b: usize,
370    q_n: usize,
371    k_n: usize,
372) -> (Vec<f32>, Vec<f32>) {
373    let e = w.self_attn.embed_dim;
374
375    // ── Self attention block ──
376    let mut queries = if w.skip_first_layer_pe {
377        sam2_attention_forward(&w.self_attn, &queries, q_n, &queries, q_n, &queries, q_n, b)
378    } else {
379        let mut q = queries.clone();
380        add_inplace(&mut q, query_pe);
381        let attn_out = sam2_attention_forward(&w.self_attn, &q, q_n, &q, q_n, &queries, q_n, b);
382        let mut out = queries;
383        add_inplace(&mut out, &attn_out);
384        out
385    };
386    layer_norm_last(&mut queries, b * q_n, e, &w.norm1_g, &w.norm1_b, 1e-5);
387
388    // ── Cross attention, tokens attending to image ──
389    let mut q_pe = queries.clone();
390    add_inplace(&mut q_pe, query_pe);
391    let mut k_pe = keys.clone();
392    add_inplace(&mut k_pe, key_pe);
393    let attn_out = sam2_attention_forward(
394        &w.cross_token_to_image,
395        &q_pe,
396        q_n,
397        &k_pe,
398        k_n,
399        &keys,
400        k_n,
401        b,
402    );
403    add_inplace(&mut queries, &attn_out);
404    layer_norm_last(&mut queries, b * q_n, e, &w.norm2_g, &w.norm2_b, 1e-5);
405
406    // ── MLP (ReLU activation per reference's `MLPBlock`) ──
407    let mlp_dim = w.mlp_lin1_b.len();
408    let mut mlp_mid = linear(&queries, &w.mlp_lin1_w, &w.mlp_lin1_b, b * q_n, e, mlp_dim);
409    relu_inplace(&mut mlp_mid);
410    let mlp_out = linear(&mlp_mid, &w.mlp_lin2_w, &w.mlp_lin2_b, b * q_n, mlp_dim, e);
411    add_inplace(&mut queries, &mlp_out);
412    layer_norm_last(&mut queries, b * q_n, e, &w.norm3_g, &w.norm3_b, 1e-5);
413
414    // ── Cross attention, image attending to tokens ──
415    let mut q_pe = queries.clone();
416    add_inplace(&mut q_pe, query_pe);
417    let mut k_pe = keys.clone();
418    add_inplace(&mut k_pe, key_pe);
419    let attn_out = sam2_attention_forward(
420        &w.cross_image_to_token,
421        &k_pe,
422        k_n,
423        &q_pe,
424        q_n,
425        &queries,
426        q_n,
427        b,
428    );
429    let mut keys = keys;
430    add_inplace(&mut keys, &attn_out);
431    layer_norm_last(&mut keys, b * k_n, e, &w.norm4_g, &w.norm4_b, 1e-5);
432
433    (queries, keys)
434}
435
436/// Top-level two-way transformer forward.
437///
438/// `image_embedding`: NCHW `[B, C, H, W]` (flat).
439/// `image_pe`: same shape.
440/// `point_embedding`: `[B, q_n, E]`.
441///
442/// Returns `(queries, keys)` where queries is `[B, q_n, E]` and keys
443/// is `[B, H*W, E]` (after the final LN on queries).
444pub fn two_way_transformer_forward(
445    w: &Sam2TwoWayTransformerWeights,
446    image_embedding: &[f32],
447    image_pe: &[f32],
448    point_embedding: &[f32],
449    b: usize,
450    c: usize,
451    h: usize,
452    ww: usize,
453    q_n: usize,
454) -> (Vec<f32>, Vec<f32>) {
455    let k_n = h * ww;
456    let mut image_seq = vec![0f32; b * k_n * c];
457    let mut image_pe_seq = vec![0f32; b * k_n * c];
458    for bi in 0..b {
459        for y in 0..h {
460            for x in 0..ww {
461                for ch in 0..c {
462                    let src = (bi * c + ch) * h * ww + y * ww + x;
463                    let dst = (bi * k_n + y * ww + x) * c + ch;
464                    image_seq[dst] = image_embedding[src];
465                    image_pe_seq[dst] = image_pe[src];
466                }
467            }
468        }
469    }
470
471    let mut queries = point_embedding.to_vec();
472    let mut keys = image_seq;
473
474    for layer in &w.layers {
475        let (q, k) = two_way_attention_block_forward(
476            layer,
477            queries,
478            keys,
479            point_embedding,
480            &image_pe_seq,
481            b,
482            q_n,
483            k_n,
484        );
485        queries = q;
486        keys = k;
487    }
488
489    // Final cross attention token → image
490    let mut q_pe = queries.clone();
491    add_inplace(&mut q_pe, point_embedding);
492    let mut k_pe = keys.clone();
493    add_inplace(&mut k_pe, &image_pe_seq);
494    let attn_out = sam2_attention_forward(
495        &w.final_attn_token_to_image,
496        &q_pe,
497        q_n,
498        &k_pe,
499        k_n,
500        &keys,
501        k_n,
502        b,
503    );
504    add_inplace(&mut queries, &attn_out);
505    layer_norm_last(
506        &mut queries,
507        b * q_n,
508        w.embed_dim,
509        &w.norm_final_g,
510        &w.norm_final_b,
511        1e-5,
512    );
513
514    (queries, keys)
515}