1use super::config::SAM_EMBED_HW;
24use super::transformer::{
25 TwoWayTransformerWeights, extract_two_way_transformer_weights, linear,
26 two_way_transformer_forward,
27};
28use super::upscale_ir::SamMaskUpscaleCompiled;
29use anyhow::{Result, ensure};
30use rlx_core::weight_map::WeightMap;
31use rlx_sam_ir::mask_hyper_matmul_ir::MaskHyperMatmulCompiled;
32use rlx_sam_ir::mlp_relu_ir::MlpReluCompiled;
33use rlx_sam_ir::twoway_transformer_ir::TwoWayTransformerCompiled;
34
35pub struct MaskDecoderWeights {
36 pub iou_token: Vec<f32>, pub mask_tokens: Vec<f32>, pub transformer: TwoWayTransformerWeights,
39
40 pub upscale_conv1_w: Vec<f32>,
43 pub upscale_conv1_b: Vec<f32>,
44 pub upscale_ln_g: Vec<f32>,
46 pub upscale_ln_b: Vec<f32>,
47 pub upscale_conv2_w: Vec<f32>,
49 pub upscale_conv2_b: Vec<f32>,
50
51 pub hyper_mlps: Vec<HypernetMlp>,
55
56 pub iou_head: HypernetMlp,
59
60 pub transformer_dim: usize,
61 pub num_mask_tokens: usize,
62}
63
64pub struct HypernetMlp {
65 pub layers: Vec<MlpLayer>,
66}
67
/// One dense (fully connected) layer of a [`HypernetMlp`].
pub struct MlpLayer {
    /// Flattened weight matrix; the loader checks its shape is
    /// `[out_d, in_d]`.
    pub w: Vec<f32>,
    /// Bias vector that accompanies `w`.
    pub b: Vec<f32>,
    /// Number of input features.
    pub in_d: usize,
    /// Number of output features.
    pub out_d: usize,
}
74
75pub(super) fn extract_mask_decoder_weights(
76 weights: &mut WeightMap,
77 transformer_dim: usize,
78 num_mask_tokens: usize,
79 iou_head_depth: usize,
80 iou_head_hidden_dim: usize,
81 transformer_depth: usize,
82 transformer_num_heads: usize,
83 transformer_mlp_dim: usize,
84) -> Result<MaskDecoderWeights> {
85 let (iou_token, sh) = weights.take("mask_decoder.iou_token.weight")?;
86 ensure!(
87 sh == vec![1, transformer_dim],
88 "iou_token shape {sh:?} not [1, {transformer_dim}]"
89 );
90 let (mask_tokens, sh) = weights.take("mask_decoder.mask_tokens.weight")?;
91 ensure!(
92 sh == vec![num_mask_tokens, transformer_dim],
93 "mask_tokens shape {sh:?} not [{num_mask_tokens}, {transformer_dim}]"
94 );
95
96 let q4 = transformer_dim / 4;
98 let q8 = transformer_dim / 8;
99 let (upscale_conv1_w, sh) = weights.take("mask_decoder.output_upscaling.0.weight")?;
100 ensure!(
101 sh == vec![transformer_dim, q4, 2, 2],
102 "output_upscaling.0.weight shape {sh:?} not [{transformer_dim}, {q4}, 2, 2]"
103 );
104 let (upscale_conv1_b, _) = weights.take("mask_decoder.output_upscaling.0.bias")?;
105 let (upscale_ln_g, _) = weights.take("mask_decoder.output_upscaling.1.weight")?;
106 let (upscale_ln_b, _) = weights.take("mask_decoder.output_upscaling.1.bias")?;
107 let (upscale_conv2_w, sh) = weights.take("mask_decoder.output_upscaling.3.weight")?;
108 ensure!(
109 sh == vec![q4, q8, 2, 2],
110 "output_upscaling.3.weight shape {sh:?} not [{q4}, {q8}, 2, 2]"
111 );
112 let (upscale_conv2_b, _) = weights.take("mask_decoder.output_upscaling.3.bias")?;
113
114 let mut hyper_mlps = Vec::with_capacity(num_mask_tokens);
117 for i in 0..num_mask_tokens {
118 let mlp = extract_mlp(
119 weights,
120 &format!("mask_decoder.output_hypernetworks_mlps.{i}"),
121 transformer_dim,
122 transformer_dim,
123 q8,
124 3,
125 )?;
126 hyper_mlps.push(mlp);
127 }
128
129 let iou_head = extract_mlp(
130 weights,
131 "mask_decoder.iou_prediction_head",
132 transformer_dim,
133 iou_head_hidden_dim,
134 num_mask_tokens,
135 iou_head_depth,
136 )?;
137
138 let transformer = extract_two_way_transformer_weights(
139 weights,
140 transformer_dim,
141 transformer_depth,
142 transformer_num_heads,
143 transformer_mlp_dim,
144 )?;
145
146 Ok(MaskDecoderWeights {
147 iou_token,
148 mask_tokens,
149 transformer,
150 upscale_conv1_w,
151 upscale_conv1_b,
152 upscale_ln_g,
153 upscale_ln_b,
154 upscale_conv2_w,
155 upscale_conv2_b,
156 hyper_mlps,
157 iou_head,
158 transformer_dim,
159 num_mask_tokens,
160 })
161}
162
163fn extract_mlp(
164 weights: &mut WeightMap,
165 prefix: &str,
166 input_dim: usize,
167 hidden_dim: usize,
168 output_dim: usize,
169 num_layers: usize,
170) -> Result<HypernetMlp> {
171 let mut layers = Vec::with_capacity(num_layers);
172 for i in 0..num_layers {
173 let in_d = if i == 0 { input_dim } else { hidden_dim };
174 let out_d = if i + 1 == num_layers {
175 output_dim
176 } else {
177 hidden_dim
178 };
179 let (w, sh) = weights.take(&format!("{prefix}.layers.{i}.weight"))?;
180 ensure!(
181 sh == vec![out_d, in_d],
182 "{prefix}.layers.{i}.weight shape {sh:?} not [{out_d}, {in_d}]"
183 );
184 let (b, _) = weights.take(&format!("{prefix}.layers.{i}.bias"))?;
185 layers.push(MlpLayer { w, b, in_d, out_d });
186 }
187 Ok(HypernetMlp { layers })
188}
189
190pub fn mlp_forward(mlp: &HypernetMlp, x: &[f32], rows: usize) -> Vec<f32> {
193 let mut cur = x.to_vec();
194 let n = mlp.layers.len();
195 for (i, layer) in mlp.layers.iter().enumerate() {
196 cur = linear(&cur, &layer.w, &layer.b, rows, layer.in_d, layer.out_d);
197 if i + 1 < n {
198 for v in cur.iter_mut() {
199 if *v < 0.0 {
200 *v = 0.0;
201 }
202 }
203 }
204 }
205 cur
206}
207
208pub fn mask_decoder_forward(
223 w: &MaskDecoderWeights,
224 upscale: &mut SamMaskUpscaleCompiled,
225 hyper_matmul: Option<&mut MaskHyperMatmulCompiled>,
226 hyper_mlps_ir: Option<&mut [MlpReluCompiled]>,
227 iou_head_ir: Option<&mut MlpReluCompiled>,
228 tw_ir: Option<&mut TwoWayTransformerCompiled>,
229 image_embeddings: &[f32],
230 image_pe: &[f32],
231 sparse_prompt_embeddings: &[f32],
232 num_sparse_tokens: usize,
233 dense_prompt_embeddings: &[f32],
234 multimask_output: bool,
235) -> Result<(Vec<f32>, Vec<f32>, usize, usize)> {
236 let e = w.transformer_dim;
237 let hw = SAM_EMBED_HW;
238 ensure!(
239 image_embeddings.len() == e * hw * hw,
240 "image_embeddings len {} ≠ E·hw·hw ({e}·{hw}·{hw})",
241 image_embeddings.len()
242 );
243 ensure!(
244 image_pe.len() == e * hw * hw,
245 "image_pe len {} ≠ E·hw·hw",
246 image_pe.len()
247 );
248 ensure!(
249 dense_prompt_embeddings.len() == e * hw * hw,
250 "dense_prompt_embeddings len {} ≠ E·hw·hw",
251 dense_prompt_embeddings.len()
252 );
253 ensure!(
254 sparse_prompt_embeddings.len() == num_sparse_tokens * e,
255 "sparse_prompt_embeddings len {} ≠ num_sparse·E ({num_sparse_tokens}·{e})",
256 sparse_prompt_embeddings.len()
257 );
258
259 let nm = w.num_mask_tokens;
262 let n_out_tokens = 1 + nm;
263 let q_n = n_out_tokens + num_sparse_tokens;
264 let mut tokens = Vec::with_capacity(q_n * e);
265 tokens.extend_from_slice(&w.iou_token); tokens.extend_from_slice(&w.mask_tokens); tokens.extend_from_slice(sparse_prompt_embeddings); let mut src = image_embeddings.to_vec();
272 for i in 0..src.len() {
273 src[i] += dense_prompt_embeddings[i];
274 }
275 let pos_src = image_pe.to_vec();
276
277 let k_n = hw * hw;
279 let (hs, src_post) = if let Some(tw) = tw_ir {
280 if tw.masked && q_n <= tw.max_q_n && tw.k_n == k_n {
281 tw.run_nchw_masked(&tokens, q_n, &src, &pos_src, hw)?
282 } else if !tw.masked && q_n == tw.max_q_n && tw.k_n == k_n {
283 tw.run_nchw(&tokens, &src, &pos_src, hw)?
284 } else {
285 two_way_transformer_forward(&w.transformer, &src, &pos_src, &tokens, 1, e, hw, hw, q_n)
286 }
287 } else {
288 two_way_transformer_forward(&w.transformer, &src, &pos_src, &tokens, 1, e, hw, hw, q_n)
289 };
290 let iou_token_out: Vec<f32> = hs[..e].to_vec();
294 let mask_tokens_out = &hs[e..e * (1 + nm)];
296
297 let mut src_nchw = vec![0f32; e * hw * hw];
300 for s in 0..hw * hw {
301 for c in 0..e {
302 src_nchw[c * hw * hw + s] = src_post[s * e + c];
303 }
304 }
305
306 let q8 = e / 8;
308 let h2 = hw * 4;
309 let w2 = hw * 4;
310 let up2 = upscale.run(&src_nchw)?;
311
312 let mut hyper_in = vec![0f32; nm * q8];
314 if let Some(mlps) = hyper_mlps_ir {
315 ensure!(
316 mlps.len() == nm,
317 "hyper_mlps_ir len {} ≠ num_mask_tokens {}",
318 mlps.len(),
319 nm
320 );
321 for i in 0..nm {
322 let token = &mask_tokens_out[i * e..(i + 1) * e];
323 let h = mlps[i].run(token, 1)?;
324 hyper_in[i * q8..(i + 1) * q8].copy_from_slice(&h);
325 }
326 } else {
327 for i in 0..nm {
328 let token = &mask_tokens_out[i * e..(i + 1) * e];
329 let h = mlp_forward(&w.hyper_mlps[i], token, 1);
330 hyper_in[i * q8..(i + 1) * q8].copy_from_slice(&h);
331 }
332 }
333 let spat = h2 * w2;
336 let mut masks_all = vec![0f32; nm * spat];
337 if let Some(hm) = hyper_matmul {
338 hm.run(&hyper_in, &up2, &mut masks_all)?;
339 } else {
340 rlx_cpu::blas::sgemm_auto(&hyper_in, &up2, &mut masks_all, nm, q8, spat);
341 }
342
343 let iou_pred_all = if let Some(head) = iou_head_ir {
345 head.run(&iou_token_out, 1)?
346 } else {
347 mlp_forward(&w.iou_head, &iou_token_out, 1)
348 };
349
350 let (masks, iou_pred, num_masks) = if multimask_output {
352 let mut masks = vec![0f32; (nm - 1) * spat];
354 masks.copy_from_slice(&masks_all[spat..]);
355 let mut iou = vec![0f32; nm - 1];
356 iou.copy_from_slice(&iou_pred_all[1..]);
357 (masks, iou, nm - 1)
358 } else {
359 let masks = masks_all[..spat].to_vec();
360 let iou = iou_pred_all[..1].to_vec();
361 (masks, iou, 1)
362 };
363
364 Ok((masks, iou_pred, num_masks, h2))
365}