Skip to main content

rlx_sam/
transformer.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! SAM v1 two-way transformer — host-side.
17//!
18//! Implements `transformer.rs` from candle's `segment_anything`
19//! module verbatim: 2 layers of `TwoWayAttentionBlock` (self-attn on
20//! tokens → cross token→image → ReLU MLP → cross image→token) followed
21//! by a final token→image attention + LayerNorm.
22//!
23//! Decoder compute is small enough that staying on the CPU is the
24//! right tradeoff vs. growing the IR surface with cross-attention,
25//! ConvTranspose2d, etc.
26
27use anyhow::{Result, ensure};
28use rlx_core::weight_map::WeightMap;
29
/// Weights for one `Attention` layer (`embed_dim → internal_dim → embed_dim`).
///
/// Every matrix is a flat `f32` buffer in PyTorch's row-major `[out, in]`
/// layout; every bias has one entry per output feature.
#[derive(Debug, Clone)]
pub struct AttentionWeights {
    /// Query projection weight, `[internal_dim, embed_dim]`.
    pub q_w: Vec<f32>,
    /// Query projection bias, `[internal_dim]`.
    pub q_b: Vec<f32>,
    /// Key projection weight, `[internal_dim, embed_dim]`.
    pub k_w: Vec<f32>,
    /// Key projection bias, `[internal_dim]`.
    pub k_b: Vec<f32>,
    /// Value projection weight, `[internal_dim, embed_dim]`.
    pub v_w: Vec<f32>,
    /// Value projection bias, `[internal_dim]`.
    pub v_b: Vec<f32>,
    /// Output projection weight, `[embed_dim, internal_dim]`.
    pub out_w: Vec<f32>,
    /// Output projection bias, `[embed_dim]`.
    pub out_b: Vec<f32>,
    /// Number of attention heads `internal_dim` is split into.
    pub num_heads: usize,
    /// Width of the layer's inputs and outputs.
    pub embed_dim: usize,
    /// Width of the projected Q/K/V tensors.
    pub internal_dim: usize,
}

/// Weights for one `TwoWayAttentionBlock`: token self-attention,
/// token→image cross-attention, an MLP and image→token cross-attention,
/// each followed by its own LayerNorm (`norm1` … `norm4`).
#[derive(Debug, Clone)]
pub struct TwoWayAttentionBlockWeights {
    /// Self-attention over the tokens.
    pub self_attn: AttentionWeights,
    /// LayerNorm scale applied after self-attention.
    pub norm1_g: Vec<f32>,
    /// LayerNorm shift applied after self-attention.
    pub norm1_b: Vec<f32>,
    /// Cross-attention with tokens as queries and the image as keys/values.
    pub cross_token_to_image: AttentionWeights,
    /// LayerNorm scale applied after token→image attention.
    pub norm2_g: Vec<f32>,
    /// LayerNorm shift applied after token→image attention.
    pub norm2_b: Vec<f32>,
    /// First MLP linear weight, `[mlp_dim, embed_dim]`.
    pub mlp_lin1_w: Vec<f32>,
    /// First MLP linear bias; its length is used as `mlp_dim` at run time.
    pub mlp_lin1_b: Vec<f32>,
    /// Second MLP linear weight, `[embed_dim, mlp_dim]`.
    pub mlp_lin2_w: Vec<f32>,
    /// Second MLP linear bias, `[embed_dim]`.
    pub mlp_lin2_b: Vec<f32>,
    /// LayerNorm scale applied after the MLP.
    pub norm3_g: Vec<f32>,
    /// LayerNorm shift applied after the MLP.
    pub norm3_b: Vec<f32>,
    /// Cross-attention with the image as queries and tokens as keys/values.
    pub cross_image_to_token: AttentionWeights,
    /// LayerNorm scale applied to the image keys after image→token attention.
    pub norm4_g: Vec<f32>,
    /// LayerNorm shift applied to the image keys after image→token attention.
    pub norm4_b: Vec<f32>,
    /// When set, self-attention runs on the raw queries (no positional
    /// encoding added) and its output replaces the queries instead of being
    /// added to them.
    pub skip_first_layer_pe: bool,
}

/// Weights for the whole two-way transformer: the stacked blocks plus the
/// final token→image attention and its LayerNorm.
#[derive(Debug, Clone)]
pub struct TwoWayTransformerWeights {
    /// The attention blocks, in execution order.
    pub layers: Vec<TwoWayAttentionBlockWeights>,
    /// Final token→image cross-attention.
    pub final_attn_token_to_image: AttentionWeights,
    /// Scale of the LayerNorm applied to the queries after the final attention.
    pub norm_final_g: Vec<f32>,
    /// Shift of the LayerNorm applied to the queries after the final attention.
    pub norm_final_b: Vec<f32>,
    /// Embedding width shared by tokens and image features.
    pub embed_dim: usize,
}
71
72fn load_attention(
73    weights: &mut WeightMap,
74    prefix: &str,
75    embed_dim: usize,
76    num_heads: usize,
77    downsample_rate: usize,
78) -> Result<AttentionWeights> {
79    let internal_dim = embed_dim / downsample_rate;
80    let (q_w, sh) = weights.take(&format!("{prefix}.q_proj.weight"))?;
81    ensure!(
82        sh == vec![internal_dim, embed_dim],
83        "{prefix}.q_proj.weight shape {sh:?}"
84    );
85    let (q_b, _) = weights.take(&format!("{prefix}.q_proj.bias"))?;
86    let (k_w, _) = weights.take(&format!("{prefix}.k_proj.weight"))?;
87    let (k_b, _) = weights.take(&format!("{prefix}.k_proj.bias"))?;
88    let (v_w, _) = weights.take(&format!("{prefix}.v_proj.weight"))?;
89    let (v_b, _) = weights.take(&format!("{prefix}.v_proj.bias"))?;
90    let (out_w, sh) = weights.take(&format!("{prefix}.out_proj.weight"))?;
91    ensure!(
92        sh == vec![embed_dim, internal_dim],
93        "{prefix}.out_proj.weight shape {sh:?}"
94    );
95    let (out_b, _) = weights.take(&format!("{prefix}.out_proj.bias"))?;
96    Ok(AttentionWeights {
97        q_w,
98        q_b,
99        k_w,
100        k_b,
101        v_w,
102        v_b,
103        out_w,
104        out_b,
105        num_heads,
106        embed_dim,
107        internal_dim,
108    })
109}
110
111pub(super) fn extract_two_way_transformer_weights(
112    weights: &mut WeightMap,
113    embed_dim: usize,
114    depth: usize,
115    num_heads: usize,
116    mlp_dim: usize,
117) -> Result<TwoWayTransformerWeights> {
118    let mut layers = Vec::with_capacity(depth);
119    for i in 0..depth {
120        let p = format!("mask_decoder.transformer.layers.{i}");
121        let self_attn =
122            load_attention(weights, &format!("{p}.self_attn"), embed_dim, num_heads, 1)?;
123        let (norm1_g, _) = weights.take(&format!("{p}.norm1.weight"))?;
124        let (norm1_b, _) = weights.take(&format!("{p}.norm1.bias"))?;
125        let cross_t2i = load_attention(
126            weights,
127            &format!("{p}.cross_attn_token_to_image"),
128            embed_dim,
129            num_heads,
130            2,
131        )?;
132        let (norm2_g, _) = weights.take(&format!("{p}.norm2.weight"))?;
133        let (norm2_b, _) = weights.take(&format!("{p}.norm2.bias"))?;
134        let (mlp_lin1_w, sh) = weights.take(&format!("{p}.mlp.lin1.weight"))?;
135        ensure!(
136            sh == vec![mlp_dim, embed_dim],
137            "{p}.mlp.lin1.weight shape {sh:?}"
138        );
139        let (mlp_lin1_b, _) = weights.take(&format!("{p}.mlp.lin1.bias"))?;
140        let (mlp_lin2_w, _) = weights.take(&format!("{p}.mlp.lin2.weight"))?;
141        let (mlp_lin2_b, _) = weights.take(&format!("{p}.mlp.lin2.bias"))?;
142        let (norm3_g, _) = weights.take(&format!("{p}.norm3.weight"))?;
143        let (norm3_b, _) = weights.take(&format!("{p}.norm3.bias"))?;
144        let cross_i2t = load_attention(
145            weights,
146            &format!("{p}.cross_attn_image_to_token"),
147            embed_dim,
148            num_heads,
149            2,
150        )?;
151        let (norm4_g, _) = weights.take(&format!("{p}.norm4.weight"))?;
152        let (norm4_b, _) = weights.take(&format!("{p}.norm4.bias"))?;
153        layers.push(TwoWayAttentionBlockWeights {
154            self_attn,
155            norm1_g,
156            norm1_b,
157            cross_token_to_image: cross_t2i,
158            norm2_g,
159            norm2_b,
160            mlp_lin1_w,
161            mlp_lin1_b,
162            mlp_lin2_w,
163            mlp_lin2_b,
164            norm3_g,
165            norm3_b,
166            cross_image_to_token: cross_i2t,
167            norm4_g,
168            norm4_b,
169            skip_first_layer_pe: i == 0,
170        });
171    }
172    let final_attn = load_attention(
173        weights,
174        "mask_decoder.transformer.final_attn_token_to_image",
175        embed_dim,
176        num_heads,
177        2,
178    )?;
179    let (norm_final_g, _) = weights.take("mask_decoder.transformer.norm_final_attn.weight")?;
180    let (norm_final_b, _) = weights.take("mask_decoder.transformer.norm_final_attn.bias")?;
181    Ok(TwoWayTransformerWeights {
182        layers,
183        final_attn_token_to_image: final_attn,
184        norm_final_g,
185        norm_final_b,
186        embed_dim,
187    })
188}
189
190// ─── Host-side execution ─────────────────────────────────────────
191
192/// Standard scaled-dot-product multi-head attention (candle's
193/// `transformer::Attention::forward`). All inputs `[B, N_*, embed_dim]`.
194///
195/// `b` is the batch dim (typically 1 for a single prompt batch in SAM).
196/// Each of `q, k, v` may have a different sequence length.
197pub fn attention_forward(
198    w: &AttentionWeights,
199    q: &[f32],
200    q_n: usize,
201    k: &[f32],
202    k_n: usize,
203    v: &[f32],
204    v_n: usize,
205    b: usize,
206) -> Vec<f32> {
207    let e = w.embed_dim;
208    let id = w.internal_dim;
209    let nh = w.num_heads;
210    let dh = id / nh;
211    let scale = 1.0 / (dh as f32).sqrt();
212
213    // Project Q, K, V to internal_dim: [B, N, id]
214    let q_p = linear(q, &w.q_w, &w.q_b, b * q_n, e, id);
215    let k_p = linear(k, &w.k_w, &w.k_b, b * k_n, e, id);
216    let v_p = linear(v, &w.v_w, &w.v_b, b * v_n, e, id);
217
218    // Separate heads: [B, N, id] → [B, nh, N, dh]
219    let q_h = separate_heads(&q_p, b, q_n, nh, dh);
220    let k_h = separate_heads(&k_p, b, k_n, nh, dh);
221    let v_h = separate_heads(&v_p, b, v_n, nh, dh);
222
223    // Scaled dot-product per (b, h), BLAS-backed:
224    //   scores = q_h @ k_h^T * scale    [B, nh, q_n, k_n]
225    //   attn = softmax_last_dim(scores)
226    //   out_h = attn @ v_h               [B, nh, q_n, dh]
227    let mut out_h = vec![0f32; b * nh * q_n * dh];
228    let mut scores = vec![0f32; q_n * k_n];
229    // k_h^T per head: pre-transpose into [dh, k_n] once per head so
230    // sgemm sees a standard `[q_n, dh] @ [dh, k_n]`.
231    let mut k_t = vec![0f32; dh * k_n];
232    for bi in 0..b {
233        for h in 0..nh {
234            let q_off = ((bi * nh) + h) * q_n * dh;
235            let k_off = ((bi * nh) + h) * k_n * dh;
236            let v_off = ((bi * nh) + h) * v_n * dh;
237            let out_off = ((bi * nh) + h) * q_n * dh;
238
239            // Build k_h^T as [dh, k_n]
240            for j in 0..k_n {
241                for d in 0..dh {
242                    k_t[d * k_n + j] = k_h[k_off + j * dh + d];
243                }
244            }
245            // scores = q @ k_t   (no bias)
246            rlx_cpu::blas::sgemm_auto(
247                &q_h[q_off..q_off + q_n * dh],
248                &k_t,
249                &mut scores,
250                q_n,
251                dh,
252                k_n,
253            );
254            // Apply scale and softmax in one row-pass.
255            for i in 0..q_n {
256                let row = &mut scores[i * k_n..(i + 1) * k_n];
257                let mut m = f32::NEG_INFINITY;
258                for v in row.iter_mut() {
259                    *v *= scale;
260                    if *v > m {
261                        m = *v;
262                    }
263                }
264                let mut s = 0f32;
265                for v in row.iter_mut() {
266                    *v = (*v - m).exp();
267                    s += *v;
268                }
269                let inv = 1.0 / s;
270                for v in row.iter_mut() {
271                    *v *= inv;
272                }
273            }
274            // out = scores @ V  (V is already [k_n, dh] row-major within
275            // this head's slice)
276            rlx_cpu::blas::sgemm_auto(
277                &scores,
278                &v_h[v_off..v_off + v_n * dh],
279                &mut out_h[out_off..out_off + q_n * dh],
280                q_n,
281                k_n,
282                dh,
283            );
284        }
285    }
286
287    // Recombine heads: [B, nh, q_n, dh] → [B, q_n, id]
288    let merged = recombine_heads(&out_h, b, q_n, nh, dh);
289    // Output projection
290    linear(&merged, &w.out_w, &w.out_b, b * q_n, id, e)
291}
292
293/// Standard PyTorch Linear: `y = x @ W^T + b` where `W: [out, in]`.
294/// `x: [rows, in]`, output `[rows, out]`.
295///
296/// Uses `rlx-cpu`'s NEON/AMX-tuned `sgemm_auto` for the inner matmul,
297/// then adds the per-output bias. Weight is given as `[out, in]`
298/// (PyTorch layout), so we transpose into `[in, out]` row-major
299/// before sgemm.
300///
301/// The transpose is a `rows·in·out` matmul amortized over `rows`
302/// iters in the SAM decoder, so the one-time `in·out` transpose is
303/// negligible. For the IoU head (single-row matmul) we still see a
304/// 5–10× speedup over naive loops because BLAS uses AMX/NEON SIMD.
305pub fn linear(x: &[f32], w: &[f32], b: &[f32], rows: usize, in_d: usize, out_d: usize) -> Vec<f32> {
306    let mut w_t = vec![0f32; in_d * out_d];
307    for o in 0..out_d {
308        for k in 0..in_d {
309            w_t[k * out_d + o] = w[o * in_d + k];
310        }
311    }
312    let mut out = vec![0f32; rows * out_d];
313    rlx_cpu::blas::sgemm_auto(x, &w_t, &mut out, rows, in_d, out_d);
314    for r in 0..rows {
315        for o in 0..out_d {
316            out[r * out_d + o] += b[o];
317        }
318    }
319    out
320}
321
/// Splits the feature axis into heads: `[B, N, nh*dh] → [B, nh, N, dh]`.
///
/// Each token's `dh`-wide slice for a given head is copied as one
/// contiguous run into that head's plane.
fn separate_heads(x: &[f32], b: usize, n: usize, nh: usize, dh: usize) -> Vec<f32> {
    let width = nh * dh;
    let mut heads = vec![0f32; b * nh * n * dh];
    for batch in 0..b {
        for token in 0..n {
            let src_row = (batch * n + token) * width;
            for head in 0..nh {
                let src = src_row + head * dh;
                let dst = ((batch * nh + head) * n + token) * dh;
                heads[dst..dst + dh].copy_from_slice(&x[src..src + dh]);
            }
        }
    }
    heads
}
337
/// Merges heads back into the feature axis: `[B, nh, N, dh] → [B, N, nh*dh]`.
///
/// Inverse of `separate_heads`: every head's `dh`-wide slice for a token is
/// copied back into that token's row at offset `head * dh`.
fn recombine_heads(x: &[f32], b: usize, n: usize, nh: usize, dh: usize) -> Vec<f32> {
    let width = nh * dh;
    let mut merged = vec![0f32; b * n * nh * dh];
    for batch in 0..b {
        for head in 0..nh {
            let plane = (batch * nh + head) * n;
            for token in 0..n {
                let src = (plane + token) * dh;
                let dst = (batch * n + token) * width + head * dh;
                merged[dst..dst + dh].copy_from_slice(&x[src..src + dh]);
            }
        }
    }
    merged
}
353
/// In-place LayerNorm over the last axis of `x: [rows, n]`.
///
/// Each row is normalised with its own mean and (biased, two-pass) variance,
/// then scaled by `g` and shifted by `b`, both of length `n`. `eps` is added
/// to the variance before the square root.
pub fn layer_norm_last(x: &mut [f32], rows: usize, n: usize, g: &[f32], b: &[f32], eps: f32) {
    let count = n as f32;
    for r in 0..rows {
        let start = r * n;
        let row = &mut x[start..start + n];

        // First pass: mean.
        let mean = row.iter().fold(0f32, |acc, v| acc + *v) / count;
        // Second pass: mean of squared deviations.
        let var = row.iter().fold(0f32, |acc, v| {
            let d = *v - mean;
            acc + d * d
        }) / count;

        let inv = 1.0 / (var + eps).sqrt();
        for (k, v) in row.iter_mut().enumerate() {
            *v = (*v - mean) * inv * g[k] + b[k];
        }
    }
}
376
377/// Matches compiled `Op::LayerNorm` on CPU backends.
378pub fn layer_norm_last_cpu(x: &mut [f32], rows: usize, n: usize, g: &[f32], b: &[f32], eps: f32) {
379    let mut tmp = vec![0f32; n];
380    for r in 0..rows {
381        let base = r * n;
382        rlx_cpu::kernels::layer_norm_row(&x[base..base + n], g, b, &mut tmp, n, eps);
383        x[base..base + n].copy_from_slice(&tmp);
384    }
385}
386
/// Element-wise `dst[i] += src[i]`.
///
/// Only the overlapping prefix is updated when the slices differ in length.
fn add_inplace(dst: &mut [f32], src: &[f32]) {
    dst.iter_mut()
        .zip(src)
        .for_each(|(acc, term)| *acc += *term);
}
392
/// In-place ReLU: every strictly negative entry becomes `0.0`.
///
/// Values that do not compare below zero (including NaN) are left as they are.
fn relu_inplace(x: &mut [f32]) {
    x.iter_mut()
        .filter(|v| **v < 0.0)
        .for_each(|v| *v = 0.0);
}
400
401/// One TwoWayAttentionBlock forward. `queries: [B, q_n, E]`,
402/// `keys: [B, k_n, E]`. `query_pe`/`key_pe` same shapes as q/k.
403pub fn two_way_attention_block_forward(
404    w: &TwoWayAttentionBlockWeights,
405    queries: Vec<f32>,
406    keys: Vec<f32>,
407    query_pe: &[f32],
408    key_pe: &[f32],
409    b: usize,
410    q_n: usize,
411    k_n: usize,
412) -> (Vec<f32>, Vec<f32>) {
413    let e = w.self_attn.embed_dim;
414
415    // ── Self attention block ──
416    let mut queries = if w.skip_first_layer_pe {
417        attention_forward(&w.self_attn, &queries, q_n, &queries, q_n, &queries, q_n, b)
418    } else {
419        let mut q = queries.clone();
420        add_inplace(&mut q, query_pe);
421        let attn_out = attention_forward(&w.self_attn, &q, q_n, &q, q_n, &queries, q_n, b);
422        let mut out = queries;
423        add_inplace(&mut out, &attn_out);
424        out
425    };
426    layer_norm_last(&mut queries, b * q_n, e, &w.norm1_g, &w.norm1_b, 1e-5);
427
428    // ── Cross attention block, tokens attending to image ──
429    let mut q_pe = queries.clone();
430    add_inplace(&mut q_pe, query_pe);
431    let mut k_pe = keys.clone();
432    add_inplace(&mut k_pe, key_pe);
433    let attn_out = attention_forward(
434        &w.cross_token_to_image,
435        &q_pe,
436        q_n,
437        &k_pe,
438        k_n,
439        &keys,
440        k_n,
441        b,
442    );
443    add_inplace(&mut queries, &attn_out);
444    layer_norm_last(&mut queries, b * q_n, e, &w.norm2_g, &w.norm2_b, 1e-5);
445
446    // ── MLP block (Linear → ReLU → Linear, candle uses Activation::Relu) ──
447    let mlp_dim = w.mlp_lin1_b.len();
448    let mut mlp_mid = linear(&queries, &w.mlp_lin1_w, &w.mlp_lin1_b, b * q_n, e, mlp_dim);
449    relu_inplace(&mut mlp_mid);
450    let mlp_out = linear(&mlp_mid, &w.mlp_lin2_w, &w.mlp_lin2_b, b * q_n, mlp_dim, e);
451    add_inplace(&mut queries, &mlp_out);
452    layer_norm_last(&mut queries, b * q_n, e, &w.norm3_g, &w.norm3_b, 1e-5);
453
454    // ── Cross attention block, image attending to tokens ──
455    let mut q_pe = queries.clone();
456    add_inplace(&mut q_pe, query_pe);
457    let mut k_pe = keys.clone();
458    add_inplace(&mut k_pe, key_pe);
459    // Per candle: q = k_pe, k = q_pe, v = queries
460    let attn_out = attention_forward(
461        &w.cross_image_to_token,
462        &k_pe,
463        k_n,
464        &q_pe,
465        q_n,
466        &queries,
467        q_n,
468        b,
469    );
470    let mut keys = keys;
471    add_inplace(&mut keys, &attn_out);
472    layer_norm_last(&mut keys, b * k_n, e, &w.norm4_g, &w.norm4_b, 1e-5);
473
474    (queries, keys)
475}
476
477/// Top-level two-way transformer forward.
478///
479/// `image_embedding`: NCHW `[B, C, H, W]` (flat).
480/// `image_pe`: same shape.
481/// `point_embedding`: `[B, q_n, E]`.
482///
483/// Returns `(queries, keys)` where queries is `[B, q_n, E]` and keys is
484/// `[B, H*W, E]` (after the final LN).
485pub fn two_way_transformer_forward(
486    w: &TwoWayTransformerWeights,
487    image_embedding: &[f32],
488    image_pe: &[f32],
489    point_embedding: &[f32],
490    b: usize,
491    c: usize,
492    h: usize,
493    ww: usize,
494    q_n: usize,
495) -> (Vec<f32>, Vec<f32>) {
496    let k_n = h * ww;
497    // Flatten NCHW → [B, H*W, C]
498    let mut image_seq = vec![0f32; b * k_n * c];
499    let mut image_pe_seq = vec![0f32; b * k_n * c];
500    for bi in 0..b {
501        for y in 0..h {
502            for x in 0..ww {
503                for ch in 0..c {
504                    let src = (bi * c + ch) * h * ww + y * ww + x;
505                    let dst = (bi * k_n + y * ww + x) * c + ch;
506                    image_seq[dst] = image_embedding[src];
507                    image_pe_seq[dst] = image_pe[src];
508                }
509            }
510        }
511    }
512
513    let mut queries = point_embedding.to_vec();
514    let mut keys = image_seq;
515
516    for layer in &w.layers {
517        let (q, k) = two_way_attention_block_forward(
518            layer,
519            queries,
520            keys,
521            point_embedding,
522            &image_pe_seq,
523            b,
524            q_n,
525            k_n,
526        );
527        queries = q;
528        keys = k;
529    }
530
531    // Final cross attention token→image
532    let mut q_pe = queries.clone();
533    add_inplace(&mut q_pe, point_embedding);
534    let mut k_pe = keys.clone();
535    add_inplace(&mut k_pe, &image_pe_seq);
536    let attn_out = attention_forward(
537        &w.final_attn_token_to_image,
538        &q_pe,
539        q_n,
540        &k_pe,
541        k_n,
542        &keys,
543        k_n,
544        b,
545    );
546    add_inplace(&mut queries, &attn_out);
547    layer_norm_last(
548        &mut queries,
549        b * q_n,
550        w.embed_dim,
551        &w.norm_final_g,
552        &w.norm_final_b,
553        1e-5,
554    );
555
556    (queries, keys)
557}