Skip to main content

rlx_sam_ir/
twoway_transformer_ir.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Two-way transformer IR (`Op::Attention` + LayerNorm + ReLU MLP).
17
18use anyhow::Result;
19use rlx_flow::CompileProfile;
20use rlx_ir::op::{Activation, BinaryOp, MaskKind};
21use rlx_ir::{DType, Graph, NodeId, Shape};
22use rlx_runtime::{CompiledGraph, Device};
23use std::collections::HashMap;
24
/// Epsilon handed to every `layer_norm` node emitted by this module.
const LN_EPS: f32 = 1.0e-5;

/// Number of padded query rows a sparse-slot compile reserves for extra prompt
/// tokens (points/boxes) on top of the base query count.
pub const MAX_SPARSE_PROMPT_TOKENS: usize = 32;
29
30struct LayerMaskIds {
31    self_attn: NodeId,
32    t2i: NodeId,
33    i2t: NodeId,
34}
35
/// Weights of one multi-head attention module.
///
/// Every projection matrix is stored row-major in PyTorch `[out, in]` layout.
/// The graph builder transposes it to `[in, out]` before binding it as a
/// parameter.
#[derive(Clone)]
pub struct AttentionSpec {
    /// Query projection `[internal_dim, embed_dim]`.
    pub q_w: Vec<f32>,
    /// Query bias `[internal_dim]`.
    pub q_b: Vec<f32>,
    /// Key projection `[internal_dim, embed_dim]`.
    pub k_w: Vec<f32>,
    /// Key bias `[internal_dim]`.
    pub k_b: Vec<f32>,
    /// Value projection `[internal_dim, embed_dim]`.
    pub v_w: Vec<f32>,
    /// Value bias `[internal_dim]`.
    pub v_b: Vec<f32>,
    /// Output projection `[embed_dim, internal_dim]`.
    pub out_w: Vec<f32>,
    /// Output bias `[embed_dim]`.
    pub out_b: Vec<f32>,
    /// Number of heads; each head gets `internal_dim / num_heads` channels.
    pub num_heads: usize,
    /// Width of the attention's inputs and output.
    pub embed_dim: usize,
    /// Width of the projected q/k/v space.
    pub internal_dim: usize,
}
51
52#[derive(Clone)]
53pub struct TwoWayBlockSpec {
54    pub self_attn: AttentionSpec,
55    pub norm1_g: Vec<f32>,
56    pub norm1_b: Vec<f32>,
57    pub cross_token_to_image: AttentionSpec,
58    pub norm2_g: Vec<f32>,
59    pub norm2_b: Vec<f32>,
60    pub mlp_lin1_w: Vec<f32>,
61    pub mlp_lin1_b: Vec<f32>,
62    pub mlp_lin2_w: Vec<f32>,
63    pub mlp_lin2_b: Vec<f32>,
64    pub norm3_g: Vec<f32>,
65    pub norm3_b: Vec<f32>,
66    pub cross_image_to_token: AttentionSpec,
67    pub norm4_g: Vec<f32>,
68    pub norm4_b: Vec<f32>,
69    pub skip_first_layer_pe: bool,
70}
71
72#[derive(Clone)]
73pub struct TwoWayTransformerSpec {
74    pub layers: Vec<TwoWayBlockSpec>,
75    pub final_attn: AttentionSpec,
76    pub norm_final_g: Vec<f32>,
77    pub norm_final_b: Vec<f32>,
78    pub embed_dim: usize,
79}
80
81pub struct TwoWayTransformerCompiled {
82    graph: CompiledGraph,
83    /// Compiled query-token slots (`base + MAX_SPARSE` when `masked`).
84    pub max_q_n: usize,
85    pub k_n: usize,
86    pub embed_dim: usize,
87    pub num_heads: usize,
88    pub num_layers: usize,
89    /// When true, pass per-attention masks and pad queries to `max_q_n`.
90    pub masked: bool,
91}
92
93impl TwoWayTransformerCompiled {
94    pub fn compile(
95        spec: &TwoWayTransformerSpec,
96        q_n: usize,
97        k_n: usize,
98        device: Device,
99    ) -> Result<Self> {
100        Self::compile_with_profile(
101            spec,
102            q_n,
103            k_n,
104            device,
105            false,
106            &CompileProfile::sam_encoder(),
107        )
108    }
109
110    pub fn compile_with_profile(
111        spec: &TwoWayTransformerSpec,
112        q_n: usize,
113        k_n: usize,
114        device: Device,
115        masked: bool,
116        profile: &CompileProfile,
117    ) -> Result<Self> {
118        Self::compile_inner(spec, q_n, k_n, device, masked, profile)
119    }
120
121    /// `base_q_n` + up to [`MAX_SPARSE_PROMPT_TOKENS`] padded query slots (masked attention).
122    pub fn compile_with_sparse_slots(
123        spec: &TwoWayTransformerSpec,
124        base_q_n: usize,
125        k_n: usize,
126        device: Device,
127    ) -> Result<Self> {
128        let max_q = base_q_n + MAX_SPARSE_PROMPT_TOKENS;
129        Self::compile_with_profile(
130            spec,
131            max_q,
132            k_n,
133            device,
134            true,
135            &CompileProfile::sam_encoder(),
136        )
137    }
138
139    pub fn compile_with_sparse_slots_profile(
140        spec: &TwoWayTransformerSpec,
141        base_q_n: usize,
142        k_n: usize,
143        device: Device,
144        profile: &CompileProfile,
145    ) -> Result<Self> {
146        let max_q = base_q_n + MAX_SPARSE_PROMPT_TOKENS;
147        Self::compile_with_profile(spec, max_q, k_n, device, true, profile)
148    }
149
150    fn compile_inner(
151        spec: &TwoWayTransformerSpec,
152        max_q_n: usize,
153        k_n: usize,
154        device: Device,
155        masked: bool,
156        profile: &CompileProfile,
157    ) -> Result<Self> {
158        let nh = spec
159            .layers
160            .first()
161            .map(|l| l.self_attn.num_heads)
162            .unwrap_or(spec.final_attn.num_heads);
163        let (graph, params) = build_transformer_graph(spec, max_q_n, k_n, masked)?;
164        let mut compiled =
165            rlx_core::flow_bridge::compile_graph_with_profile(device, graph, profile)?;
166        for (name, data) in &params {
167            compiled.set_param(name, data);
168        }
169        Ok(Self {
170            graph: compiled,
171            max_q_n,
172            k_n,
173            embed_dim: spec.embed_dim,
174            num_heads: nh,
175            num_layers: spec.layers.len(),
176            masked,
177        })
178    }
179
180    /// Fill `[1, H, max_q, max_k]` mask (1 = attend, 0 = mask out).
181    pub fn fill_attn_mask(
182        out: &mut [f32],
183        num_heads: usize,
184        max_q: usize,
185        max_k: usize,
186        active_q: usize,
187        active_k: usize,
188    ) {
189        out.fill(0.0);
190        for h in 0..num_heads {
191            for qi in 0..active_q.min(max_q) {
192                for s in 0..active_k.min(max_k) {
193                    let idx = (h * max_q + qi) * max_k + s;
194                    out[idx] = 1.0;
195                }
196            }
197        }
198    }
199
200    /// NCHW `[E, H, W]` → sequence `[H*W, E]` (same layout as host `two_way_transformer_forward`).
201    pub fn nchw_to_seq(nchw: &[f32], e: usize, h: usize, w: usize) -> Vec<f32> {
202        let k_n = h * w;
203        let mut seq = vec![0f32; k_n * e];
204        for y in 0..h {
205            for x in 0..w {
206                for ch in 0..e {
207                    let src = ch * h * w + y * w + x;
208                    let dst = (y * w + x) * e + ch;
209                    seq[dst] = nchw[src];
210                }
211            }
212        }
213        seq
214    }
215
216    /// `tokens`: `[q_n, E]`; image tensors NCHW `[E, g, g]`.
217    pub fn run_nchw(
218        &mut self,
219        tokens: &[f32],
220        image_nchw: &[f32],
221        image_pe_nchw: &[f32],
222        grid: usize,
223    ) -> Result<(Vec<f32>, Vec<f32>)> {
224        let e = self.embed_dim;
225        let image_seq = Self::nchw_to_seq(image_nchw, e, grid, grid);
226        let image_pe = Self::nchw_to_seq(image_pe_nchw, e, grid, grid);
227        if self.masked {
228            self.run_nchw_masked(tokens, tokens.len() / e, image_nchw, image_pe_nchw, grid)
229        } else {
230            self.run(tokens, &image_seq, &image_pe)
231        }
232    }
233
234    /// Padded-query path: `active_q_n` real tokens, rest masked out.
235    pub fn run_nchw_masked(
236        &mut self,
237        tokens: &[f32],
238        active_q_n: usize,
239        image_nchw: &[f32],
240        image_pe_nchw: &[f32],
241        grid: usize,
242    ) -> Result<(Vec<f32>, Vec<f32>)> {
243        anyhow::ensure!(
244            self.masked,
245            "run_nchw_masked requires compile_with_sparse_slots"
246        );
247        anyhow::ensure!(
248            active_q_n <= self.max_q_n,
249            "active_q_n {active_q_n} > compiled max_q_n {}",
250            self.max_q_n
251        );
252        let e = self.embed_dim;
253        let image_seq = Self::nchw_to_seq(image_nchw, e, grid, grid);
254        let image_pe = Self::nchw_to_seq(image_pe_nchw, e, grid, grid);
255        let mut padded = vec![0f32; self.max_q_n * e];
256        padded[..tokens.len()].copy_from_slice(tokens);
257        let (q, k) = self.run_masked(&padded, active_q_n, &image_seq, &image_pe)?;
258        Ok((q, k))
259    }
260
261    /// `tokens` / `query_pe`: `[q_n, E]`; `image_seq` / `image_pe_seq`: `[k_n, E]` row-major.
262    pub fn run(
263        &mut self,
264        tokens: &[f32],
265        image_seq: &[f32],
266        image_pe_seq: &[f32],
267    ) -> Result<(Vec<f32>, Vec<f32>)> {
268        let e = self.embed_dim;
269        anyhow::ensure!(!self.masked, "use run_masked for masked compile");
270        anyhow::ensure!(tokens.len() == self.max_q_n * e, "tokens len mismatch");
271        anyhow::ensure!(image_seq.len() == self.k_n * e, "image_seq len mismatch");
272        anyhow::ensure!(
273            image_pe_seq.len() == self.k_n * e,
274            "image_pe_seq len mismatch"
275        );
276        let outs = self.graph.run(&[
277            ("tokens", tokens),
278            ("image_seq", image_seq),
279            ("image_pe", image_pe_seq),
280        ]);
281        let mut it = outs.into_iter();
282        let queries = it.next().expect("queries_out");
283        let keys = it.next().expect("keys_out");
284        Ok((queries, keys))
285    }
286
287    pub fn run_masked(
288        &mut self,
289        tokens_padded: &[f32],
290        active_q_n: usize,
291        image_seq: &[f32],
292        image_pe_seq: &[f32],
293    ) -> Result<(Vec<f32>, Vec<f32>)> {
294        let e = self.embed_dim;
295        let nh = self.num_heads;
296        let max_q = self.max_q_n;
297        let max_k = self.k_n;
298        let plane = max_q * max_k;
299        let mut mask_buf = vec![0f32; nh * plane];
300
301        let mut owned: Vec<(String, Vec<f32>)> = vec![
302            ("tokens".into(), tokens_padded.to_vec()),
303            ("image_seq".into(), image_seq.to_vec()),
304            ("image_pe".into(), image_pe_seq.to_vec()),
305        ];
306        for i in 0..self.num_layers {
307            Self::fill_attn_mask(&mut mask_buf, nh, max_q, max_q, active_q_n, active_q_n);
308            owned.push((format!("mask_L{i}_self"), mask_buf.clone()));
309            Self::fill_attn_mask(&mut mask_buf, nh, max_q, max_k, active_q_n, max_k);
310            owned.push((format!("mask_L{i}_t2i"), mask_buf.clone()));
311            Self::fill_attn_mask(&mut mask_buf, nh, max_k, max_q, max_k, active_q_n);
312            owned.push((format!("mask_L{i}_i2t"), mask_buf.clone()));
313        }
314        Self::fill_attn_mask(&mut mask_buf, nh, max_q, max_k, active_q_n, max_k);
315        owned.push(("mask_final_t2i".into(), mask_buf.clone()));
316
317        let feeds: Vec<(&str, &[f32])> = owned
318            .iter()
319            .map(|(n, d)| (n.as_str(), d.as_slice()))
320            .collect();
321        let outs = self.graph.run(&feeds);
322        let mut it = outs.into_iter();
323        let queries_full = it.next().expect("queries_out");
324        let keys = it.next().expect("keys_out");
325        let mut queries = vec![0f32; active_q_n * e];
326        queries.copy_from_slice(&queries_full[..active_q_n * e]);
327        Ok((queries, keys))
328    }
329}
330
/// Transpose a row-major PyTorch `[out_d, in_d]` weight into row-major
/// `[in_d, out_d]`, the layout the graph's `x @ W` matmul expects.
fn matmul_weight(w_out_in: &[f32], in_d: usize, out_d: usize) -> Vec<f32> {
    (0..in_d)
        .flat_map(|row| (0..out_d).map(move |col| w_out_in[col * in_d + row]))
        .collect()
}
340
341fn bind_linear(
342    g: &mut Graph,
343    params: &mut HashMap<String, Vec<f32>>,
344    prefix: &str,
345    w: &[f32],
346    b: &[f32],
347    in_d: usize,
348    out_d: usize,
349) -> (NodeId, NodeId) {
350    let f = DType::F32;
351    let w_id = g.param(format!("{prefix}.w"), Shape::new(&[in_d, out_d], f));
352    let b_id = g.param(format!("{prefix}.b"), Shape::new(&[out_d], f));
353    params.insert(format!("{prefix}.w"), matmul_weight(w, in_d, out_d));
354    params.insert(format!("{prefix}.b"), b.to_vec());
355    (w_id, b_id)
356}
357
358fn linear(
359    g: &mut Graph,
360    params: &mut HashMap<String, Vec<f32>>,
361    prefix: &str,
362    x: NodeId,
363    w: &[f32],
364    b: &[f32],
365    in_d: usize,
366    out_d: usize,
367    seq: usize,
368) -> NodeId {
369    let f = DType::F32;
370    let (w_id, b_id) = bind_linear(g, params, prefix, w, b, in_d, out_d);
371    g.fused_matmul_bias_act(x, w_id, b_id, None, Shape::new(&[1, seq, out_d], f))
372}
373
374fn bind_ln(
375    g: &mut Graph,
376    params: &mut HashMap<String, Vec<f32>>,
377    prefix: &str,
378    gamm: &[f32],
379    bet: &[f32],
380    e: usize,
381) -> (NodeId, NodeId) {
382    let f = DType::F32;
383    let g_id = g.param(format!("{prefix}.g"), Shape::new(&[e], f));
384    let b_id = g.param(format!("{prefix}.b"), Shape::new(&[e], f));
385    params.insert(format!("{prefix}.g"), gamm.to_vec());
386    params.insert(format!("{prefix}.b"), bet.to_vec());
387    (g_id, b_id)
388}
389
390fn layer_norm(
391    g: &mut Graph,
392    params: &mut HashMap<String, Vec<f32>>,
393    prefix: &str,
394    x: NodeId,
395    gamm: &[f32],
396    bet: &[f32],
397    seq: usize,
398    e: usize,
399) -> NodeId {
400    let f = DType::F32;
401    let shape = Shape::new(&[1, seq, e], f);
402    let (g_id, b_id) = bind_ln(g, params, prefix, gamm, bet, e);
403    g.layer_norm(x, g_id, b_id, -1, LN_EPS, shape)
404}
405
406fn build_attention(
407    g: &mut Graph,
408    params: &mut HashMap<String, Vec<f32>>,
409    prefix: &str,
410    spec: &AttentionSpec,
411    q_in: NodeId,
412    k_in: NodeId,
413    v_in: NodeId,
414    q_len: usize,
415    k_len: usize,
416    mask: Option<NodeId>,
417) -> NodeId {
418    let e = spec.embed_dim;
419    let id = spec.internal_dim;
420    let nh = spec.num_heads;
421    let dh = id / nh;
422    let f = DType::F32;
423
424    let q_proj = linear(
425        g,
426        params,
427        &format!("{prefix}.q"),
428        q_in,
429        &spec.q_w,
430        &spec.q_b,
431        e,
432        id,
433        q_len,
434    );
435    let k_proj = linear(
436        g,
437        params,
438        &format!("{prefix}.k"),
439        k_in,
440        &spec.k_w,
441        &spec.k_b,
442        e,
443        id,
444        k_len,
445    );
446    let v_proj = linear(
447        g,
448        params,
449        &format!("{prefix}.v"),
450        v_in,
451        &spec.v_w,
452        &spec.v_b,
453        e,
454        id,
455        k_len,
456    );
457    let out_shape = Shape::new(&[1, q_len, id], f);
458    let attn = if let Some(m) = mask {
459        g.attention(q_proj, k_proj, v_proj, m, nh, dh, out_shape.clone())
460    } else {
461        g.attention_kind(
462            q_proj,
463            k_proj,
464            v_proj,
465            nh,
466            dh,
467            MaskKind::None,
468            out_shape.clone(),
469        )
470    };
471    linear(
472        g,
473        params,
474        &format!("{prefix}.o"),
475        attn,
476        &spec.out_w,
477        &spec.out_b,
478        id,
479        e,
480        q_len,
481    )
482}
483
484fn build_block(
485    g: &mut Graph,
486    params: &mut HashMap<String, Vec<f32>>,
487    prefix: &str,
488    block: &TwoWayBlockSpec,
489    queries: NodeId,
490    keys: NodeId,
491    query_pe: NodeId,
492    key_pe: NodeId,
493    q_n: usize,
494    k_n: usize,
495    e: usize,
496    masks: Option<&LayerMaskIds>,
497) -> (NodeId, NodeId) {
498    let f = DType::F32;
499    let q_shape = Shape::new(&[1, q_n, e], f);
500    let k_shape = Shape::new(&[1, k_n, e], f);
501
502    let m_self = masks.map(|m| m.self_attn);
503    let m_t2i = masks.map(|m| m.t2i);
504    let m_i2t = masks.map(|m| m.i2t);
505
506    let mut q = if block.skip_first_layer_pe {
507        build_attention(
508            g,
509            params,
510            &format!("{prefix}.self"),
511            &block.self_attn,
512            queries,
513            queries,
514            queries,
515            q_n,
516            q_n,
517            m_self,
518        )
519    } else {
520        let q_pe_sum = g.binary(BinaryOp::Add, queries, query_pe, q_shape.clone());
521        let attn = build_attention(
522            g,
523            params,
524            &format!("{prefix}.self"),
525            &block.self_attn,
526            q_pe_sum,
527            q_pe_sum,
528            queries,
529            q_n,
530            q_n,
531            m_self,
532        );
533        g.binary(BinaryOp::Add, queries, attn, q_shape.clone())
534    };
535    q = layer_norm(
536        g,
537        params,
538        &format!("{prefix}.n1"),
539        q,
540        &block.norm1_g,
541        &block.norm1_b,
542        q_n,
543        e,
544    );
545
546    let q_pe_sum = g.binary(BinaryOp::Add, q, query_pe, q_shape.clone());
547    let k_pe_sum = g.binary(BinaryOp::Add, keys, key_pe, k_shape.clone());
548    let cross_t = build_attention(
549        g,
550        params,
551        &format!("{prefix}.t2i"),
552        &block.cross_token_to_image,
553        q_pe_sum,
554        k_pe_sum,
555        keys,
556        q_n,
557        k_n,
558        m_t2i,
559    );
560    q = g.binary(BinaryOp::Add, q, cross_t, q_shape.clone());
561    q = layer_norm(
562        g,
563        params,
564        &format!("{prefix}.n2"),
565        q,
566        &block.norm2_g,
567        &block.norm2_b,
568        q_n,
569        e,
570    );
571
572    let mlp_dim = block.mlp_lin1_b.len();
573    let mid = linear(
574        g,
575        params,
576        &format!("{prefix}.mlp1"),
577        q,
578        &block.mlp_lin1_w,
579        &block.mlp_lin1_b,
580        e,
581        mlp_dim,
582        q_n,
583    );
584    let mid_relu = g.activation(Activation::Relu, mid, Shape::new(&[1, q_n, mlp_dim], f));
585    let mlp_out = linear(
586        g,
587        params,
588        &format!("{prefix}.mlp2"),
589        mid_relu,
590        &block.mlp_lin2_w,
591        &block.mlp_lin2_b,
592        mlp_dim,
593        e,
594        q_n,
595    );
596    q = g.binary(BinaryOp::Add, q, mlp_out, q_shape.clone());
597    q = layer_norm(
598        g,
599        params,
600        &format!("{prefix}.n3"),
601        q,
602        &block.norm3_g,
603        &block.norm3_b,
604        q_n,
605        e,
606    );
607
608    let q_pe2 = g.binary(BinaryOp::Add, q, query_pe, q_shape.clone());
609    let k_pe2 = g.binary(BinaryOp::Add, keys, key_pe, k_shape.clone());
610    let cross_i = build_attention(
611        g,
612        params,
613        &format!("{prefix}.i2t"),
614        &block.cross_image_to_token,
615        k_pe2,
616        q_pe2,
617        q,
618        k_n,
619        q_n,
620        m_i2t,
621    );
622    let keys_out = g.binary(BinaryOp::Add, keys, cross_i, k_shape);
623    let keys_out = layer_norm(
624        g,
625        params,
626        &format!("{prefix}.n4"),
627        keys_out,
628        &block.norm4_g,
629        &block.norm4_b,
630        k_n,
631        e,
632    );
633    (q, keys_out)
634}
635
636fn build_transformer_graph(
637    spec: &TwoWayTransformerSpec,
638    q_n: usize,
639    k_n: usize,
640    masked: bool,
641) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
642    let e = spec.embed_dim;
643    let f = DType::F32;
644    let mut g = Graph::new("twoway_transformer");
645    let mut params = HashMap::new();
646    let nh0 = spec
647        .layers
648        .first()
649        .map(|l| l.self_attn.num_heads)
650        .unwrap_or(spec.final_attn.num_heads);
651
652    let tokens = g.input("tokens", Shape::new(&[1, q_n, e], f));
653    let image_seq = g.input("image_seq", Shape::new(&[1, k_n, e], f));
654    let image_pe = g.input("image_pe", Shape::new(&[1, k_n, e], f));
655    let query_pe = tokens;
656
657    let mut layer_masks = Vec::new();
658    if masked {
659        for i in 0..spec.layers.len() {
660            let nh = spec.layers[i].self_attn.num_heads;
661            layer_masks.push(LayerMaskIds {
662                self_attn: g.input(format!("mask_L{i}_self"), Shape::new(&[1, nh, q_n, q_n], f)),
663                t2i: g.input(format!("mask_L{i}_t2i"), Shape::new(&[1, nh, q_n, k_n], f)),
664                i2t: g.input(format!("mask_L{i}_i2t"), Shape::new(&[1, nh, k_n, q_n], f)),
665            });
666        }
667    }
668    let final_mask = if masked {
669        Some(g.input("mask_final_t2i", Shape::new(&[1, nh0, q_n, k_n], f)))
670    } else {
671        None
672    };
673
674    let mut queries = tokens;
675    let mut keys = image_seq;
676    for (i, layer) in spec.layers.iter().enumerate() {
677        let masks = if masked { Some(&layer_masks[i]) } else { None };
678        let (q, k) = build_block(
679            &mut g,
680            &mut params,
681            &format!("L{i}"),
682            layer,
683            queries,
684            keys,
685            query_pe,
686            image_pe,
687            q_n,
688            k_n,
689            e,
690            masks,
691        );
692        queries = q;
693        keys = k;
694    }
695
696    let q_shape = Shape::new(&[1, q_n, e], f);
697    let k_shape = Shape::new(&[1, k_n, e], f);
698    let q_pe_f = g.binary(BinaryOp::Add, queries, query_pe, q_shape.clone());
699    let k_pe_f = g.binary(BinaryOp::Add, keys, image_pe, k_shape.clone());
700    let final_attn = build_attention(
701        &mut g,
702        &mut params,
703        "final",
704        &spec.final_attn,
705        q_pe_f,
706        k_pe_f,
707        keys,
708        q_n,
709        k_n,
710        final_mask,
711    );
712    let queries_out = g.binary(BinaryOp::Add, queries, final_attn, q_shape.clone());
713    let queries_out = layer_norm(
714        &mut g,
715        &mut params,
716        "final_ln",
717        queries_out,
718        &spec.norm_final_g,
719        &spec.norm_final_b,
720        q_n,
721        e,
722    );
723
724    g.set_outputs(vec![queries_out, keys]);
725    Ok((g, params))
726}