1use anyhow::Result;
19use rlx_flow::CompileProfile;
20use rlx_ir::op::{Activation, BinaryOp, MaskKind};
21use rlx_ir::{DType, Graph, NodeId, Shape};
22use rlx_runtime::{CompiledGraph, Device};
23use std::collections::HashMap;
24
/// Epsilon value handed to every `layer_norm` node created by this file.
const LN_EPS: f32 = 1e-5;

/// Number of extra query slots the sparse-slot compile entry points add on top of the
/// caller-supplied base query count.
pub const MAX_SPARSE_PROMPT_TOKENS: usize = 32;
29
30struct LayerMaskIds {
31 self_attn: NodeId,
32 t2i: NodeId,
33 i2t: NodeId,
34}
35
/// Weights and dimensions of one multi-head attention module.
///
/// Weight matrices are stored row-major as `[out, in]`; they are transposed to `[in, out]`
/// when bound into the graph.
#[derive(Clone)]
pub struct AttentionSpec {
    /// Query projection weight, `[internal_dim, embed_dim]`.
    pub q_w: Vec<f32>,
    /// Query projection bias, `internal_dim` values.
    pub q_b: Vec<f32>,
    /// Key projection weight, `[internal_dim, embed_dim]`.
    pub k_w: Vec<f32>,
    /// Key projection bias, `internal_dim` values.
    pub k_b: Vec<f32>,
    /// Value projection weight, `[internal_dim, embed_dim]`.
    pub v_w: Vec<f32>,
    /// Value projection bias, `internal_dim` values.
    pub v_b: Vec<f32>,
    /// Output projection weight, `[embed_dim, internal_dim]`.
    pub out_w: Vec<f32>,
    /// Output projection bias, `embed_dim` values.
    pub out_b: Vec<f32>,
    /// Number of attention heads; the per-head width is `internal_dim / num_heads`.
    pub num_heads: usize,
    /// Width of the module's inputs and of its final output.
    pub embed_dim: usize,
    /// Width of the projected q/k/v tensors.
    pub internal_dim: usize,
}
51
52#[derive(Clone)]
53pub struct TwoWayBlockSpec {
54 pub self_attn: AttentionSpec,
55 pub norm1_g: Vec<f32>,
56 pub norm1_b: Vec<f32>,
57 pub cross_token_to_image: AttentionSpec,
58 pub norm2_g: Vec<f32>,
59 pub norm2_b: Vec<f32>,
60 pub mlp_lin1_w: Vec<f32>,
61 pub mlp_lin1_b: Vec<f32>,
62 pub mlp_lin2_w: Vec<f32>,
63 pub mlp_lin2_b: Vec<f32>,
64 pub norm3_g: Vec<f32>,
65 pub norm3_b: Vec<f32>,
66 pub cross_image_to_token: AttentionSpec,
67 pub norm4_g: Vec<f32>,
68 pub norm4_b: Vec<f32>,
69 pub skip_first_layer_pe: bool,
70}
71
72#[derive(Clone)]
73pub struct TwoWayTransformerSpec {
74 pub layers: Vec<TwoWayBlockSpec>,
75 pub final_attn: AttentionSpec,
76 pub norm_final_g: Vec<f32>,
77 pub norm_final_b: Vec<f32>,
78 pub embed_dim: usize,
79}
80
81pub struct TwoWayTransformerCompiled {
82 graph: CompiledGraph,
83 pub max_q_n: usize,
85 pub k_n: usize,
86 pub embed_dim: usize,
87 pub num_heads: usize,
88 pub num_layers: usize,
89 pub masked: bool,
91}
92
93impl TwoWayTransformerCompiled {
94 pub fn compile(
95 spec: &TwoWayTransformerSpec,
96 q_n: usize,
97 k_n: usize,
98 device: Device,
99 ) -> Result<Self> {
100 Self::compile_with_profile(
101 spec,
102 q_n,
103 k_n,
104 device,
105 false,
106 &CompileProfile::sam_encoder(),
107 )
108 }
109
110 pub fn compile_with_profile(
111 spec: &TwoWayTransformerSpec,
112 q_n: usize,
113 k_n: usize,
114 device: Device,
115 masked: bool,
116 profile: &CompileProfile,
117 ) -> Result<Self> {
118 Self::compile_inner(spec, q_n, k_n, device, masked, profile)
119 }
120
121 pub fn compile_with_sparse_slots(
123 spec: &TwoWayTransformerSpec,
124 base_q_n: usize,
125 k_n: usize,
126 device: Device,
127 ) -> Result<Self> {
128 let max_q = base_q_n + MAX_SPARSE_PROMPT_TOKENS;
129 Self::compile_with_profile(
130 spec,
131 max_q,
132 k_n,
133 device,
134 true,
135 &CompileProfile::sam_encoder(),
136 )
137 }
138
139 pub fn compile_with_sparse_slots_profile(
140 spec: &TwoWayTransformerSpec,
141 base_q_n: usize,
142 k_n: usize,
143 device: Device,
144 profile: &CompileProfile,
145 ) -> Result<Self> {
146 let max_q = base_q_n + MAX_SPARSE_PROMPT_TOKENS;
147 Self::compile_with_profile(spec, max_q, k_n, device, true, profile)
148 }
149
150 fn compile_inner(
151 spec: &TwoWayTransformerSpec,
152 max_q_n: usize,
153 k_n: usize,
154 device: Device,
155 masked: bool,
156 profile: &CompileProfile,
157 ) -> Result<Self> {
158 let nh = spec
159 .layers
160 .first()
161 .map(|l| l.self_attn.num_heads)
162 .unwrap_or(spec.final_attn.num_heads);
163 let (graph, params) = build_transformer_graph(spec, max_q_n, k_n, masked)?;
164 let mut compiled =
165 rlx_core::flow_bridge::compile_graph_with_profile(device, graph, profile)?;
166 for (name, data) in ¶ms {
167 compiled.set_param(name, data);
168 }
169 Ok(Self {
170 graph: compiled,
171 max_q_n,
172 k_n,
173 embed_dim: spec.embed_dim,
174 num_heads: nh,
175 num_layers: spec.layers.len(),
176 masked,
177 })
178 }
179
180 pub fn fill_attn_mask(
182 out: &mut [f32],
183 num_heads: usize,
184 max_q: usize,
185 max_k: usize,
186 active_q: usize,
187 active_k: usize,
188 ) {
189 out.fill(0.0);
190 for h in 0..num_heads {
191 for qi in 0..active_q.min(max_q) {
192 for s in 0..active_k.min(max_k) {
193 let idx = (h * max_q + qi) * max_k + s;
194 out[idx] = 1.0;
195 }
196 }
197 }
198 }
199
200 pub fn nchw_to_seq(nchw: &[f32], e: usize, h: usize, w: usize) -> Vec<f32> {
202 let k_n = h * w;
203 let mut seq = vec![0f32; k_n * e];
204 for y in 0..h {
205 for x in 0..w {
206 for ch in 0..e {
207 let src = ch * h * w + y * w + x;
208 let dst = (y * w + x) * e + ch;
209 seq[dst] = nchw[src];
210 }
211 }
212 }
213 seq
214 }
215
216 pub fn run_nchw(
218 &mut self,
219 tokens: &[f32],
220 image_nchw: &[f32],
221 image_pe_nchw: &[f32],
222 grid: usize,
223 ) -> Result<(Vec<f32>, Vec<f32>)> {
224 let e = self.embed_dim;
225 let image_seq = Self::nchw_to_seq(image_nchw, e, grid, grid);
226 let image_pe = Self::nchw_to_seq(image_pe_nchw, e, grid, grid);
227 if self.masked {
228 self.run_nchw_masked(tokens, tokens.len() / e, image_nchw, image_pe_nchw, grid)
229 } else {
230 self.run(tokens, &image_seq, &image_pe)
231 }
232 }
233
234 pub fn run_nchw_masked(
236 &mut self,
237 tokens: &[f32],
238 active_q_n: usize,
239 image_nchw: &[f32],
240 image_pe_nchw: &[f32],
241 grid: usize,
242 ) -> Result<(Vec<f32>, Vec<f32>)> {
243 anyhow::ensure!(
244 self.masked,
245 "run_nchw_masked requires compile_with_sparse_slots"
246 );
247 anyhow::ensure!(
248 active_q_n <= self.max_q_n,
249 "active_q_n {active_q_n} > compiled max_q_n {}",
250 self.max_q_n
251 );
252 let e = self.embed_dim;
253 let image_seq = Self::nchw_to_seq(image_nchw, e, grid, grid);
254 let image_pe = Self::nchw_to_seq(image_pe_nchw, e, grid, grid);
255 let mut padded = vec![0f32; self.max_q_n * e];
256 padded[..tokens.len()].copy_from_slice(tokens);
257 let (q, k) = self.run_masked(&padded, active_q_n, &image_seq, &image_pe)?;
258 Ok((q, k))
259 }
260
261 pub fn run(
263 &mut self,
264 tokens: &[f32],
265 image_seq: &[f32],
266 image_pe_seq: &[f32],
267 ) -> Result<(Vec<f32>, Vec<f32>)> {
268 let e = self.embed_dim;
269 anyhow::ensure!(!self.masked, "use run_masked for masked compile");
270 anyhow::ensure!(tokens.len() == self.max_q_n * e, "tokens len mismatch");
271 anyhow::ensure!(image_seq.len() == self.k_n * e, "image_seq len mismatch");
272 anyhow::ensure!(
273 image_pe_seq.len() == self.k_n * e,
274 "image_pe_seq len mismatch"
275 );
276 let outs = self.graph.run(&[
277 ("tokens", tokens),
278 ("image_seq", image_seq),
279 ("image_pe", image_pe_seq),
280 ]);
281 let mut it = outs.into_iter();
282 let queries = it.next().expect("queries_out");
283 let keys = it.next().expect("keys_out");
284 Ok((queries, keys))
285 }
286
287 pub fn run_masked(
288 &mut self,
289 tokens_padded: &[f32],
290 active_q_n: usize,
291 image_seq: &[f32],
292 image_pe_seq: &[f32],
293 ) -> Result<(Vec<f32>, Vec<f32>)> {
294 let e = self.embed_dim;
295 let nh = self.num_heads;
296 let max_q = self.max_q_n;
297 let max_k = self.k_n;
298 let plane = max_q * max_k;
299 let mut mask_buf = vec![0f32; nh * plane];
300
301 let mut owned: Vec<(String, Vec<f32>)> = vec![
302 ("tokens".into(), tokens_padded.to_vec()),
303 ("image_seq".into(), image_seq.to_vec()),
304 ("image_pe".into(), image_pe_seq.to_vec()),
305 ];
306 for i in 0..self.num_layers {
307 Self::fill_attn_mask(&mut mask_buf, nh, max_q, max_q, active_q_n, active_q_n);
308 owned.push((format!("mask_L{i}_self"), mask_buf.clone()));
309 Self::fill_attn_mask(&mut mask_buf, nh, max_q, max_k, active_q_n, max_k);
310 owned.push((format!("mask_L{i}_t2i"), mask_buf.clone()));
311 Self::fill_attn_mask(&mut mask_buf, nh, max_k, max_q, max_k, active_q_n);
312 owned.push((format!("mask_L{i}_i2t"), mask_buf.clone()));
313 }
314 Self::fill_attn_mask(&mut mask_buf, nh, max_q, max_k, active_q_n, max_k);
315 owned.push(("mask_final_t2i".into(), mask_buf.clone()));
316
317 let feeds: Vec<(&str, &[f32])> = owned
318 .iter()
319 .map(|(n, d)| (n.as_str(), d.as_slice()))
320 .collect();
321 let outs = self.graph.run(&feeds);
322 let mut it = outs.into_iter();
323 let queries_full = it.next().expect("queries_out");
324 let keys = it.next().expect("keys_out");
325 let mut queries = vec![0f32; active_q_n * e];
326 queries.copy_from_slice(&queries_full[..active_q_n * e]);
327 Ok((queries, keys))
328 }
329}
330
/// Transposes a row-major `[out_d, in_d]` weight into a row-major `[in_d, out_d]` buffer.
///
/// Panics if `w_out_in` holds fewer than `in_d * out_d` values.
fn matmul_weight(w_out_in: &[f32], in_d: usize, out_d: usize) -> Vec<f32> {
    let mut transposed = vec![0f32; in_d * out_d];
    for (flat, slot) in transposed.iter_mut().enumerate() {
        // `flat` walks the destination in `[in, out]` order; the buffer is empty when
        // either dimension is zero, so the division below never sees a zero divisor.
        let row_in = flat / out_d;
        let col_out = flat % out_d;
        *slot = w_out_in[col_out * in_d + row_in];
    }
    transposed
}
340
341fn bind_linear(
342 g: &mut Graph,
343 params: &mut HashMap<String, Vec<f32>>,
344 prefix: &str,
345 w: &[f32],
346 b: &[f32],
347 in_d: usize,
348 out_d: usize,
349) -> (NodeId, NodeId) {
350 let f = DType::F32;
351 let w_id = g.param(format!("{prefix}.w"), Shape::new(&[in_d, out_d], f));
352 let b_id = g.param(format!("{prefix}.b"), Shape::new(&[out_d], f));
353 params.insert(format!("{prefix}.w"), matmul_weight(w, in_d, out_d));
354 params.insert(format!("{prefix}.b"), b.to_vec());
355 (w_id, b_id)
356}
357
358fn linear(
359 g: &mut Graph,
360 params: &mut HashMap<String, Vec<f32>>,
361 prefix: &str,
362 x: NodeId,
363 w: &[f32],
364 b: &[f32],
365 in_d: usize,
366 out_d: usize,
367 seq: usize,
368) -> NodeId {
369 let f = DType::F32;
370 let (w_id, b_id) = bind_linear(g, params, prefix, w, b, in_d, out_d);
371 g.fused_matmul_bias_act(x, w_id, b_id, None, Shape::new(&[1, seq, out_d], f))
372}
373
374fn bind_ln(
375 g: &mut Graph,
376 params: &mut HashMap<String, Vec<f32>>,
377 prefix: &str,
378 gamm: &[f32],
379 bet: &[f32],
380 e: usize,
381) -> (NodeId, NodeId) {
382 let f = DType::F32;
383 let g_id = g.param(format!("{prefix}.g"), Shape::new(&[e], f));
384 let b_id = g.param(format!("{prefix}.b"), Shape::new(&[e], f));
385 params.insert(format!("{prefix}.g"), gamm.to_vec());
386 params.insert(format!("{prefix}.b"), bet.to_vec());
387 (g_id, b_id)
388}
389
390fn layer_norm(
391 g: &mut Graph,
392 params: &mut HashMap<String, Vec<f32>>,
393 prefix: &str,
394 x: NodeId,
395 gamm: &[f32],
396 bet: &[f32],
397 seq: usize,
398 e: usize,
399) -> NodeId {
400 let f = DType::F32;
401 let shape = Shape::new(&[1, seq, e], f);
402 let (g_id, b_id) = bind_ln(g, params, prefix, gamm, bet, e);
403 g.layer_norm(x, g_id, b_id, -1, LN_EPS, shape)
404}
405
406fn build_attention(
407 g: &mut Graph,
408 params: &mut HashMap<String, Vec<f32>>,
409 prefix: &str,
410 spec: &AttentionSpec,
411 q_in: NodeId,
412 k_in: NodeId,
413 v_in: NodeId,
414 q_len: usize,
415 k_len: usize,
416 mask: Option<NodeId>,
417) -> NodeId {
418 let e = spec.embed_dim;
419 let id = spec.internal_dim;
420 let nh = spec.num_heads;
421 let dh = id / nh;
422 let f = DType::F32;
423
424 let q_proj = linear(
425 g,
426 params,
427 &format!("{prefix}.q"),
428 q_in,
429 &spec.q_w,
430 &spec.q_b,
431 e,
432 id,
433 q_len,
434 );
435 let k_proj = linear(
436 g,
437 params,
438 &format!("{prefix}.k"),
439 k_in,
440 &spec.k_w,
441 &spec.k_b,
442 e,
443 id,
444 k_len,
445 );
446 let v_proj = linear(
447 g,
448 params,
449 &format!("{prefix}.v"),
450 v_in,
451 &spec.v_w,
452 &spec.v_b,
453 e,
454 id,
455 k_len,
456 );
457 let out_shape = Shape::new(&[1, q_len, id], f);
458 let attn = if let Some(m) = mask {
459 g.attention(q_proj, k_proj, v_proj, m, nh, dh, out_shape.clone())
460 } else {
461 g.attention_kind(
462 q_proj,
463 k_proj,
464 v_proj,
465 nh,
466 dh,
467 MaskKind::None,
468 out_shape.clone(),
469 )
470 };
471 linear(
472 g,
473 params,
474 &format!("{prefix}.o"),
475 attn,
476 &spec.out_w,
477 &spec.out_b,
478 id,
479 e,
480 q_len,
481 )
482}
483
484fn build_block(
485 g: &mut Graph,
486 params: &mut HashMap<String, Vec<f32>>,
487 prefix: &str,
488 block: &TwoWayBlockSpec,
489 queries: NodeId,
490 keys: NodeId,
491 query_pe: NodeId,
492 key_pe: NodeId,
493 q_n: usize,
494 k_n: usize,
495 e: usize,
496 masks: Option<&LayerMaskIds>,
497) -> (NodeId, NodeId) {
498 let f = DType::F32;
499 let q_shape = Shape::new(&[1, q_n, e], f);
500 let k_shape = Shape::new(&[1, k_n, e], f);
501
502 let m_self = masks.map(|m| m.self_attn);
503 let m_t2i = masks.map(|m| m.t2i);
504 let m_i2t = masks.map(|m| m.i2t);
505
506 let mut q = if block.skip_first_layer_pe {
507 build_attention(
508 g,
509 params,
510 &format!("{prefix}.self"),
511 &block.self_attn,
512 queries,
513 queries,
514 queries,
515 q_n,
516 q_n,
517 m_self,
518 )
519 } else {
520 let q_pe_sum = g.binary(BinaryOp::Add, queries, query_pe, q_shape.clone());
521 let attn = build_attention(
522 g,
523 params,
524 &format!("{prefix}.self"),
525 &block.self_attn,
526 q_pe_sum,
527 q_pe_sum,
528 queries,
529 q_n,
530 q_n,
531 m_self,
532 );
533 g.binary(BinaryOp::Add, queries, attn, q_shape.clone())
534 };
535 q = layer_norm(
536 g,
537 params,
538 &format!("{prefix}.n1"),
539 q,
540 &block.norm1_g,
541 &block.norm1_b,
542 q_n,
543 e,
544 );
545
546 let q_pe_sum = g.binary(BinaryOp::Add, q, query_pe, q_shape.clone());
547 let k_pe_sum = g.binary(BinaryOp::Add, keys, key_pe, k_shape.clone());
548 let cross_t = build_attention(
549 g,
550 params,
551 &format!("{prefix}.t2i"),
552 &block.cross_token_to_image,
553 q_pe_sum,
554 k_pe_sum,
555 keys,
556 q_n,
557 k_n,
558 m_t2i,
559 );
560 q = g.binary(BinaryOp::Add, q, cross_t, q_shape.clone());
561 q = layer_norm(
562 g,
563 params,
564 &format!("{prefix}.n2"),
565 q,
566 &block.norm2_g,
567 &block.norm2_b,
568 q_n,
569 e,
570 );
571
572 let mlp_dim = block.mlp_lin1_b.len();
573 let mid = linear(
574 g,
575 params,
576 &format!("{prefix}.mlp1"),
577 q,
578 &block.mlp_lin1_w,
579 &block.mlp_lin1_b,
580 e,
581 mlp_dim,
582 q_n,
583 );
584 let mid_relu = g.activation(Activation::Relu, mid, Shape::new(&[1, q_n, mlp_dim], f));
585 let mlp_out = linear(
586 g,
587 params,
588 &format!("{prefix}.mlp2"),
589 mid_relu,
590 &block.mlp_lin2_w,
591 &block.mlp_lin2_b,
592 mlp_dim,
593 e,
594 q_n,
595 );
596 q = g.binary(BinaryOp::Add, q, mlp_out, q_shape.clone());
597 q = layer_norm(
598 g,
599 params,
600 &format!("{prefix}.n3"),
601 q,
602 &block.norm3_g,
603 &block.norm3_b,
604 q_n,
605 e,
606 );
607
608 let q_pe2 = g.binary(BinaryOp::Add, q, query_pe, q_shape.clone());
609 let k_pe2 = g.binary(BinaryOp::Add, keys, key_pe, k_shape.clone());
610 let cross_i = build_attention(
611 g,
612 params,
613 &format!("{prefix}.i2t"),
614 &block.cross_image_to_token,
615 k_pe2,
616 q_pe2,
617 q,
618 k_n,
619 q_n,
620 m_i2t,
621 );
622 let keys_out = g.binary(BinaryOp::Add, keys, cross_i, k_shape);
623 let keys_out = layer_norm(
624 g,
625 params,
626 &format!("{prefix}.n4"),
627 keys_out,
628 &block.norm4_g,
629 &block.norm4_b,
630 k_n,
631 e,
632 );
633 (q, keys_out)
634}
635
636fn build_transformer_graph(
637 spec: &TwoWayTransformerSpec,
638 q_n: usize,
639 k_n: usize,
640 masked: bool,
641) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
642 let e = spec.embed_dim;
643 let f = DType::F32;
644 let mut g = Graph::new("twoway_transformer");
645 let mut params = HashMap::new();
646 let nh0 = spec
647 .layers
648 .first()
649 .map(|l| l.self_attn.num_heads)
650 .unwrap_or(spec.final_attn.num_heads);
651
652 let tokens = g.input("tokens", Shape::new(&[1, q_n, e], f));
653 let image_seq = g.input("image_seq", Shape::new(&[1, k_n, e], f));
654 let image_pe = g.input("image_pe", Shape::new(&[1, k_n, e], f));
655 let query_pe = tokens;
656
657 let mut layer_masks = Vec::new();
658 if masked {
659 for i in 0..spec.layers.len() {
660 let nh = spec.layers[i].self_attn.num_heads;
661 layer_masks.push(LayerMaskIds {
662 self_attn: g.input(format!("mask_L{i}_self"), Shape::new(&[1, nh, q_n, q_n], f)),
663 t2i: g.input(format!("mask_L{i}_t2i"), Shape::new(&[1, nh, q_n, k_n], f)),
664 i2t: g.input(format!("mask_L{i}_i2t"), Shape::new(&[1, nh, k_n, q_n], f)),
665 });
666 }
667 }
668 let final_mask = if masked {
669 Some(g.input("mask_final_t2i", Shape::new(&[1, nh0, q_n, k_n], f)))
670 } else {
671 None
672 };
673
674 let mut queries = tokens;
675 let mut keys = image_seq;
676 for (i, layer) in spec.layers.iter().enumerate() {
677 let masks = if masked { Some(&layer_masks[i]) } else { None };
678 let (q, k) = build_block(
679 &mut g,
680 &mut params,
681 &format!("L{i}"),
682 layer,
683 queries,
684 keys,
685 query_pe,
686 image_pe,
687 q_n,
688 k_n,
689 e,
690 masks,
691 );
692 queries = q;
693 keys = k;
694 }
695
696 let q_shape = Shape::new(&[1, q_n, e], f);
697 let k_shape = Shape::new(&[1, k_n, e], f);
698 let q_pe_f = g.binary(BinaryOp::Add, queries, query_pe, q_shape.clone());
699 let k_pe_f = g.binary(BinaryOp::Add, keys, image_pe, k_shape.clone());
700 let final_attn = build_attention(
701 &mut g,
702 &mut params,
703 "final",
704 &spec.final_attn,
705 q_pe_f,
706 k_pe_f,
707 keys,
708 q_n,
709 k_n,
710 final_mask,
711 );
712 let queries_out = g.binary(BinaryOp::Add, queries, final_attn, q_shape.clone());
713 let queries_out = layer_norm(
714 &mut g,
715 &mut params,
716 "final_ln",
717 queries_out,
718 &spec.norm_final_g,
719 &spec.norm_final_b,
720 q_n,
721 e,
722 );
723
724 g.set_outputs(vec![queries_out, keys]);
725 Ok((g, params))
726}