1use anyhow::{Result, ensure};
35use rlx_core::weight_map::WeightMap;
36
/// Parameters of a single SAM2 attention module.
///
/// Every weight matrix is a flat buffer that the forward pass reads as
/// row-major `[out_features, in_features]`; each bias has one entry per
/// output feature.
pub struct Sam2AttentionWeights {
    /// Query projection weight (`<prefix>.q_proj.weight`).
    pub q_w: Vec<f32>,
    /// Query projection bias (`<prefix>.q_proj.bias`).
    pub q_b: Vec<f32>,
    /// Key projection weight (`<prefix>.k_proj.weight`).
    pub k_w: Vec<f32>,
    /// Key projection bias (`<prefix>.k_proj.bias`).
    pub k_b: Vec<f32>,
    /// Value projection weight (`<prefix>.v_proj.weight`).
    pub v_w: Vec<f32>,
    /// Value projection bias (`<prefix>.v_proj.bias`).
    pub v_b: Vec<f32>,
    /// Output projection weight (`<prefix>.out_proj.weight`), mapping
    /// `internal_dim` back to `embed_dim`.
    pub out_w: Vec<f32>,
    /// Output projection bias (`<prefix>.out_proj.bias`).
    pub out_b: Vec<f32>,
    /// Number of attention heads; each head is `internal_dim / num_heads` wide.
    pub num_heads: usize,
    /// Width of the module's inputs and of its output.
    pub embed_dim: usize,
    /// Width of the projected queries, keys and values.
    pub internal_dim: usize,
}
51
52pub struct Sam2TwoWayAttentionBlockWeights {
53 pub self_attn: Sam2AttentionWeights,
54 pub norm1_g: Vec<f32>,
55 pub norm1_b: Vec<f32>,
56 pub cross_token_to_image: Sam2AttentionWeights,
57 pub norm2_g: Vec<f32>,
58 pub norm2_b: Vec<f32>,
59 pub mlp_lin1_w: Vec<f32>,
60 pub mlp_lin1_b: Vec<f32>,
61 pub mlp_lin2_w: Vec<f32>,
62 pub mlp_lin2_b: Vec<f32>,
63 pub norm3_g: Vec<f32>,
64 pub norm3_b: Vec<f32>,
65 pub cross_image_to_token: Sam2AttentionWeights,
66 pub norm4_g: Vec<f32>,
67 pub norm4_b: Vec<f32>,
68 pub skip_first_layer_pe: bool,
69}
70
71pub struct Sam2TwoWayTransformerWeights {
72 pub layers: Vec<Sam2TwoWayAttentionBlockWeights>,
73 pub final_attn_token_to_image: Sam2AttentionWeights,
74 pub norm_final_g: Vec<f32>,
75 pub norm_final_b: Vec<f32>,
76 pub embed_dim: usize,
77}
78
79fn load_attention(
80 weights: &mut WeightMap,
81 prefix: &str,
82 embed_dim: usize,
83 num_heads: usize,
84 downsample_rate: usize,
85) -> Result<Sam2AttentionWeights> {
86 let internal_dim = embed_dim / downsample_rate;
87 let (q_w, sh) = weights.take(&format!("{prefix}.q_proj.weight"))?;
88 ensure!(
89 sh == vec![internal_dim, embed_dim],
90 "{prefix}.q_proj.weight shape {sh:?} not [{internal_dim}, {embed_dim}]"
91 );
92 let (q_b, _) = weights.take(&format!("{prefix}.q_proj.bias"))?;
93 let (k_w, _) = weights.take(&format!("{prefix}.k_proj.weight"))?;
94 let (k_b, _) = weights.take(&format!("{prefix}.k_proj.bias"))?;
95 let (v_w, _) = weights.take(&format!("{prefix}.v_proj.weight"))?;
96 let (v_b, _) = weights.take(&format!("{prefix}.v_proj.bias"))?;
97 let (out_w, sh) = weights.take(&format!("{prefix}.out_proj.weight"))?;
98 ensure!(
99 sh == vec![embed_dim, internal_dim],
100 "{prefix}.out_proj.weight shape {sh:?} not [{embed_dim}, {internal_dim}]"
101 );
102 let (out_b, _) = weights.take(&format!("{prefix}.out_proj.bias"))?;
103 Ok(Sam2AttentionWeights {
104 q_w,
105 q_b,
106 k_w,
107 k_b,
108 v_w,
109 v_b,
110 out_w,
111 out_b,
112 num_heads,
113 embed_dim,
114 internal_dim,
115 })
116}
117
118pub(super) fn extract_two_way_transformer_weights(
119 weights: &mut WeightMap,
120 embed_dim: usize,
121 depth: usize,
122 num_heads: usize,
123 mlp_dim: usize,
124) -> Result<Sam2TwoWayTransformerWeights> {
125 let mut layers = Vec::with_capacity(depth);
126 for i in 0..depth {
127 let p = format!("sam_mask_decoder.transformer.layers.{i}");
128 let self_attn =
129 load_attention(weights, &format!("{p}.self_attn"), embed_dim, num_heads, 1)?;
130 let (norm1_g, _) = weights.take(&format!("{p}.norm1.weight"))?;
131 let (norm1_b, _) = weights.take(&format!("{p}.norm1.bias"))?;
132 let cross_t2i = load_attention(
133 weights,
134 &format!("{p}.cross_attn_token_to_image"),
135 embed_dim,
136 num_heads,
137 2,
138 )?;
139 let (norm2_g, _) = weights.take(&format!("{p}.norm2.weight"))?;
140 let (norm2_b, _) = weights.take(&format!("{p}.norm2.bias"))?;
141 let (mlp_lin1_w, sh) = weights.take(&format!("{p}.mlp.layers.0.weight"))?;
142 ensure!(
143 sh == vec![mlp_dim, embed_dim],
144 "{p}.mlp.layers.0.weight shape {sh:?} not [{mlp_dim}, {embed_dim}]"
145 );
146 let (mlp_lin1_b, _) = weights.take(&format!("{p}.mlp.layers.0.bias"))?;
147 let (mlp_lin2_w, _) = weights.take(&format!("{p}.mlp.layers.1.weight"))?;
148 let (mlp_lin2_b, _) = weights.take(&format!("{p}.mlp.layers.1.bias"))?;
149 let (norm3_g, _) = weights.take(&format!("{p}.norm3.weight"))?;
150 let (norm3_b, _) = weights.take(&format!("{p}.norm3.bias"))?;
151 let cross_i2t = load_attention(
152 weights,
153 &format!("{p}.cross_attn_image_to_token"),
154 embed_dim,
155 num_heads,
156 2,
157 )?;
158 let (norm4_g, _) = weights.take(&format!("{p}.norm4.weight"))?;
159 let (norm4_b, _) = weights.take(&format!("{p}.norm4.bias"))?;
160 layers.push(Sam2TwoWayAttentionBlockWeights {
161 self_attn,
162 norm1_g,
163 norm1_b,
164 cross_token_to_image: cross_t2i,
165 norm2_g,
166 norm2_b,
167 mlp_lin1_w,
168 mlp_lin1_b,
169 mlp_lin2_w,
170 mlp_lin2_b,
171 norm3_g,
172 norm3_b,
173 cross_image_to_token: cross_i2t,
174 norm4_g,
175 norm4_b,
176 skip_first_layer_pe: i == 0,
177 });
178 }
179 let final_attn = load_attention(
180 weights,
181 "sam_mask_decoder.transformer.final_attn_token_to_image",
182 embed_dim,
183 num_heads,
184 2,
185 )?;
186 let (norm_final_g, _) = weights.take("sam_mask_decoder.transformer.norm_final_attn.weight")?;
187 let (norm_final_b, _) = weights.take("sam_mask_decoder.transformer.norm_final_attn.bias")?;
188 Ok(Sam2TwoWayTransformerWeights {
189 layers,
190 final_attn_token_to_image: final_attn,
191 norm_final_g,
192 norm_final_b,
193 embed_dim,
194 })
195}
196
197pub fn sam2_attention_forward(
202 w: &Sam2AttentionWeights,
203 q: &[f32],
204 q_n: usize,
205 k: &[f32],
206 k_n: usize,
207 v: &[f32],
208 v_n: usize,
209 b: usize,
210) -> Vec<f32> {
211 let e = w.embed_dim;
212 let id = w.internal_dim;
213 let nh = w.num_heads;
214 let dh = id / nh;
215 let scale = 1.0 / (dh as f32).sqrt();
216
217 let q_p = linear(q, &w.q_w, &w.q_b, b * q_n, e, id);
218 let k_p = linear(k, &w.k_w, &w.k_b, b * k_n, e, id);
219 let v_p = linear(v, &w.v_w, &w.v_b, b * v_n, e, id);
220
221 let q_h = separate_heads(&q_p, b, q_n, nh, dh);
222 let k_h = separate_heads(&k_p, b, k_n, nh, dh);
223 let v_h = separate_heads(&v_p, b, v_n, nh, dh);
224
225 let mut out_h = vec![0f32; b * nh * q_n * dh];
226 let mut scores = vec![0f32; q_n * k_n];
227 for bi in 0..b {
228 for h in 0..nh {
229 let q_off = ((bi * nh) + h) * q_n * dh;
230 let k_off = ((bi * nh) + h) * k_n * dh;
231 let v_off = ((bi * nh) + h) * v_n * dh;
232 let out_off = ((bi * nh) + h) * q_n * dh;
233
234 for i in 0..q_n {
235 for j in 0..k_n {
236 let mut acc = 0f32;
237 for d in 0..dh {
238 acc += q_h[q_off + i * dh + d] * k_h[k_off + j * dh + d];
239 }
240 scores[i * k_n + j] = acc * scale;
241 }
242 }
243 for i in 0..q_n {
244 let row = &mut scores[i * k_n..(i + 1) * k_n];
245 let m = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
246 let mut s = 0f32;
247 for v in row.iter_mut() {
248 *v = (*v - m).exp();
249 s += *v;
250 }
251 for v in row.iter_mut() {
252 *v /= s;
253 }
254 }
255 for i in 0..q_n {
256 for d in 0..dh {
257 let mut acc = 0f32;
258 for j in 0..k_n {
259 acc += scores[i * k_n + j] * v_h[v_off + j * dh + d];
260 }
261 out_h[out_off + i * dh + d] = acc;
262 }
263 }
264 }
265 }
266
267 let merged = recombine_heads(&out_h, b, q_n, nh, dh);
268 linear(&merged, &w.out_w, &w.out_b, b * q_n, id, e)
269}
270
/// Dense layer `out = x * wᵀ + b` on flat row-major buffers.
///
/// `x` holds `rows` rows of `in_d` values, `w` holds `out_d` rows of `in_d`
/// values and `b` holds `out_d` biases. Returns `rows` rows of `out_d` values.
///
/// # Panics
///
/// Panics if a buffer is shorter than the dimensions require.
pub fn linear(x: &[f32], w: &[f32], b: &[f32], rows: usize, in_d: usize, out_d: usize) -> Vec<f32> {
    let mut out = Vec::with_capacity(rows * out_d);
    for r in 0..rows {
        for o in 0..out_d {
            let bias = b[o];
            let input = &x[r * in_d..(r + 1) * in_d];
            let kernel = &w[o * in_d..(o + 1) * in_d];
            // Start from the bias and add products left to right.
            let value = input
                .iter()
                .zip(kernel)
                .fold(bias, |acc, (xv, wv)| acc + xv * wv);
            out.push(value);
        }
    }
    out
}
285
/// Regroups a `[b, n, nh * dh]` buffer into `[b, nh, n, dh]` so that each
/// head's rows are contiguous.
fn separate_heads(x: &[f32], b: usize, n: usize, nh: usize, dh: usize) -> Vec<f32> {
    let token_width = nh * dh;
    let mut out = vec![0f32; b * nh * n * dh];
    for bi in 0..b {
        for h in 0..nh {
            for i in 0..n {
                // One head's slice of one token moves as a single run of `dh`.
                let src = (bi * n + i) * token_width + h * dh;
                let dst = ((bi * nh + h) * n + i) * dh;
                out[dst..dst + dh].copy_from_slice(&x[src..src + dh]);
            }
        }
    }
    out
}
300
/// Regroups a `[b, nh, n, dh]` buffer back into `[b, n, nh * dh]`, placing the
/// heads of each token side by side.
fn recombine_heads(x: &[f32], b: usize, n: usize, nh: usize, dh: usize) -> Vec<f32> {
    let token_width = nh * dh;
    let mut out = vec![0f32; b * n * nh * dh];
    for bi in 0..b {
        for i in 0..n {
            for h in 0..nh {
                // One head's slice of one token moves as a single run of `dh`.
                let src = ((bi * nh + h) * n + i) * dh;
                let dst = (bi * n + i) * token_width + h * dh;
                out[dst..dst + dh].copy_from_slice(&x[src..src + dh]);
            }
        }
    }
    out
}
315
/// In-place layer normalisation over the last axis.
///
/// `x` is read as `rows` rows of `n` values. Each row is shifted to zero mean,
/// scaled by `1 / sqrt(variance + eps)` using the population variance, then
/// multiplied element-wise by `g` and offset by `b`.
///
/// # Panics
///
/// Panics if `x` holds fewer than `rows * n` values or if `g` or `b` holds
/// fewer than `n` values (when there is at least one row).
pub fn layer_norm_last(x: &mut [f32], rows: usize, n: usize, g: &[f32], b: &[f32], eps: f32) {
    let count = n as f32;
    for r in 0..rows {
        let row = &mut x[r * n..(r + 1) * n];

        let mean = row.iter().fold(0f32, |acc, &v| acc + v) / count;
        let var = row.iter().fold(0f32, |acc, &v| {
            let d = v - mean;
            acc + d * d
        }) / count;
        let inv = 1.0 / (var + eps).sqrt();

        for (k, v) in row.iter_mut().enumerate() {
            *v = (*v - mean) * inv * g[k] + b[k];
        }
    }
}
337
338pub fn layer_norm_last_cpu(x: &mut [f32], rows: usize, n: usize, g: &[f32], b: &[f32], eps: f32) {
340 let mut tmp = vec![0f32; n];
341 for r in 0..rows {
342 let base = r * n;
343 rlx_cpu::kernels::layer_norm_row(&x[base..base + n], g, b, &mut tmp, n, eps);
344 x[base..base + n].copy_from_slice(&tmp);
345 }
346}
347
348pub(super) fn add_inplace(dst: &mut [f32], src: &[f32]) {
349 for (d, s) in dst.iter_mut().zip(src.iter()) {
350 *d += *s;
351 }
352}
353
/// Replaces every strictly negative element of `x` with `0.0`.
///
/// Values that do not compare below zero (including NaN and `-0.0`) are left
/// as they are.
fn relu_inplace(x: &mut [f32]) {
    for negative in x.iter_mut().filter(|v| **v < 0.0) {
        *negative = 0.0;
    }
}
361
362pub fn two_way_attention_block_forward(
364 w: &Sam2TwoWayAttentionBlockWeights,
365 queries: Vec<f32>,
366 keys: Vec<f32>,
367 query_pe: &[f32],
368 key_pe: &[f32],
369 b: usize,
370 q_n: usize,
371 k_n: usize,
372) -> (Vec<f32>, Vec<f32>) {
373 let e = w.self_attn.embed_dim;
374
375 let mut queries = if w.skip_first_layer_pe {
377 sam2_attention_forward(&w.self_attn, &queries, q_n, &queries, q_n, &queries, q_n, b)
378 } else {
379 let mut q = queries.clone();
380 add_inplace(&mut q, query_pe);
381 let attn_out = sam2_attention_forward(&w.self_attn, &q, q_n, &q, q_n, &queries, q_n, b);
382 let mut out = queries;
383 add_inplace(&mut out, &attn_out);
384 out
385 };
386 layer_norm_last(&mut queries, b * q_n, e, &w.norm1_g, &w.norm1_b, 1e-5);
387
388 let mut q_pe = queries.clone();
390 add_inplace(&mut q_pe, query_pe);
391 let mut k_pe = keys.clone();
392 add_inplace(&mut k_pe, key_pe);
393 let attn_out = sam2_attention_forward(
394 &w.cross_token_to_image,
395 &q_pe,
396 q_n,
397 &k_pe,
398 k_n,
399 &keys,
400 k_n,
401 b,
402 );
403 add_inplace(&mut queries, &attn_out);
404 layer_norm_last(&mut queries, b * q_n, e, &w.norm2_g, &w.norm2_b, 1e-5);
405
406 let mlp_dim = w.mlp_lin1_b.len();
408 let mut mlp_mid = linear(&queries, &w.mlp_lin1_w, &w.mlp_lin1_b, b * q_n, e, mlp_dim);
409 relu_inplace(&mut mlp_mid);
410 let mlp_out = linear(&mlp_mid, &w.mlp_lin2_w, &w.mlp_lin2_b, b * q_n, mlp_dim, e);
411 add_inplace(&mut queries, &mlp_out);
412 layer_norm_last(&mut queries, b * q_n, e, &w.norm3_g, &w.norm3_b, 1e-5);
413
414 let mut q_pe = queries.clone();
416 add_inplace(&mut q_pe, query_pe);
417 let mut k_pe = keys.clone();
418 add_inplace(&mut k_pe, key_pe);
419 let attn_out = sam2_attention_forward(
420 &w.cross_image_to_token,
421 &k_pe,
422 k_n,
423 &q_pe,
424 q_n,
425 &queries,
426 q_n,
427 b,
428 );
429 let mut keys = keys;
430 add_inplace(&mut keys, &attn_out);
431 layer_norm_last(&mut keys, b * k_n, e, &w.norm4_g, &w.norm4_b, 1e-5);
432
433 (queries, keys)
434}
435
436pub fn two_way_transformer_forward(
445 w: &Sam2TwoWayTransformerWeights,
446 image_embedding: &[f32],
447 image_pe: &[f32],
448 point_embedding: &[f32],
449 b: usize,
450 c: usize,
451 h: usize,
452 ww: usize,
453 q_n: usize,
454) -> (Vec<f32>, Vec<f32>) {
455 let k_n = h * ww;
456 let mut image_seq = vec![0f32; b * k_n * c];
457 let mut image_pe_seq = vec![0f32; b * k_n * c];
458 for bi in 0..b {
459 for y in 0..h {
460 for x in 0..ww {
461 for ch in 0..c {
462 let src = (bi * c + ch) * h * ww + y * ww + x;
463 let dst = (bi * k_n + y * ww + x) * c + ch;
464 image_seq[dst] = image_embedding[src];
465 image_pe_seq[dst] = image_pe[src];
466 }
467 }
468 }
469 }
470
471 let mut queries = point_embedding.to_vec();
472 let mut keys = image_seq;
473
474 for layer in &w.layers {
475 let (q, k) = two_way_attention_block_forward(
476 layer,
477 queries,
478 keys,
479 point_embedding,
480 &image_pe_seq,
481 b,
482 q_n,
483 k_n,
484 );
485 queries = q;
486 keys = k;
487 }
488
489 let mut q_pe = queries.clone();
491 add_inplace(&mut q_pe, point_embedding);
492 let mut k_pe = keys.clone();
493 add_inplace(&mut k_pe, &image_pe_seq);
494 let attn_out = sam2_attention_forward(
495 &w.final_attn_token_to_image,
496 &q_pe,
497 q_n,
498 &k_pe,
499 k_n,
500 &keys,
501 k_n,
502 b,
503 );
504 add_inplace(&mut queries, &attn_out);
505 layer_norm_last(
506 &mut queries,
507 b * q_n,
508 w.embed_dim,
509 &w.norm_final_g,
510 &w.norm_final_b,
511 1e-5,
512 );
513
514 (queries, keys)
515}