1use anyhow::{Result, ensure};
28use rlx_core::weight_map::WeightMap;
29
/// Parameters of one multi-head attention module, held as flat `f32` buffers.
///
/// Every projection matrix is consumed row-major as `[out_features, in_features]`
/// and is paired with a bias that has one entry per output feature.
pub struct AttentionWeights {
    /// Query projection weight, `[internal_dim, embed_dim]`.
    pub q_w: Vec<f32>,
    /// Query projection bias (`internal_dim` values).
    pub q_b: Vec<f32>,
    /// Key projection weight, `[internal_dim, embed_dim]`.
    pub k_w: Vec<f32>,
    /// Key projection bias (`internal_dim` values).
    pub k_b: Vec<f32>,
    /// Value projection weight, `[internal_dim, embed_dim]`.
    pub v_w: Vec<f32>,
    /// Value projection bias (`internal_dim` values).
    pub v_b: Vec<f32>,
    /// Output projection weight, `[embed_dim, internal_dim]`.
    pub out_w: Vec<f32>,
    /// Output projection bias (`embed_dim` values).
    pub out_b: Vec<f32>,
    /// Number of attention heads that `internal_dim` is split across.
    pub num_heads: usize,
    /// Width of the module's inputs and outputs.
    pub embed_dim: usize,
    /// Width of the projected q/k/v activations (`embed_dim / downsample_rate`).
    pub internal_dim: usize,
}
44
45pub struct TwoWayAttentionBlockWeights {
46 pub self_attn: AttentionWeights,
47 pub norm1_g: Vec<f32>,
48 pub norm1_b: Vec<f32>,
49 pub cross_token_to_image: AttentionWeights,
50 pub norm2_g: Vec<f32>,
51 pub norm2_b: Vec<f32>,
52 pub mlp_lin1_w: Vec<f32>,
53 pub mlp_lin1_b: Vec<f32>,
54 pub mlp_lin2_w: Vec<f32>,
55 pub mlp_lin2_b: Vec<f32>,
56 pub norm3_g: Vec<f32>,
57 pub norm3_b: Vec<f32>,
58 pub cross_image_to_token: AttentionWeights,
59 pub norm4_g: Vec<f32>,
60 pub norm4_b: Vec<f32>,
61 pub skip_first_layer_pe: bool,
62}
63
64pub struct TwoWayTransformerWeights {
65 pub layers: Vec<TwoWayAttentionBlockWeights>,
66 pub final_attn_token_to_image: AttentionWeights,
67 pub norm_final_g: Vec<f32>,
68 pub norm_final_b: Vec<f32>,
69 pub embed_dim: usize,
70}
71
72fn load_attention(
73 weights: &mut WeightMap,
74 prefix: &str,
75 embed_dim: usize,
76 num_heads: usize,
77 downsample_rate: usize,
78) -> Result<AttentionWeights> {
79 let internal_dim = embed_dim / downsample_rate;
80 let (q_w, sh) = weights.take(&format!("{prefix}.q_proj.weight"))?;
81 ensure!(
82 sh == vec![internal_dim, embed_dim],
83 "{prefix}.q_proj.weight shape {sh:?}"
84 );
85 let (q_b, _) = weights.take(&format!("{prefix}.q_proj.bias"))?;
86 let (k_w, _) = weights.take(&format!("{prefix}.k_proj.weight"))?;
87 let (k_b, _) = weights.take(&format!("{prefix}.k_proj.bias"))?;
88 let (v_w, _) = weights.take(&format!("{prefix}.v_proj.weight"))?;
89 let (v_b, _) = weights.take(&format!("{prefix}.v_proj.bias"))?;
90 let (out_w, sh) = weights.take(&format!("{prefix}.out_proj.weight"))?;
91 ensure!(
92 sh == vec![embed_dim, internal_dim],
93 "{prefix}.out_proj.weight shape {sh:?}"
94 );
95 let (out_b, _) = weights.take(&format!("{prefix}.out_proj.bias"))?;
96 Ok(AttentionWeights {
97 q_w,
98 q_b,
99 k_w,
100 k_b,
101 v_w,
102 v_b,
103 out_w,
104 out_b,
105 num_heads,
106 embed_dim,
107 internal_dim,
108 })
109}
110
111pub(super) fn extract_two_way_transformer_weights(
112 weights: &mut WeightMap,
113 embed_dim: usize,
114 depth: usize,
115 num_heads: usize,
116 mlp_dim: usize,
117) -> Result<TwoWayTransformerWeights> {
118 let mut layers = Vec::with_capacity(depth);
119 for i in 0..depth {
120 let p = format!("mask_decoder.transformer.layers.{i}");
121 let self_attn =
122 load_attention(weights, &format!("{p}.self_attn"), embed_dim, num_heads, 1)?;
123 let (norm1_g, _) = weights.take(&format!("{p}.norm1.weight"))?;
124 let (norm1_b, _) = weights.take(&format!("{p}.norm1.bias"))?;
125 let cross_t2i = load_attention(
126 weights,
127 &format!("{p}.cross_attn_token_to_image"),
128 embed_dim,
129 num_heads,
130 2,
131 )?;
132 let (norm2_g, _) = weights.take(&format!("{p}.norm2.weight"))?;
133 let (norm2_b, _) = weights.take(&format!("{p}.norm2.bias"))?;
134 let (mlp_lin1_w, sh) = weights.take(&format!("{p}.mlp.lin1.weight"))?;
135 ensure!(
136 sh == vec![mlp_dim, embed_dim],
137 "{p}.mlp.lin1.weight shape {sh:?}"
138 );
139 let (mlp_lin1_b, _) = weights.take(&format!("{p}.mlp.lin1.bias"))?;
140 let (mlp_lin2_w, _) = weights.take(&format!("{p}.mlp.lin2.weight"))?;
141 let (mlp_lin2_b, _) = weights.take(&format!("{p}.mlp.lin2.bias"))?;
142 let (norm3_g, _) = weights.take(&format!("{p}.norm3.weight"))?;
143 let (norm3_b, _) = weights.take(&format!("{p}.norm3.bias"))?;
144 let cross_i2t = load_attention(
145 weights,
146 &format!("{p}.cross_attn_image_to_token"),
147 embed_dim,
148 num_heads,
149 2,
150 )?;
151 let (norm4_g, _) = weights.take(&format!("{p}.norm4.weight"))?;
152 let (norm4_b, _) = weights.take(&format!("{p}.norm4.bias"))?;
153 layers.push(TwoWayAttentionBlockWeights {
154 self_attn,
155 norm1_g,
156 norm1_b,
157 cross_token_to_image: cross_t2i,
158 norm2_g,
159 norm2_b,
160 mlp_lin1_w,
161 mlp_lin1_b,
162 mlp_lin2_w,
163 mlp_lin2_b,
164 norm3_g,
165 norm3_b,
166 cross_image_to_token: cross_i2t,
167 norm4_g,
168 norm4_b,
169 skip_first_layer_pe: i == 0,
170 });
171 }
172 let final_attn = load_attention(
173 weights,
174 "mask_decoder.transformer.final_attn_token_to_image",
175 embed_dim,
176 num_heads,
177 2,
178 )?;
179 let (norm_final_g, _) = weights.take("mask_decoder.transformer.norm_final_attn.weight")?;
180 let (norm_final_b, _) = weights.take("mask_decoder.transformer.norm_final_attn.bias")?;
181 Ok(TwoWayTransformerWeights {
182 layers,
183 final_attn_token_to_image: final_attn,
184 norm_final_g,
185 norm_final_b,
186 embed_dim,
187 })
188}
189
190pub fn attention_forward(
198 w: &AttentionWeights,
199 q: &[f32],
200 q_n: usize,
201 k: &[f32],
202 k_n: usize,
203 v: &[f32],
204 v_n: usize,
205 b: usize,
206) -> Vec<f32> {
207 let e = w.embed_dim;
208 let id = w.internal_dim;
209 let nh = w.num_heads;
210 let dh = id / nh;
211 let scale = 1.0 / (dh as f32).sqrt();
212
213 let q_p = linear(q, &w.q_w, &w.q_b, b * q_n, e, id);
215 let k_p = linear(k, &w.k_w, &w.k_b, b * k_n, e, id);
216 let v_p = linear(v, &w.v_w, &w.v_b, b * v_n, e, id);
217
218 let q_h = separate_heads(&q_p, b, q_n, nh, dh);
220 let k_h = separate_heads(&k_p, b, k_n, nh, dh);
221 let v_h = separate_heads(&v_p, b, v_n, nh, dh);
222
223 let mut out_h = vec![0f32; b * nh * q_n * dh];
228 let mut scores = vec![0f32; q_n * k_n];
229 let mut k_t = vec![0f32; dh * k_n];
232 for bi in 0..b {
233 for h in 0..nh {
234 let q_off = ((bi * nh) + h) * q_n * dh;
235 let k_off = ((bi * nh) + h) * k_n * dh;
236 let v_off = ((bi * nh) + h) * v_n * dh;
237 let out_off = ((bi * nh) + h) * q_n * dh;
238
239 for j in 0..k_n {
241 for d in 0..dh {
242 k_t[d * k_n + j] = k_h[k_off + j * dh + d];
243 }
244 }
245 rlx_cpu::blas::sgemm_auto(
247 &q_h[q_off..q_off + q_n * dh],
248 &k_t,
249 &mut scores,
250 q_n,
251 dh,
252 k_n,
253 );
254 for i in 0..q_n {
256 let row = &mut scores[i * k_n..(i + 1) * k_n];
257 let mut m = f32::NEG_INFINITY;
258 for v in row.iter_mut() {
259 *v *= scale;
260 if *v > m {
261 m = *v;
262 }
263 }
264 let mut s = 0f32;
265 for v in row.iter_mut() {
266 *v = (*v - m).exp();
267 s += *v;
268 }
269 let inv = 1.0 / s;
270 for v in row.iter_mut() {
271 *v *= inv;
272 }
273 }
274 rlx_cpu::blas::sgemm_auto(
277 &scores,
278 &v_h[v_off..v_off + v_n * dh],
279 &mut out_h[out_off..out_off + q_n * dh],
280 q_n,
281 k_n,
282 dh,
283 );
284 }
285 }
286
287 let merged = recombine_heads(&out_h, b, q_n, nh, dh);
289 linear(&merged, &w.out_w, &w.out_b, b * q_n, id, e)
291}
292
293pub fn linear(x: &[f32], w: &[f32], b: &[f32], rows: usize, in_d: usize, out_d: usize) -> Vec<f32> {
306 let mut w_t = vec![0f32; in_d * out_d];
307 for o in 0..out_d {
308 for k in 0..in_d {
309 w_t[k * out_d + o] = w[o * in_d + k];
310 }
311 }
312 let mut out = vec![0f32; rows * out_d];
313 rlx_cpu::blas::sgemm_auto(x, &w_t, &mut out, rows, in_d, out_d);
314 for r in 0..rows {
315 for o in 0..out_d {
316 out[r * out_d + o] += b[o];
317 }
318 }
319 out
320}
321
/// Re-lays `x` from `[b, n, nh * dh]` to `[b, nh, n, dh]`, so that each head's
/// rows become contiguous.
fn separate_heads(x: &[f32], b: usize, n: usize, nh: usize, dh: usize) -> Vec<f32> {
    let width = nh * dh;
    let mut split = vec![0f32; b * nh * n * dh];
    for bi in 0..b {
        for h in 0..nh {
            for i in 0..n {
                // One head's slice of one token is contiguous on both sides.
                let src = (bi * n + i) * width + h * dh;
                let dst = ((bi * nh + h) * n + i) * dh;
                split[dst..dst + dh].copy_from_slice(&x[src..src + dh]);
            }
        }
    }
    split
}
337
/// Re-lays `x` from `[b, nh, n, dh]` back to `[b, n, nh * dh]`, placing the
/// heads of each token side by side again.
fn recombine_heads(x: &[f32], b: usize, n: usize, nh: usize, dh: usize) -> Vec<f32> {
    let width = nh * dh;
    let mut merged = vec![0f32; b * n * nh * dh];
    for bi in 0..b {
        for i in 0..n {
            for h in 0..nh {
                // Copy one head's `dh` values for token `i` in a single move.
                let src = ((bi * nh + h) * n + i) * dh;
                let dst = (bi * n + i) * width + h * dh;
                merged[dst..dst + dh].copy_from_slice(&x[src..src + dh]);
            }
        }
    }
    merged
}
353
/// In-place layer normalisation of the first `rows` rows of `x`, each `n` wide.
///
/// Every row is shifted by its mean, divided by `sqrt(variance + eps)` using
/// the population variance (divisor `n`), then scaled by `g` and shifted by
/// `b` element-wise. Elements of `x` beyond `rows * n` are left untouched.
///
/// # Panics
///
/// Panics if `x` holds fewer than `rows * n` values, or if `g` or `b` has fewer
/// than `n` entries while at least one row is processed.
pub fn layer_norm_last(x: &mut [f32], rows: usize, n: usize, g: &[f32], b: &[f32], eps: f32) {
    let width = n as f32;
    for r in 0..rows {
        let row = &mut x[r * n..(r + 1) * n];

        // Sequential left-to-right reductions starting from zero.
        let mean = row.iter().fold(0f32, |acc, v| acc + *v) / width;
        let var = row.iter().fold(0f32, |acc, v| {
            let d = *v - mean;
            acc + d * d
        }) / width;
        let inv = 1.0 / (var + eps).sqrt();

        for (k, v) in row.iter_mut().enumerate() {
            *v = (*v - mean) * inv * g[k] + b[k];
        }
    }
}
376
377pub fn layer_norm_last_cpu(x: &mut [f32], rows: usize, n: usize, g: &[f32], b: &[f32], eps: f32) {
379 let mut tmp = vec![0f32; n];
380 for r in 0..rows {
381 let base = r * n;
382 rlx_cpu::kernels::layer_norm_row(&x[base..base + n], g, b, &mut tmp, n, eps);
383 x[base..base + n].copy_from_slice(&tmp);
384 }
385}
386
/// Adds `src` onto `dst` element by element.
///
/// Only the overlapping prefix is updated: if the slices differ in length the
/// extra elements of the longer one are ignored.
fn add_inplace(dst: &mut [f32], src: &[f32]) {
    dst.iter_mut().zip(src).for_each(|(acc, inc)| *acc += inc);
}
392
/// In-place ReLU: every strictly negative element becomes `0.0`.
///
/// Values that do not compare below zero (including NaN and `-0.0`) are left
/// exactly as they are.
fn relu_inplace(x: &mut [f32]) {
    x.iter_mut()
        .filter(|v| **v < 0.0)
        .for_each(|v| *v = 0.0);
}
400
401pub fn two_way_attention_block_forward(
404 w: &TwoWayAttentionBlockWeights,
405 queries: Vec<f32>,
406 keys: Vec<f32>,
407 query_pe: &[f32],
408 key_pe: &[f32],
409 b: usize,
410 q_n: usize,
411 k_n: usize,
412) -> (Vec<f32>, Vec<f32>) {
413 let e = w.self_attn.embed_dim;
414
415 let mut queries = if w.skip_first_layer_pe {
417 attention_forward(&w.self_attn, &queries, q_n, &queries, q_n, &queries, q_n, b)
418 } else {
419 let mut q = queries.clone();
420 add_inplace(&mut q, query_pe);
421 let attn_out = attention_forward(&w.self_attn, &q, q_n, &q, q_n, &queries, q_n, b);
422 let mut out = queries;
423 add_inplace(&mut out, &attn_out);
424 out
425 };
426 layer_norm_last(&mut queries, b * q_n, e, &w.norm1_g, &w.norm1_b, 1e-5);
427
428 let mut q_pe = queries.clone();
430 add_inplace(&mut q_pe, query_pe);
431 let mut k_pe = keys.clone();
432 add_inplace(&mut k_pe, key_pe);
433 let attn_out = attention_forward(
434 &w.cross_token_to_image,
435 &q_pe,
436 q_n,
437 &k_pe,
438 k_n,
439 &keys,
440 k_n,
441 b,
442 );
443 add_inplace(&mut queries, &attn_out);
444 layer_norm_last(&mut queries, b * q_n, e, &w.norm2_g, &w.norm2_b, 1e-5);
445
446 let mlp_dim = w.mlp_lin1_b.len();
448 let mut mlp_mid = linear(&queries, &w.mlp_lin1_w, &w.mlp_lin1_b, b * q_n, e, mlp_dim);
449 relu_inplace(&mut mlp_mid);
450 let mlp_out = linear(&mlp_mid, &w.mlp_lin2_w, &w.mlp_lin2_b, b * q_n, mlp_dim, e);
451 add_inplace(&mut queries, &mlp_out);
452 layer_norm_last(&mut queries, b * q_n, e, &w.norm3_g, &w.norm3_b, 1e-5);
453
454 let mut q_pe = queries.clone();
456 add_inplace(&mut q_pe, query_pe);
457 let mut k_pe = keys.clone();
458 add_inplace(&mut k_pe, key_pe);
459 let attn_out = attention_forward(
461 &w.cross_image_to_token,
462 &k_pe,
463 k_n,
464 &q_pe,
465 q_n,
466 &queries,
467 q_n,
468 b,
469 );
470 let mut keys = keys;
471 add_inplace(&mut keys, &attn_out);
472 layer_norm_last(&mut keys, b * k_n, e, &w.norm4_g, &w.norm4_b, 1e-5);
473
474 (queries, keys)
475}
476
477pub fn two_way_transformer_forward(
486 w: &TwoWayTransformerWeights,
487 image_embedding: &[f32],
488 image_pe: &[f32],
489 point_embedding: &[f32],
490 b: usize,
491 c: usize,
492 h: usize,
493 ww: usize,
494 q_n: usize,
495) -> (Vec<f32>, Vec<f32>) {
496 let k_n = h * ww;
497 let mut image_seq = vec![0f32; b * k_n * c];
499 let mut image_pe_seq = vec![0f32; b * k_n * c];
500 for bi in 0..b {
501 for y in 0..h {
502 for x in 0..ww {
503 for ch in 0..c {
504 let src = (bi * c + ch) * h * ww + y * ww + x;
505 let dst = (bi * k_n + y * ww + x) * c + ch;
506 image_seq[dst] = image_embedding[src];
507 image_pe_seq[dst] = image_pe[src];
508 }
509 }
510 }
511 }
512
513 let mut queries = point_embedding.to_vec();
514 let mut keys = image_seq;
515
516 for layer in &w.layers {
517 let (q, k) = two_way_attention_block_forward(
518 layer,
519 queries,
520 keys,
521 point_embedding,
522 &image_pe_seq,
523 b,
524 q_n,
525 k_n,
526 );
527 queries = q;
528 keys = k;
529 }
530
531 let mut q_pe = queries.clone();
533 add_inplace(&mut q_pe, point_embedding);
534 let mut k_pe = keys.clone();
535 add_inplace(&mut k_pe, &image_pe_seq);
536 let attn_out = attention_forward(
537 &w.final_attn_token_to_image,
538 &q_pe,
539 q_n,
540 &k_pe,
541 k_n,
542 &keys,
543 k_n,
544 b,
545 );
546 add_inplace(&mut queries, &attn_out);
547 layer_norm_last(
548 &mut queries,
549 b * q_n,
550 w.embed_dim,
551 &w.norm_final_g,
552 &w.norm_final_b,
553 1e-5,
554 );
555
556 (queries, keys)
557}