rlx-sam-ir 0.2.5

Shared SAM / SAM2 mask-decoder IR helpers
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

//! Two-way transformer IR (`Op::Attention` + LayerNorm + ReLU MLP).

use anyhow::Result;
use rlx_flow::CompileProfile;
use rlx_ir::op::{Activation, BinaryOp, MaskKind};
use rlx_ir::{DType, Graph, NodeId, Shape};
use rlx_runtime::{CompiledGraph, Device};
use std::collections::HashMap;

/// LayerNorm epsilon used by every `layer_norm` node emitted in this graph (same value as
/// the PyTorch `nn.LayerNorm` default).
const LN_EPS: f32 = 1e-5;

/// Max extra sparse prompt tokens (points/boxes) compiled into the IR graph.
///
/// The `compile_with_sparse_slots*` constructors compile `base_q_n + MAX_SPARSE_PROMPT_TOKENS`
/// query slots; unused slots are zero-padded and masked out at run time.
pub const MAX_SPARSE_PROMPT_TOKENS: usize = 32;

/// Graph input nodes holding the attention masks for one two-way block (masked compile only).
struct LayerMaskIds {
    /// Token self-attention mask, declared as `[1, H, q_n, q_n]`.
    self_attn: NodeId,
    /// Token→image cross-attention mask, declared as `[1, H, q_n, k_n]`.
    t2i: NodeId,
    /// Image→token cross-attention mask, declared as `[1, H, k_n, q_n]`.
    i2t: NodeId,
}

/// Attention weights in PyTorch `[out, in]` layout.
///
/// The q/k/v projections map `embed_dim -> internal_dim` and the output projection maps
/// `internal_dim -> embed_dim`. Weights are transposed to `[in, out]` when they are bound
/// into the graph.
#[derive(Clone)]
pub struct AttentionSpec {
    /// Query projection weight, `[internal_dim, embed_dim]`.
    pub q_w: Vec<f32>,
    /// Query projection bias, `[internal_dim]`.
    pub q_b: Vec<f32>,
    /// Key projection weight, `[internal_dim, embed_dim]`.
    pub k_w: Vec<f32>,
    /// Key projection bias, `[internal_dim]`.
    pub k_b: Vec<f32>,
    /// Value projection weight, `[internal_dim, embed_dim]`.
    pub v_w: Vec<f32>,
    /// Value projection bias, `[internal_dim]`.
    pub v_b: Vec<f32>,
    /// Output projection weight, `[embed_dim, internal_dim]`.
    pub out_w: Vec<f32>,
    /// Output projection bias, `[embed_dim]`.
    pub out_b: Vec<f32>,
    /// Number of heads; each head gets `internal_dim / num_heads` channels.
    pub num_heads: usize,
    /// Width of the sequences entering and leaving this attention.
    pub embed_dim: usize,
    /// Projected q/k/v width (tracked separately from `embed_dim`).
    pub internal_dim: usize,
}

/// Weights for one two-way block: token self-attention, token→image cross-attention,
/// ReLU MLP on the tokens, and image→token cross-attention. Each sub-layer is followed by a
/// LayerNorm.
#[derive(Clone)]
pub struct TwoWayBlockSpec {
    /// Token self-attention.
    pub self_attn: AttentionSpec,
    /// LayerNorm gamma after self-attention (length `embed_dim`).
    pub norm1_g: Vec<f32>,
    /// LayerNorm beta after self-attention (length `embed_dim`).
    pub norm1_b: Vec<f32>,
    /// Cross-attention with tokens as queries and image features as keys/values.
    pub cross_token_to_image: AttentionSpec,
    /// LayerNorm gamma after token→image attention.
    pub norm2_g: Vec<f32>,
    /// LayerNorm beta after token→image attention.
    pub norm2_b: Vec<f32>,
    /// First MLP weight, PyTorch `[mlp_dim, embed_dim]`; `mlp_dim` is read from `mlp_lin1_b.len()`.
    pub mlp_lin1_w: Vec<f32>,
    /// First MLP bias, `[mlp_dim]`.
    pub mlp_lin1_b: Vec<f32>,
    /// Second MLP weight, PyTorch `[embed_dim, mlp_dim]`.
    pub mlp_lin2_w: Vec<f32>,
    /// Second MLP bias, `[embed_dim]`.
    pub mlp_lin2_b: Vec<f32>,
    /// LayerNorm gamma after the MLP.
    pub norm3_g: Vec<f32>,
    /// LayerNorm beta after the MLP.
    pub norm3_b: Vec<f32>,
    /// Cross-attention with image features as queries and tokens as keys/values (updates keys).
    pub cross_image_to_token: AttentionSpec,
    /// LayerNorm gamma applied to the updated image features.
    pub norm4_g: Vec<f32>,
    /// LayerNorm beta applied to the updated image features.
    pub norm4_b: Vec<f32>,
    /// When true, self-attention runs on the raw queries (no positional encoding added), and
    /// its output replaces the queries instead of being added as a residual.
    pub skip_first_layer_pe: bool,
}

/// Full two-way transformer: a stack of [`TwoWayBlockSpec`]s, followed by a final
/// token→image attention and a LayerNorm on the query tokens.
#[derive(Clone)]
pub struct TwoWayTransformerSpec {
    /// Two-way blocks, applied in order.
    pub layers: Vec<TwoWayBlockSpec>,
    /// Final token→image cross-attention.
    pub final_attn: AttentionSpec,
    /// Final LayerNorm gamma on the query tokens (length `embed_dim`).
    pub norm_final_g: Vec<f32>,
    /// Final LayerNorm beta on the query tokens (length `embed_dim`).
    pub norm_final_b: Vec<f32>,
    /// Token / image feature width `E`.
    pub embed_dim: usize,
}

/// A [`TwoWayTransformerSpec`] compiled to a device graph with fixed sequence lengths.
///
/// Graph inputs are `tokens` `[1, max_q_n, E]` plus `image_seq` and `image_pe` `[1, k_n, E]`,
/// with extra mask inputs when `masked`. Outputs are the updated tokens and image features.
pub struct TwoWayTransformerCompiled {
    /// Compiled graph; parameters are uploaded once at compile time.
    graph: CompiledGraph,
    /// Compiled query-token slots (`base + MAX_SPARSE` when `masked`).
    pub max_q_n: usize,
    /// Number of image tokens (key sequence length) the graph was compiled for.
    pub k_n: usize,
    /// Token / image feature width `E`.
    pub embed_dim: usize,
    /// Head count used for mask buffers: the first layer's self-attention heads, or the
    /// final attention's heads when there are no layers.
    pub num_heads: usize,
    /// Number of two-way blocks in the graph.
    pub num_layers: usize,
    /// When true, pass per-attention masks and pad queries to `max_q_n`.
    pub masked: bool,
}

impl TwoWayTransformerCompiled {
    /// Compile an unmasked transformer for exactly `q_n` query tokens and `k_n` image tokens,
    /// using [`CompileProfile::sam_encoder`].
    pub fn compile(
        spec: &TwoWayTransformerSpec,
        q_n: usize,
        k_n: usize,
        device: Device,
    ) -> Result<Self> {
        Self::compile_with_profile(
            spec,
            q_n,
            k_n,
            device,
            false,
            &CompileProfile::sam_encoder(),
        )
    }

    /// Compile with an explicit [`CompileProfile`].
    ///
    /// When `masked` is true, `q_n` is the number of compiled query slots, and the graph takes
    /// one mask input per attention (fed by [`Self::run_masked`]).
    pub fn compile_with_profile(
        spec: &TwoWayTransformerSpec,
        q_n: usize,
        k_n: usize,
        device: Device,
        masked: bool,
        profile: &CompileProfile,
    ) -> Result<Self> {
        Self::compile_inner(spec, q_n, k_n, device, masked, profile)
    }

    /// `base_q_n` + up to [`MAX_SPARSE_PROMPT_TOKENS`] padded query slots (masked attention).
    pub fn compile_with_sparse_slots(
        spec: &TwoWayTransformerSpec,
        base_q_n: usize,
        k_n: usize,
        device: Device,
    ) -> Result<Self> {
        Self::compile_with_sparse_slots_profile(
            spec,
            base_q_n,
            k_n,
            device,
            &CompileProfile::sam_encoder(),
        )
    }

    /// [`Self::compile_with_sparse_slots`] with an explicit [`CompileProfile`].
    pub fn compile_with_sparse_slots_profile(
        spec: &TwoWayTransformerSpec,
        base_q_n: usize,
        k_n: usize,
        device: Device,
        profile: &CompileProfile,
    ) -> Result<Self> {
        let max_q = base_q_n + MAX_SPARSE_PROMPT_TOKENS;
        Self::compile_with_profile(spec, max_q, k_n, device, true, profile)
    }

    /// Build the IR graph, compile it for `device`, and upload every parameter tensor.
    fn compile_inner(
        spec: &TwoWayTransformerSpec,
        max_q_n: usize,
        k_n: usize,
        device: Device,
        masked: bool,
        profile: &CompileProfile,
    ) -> Result<Self> {
        // Head count used for all mask buffers: the first layer's self-attention, or the final
        // attention when the spec has no layers.
        let nh = spec
            .layers
            .first()
            .map(|l| l.self_attn.num_heads)
            .unwrap_or(spec.final_attn.num_heads);
        let (graph, params) = build_transformer_graph(spec, max_q_n, k_n, masked)?;
        let mut compiled =
            rlx_core::flow_bridge::compile_graph_with_profile(device, graph, profile)?;
        for (name, data) in &params {
            compiled.set_param(name, data);
        }
        Ok(Self {
            graph: compiled,
            max_q_n,
            k_n,
            embed_dim: spec.embed_dim,
            num_heads: nh,
            num_layers: spec.layers.len(),
            masked,
        })
    }

    /// Fill `[1, H, max_q, max_k]` mask (1 = attend, 0 = mask out).
    ///
    /// For every head, the top-left `active_q x active_k` window (clamped to
    /// `max_q x max_k`) is set to 1. Everything else in `out` is zeroed.
    ///
    /// # Panics
    /// If `out` is too short to hold the active window.
    pub fn fill_attn_mask(
        out: &mut [f32],
        num_heads: usize,
        max_q: usize,
        max_k: usize,
        active_q: usize,
        active_k: usize,
    ) {
        out.fill(0.0);
        let rows = active_q.min(max_q);
        let cols = active_k.min(max_k);
        if cols == 0 {
            return;
        }
        for h in 0..num_heads {
            for qi in 0..rows {
                let start = (h * max_q + qi) * max_k;
                out[start..start + cols].fill(1.0);
            }
        }
    }

    /// Allocate a mask of exactly `num_heads * max_q * max_k` elements and fill it with
    /// [`Self::fill_attn_mask`].
    fn attn_mask(
        num_heads: usize,
        max_q: usize,
        max_k: usize,
        active_q: usize,
        active_k: usize,
    ) -> Vec<f32> {
        let mut mask = vec![0f32; num_heads * max_q * max_k];
        Self::fill_attn_mask(&mut mask, num_heads, max_q, max_k, active_q, active_k);
        mask
    }

    /// NCHW `[E, H, W]` → sequence `[H*W, E]` (same layout as host `two_way_transformer_forward`).
    ///
    /// # Panics
    /// If `nchw.len() < e * h * w`.
    pub fn nchw_to_seq(nchw: &[f32], e: usize, h: usize, w: usize) -> Vec<f32> {
        let k_n = h * w;
        let mut seq = vec![0f32; k_n * e];
        for y in 0..h {
            for x in 0..w {
                for ch in 0..e {
                    let src = ch * h * w + y * w + x;
                    let dst = (y * w + x) * e + ch;
                    seq[dst] = nchw[src];
                }
            }
        }
        seq
    }

    /// Validate the NCHW image / PE buffers against the compiled shape, then convert both to
    /// `[k_n, E]` sequences.
    fn image_to_seq(
        &self,
        image_nchw: &[f32],
        image_pe_nchw: &[f32],
        grid: usize,
    ) -> Result<(Vec<f32>, Vec<f32>)> {
        let e = self.embed_dim;
        anyhow::ensure!(
            grid * grid == self.k_n,
            "grid {grid}x{grid} does not match compiled k_n {}",
            self.k_n
        );
        let needed = e * grid * grid;
        anyhow::ensure!(
            image_nchw.len() >= needed,
            "image_nchw len {} < E*grid*grid = {needed}",
            image_nchw.len()
        );
        anyhow::ensure!(
            image_pe_nchw.len() >= needed,
            "image_pe_nchw len {} < E*grid*grid = {needed}",
            image_pe_nchw.len()
        );
        Ok((
            Self::nchw_to_seq(image_nchw, e, grid, grid),
            Self::nchw_to_seq(image_pe_nchw, e, grid, grid),
        ))
    }

    /// Split the graph outputs into `(queries_out, keys_out)`.
    fn take_outputs(
        outs: impl IntoIterator<Item = Vec<f32>>,
    ) -> Result<(Vec<f32>, Vec<f32>)> {
        let mut it = outs.into_iter();
        let queries = it
            .next()
            .ok_or_else(|| anyhow::anyhow!("graph returned no queries_out"))?;
        let keys = it
            .next()
            .ok_or_else(|| anyhow::anyhow!("graph returned no keys_out"))?;
        Ok((queries, keys))
    }

    /// `tokens`: `[q_n, E]`; image tensors NCHW `[E, g, g]`.
    ///
    /// On a masked compile, every token in `tokens` is treated as active and the rest of the
    /// compiled slots are padded (see [`Self::run_nchw_masked`]).
    pub fn run_nchw(
        &mut self,
        tokens: &[f32],
        image_nchw: &[f32],
        image_pe_nchw: &[f32],
        grid: usize,
    ) -> Result<(Vec<f32>, Vec<f32>)> {
        let e = self.embed_dim;
        if self.masked {
            anyhow::ensure!(
                e > 0 && tokens.len() % e == 0,
                "tokens len {} is not a multiple of embed_dim {e}",
                tokens.len()
            );
            // The masked path does its own NCHW conversion; don't convert twice.
            return self.run_nchw_masked(tokens, tokens.len() / e, image_nchw, image_pe_nchw, grid);
        }
        let (image_seq, image_pe) = self.image_to_seq(image_nchw, image_pe_nchw, grid)?;
        self.run(tokens, &image_seq, &image_pe)
    }

    /// Padded-query path: `active_q_n` real tokens, rest masked out.
    ///
    /// `tokens` must hold at least `active_q_n` tokens and at most `max_q_n` tokens of width `E`.
    /// It is zero-padded to `max_q_n` slots before running.
    pub fn run_nchw_masked(
        &mut self,
        tokens: &[f32],
        active_q_n: usize,
        image_nchw: &[f32],
        image_pe_nchw: &[f32],
        grid: usize,
    ) -> Result<(Vec<f32>, Vec<f32>)> {
        anyhow::ensure!(
            self.masked,
            "run_nchw_masked requires compile_with_sparse_slots"
        );
        anyhow::ensure!(
            active_q_n <= self.max_q_n,
            "active_q_n {active_q_n} > compiled max_q_n {}",
            self.max_q_n
        );
        let e = self.embed_dim;
        anyhow::ensure!(
            tokens.len() >= active_q_n * e && tokens.len() <= self.max_q_n * e,
            "tokens len {} must cover {active_q_n} active tokens and fit {} slots of width {e}",
            tokens.len(),
            self.max_q_n
        );
        let (image_seq, image_pe) = self.image_to_seq(image_nchw, image_pe_nchw, grid)?;
        let mut padded = vec![0f32; self.max_q_n * e];
        padded[..tokens.len()].copy_from_slice(tokens);
        self.run_masked(&padded, active_q_n, &image_seq, &image_pe)
    }

    /// `tokens` / `query_pe`: `[q_n, E]`; `image_seq` / `image_pe_seq`: `[k_n, E]` row-major.
    ///
    /// Only valid on an unmasked compile; the token input also serves as the query PE.
    pub fn run(
        &mut self,
        tokens: &[f32],
        image_seq: &[f32],
        image_pe_seq: &[f32],
    ) -> Result<(Vec<f32>, Vec<f32>)> {
        let e = self.embed_dim;
        anyhow::ensure!(!self.masked, "use run_masked for masked compile");
        anyhow::ensure!(tokens.len() == self.max_q_n * e, "tokens len mismatch");
        anyhow::ensure!(image_seq.len() == self.k_n * e, "image_seq len mismatch");
        anyhow::ensure!(
            image_pe_seq.len() == self.k_n * e,
            "image_pe_seq len mismatch"
        );
        let outs = self.graph.run(&[
            ("tokens", tokens),
            ("image_seq", image_seq),
            ("image_pe", image_pe_seq),
        ]);
        Self::take_outputs(outs)
    }

    /// Run a masked compile with `tokens_padded` (`[max_q_n, E]`), of which the first
    /// `active_q_n` tokens are real.
    ///
    /// Returns the `active_q_n` updated tokens and the full updated image sequence.
    /// NOTE(review): padded query rows have no attendable keys; confirm the backend's masked
    /// softmax yields finite values for fully-masked rows.
    pub fn run_masked(
        &mut self,
        tokens_padded: &[f32],
        active_q_n: usize,
        image_seq: &[f32],
        image_pe_seq: &[f32],
    ) -> Result<(Vec<f32>, Vec<f32>)> {
        let e = self.embed_dim;
        let nh = self.num_heads;
        let max_q = self.max_q_n;
        let max_k = self.k_n;
        anyhow::ensure!(
            self.masked,
            "run_masked requires a masked compile (compile_with_sparse_slots)"
        );
        anyhow::ensure!(
            active_q_n <= max_q,
            "active_q_n {active_q_n} > compiled max_q_n {max_q}"
        );
        anyhow::ensure!(
            tokens_padded.len() == max_q * e,
            "tokens_padded len mismatch"
        );
        anyhow::ensure!(image_seq.len() == max_k * e, "image_seq len mismatch");
        anyhow::ensure!(
            image_pe_seq.len() == max_k * e,
            "image_pe_seq len mismatch"
        );

        // Mask contents depend only on `active_q_n`, so each kind is built once at its exact
        // declared size and shared across layers. Self-attention masks are
        // `[1, H, max_q, max_q]`, t2i masks `[1, H, max_q, max_k]` and i2t masks
        // `[1, H, max_k, max_q]`.
        let self_mask = Self::attn_mask(nh, max_q, max_q, active_q_n, active_q_n);
        let t2i_mask = Self::attn_mask(nh, max_q, max_k, active_q_n, max_k);
        let i2t_mask = Self::attn_mask(nh, max_k, max_q, max_k, active_q_n);

        let names: Vec<[String; 3]> = (0..self.num_layers)
            .map(|i| {
                [
                    format!("mask_L{i}_self"),
                    format!("mask_L{i}_t2i"),
                    format!("mask_L{i}_i2t"),
                ]
            })
            .collect();

        let mut feeds: Vec<(&str, &[f32])> = Vec::with_capacity(4 + 3 * names.len());
        feeds.push(("tokens", tokens_padded));
        feeds.push(("image_seq", image_seq));
        feeds.push(("image_pe", image_pe_seq));
        for [self_name, t2i_name, i2t_name] in &names {
            feeds.push((self_name.as_str(), self_mask.as_slice()));
            feeds.push((t2i_name.as_str(), t2i_mask.as_slice()));
            feeds.push((i2t_name.as_str(), i2t_mask.as_slice()));
        }
        feeds.push(("mask_final_t2i", t2i_mask.as_slice()));

        let outs = self.graph.run(&feeds);
        let (mut queries, keys) = Self::take_outputs(outs)?;
        // Drop the padded query slots; only the active tokens are meaningful.
        queries.truncate(active_q_n * e);
        Ok((queries, keys))
    }
}

/// Transpose a PyTorch-style `[out_d, in_d]` weight into the `[in_d, out_d]` layout expected
/// by the graph's matmul.
///
/// # Panics
/// If `w_out_in` has fewer than `in_d * out_d` elements (when both are non-zero).
fn matmul_weight(w_out_in: &[f32], in_d: usize, out_d: usize) -> Vec<f32> {
    // Row `k` of the result is column `k` of the source matrix.
    (0..in_d)
        .flat_map(|k| (0..out_d).map(move |o| w_out_in[o * in_d + k]))
        .collect()
}

/// Declare the weight (`[in_d, out_d]`) and bias (`[out_d]`) parameters for a linear layer
/// named `prefix`, and record their host data (the weight transposed from `[out, in]`).
///
/// Returns the `(weight, bias)` parameter nodes.
fn bind_linear(
    g: &mut Graph,
    params: &mut HashMap<String, Vec<f32>>,
    prefix: &str,
    w: &[f32],
    b: &[f32],
    in_d: usize,
    out_d: usize,
) -> (NodeId, NodeId) {
    let weight_name = format!("{prefix}.w");
    let bias_name = format!("{prefix}.b");
    let weight_node = g.param(weight_name.clone(), Shape::new(&[in_d, out_d], DType::F32));
    let bias_node = g.param(bias_name.clone(), Shape::new(&[out_d], DType::F32));
    params.insert(weight_name, matmul_weight(w, in_d, out_d));
    params.insert(bias_name, b.to_vec());
    (weight_node, bias_node)
}

/// Emit `x @ W + b` as a fused matmul+bias node (no activation), producing `[1, seq, out_d]`.
///
/// `w` is in PyTorch `[out_d, in_d]` layout; parameters are registered under `prefix`.
fn linear(
    g: &mut Graph,
    params: &mut HashMap<String, Vec<f32>>,
    prefix: &str,
    x: NodeId,
    w: &[f32],
    b: &[f32],
    in_d: usize,
    out_d: usize,
    seq: usize,
) -> NodeId {
    let (weight_node, bias_node) = bind_linear(g, params, prefix, w, b, in_d, out_d);
    let out_shape = Shape::new(&[1, seq, out_d], DType::F32);
    g.fused_matmul_bias_act(x, weight_node, bias_node, None, out_shape)
}

/// Declare LayerNorm gamma / beta parameters (each `[e]`) under `prefix` and record their data.
///
/// Returns the `(gamma, beta)` parameter nodes.
fn bind_ln(
    g: &mut Graph,
    params: &mut HashMap<String, Vec<f32>>,
    prefix: &str,
    gamm: &[f32],
    bet: &[f32],
    e: usize,
) -> (NodeId, NodeId) {
    let gamma_name = format!("{prefix}.g");
    let beta_name = format!("{prefix}.b");
    let gamma_node = g.param(gamma_name.clone(), Shape::new(&[e], DType::F32));
    let beta_node = g.param(beta_name.clone(), Shape::new(&[e], DType::F32));
    params.insert(gamma_name, gamm.to_vec());
    params.insert(beta_name, bet.to_vec());
    (gamma_node, beta_node)
}

/// Emit a LayerNorm over the last axis of a `[1, seq, e]` tensor, with parameters under
/// `prefix` and epsilon [`LN_EPS`].
fn layer_norm(
    g: &mut Graph,
    params: &mut HashMap<String, Vec<f32>>,
    prefix: &str,
    x: NodeId,
    gamm: &[f32],
    bet: &[f32],
    seq: usize,
    e: usize,
) -> NodeId {
    let (gamma_node, beta_node) = bind_ln(g, params, prefix, gamm, bet, e);
    let out_shape = Shape::new(&[1, seq, e], DType::F32);
    // Axis -1 is the embedding dimension `e`.
    g.layer_norm(x, gamma_node, beta_node, -1, LN_EPS, out_shape)
}

/// Emit a multi-head attention module: q/k/v projections, scaled attention, and output projection.
///
/// `q_in` is `[1, q_len, embed_dim]`; `k_in` / `v_in` are `[1, k_len, embed_dim]`. When `mask`
/// is given it is passed to the attention op; otherwise the attention is unmasked.
/// Parameters are registered as `{prefix}.q/.k/.v/.o`. Returns `[1, q_len, embed_dim]`.
fn build_attention(
    g: &mut Graph,
    params: &mut HashMap<String, Vec<f32>>,
    prefix: &str,
    spec: &AttentionSpec,
    q_in: NodeId,
    k_in: NodeId,
    v_in: NodeId,
    q_len: usize,
    k_len: usize,
    mask: Option<NodeId>,
) -> NodeId {
    let embed = spec.embed_dim;
    let inner = spec.internal_dim;
    let heads = spec.num_heads;
    let head_dim = inner / heads;

    // Projections are emitted in q -> k -> v order, each mapping `embed -> inner`.
    let mut project = |tag: &str, input: NodeId, w: &[f32], b: &[f32], len: usize| {
        linear(
            g,
            params,
            &format!("{prefix}.{tag}"),
            input,
            w,
            b,
            embed,
            inner,
            len,
        )
    };
    let q_proj = project("q", q_in, &spec.q_w, &spec.q_b, q_len);
    let k_proj = project("k", k_in, &spec.k_w, &spec.k_b, k_len);
    let v_proj = project("v", v_in, &spec.v_w, &spec.v_b, k_len);

    let attn_shape = Shape::new(&[1, q_len, inner], DType::F32);
    let attn = match mask {
        Some(m) => g.attention(q_proj, k_proj, v_proj, m, heads, head_dim, attn_shape),
        None => g.attention_kind(
            q_proj,
            k_proj,
            v_proj,
            heads,
            head_dim,
            MaskKind::None,
            attn_shape,
        ),
    };

    // Output projection back from `inner` to `embed`.
    linear(
        g,
        params,
        &format!("{prefix}.o"),
        attn,
        &spec.out_w,
        &spec.out_b,
        inner,
        embed,
        q_len,
    )
}

/// Emit one two-way block and return the updated `(queries, keys)`.
///
/// Sub-layers, each followed by a LayerNorm:
/// 1. token self-attention (`{prefix}.self`, norm `n1`). With `skip_first_layer_pe`, it runs
///    on the raw queries and its output replaces them; otherwise the query PE is added to q/k
///    and the output is added as a residual.
/// 2. token→image cross-attention (`{prefix}.t2i`, norm `n2`).
/// 3. ReLU MLP on the tokens (`{prefix}.mlp1` / `{prefix}.mlp2`, norm `n3`).
/// 4. image→token cross-attention updating the image features (`{prefix}.i2t`, norm `n4`).
///
/// `queries` / `query_pe` are `[1, q_n, e]`; `keys` / `key_pe` are `[1, k_n, e]`. `masks`
/// supplies per-attention mask inputs on a masked compile.
fn build_block(
    g: &mut Graph,
    params: &mut HashMap<String, Vec<f32>>,
    prefix: &str,
    block: &TwoWayBlockSpec,
    queries: NodeId,
    keys: NodeId,
    query_pe: NodeId,
    key_pe: NodeId,
    q_n: usize,
    k_n: usize,
    e: usize,
    masks: Option<&LayerMaskIds>,
) -> (NodeId, NodeId) {
    let f = DType::F32;
    let q_shape = Shape::new(&[1, q_n, e], f);
    let k_shape = Shape::new(&[1, k_n, e], f);

    let m_self = masks.map(|m| m.self_attn);
    let m_t2i = masks.map(|m| m.t2i);
    let m_i2t = masks.map(|m| m.i2t);

    // 1. Token self-attention.
    let mut q = if block.skip_first_layer_pe {
        build_attention(
            g,
            params,
            &format!("{prefix}.self"),
            &block.self_attn,
            queries,
            queries,
            queries,
            q_n,
            q_n,
            m_self,
        )
    } else {
        let q_pe_sum = g.binary(BinaryOp::Add, queries, query_pe, q_shape.clone());
        let attn = build_attention(
            g,
            params,
            &format!("{prefix}.self"),
            &block.self_attn,
            q_pe_sum,
            q_pe_sum,
            queries,
            q_n,
            q_n,
            m_self,
        );
        g.binary(BinaryOp::Add, queries, attn, q_shape.clone())
    };
    q = layer_norm(
        g,
        params,
        &format!("{prefix}.n1"),
        q,
        &block.norm1_g,
        &block.norm1_b,
        q_n,
        e,
    );

    // 2. Token -> image cross-attention.
    let q_pe_sum = g.binary(BinaryOp::Add, q, query_pe, q_shape.clone());
    // `keys` and `key_pe` are not updated until step 4, so `keys + key_pe` is identical for
    // both cross-attentions. Build it once and reuse it instead of emitting a duplicate
    // `[k_n, e]` Add node.
    let k_pe_sum = g.binary(BinaryOp::Add, keys, key_pe, k_shape.clone());
    let cross_t = build_attention(
        g,
        params,
        &format!("{prefix}.t2i"),
        &block.cross_token_to_image,
        q_pe_sum,
        k_pe_sum,
        keys,
        q_n,
        k_n,
        m_t2i,
    );
    q = g.binary(BinaryOp::Add, q, cross_t, q_shape.clone());
    q = layer_norm(
        g,
        params,
        &format!("{prefix}.n2"),
        q,
        &block.norm2_g,
        &block.norm2_b,
        q_n,
        e,
    );

    // 3. MLP (linear -> ReLU -> linear) with residual.
    let mlp_dim = block.mlp_lin1_b.len();
    let mid = linear(
        g,
        params,
        &format!("{prefix}.mlp1"),
        q,
        &block.mlp_lin1_w,
        &block.mlp_lin1_b,
        e,
        mlp_dim,
        q_n,
    );
    let mid_relu = g.activation(Activation::Relu, mid, Shape::new(&[1, q_n, mlp_dim], f));
    let mlp_out = linear(
        g,
        params,
        &format!("{prefix}.mlp2"),
        mid_relu,
        &block.mlp_lin2_w,
        &block.mlp_lin2_b,
        mlp_dim,
        e,
        q_n,
    );
    q = g.binary(BinaryOp::Add, q, mlp_out, q_shape.clone());
    q = layer_norm(
        g,
        params,
        &format!("{prefix}.n3"),
        q,
        &block.norm3_g,
        &block.norm3_b,
        q_n,
        e,
    );

    // 4. Image -> token cross-attention. The image features (with PE) are the queries, the
    //    tokens (with PE) are the keys, and the normalized tokens are the values.
    let q_pe2 = g.binary(BinaryOp::Add, q, query_pe, q_shape.clone());
    let cross_i = build_attention(
        g,
        params,
        &format!("{prefix}.i2t"),
        &block.cross_image_to_token,
        k_pe_sum,
        q_pe2,
        q,
        k_n,
        q_n,
        m_i2t,
    );
    let keys_out = g.binary(BinaryOp::Add, keys, cross_i, k_shape);
    let keys_out = layer_norm(
        g,
        params,
        &format!("{prefix}.n4"),
        keys_out,
        &block.norm4_g,
        &block.norm4_b,
        k_n,
        e,
    );
    (q, keys_out)
}

fn build_transformer_graph(
    spec: &TwoWayTransformerSpec,
    q_n: usize,
    k_n: usize,
    masked: bool,
) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
    let e = spec.embed_dim;
    let f = DType::F32;
    let mut g = Graph::new("twoway_transformer");
    let mut params = HashMap::new();
    let nh0 = spec
        .layers
        .first()
        .map(|l| l.self_attn.num_heads)
        .unwrap_or(spec.final_attn.num_heads);

    let tokens = g.input("tokens", Shape::new(&[1, q_n, e], f));
    let image_seq = g.input("image_seq", Shape::new(&[1, k_n, e], f));
    let image_pe = g.input("image_pe", Shape::new(&[1, k_n, e], f));
    let query_pe = tokens;

    let mut layer_masks = Vec::new();
    if masked {
        for i in 0..spec.layers.len() {
            let nh = spec.layers[i].self_attn.num_heads;
            layer_masks.push(LayerMaskIds {
                self_attn: g.input(format!("mask_L{i}_self"), Shape::new(&[1, nh, q_n, q_n], f)),
                t2i: g.input(format!("mask_L{i}_t2i"), Shape::new(&[1, nh, q_n, k_n], f)),
                i2t: g.input(format!("mask_L{i}_i2t"), Shape::new(&[1, nh, k_n, q_n], f)),
            });
        }
    }
    let final_mask = if masked {
        Some(g.input("mask_final_t2i", Shape::new(&[1, nh0, q_n, k_n], f)))
    } else {
        None
    };

    let mut queries = tokens;
    let mut keys = image_seq;
    for (i, layer) in spec.layers.iter().enumerate() {
        let masks = if masked { Some(&layer_masks[i]) } else { None };
        let (q, k) = build_block(
            &mut g,
            &mut params,
            &format!("L{i}"),
            layer,
            queries,
            keys,
            query_pe,
            image_pe,
            q_n,
            k_n,
            e,
            masks,
        );
        queries = q;
        keys = k;
    }

    let q_shape = Shape::new(&[1, q_n, e], f);
    let k_shape = Shape::new(&[1, k_n, e], f);
    let q_pe_f = g.binary(BinaryOp::Add, queries, query_pe, q_shape.clone());
    let k_pe_f = g.binary(BinaryOp::Add, keys, image_pe, k_shape.clone());
    let final_attn = build_attention(
        &mut g,
        &mut params,
        "final",
        &spec.final_attn,
        q_pe_f,
        k_pe_f,
        keys,
        q_n,
        k_n,
        final_mask,
    );
    let queries_out = g.binary(BinaryOp::Add, queries, final_attn, q_shape.clone());
    let queries_out = layer_norm(
        &mut g,
        &mut params,
        "final_ln",
        queries_out,
        &spec.norm_final_g,
        &spec.norm_final_b,
        q_n,
        e,
    );

    g.set_outputs(vec![queries_out, keys]);
    Ok((g, params))
}