use super::config::SAM_EMBED_HW;
use super::transformer::{
TwoWayTransformerWeights, extract_two_way_transformer_weights, linear,
two_way_transformer_forward,
};
use super::upscale_ir::SamMaskUpscaleCompiled;
use anyhow::{Result, ensure};
use rlx_core::weight_map::WeightMap;
use rlx_sam_ir::mask_hyper_matmul_ir::MaskHyperMatmulCompiled;
use rlx_sam_ir::mlp_relu_ir::MlpReluCompiled;
use rlx_sam_ir::twoway_transformer_ir::TwoWayTransformerCompiled;
/// Parameters of the mask decoder, as produced by `extract_mask_decoder_weights`.
///
/// All tensors are stored as flat `f32` buffers. Shapes quoted below are the
/// ones the loader checks against the checkpoint; where no shape is quoted,
/// only the checkpoint key is known from this file.
pub struct MaskDecoderWeights {
    /// `mask_decoder.iou_token.weight`, shape `[1, transformer_dim]`.
    pub iou_token: Vec<f32>,
    /// `mask_decoder.mask_tokens.weight`, shape `[num_mask_tokens, transformer_dim]`.
    pub mask_tokens: Vec<f32>,
    /// Weights of the two-way transformer run over prompt tokens and image embedding.
    pub transformer: TwoWayTransformerWeights,
    /// `mask_decoder.output_upscaling.0.weight`,
    /// shape `[transformer_dim, transformer_dim / 4, 2, 2]`.
    pub upscale_conv1_w: Vec<f32>,
    /// `mask_decoder.output_upscaling.0.bias`.
    pub upscale_conv1_b: Vec<f32>,
    /// `mask_decoder.output_upscaling.1.weight`.
    pub upscale_ln_g: Vec<f32>,
    /// `mask_decoder.output_upscaling.1.bias`.
    pub upscale_ln_b: Vec<f32>,
    /// `mask_decoder.output_upscaling.3.weight`,
    /// shape `[transformer_dim / 4, transformer_dim / 8, 2, 2]`.
    pub upscale_conv2_w: Vec<f32>,
    /// `mask_decoder.output_upscaling.3.bias`.
    pub upscale_conv2_b: Vec<f32>,
    /// One hypernetwork MLP per mask token
    /// (`mask_decoder.output_hypernetworks_mlps.{i}`), each mapping
    /// `transformer_dim` inputs to `transformer_dim / 8` outputs.
    pub hyper_mlps: Vec<HypernetMlp>,
    /// `mask_decoder.iou_prediction_head`: maps `transformer_dim` inputs to
    /// `num_mask_tokens` outputs.
    pub iou_head: HypernetMlp,
    /// Channel width shared by the tokens and the image embedding.
    pub transformer_dim: usize,
    /// Number of mask tokens (and therefore of hypernetwork MLPs).
    pub num_mask_tokens: usize,
}
/// A stack of fully connected layers, applied in order by `mlp_forward`.
pub struct HypernetMlp {
    /// Layers from input to output; may be empty.
    pub layers: Vec<MlpLayer>,
}
/// One fully connected layer of a [`HypernetMlp`].
pub struct MlpLayer {
    /// Flattened weight matrix; the loader requires checkpoint shape `[out_d, in_d]`.
    pub w: Vec<f32>,
    /// Bias vector, read from the layer's `.bias` checkpoint entry.
    pub b: Vec<f32>,
    /// Number of input features.
    pub in_d: usize,
    /// Number of output features.
    pub out_d: usize,
}
/// Pulls every mask-decoder tensor out of `weights` and validates its size.
///
/// Tensors are removed from the map via `WeightMap::take` under the
/// `mask_decoder.*` keys. Besides the token embeddings and the two upscaling
/// convolutions, this loads `num_mask_tokens` hypernetwork MLPs, the IoU
/// prediction head and the two-way transformer.
///
/// # Errors
///
/// Returns an error if a key is missing, if a weight tensor does not have the
/// expected shape, or if a bias / layer-norm vector does not have the length
/// implied by the neighbouring convolution weights.
pub(super) fn extract_mask_decoder_weights(
    weights: &mut WeightMap,
    transformer_dim: usize,
    num_mask_tokens: usize,
    iou_head_depth: usize,
    iou_head_hidden_dim: usize,
    transformer_depth: usize,
    transformer_num_heads: usize,
    transformer_mlp_dim: usize,
) -> Result<MaskDecoderWeights> {
    // Every hypernetwork MLP has a fixed depth of three layers.
    const HYPER_MLP_DEPTH: usize = 3;

    let (iou_token, sh) = weights.take("mask_decoder.iou_token.weight")?;
    ensure!(
        sh == vec![1, transformer_dim],
        "iou_token shape {sh:?} not [1, {transformer_dim}]"
    );
    let (mask_tokens, sh) = weights.take("mask_decoder.mask_tokens.weight")?;
    ensure!(
        sh == vec![num_mask_tokens, transformer_dim],
        "mask_tokens shape {sh:?} not [{num_mask_tokens}, {transformer_dim}]"
    );

    // Channel counts after the first and second upscaling stage.
    let q4 = transformer_dim / 4;
    let q8 = transformer_dim / 8;

    let (upscale_conv1_w, sh) = weights.take("mask_decoder.output_upscaling.0.weight")?;
    ensure!(
        sh == vec![transformer_dim, q4, 2, 2],
        "output_upscaling.0.weight shape {sh:?} not [{transformer_dim}, {q4}, 2, 2]"
    );
    // The per-channel vectors below previously had their shape discarded; a
    // mismatched checkpoint would only surface later, inside the upscaler.
    let upscale_conv1_b =
        take_vector_checked(weights, "mask_decoder.output_upscaling.0.bias", q4)?;
    let upscale_ln_g =
        take_vector_checked(weights, "mask_decoder.output_upscaling.1.weight", q4)?;
    let upscale_ln_b = take_vector_checked(weights, "mask_decoder.output_upscaling.1.bias", q4)?;

    let (upscale_conv2_w, sh) = weights.take("mask_decoder.output_upscaling.3.weight")?;
    ensure!(
        sh == vec![q4, q8, 2, 2],
        "output_upscaling.3.weight shape {sh:?} not [{q4}, {q8}, 2, 2]"
    );
    let upscale_conv2_b =
        take_vector_checked(weights, "mask_decoder.output_upscaling.3.bias", q8)?;

    let mut hyper_mlps = Vec::with_capacity(num_mask_tokens);
    for i in 0..num_mask_tokens {
        let mlp = extract_mlp(
            weights,
            &format!("mask_decoder.output_hypernetworks_mlps.{i}"),
            transformer_dim,
            transformer_dim,
            q8,
            HYPER_MLP_DEPTH,
        )?;
        hyper_mlps.push(mlp);
    }

    let iou_head = extract_mlp(
        weights,
        "mask_decoder.iou_prediction_head",
        transformer_dim,
        iou_head_hidden_dim,
        num_mask_tokens,
        iou_head_depth,
    )?;

    let transformer = extract_two_way_transformer_weights(
        weights,
        transformer_dim,
        transformer_depth,
        transformer_num_heads,
        transformer_mlp_dim,
    )?;

    Ok(MaskDecoderWeights {
        iou_token,
        mask_tokens,
        transformer,
        upscale_conv1_w,
        upscale_conv1_b,
        upscale_ln_g,
        upscale_ln_b,
        upscale_conv2_w,
        upscale_conv2_b,
        hyper_mlps,
        iou_head,
        transformer_dim,
        num_mask_tokens,
    })
}

/// Takes the tensor stored under `key` and checks that it holds exactly
/// `expected_len` values, regardless of how its shape is recorded.
fn take_vector_checked(
    weights: &mut WeightMap,
    key: &str,
    expected_len: usize,
) -> Result<Vec<f32>> {
    let (data, _) = weights.take(key)?;
    ensure!(
        data.len() == expected_len,
        "{key} len {} not {expected_len}",
        data.len()
    );
    Ok(data)
}
/// Loads an MLP stored under `{prefix}.layers.{i}.weight` / `.bias`.
///
/// Layer `0` takes `input_dim` features, the last layer produces `output_dim`
/// features, and every intermediate width is `hidden_dim`. With
/// `num_layers == 0` nothing is read and the returned MLP has no layers.
///
/// # Errors
///
/// Returns an error if a key is missing, if a weight tensor is not shaped
/// `[out_d, in_d]`, or if a bias does not hold exactly `out_d` values.
fn extract_mlp(
    weights: &mut WeightMap,
    prefix: &str,
    input_dim: usize,
    hidden_dim: usize,
    output_dim: usize,
    num_layers: usize,
) -> Result<HypernetMlp> {
    let mut layers = Vec::with_capacity(num_layers);
    for i in 0..num_layers {
        let in_d = if i == 0 { input_dim } else { hidden_dim };
        let out_d = if i + 1 == num_layers {
            output_dim
        } else {
            hidden_dim
        };

        let (w, sh) = weights.take(&format!("{prefix}.layers.{i}.weight"))?;
        ensure!(
            sh == vec![out_d, in_d],
            "{prefix}.layers.{i}.weight shape {sh:?} not [{out_d}, {in_d}]"
        );

        let (b, _) = weights.take(&format!("{prefix}.layers.{i}.bias"))?;
        // The bias shape used to be ignored; reject a mismatched vector here
        // rather than letting the forward pass consume a wrongly sized buffer.
        ensure!(
            b.len() == out_d,
            "{prefix}.layers.{i}.bias len {} not {out_d}",
            b.len()
        );

        layers.push(MlpLayer { w, b, in_d, out_d });
    }
    Ok(HypernetMlp { layers })
}
/// Applies every layer of `mlp` to `x` in order.
///
/// Each layer is evaluated with `linear`, passing `rows` through unchanged.
/// After every layer except the last, negative activations are replaced by
/// zero. An MLP without layers returns a copy of `x`.
pub fn mlp_forward(mlp: &HypernetMlp, x: &[f32], rows: usize) -> Vec<f32> {
    let Some((output_layer, hidden_layers)) = mlp.layers.split_last() else {
        return x.to_vec();
    };

    let mut activations = x.to_vec();
    for layer in hidden_layers {
        activations = linear(
            &activations,
            &layer.w,
            &layer.b,
            rows,
            layer.in_d,
            layer.out_d,
        );
        // Strict `< 0.0` comparison: NaN and -0.0 are left untouched.
        activations
            .iter_mut()
            .filter(|v| **v < 0.0)
            .for_each(|v| *v = 0.0);
    }

    // The final layer is purely linear.
    linear(
        &activations,
        &output_layer.w,
        &output_layer.b,
        rows,
        output_layer.in_d,
        output_layer.out_d,
    )
}
/// Runs the mask decoder for a single image and prompt set.
///
/// Steps, in order:
/// 1. Concatenate the IoU token, the mask tokens and the sparse prompt tokens.
/// 2. Add the dense prompt embedding to the image embedding element-wise.
/// 3. Run the two-way transformer. The compiled `tw_ir` is used when its
///    configuration fits this call, otherwise the reference implementation.
/// 4. Upscale the transformed image embedding with `upscale`.
/// 5. Map each mask token through its hypernetwork MLP and multiply the result
///    with the upscaled embedding to obtain one mask per mask token.
/// 6. Predict one quality score per mask from the IoU token.
///
/// Each optional `*_ir` argument selects a compiled implementation of its
/// stage; `None` falls back to the CPU code in this module.
///
/// # Returns
///
/// `(masks, iou_pred, num_masks, side)` where `side = SAM_EMBED_HW * 4` and
/// `masks` holds `num_masks * side * side` values, one mask after another.
/// With `multimask_output` the first mask is dropped and the remaining
/// `num_mask_tokens - 1` are returned; otherwise only the first is returned.
///
/// # Errors
///
/// Returns an error if any input buffer has the wrong length, if
/// `num_mask_tokens` is zero, if the stored token or MLP weights are
/// inconsistent with `transformer_dim` / `num_mask_tokens`, if an intermediate
/// stage yields a buffer of unexpected size, or if a compiled stage fails.
pub fn mask_decoder_forward(
    w: &MaskDecoderWeights,
    upscale: &mut SamMaskUpscaleCompiled,
    hyper_matmul: Option<&mut MaskHyperMatmulCompiled>,
    hyper_mlps_ir: Option<&mut [MlpReluCompiled]>,
    iou_head_ir: Option<&mut MlpReluCompiled>,
    tw_ir: Option<&mut TwoWayTransformerCompiled>,
    image_embeddings: &[f32],
    image_pe: &[f32],
    sparse_prompt_embeddings: &[f32],
    num_sparse_tokens: usize,
    dense_prompt_embeddings: &[f32],
    multimask_output: bool,
) -> Result<(Vec<f32>, Vec<f32>, usize, usize)> {
    let e = w.transformer_dim;
    let hw = SAM_EMBED_HW;
    let k_n = hw * hw;

    ensure!(
        image_embeddings.len() == e * k_n,
        "image_embeddings len {} ≠ E·hw·hw ({e}·{hw}·{hw})",
        image_embeddings.len()
    );
    ensure!(
        image_pe.len() == e * k_n,
        "image_pe len {} ≠ E·hw·hw",
        image_pe.len()
    );
    ensure!(
        dense_prompt_embeddings.len() == e * k_n,
        "dense_prompt_embeddings len {} ≠ E·hw·hw",
        dense_prompt_embeddings.len()
    );
    ensure!(
        sparse_prompt_embeddings.len() == num_sparse_tokens * e,
        "sparse_prompt_embeddings len {} ≠ num_sparse·E ({num_sparse_tokens}·{e})",
        sparse_prompt_embeddings.len()
    );

    let nm = w.num_mask_tokens;
    // Both output modes need at least the first mask; with zero mask tokens
    // the selection below would underflow or slice out of bounds.
    ensure!(nm >= 1, "num_mask_tokens must be at least 1, got {nm}");
    let n_out_tokens = 1 + nm;
    let q_n = n_out_tokens + num_sparse_tokens;

    // Query tokens: [iou token | mask tokens | sparse prompt tokens].
    let mut tokens = Vec::with_capacity(q_n * e);
    tokens.extend_from_slice(&w.iou_token);
    tokens.extend_from_slice(&w.mask_tokens);
    tokens.extend_from_slice(sparse_prompt_embeddings);
    ensure!(
        tokens.len() == q_n * e,
        "token buffer len {} ≠ q_n·E ({q_n}·{e}); iou_token/mask_tokens weights are mis-sized",
        tokens.len()
    );

    // Image embedding plus dense prompt embedding, element-wise.
    let src: Vec<f32> = image_embeddings
        .iter()
        .zip(dense_prompt_embeddings)
        .map(|(img, dense)| img + dense)
        .collect();
    let pos_src = image_pe.to_vec();

    let (hs, src_post) = if let Some(tw) = tw_ir {
        if tw.masked && q_n <= tw.max_q_n && tw.k_n == k_n {
            tw.run_nchw_masked(&tokens, q_n, &src, &pos_src, hw)?
        } else if !tw.masked && q_n == tw.max_q_n && tw.k_n == k_n {
            tw.run_nchw(&tokens, &src, &pos_src, hw)?
        } else {
            // Compiled graph does not match this call's token count or grid.
            two_way_transformer_forward(&w.transformer, &src, &pos_src, &tokens, 1, e, hw, hw, q_n)
        }
    } else {
        two_way_transformer_forward(&w.transformer, &src, &pos_src, &tokens, 1, e, hw, hw, q_n)
    };
    ensure!(
        hs.len() >= n_out_tokens * e,
        "transformer returned {} token values, need at least (1+num_mask_tokens)·E ({n_out_tokens}·{e})",
        hs.len()
    );
    ensure!(
        src_post.len() >= k_n * e,
        "transformer returned {} image values, need at least hw·hw·E ({k_n}·{e})",
        src_post.len()
    );

    let iou_token_out: Vec<f32> = hs[..e].to_vec();
    let mask_tokens_out = &hs[e..n_out_tokens * e];

    // The transformer output is indexed [position][channel]; the upscaler is
    // fed [channel][position].
    let mut src_nchw = vec![0f32; e * k_n];
    for s in 0..k_n {
        for c in 0..e {
            src_nchw[c * k_n + s] = src_post[s * e + c];
        }
    }

    let q8 = e / 8;
    let h2 = hw * 4;
    let w2 = hw * 4;
    let spat = h2 * w2;

    let up2 = upscale.run(&src_nchw)?;
    ensure!(
        up2.len() >= q8 * spat,
        "upscaled embedding len {} < (E/8)·H·W ({q8}·{h2}·{w2})",
        up2.len()
    );

    // One row of `q8` hypernetwork outputs per mask token.
    let mut hyper_in = vec![0f32; nm * q8];
    if let Some(mlps) = hyper_mlps_ir {
        ensure!(
            mlps.len() == nm,
            "hyper_mlps_ir len {} ≠ num_mask_tokens {}",
            mlps.len(),
            nm
        );
        for (i, mlp) in mlps.iter_mut().enumerate() {
            let token = &mask_tokens_out[i * e..(i + 1) * e];
            let h = mlp.run(token, 1)?;
            ensure!(
                h.len() == q8,
                "hyper MLP {i} produced {} values, expected {q8}",
                h.len()
            );
            hyper_in[i * q8..(i + 1) * q8].copy_from_slice(&h);
        }
    } else {
        ensure!(
            w.hyper_mlps.len() >= nm,
            "hyper_mlps len {} < num_mask_tokens {nm}",
            w.hyper_mlps.len()
        );
        for (i, mlp) in w.hyper_mlps.iter().take(nm).enumerate() {
            let token = &mask_tokens_out[i * e..(i + 1) * e];
            let h = mlp_forward(mlp, token, 1);
            ensure!(
                h.len() == q8,
                "hyper MLP {i} produced {} values, expected {q8}",
                h.len()
            );
            hyper_in[i * q8..(i + 1) * q8].copy_from_slice(&h);
        }
    }

    // masks_all[nm, spat] = hyper_in[nm, q8] × up2[q8, spat]
    let mut masks_all = vec![0f32; nm * spat];
    if let Some(hm) = hyper_matmul {
        hm.run(&hyper_in, &up2, &mut masks_all)?;
    } else {
        rlx_cpu::blas::sgemm_auto(&hyper_in, &up2, &mut masks_all, nm, q8, spat);
    }

    let iou_pred_all = if let Some(head) = iou_head_ir {
        head.run(&iou_token_out, 1)?
    } else {
        mlp_forward(&w.iou_head, &iou_token_out, 1)
    };

    let (masks, iou_pred, num_masks) = if multimask_output {
        ensure!(
            iou_pred_all.len() == nm,
            "iou head produced {} scores, expected num_mask_tokens {nm}",
            iou_pred_all.len()
        );
        // Drop mask 0 and its score; return the remaining nm - 1.
        (
            masks_all[spat..].to_vec(),
            iou_pred_all[1..].to_vec(),
            nm - 1,
        )
    } else {
        ensure!(!iou_pred_all.is_empty(), "iou head produced no scores");
        (masks_all[..spat].to_vec(), iou_pred_all[..1].to_vec(), 1)
    };

    Ok((masks, iou_pred, num_masks, h2))
}