use super::config::{SAM_EMBED_HW, SAM_IMG_SIZE, SAM_PROMPT_EMBED_DIM};
use super::prompt_mask_ir::SamPromptMaskCompiled;
use anyhow::{Result, ensure};
use rlx_core::weight_map::WeightMap;
/// Flat `f32` parameter buffers used by the prompt encoder.
///
/// Normally built by `extract_prompt_encoder_weights`, which pulls each tensor
/// out of a weight map under its `prompt_encoder.*` name.
pub struct PromptEncoderWeights {
    /// Positional-encoding projection matrix, loaded with shape
    /// `[2, embed_dim / 2]`. The first `embed_dim / 2` values multiply the
    /// x coordinate, the following `embed_dim / 2` multiply the y coordinate.
    pub pe_gaussian: Vec<f32>,
    /// Embedding written over the positional encoding of any point whose
    /// label is negative.
    pub not_a_point_embed: Vec<f32>,
    /// Four embeddings of `embed_dim` values each, stored back to back.
    /// Slot 0 is added to points labelled `0.0`, slot 1 to every other
    /// non-negative label, slots 2 and 3 to the first and second corner of
    /// each box.
    pub point_embeddings: Vec<f32>,
    /// `mask_downscaling.0.weight`, loaded with shape
    /// `[mask_in_chans / 4, 1, 2, 2]`.
    pub mask_conv1_w: Vec<f32>,
    /// `mask_downscaling.0.bias`.
    pub mask_conv1_b: Vec<f32>,
    /// `mask_downscaling.1.weight` (scale of the first normalisation layer,
    /// going by the field name).
    pub mask_ln1_g: Vec<f32>,
    /// `mask_downscaling.1.bias`.
    pub mask_ln1_b: Vec<f32>,
    /// `mask_downscaling.3.weight`, loaded with shape
    /// `[mask_in_chans, mask_in_chans / 4, 2, 2]`.
    pub mask_conv2_w: Vec<f32>,
    /// `mask_downscaling.3.bias`.
    pub mask_conv2_b: Vec<f32>,
    /// `mask_downscaling.4.weight` (scale of the second normalisation layer,
    /// going by the field name).
    pub mask_ln2_g: Vec<f32>,
    /// `mask_downscaling.4.bias`.
    pub mask_ln2_b: Vec<f32>,
    /// `mask_downscaling.6.weight`, loaded with shape
    /// `[embed_dim, mask_in_chans, 1, 1]`.
    pub mask_conv3_w: Vec<f32>,
    /// `mask_downscaling.6.bias`.
    pub mask_conv3_b: Vec<f32>,
    /// Per-channel value broadcast over the whole dense grid when no mask
    /// prompt is supplied.
    pub no_mask_embed: Vec<f32>,
    /// Number of channels in every embedding produced by the encoder.
    pub embed_dim: usize,
    /// Channel count of the mask-downscaling stack's second convolution.
    pub mask_in_chans: usize,
}
/// Removes every prompt-encoder tensor from `weights` and packs them into a
/// [`PromptEncoderWeights`].
///
/// * `weights` - weight map to drain; each tensor is taken by its
///   `prompt_encoder.*` name.
/// * `embed_dim` - expected embedding width; also stored in the result.
/// * `mask_in_chans` - expected channel count of the second mask
///   convolution; also stored in the result.
///
/// # Errors
///
/// Fails when a tensor cannot be taken from `weights`, when the positional
/// matrix or a convolution weight has an unexpected shape, or when one of the
/// embedding vectors (`not_a_point_embed`, `no_mask_embed`,
/// `point_embeddings.{0..3}`) does not hold exactly `embed_dim` values.
/// Tensors taken before the failure are not put back.
pub(super) fn extract_prompt_encoder_weights(
    weights: &mut WeightMap,
    embed_dim: usize,
    mask_in_chans: usize,
) -> Result<PromptEncoderWeights> {
    /// Takes `name` from `weights` and checks it holds exactly `len` values.
    fn take_embedding(weights: &mut WeightMap, name: &str, len: usize) -> Result<Vec<f32>> {
        let (data, _) = weights.take(name)?;
        ensure!(
            data.len() == len,
            "{name} expected {len} values, got {}",
            data.len()
        );
        Ok(data)
    }

    let half = embed_dim / 2;
    let (pe_gaussian, sh) =
        weights.take("prompt_encoder.pe_layer.positional_encoding_gaussian_matrix")?;
    ensure!(
        sh == vec![2, half],
        "pe_gaussian expected [2, {half}], got {sh:?}"
    );
    // The encoders index this buffer up to `2 * half - 1`, so the data must
    // agree with the declared shape.
    ensure!(
        pe_gaussian.len() == 2 * half,
        "pe_gaussian expected {} values, got {}",
        2 * half,
        pe_gaussian.len()
    );

    // These vectors are later indexed with `0..embed_dim`; reject a wrong
    // length here instead of panicking during the forward pass.
    let not_a_point_embed = take_embedding(
        weights,
        "prompt_encoder.not_a_point_embed.weight",
        embed_dim,
    )?;
    let no_mask_embed =
        take_embedding(weights, "prompt_encoder.no_mask_embed.weight", embed_dim)?;

    // Concatenate the four point/box-corner embeddings into one flat buffer.
    let mut point_embeddings = Vec::with_capacity(4 * embed_dim);
    for i in 0..4 {
        let name = format!("prompt_encoder.point_embeddings.{i}.weight");
        let data = take_embedding(weights, &name, embed_dim)?;
        point_embeddings.extend_from_slice(&data);
    }

    let q = mask_in_chans / 4;
    let (mask_conv1_w, sh1) = weights.take("prompt_encoder.mask_downscaling.0.weight")?;
    ensure!(
        sh1 == vec![q, 1, 2, 2],
        "mask_downscaling.0.weight expected [{q}, 1, 2, 2], got {sh1:?}"
    );
    let (mask_conv1_b, _) = weights.take("prompt_encoder.mask_downscaling.0.bias")?;
    let (mask_ln1_g, _) = weights.take("prompt_encoder.mask_downscaling.1.weight")?;
    let (mask_ln1_b, _) = weights.take("prompt_encoder.mask_downscaling.1.bias")?;

    let (mask_conv2_w, sh2) = weights.take("prompt_encoder.mask_downscaling.3.weight")?;
    ensure!(
        sh2 == vec![mask_in_chans, q, 2, 2],
        "mask_downscaling.3.weight expected [{mask_in_chans}, {q}, 2, 2], got {sh2:?}"
    );
    let (mask_conv2_b, _) = weights.take("prompt_encoder.mask_downscaling.3.bias")?;
    let (mask_ln2_g, _) = weights.take("prompt_encoder.mask_downscaling.4.weight")?;
    let (mask_ln2_b, _) = weights.take("prompt_encoder.mask_downscaling.4.bias")?;

    let (mask_conv3_w, sh3) = weights.take("prompt_encoder.mask_downscaling.6.weight")?;
    ensure!(
        sh3 == vec![embed_dim, mask_in_chans, 1, 1],
        "mask_downscaling.6.weight expected [{embed_dim}, {mask_in_chans}, 1, 1], got {sh3:?}"
    );
    let (mask_conv3_b, _) = weights.take("prompt_encoder.mask_downscaling.6.bias")?;

    Ok(PromptEncoderWeights {
        pe_gaussian,
        not_a_point_embed,
        point_embeddings,
        mask_conv1_w,
        mask_conv1_b,
        mask_ln1_g,
        mask_ln1_b,
        mask_conv2_w,
        mask_conv2_b,
        mask_ln2_g,
        mask_ln2_b,
        mask_conv3_w,
        mask_conv3_b,
        no_mask_embed,
        embed_dim,
        mask_in_chans,
    })
}
/// Everything `prompt_encoder_forward` produces for one set of prompts.
pub struct PromptEncoderOutput {
    /// Point tokens followed by box-corner tokens, each `embed_dim` values
    /// long and stored back to back.
    pub sparse_embeddings: Vec<f32>,
    /// Number of tokens in `sparse_embeddings`; `0` when it is empty.
    pub num_sparse_tokens: usize,
    /// Dense prompt embedding: the mask stack's output when a mask was
    /// given, otherwise the no-mask embedding broadcast channel by channel
    /// over the embedding grid.
    pub dense_embeddings: Vec<f32>,
    /// Channel-major positional encoding of the embedding grid.
    pub image_pe: Vec<f32>,
}
/// Encodes point, box and mask prompts into sparse tokens, a dense embedding
/// and the positional encoding of the embedding grid.
///
/// * `w` - prompt-encoder parameters.
/// * `mask_stack` - compiled mask-downscaling stack; only run when `masks`
///   is `Some`.
/// * `points` - `(coords, labels)` with two interleaved coordinates per
///   label.
/// * `boxes` - four coordinates per box, i.e. two corners.
/// * `masks` - a mask of `4 * SAM_EMBED_HW` by `4 * SAM_EMBED_HW` values.
///
/// Every coordinate is shifted by `0.5` before embedding. When points are
/// given without boxes, one extra `(0, 0)` point with label `-1.0` is
/// appended. Point tokens come first in the sparse output, box-corner tokens
/// after them.
///
/// # Errors
///
/// Fails when `coords` is not twice as long as `labels`, when `boxes` is not
/// a multiple of four values long, or when a helper it calls fails.
pub fn prompt_encoder_forward(
    w: &PromptEncoderWeights,
    mask_stack: &mut SamPromptMaskCompiled,
    points: Option<(&[f32], &[f32])>,
    boxes: Option<&[f32]>,
    masks: Option<&[f32]>,
) -> Result<PromptEncoderOutput> {
    let dim = w.embed_dim;
    let grid = SAM_EMBED_HW;
    let mut sparse_embeddings: Vec<f32> = Vec::new();

    if let Some((coords, labels)) = points {
        let count = labels.len();
        ensure!(
            coords.len() == count * 2,
            "points coords len {} ≠ N·2 ({}·2)",
            coords.len(),
            count
        );
        let mut shifted: Vec<f32> = coords.iter().map(|&c| c + 0.5).collect();
        let mut point_labels = labels.to_vec();
        // Without boxes the point list is padded with one "not a point" entry.
        if boxes.is_none() {
            shifted.extend_from_slice(&[0.0, 0.0]);
            point_labels.push(-1.0);
        }
        let tokens = embed_points_and_boxes(
            w,
            &shifted,
            point_labels.len(),
            false,
            Some(point_labels.as_slice()),
        )?;
        sparse_embeddings.extend(tokens);
    }

    if let Some(box_coords) = boxes {
        ensure!(box_coords.len() % 4 == 0, "boxes len must be multiple of 4");
        let corners: Vec<f32> = box_coords.iter().map(|&c| c + 0.5).collect();
        // Four values per box means two corner points per box.
        let tokens = embed_points_and_boxes(w, &corners, box_coords.len() / 2, true, None)?;
        sparse_embeddings.extend(tokens);
    }

    let num_sparse_tokens = match sparse_embeddings.len() {
        0 => 0,
        len => len / dim,
    };

    let dense_embeddings = if let Some(mask) = masks {
        embed_mask(mask_stack, mask, grid)?
    } else {
        // No mask prompt: fill each channel plane with its no-mask value.
        let plane = grid * grid;
        let mut filled = vec![0f32; dim * plane];
        for c in 0..dim {
            let value = w.no_mask_embed[c];
            filled[c * plane..(c + 1) * plane].fill(value);
        }
        filled
    };

    let image_pe = compute_image_pe(w, grid, grid);
    Ok(PromptEncoderOutput {
        sparse_embeddings,
        num_sparse_tokens,
        dense_embeddings,
        image_pe,
    })
}
/// Builds the positional encoding of an `h` by `ww` grid, channel-major.
///
/// Each cell centre is mapped to `[-1, 1]` on both axes, projected through
/// the two rows of `w.pe_gaussian`, and scaled by `2π`. The first
/// `embed_dim / 2` channels hold the sine of that phase, the next
/// `embed_dim / 2` the cosine. With an odd `embed_dim` the last channel
/// stays zero.
fn compute_image_pe(w: &PromptEncoderWeights, h: usize, ww: usize) -> Vec<f32> {
    let half = w.embed_dim / 2;
    let plane = h * ww;
    let two_pi = 2.0 * std::f32::consts::PI;
    let mut pe = vec![0f32; w.embed_dim * plane];
    for row in 0..h {
        // Cell-centre y coordinate, rescaled from [0, 1] to [-1, 1].
        let v = ((row as f32 + 0.5) / h as f32) * 2.0 - 1.0;
        for col in 0..ww {
            let u = ((col as f32 + 0.5) / ww as f32) * 2.0 - 1.0;
            let pixel = row * ww + col;
            for k in 0..half {
                let phase = (u * w.pe_gaussian[k] + v * w.pe_gaussian[half + k]) * two_pi;
                pe[k * plane + pixel] = phase.sin();
                pe[(half + k) * plane + pixel] = phase.cos();
            }
        }
    }
    pe
}
/// Positionally encodes the first `n` coordinate pairs of `coords`.
///
/// `coords` holds interleaved x/y values already normalised to `[0, 1]`.
/// Each pair is rescaled to `[-1, 1]`, projected through the two rows of
/// `w.pe_gaussian` and scaled by `2π`. The result stores one row of
/// `embed_dim` values per point: sines in the first `embed_dim / 2` slots,
/// cosines in the next `embed_dim / 2`.
///
/// Panics if `coords` has fewer than `2 * n` values.
fn pe_encode_normalized(w: &PromptEncoderWeights, coords: &[f32], n: usize) -> Vec<f32> {
    let dim = w.embed_dim;
    let half = dim / 2;
    let two_pi = 2.0 * std::f32::consts::PI;
    let mut encoded = vec![0f32; n * dim];
    for point in 0..n {
        let u = coords[2 * point] * 2.0 - 1.0;
        let v = coords[2 * point + 1] * 2.0 - 1.0;
        let row = &mut encoded[point * dim..(point + 1) * dim];
        let (sines, rest) = row.split_at_mut(half);
        for k in 0..half {
            let phase = (u * w.pe_gaussian[k] + v * w.pe_gaussian[half + k]) * two_pi;
            sines[k] = phase.sin();
            rest[k] = phase.cos();
        }
    }
    encoded
}
/// Turns `n` pixel-space points into `n` embedding tokens.
///
/// Coordinates are divided by `SAM_IMG_SIZE`, positionally encoded, and then
/// adjusted per token:
///
/// * `is_box == true`: even-indexed tokens get embedding slot 2 added,
///   odd-indexed tokens slot 3. `labels` is ignored.
/// * `is_box == false` with labels: a negative label replaces the token with
///   `not_a_point_embed`, a label of `0.0` adds slot 0, and any other label
///   adds slot 1.
/// * `is_box == false` without labels: the bare positional encoding is
///   returned.
///
/// # Errors
///
/// Fails when labels are used and their count differs from `n`.
fn embed_points_and_boxes(
    w: &PromptEncoderWeights,
    coords_in_pixels: &[f32],
    n: usize,
    is_box: bool,
    labels: Option<&[f32]>,
) -> Result<Vec<f32>> {
    /// Adds `src` element-wise into `dst`.
    fn accumulate(dst: &mut [f32], src: &[f32]) {
        for (d, s) in dst.iter_mut().zip(src) {
            *d += *s;
        }
    }

    let dim = w.embed_dim;
    let image_side = SAM_IMG_SIZE as f32;
    let unit_coords: Vec<f32> = coords_in_pixels.iter().map(|&c| c / image_side).collect();
    let mut tokens = pe_encode_normalized(w, &unit_coords, n);

    if is_box {
        for i in 0..n {
            // Corners alternate between the two box-corner embeddings.
            let slot = if i % 2 == 0 { 2 } else { 3 };
            accumulate(
                &mut tokens[i * dim..(i + 1) * dim],
                &w.point_embeddings[slot * dim..(slot + 1) * dim],
            );
        }
    } else if let Some(point_labels) = labels {
        ensure!(
            point_labels.len() == n,
            "labels len {} ≠ n {n}",
            point_labels.len()
        );
        for (i, &label) in point_labels.iter().enumerate() {
            let token = &mut tokens[i * dim..(i + 1) * dim];
            if label < 0.0 {
                // Padding / "not a point": discard the positional encoding.
                token.copy_from_slice(&w.not_a_point_embed[..dim]);
            } else {
                let slot = if label == 0.0 { 0 } else { 1 };
                accumulate(token, &w.point_embeddings[slot * dim..(slot + 1) * dim]);
            }
        }
    }
    Ok(tokens)
}
/// Runs a mask prompt through the compiled mask stack.
///
/// The mask must hold exactly `(4 * hw) * (4 * hw)` values; it is handed to
/// `stack.run` together with that height and width.
///
/// # Errors
///
/// Fails when the mask has the wrong length or when `stack.run` fails.
fn embed_mask(stack: &mut SamPromptMaskCompiled, mask: &[f32], hw: usize) -> Result<Vec<f32>> {
    let (in_h, in_w) = (4 * hw, 4 * hw);
    let expected_len = in_h * in_w;
    ensure!(
        mask.len() == expected_len,
        "mask must be [1, {in_h}, {in_w}], got len {}",
        mask.len()
    );
    stack.run(mask, in_h, in_w)
}
/// 2x2 convolution with stride 2 and no padding over a channel-major input.
///
/// * `input` - `in_c` planes of `in_h * in_w` values.
/// * `weight` - laid out as `[out_c, in_c, 2, 2]`.
/// * `bias` - one value per output channel.
///
/// Returns `out_c` planes of `(in_h / 2) * (in_w / 2)` values. A trailing odd
/// row or column of the input is ignored.
#[allow(dead_code)]
fn conv2d_stride2_k2_pad0(
    input: &[f32],
    in_c: usize,
    out_c: usize,
    in_h: usize,
    in_w: usize,
    weight: &[f32],
    bias: &[f32],
) -> Vec<f32> {
    let (out_h, out_w) = (in_h / 2, in_w / 2);
    let in_plane = in_h * in_w;
    let mut out = Vec::with_capacity(out_c * out_h * out_w);
    for oc in 0..out_c {
        for oy in 0..out_h {
            for ox in 0..out_w {
                // Offset of the window's top-left sample within one plane.
                let top_left = (2 * oy) * in_w + 2 * ox;
                let mut sum = bias[oc];
                for ic in 0..in_c {
                    let src = ic * in_plane + top_left;
                    let taps = (oc * in_c + ic) * 4;
                    // Walk the four kernel taps row by row.
                    for tap in 0..4 {
                        let (ky, kx) = (tap / 2, tap % 2);
                        sum += input[src + ky * in_w + kx] * weight[taps + tap];
                    }
                }
                out.push(sum);
            }
        }
    }
    out
}
/// Pointwise (1x1) convolution over a channel-major input.
///
/// * `input` - `in_c` planes of `h * w` values.
/// * `weight` - laid out as `[out_c, in_c]`.
/// * `bias` - one value per output channel.
///
/// Returns `out_c` planes of `h * w` values.
#[allow(dead_code)]
fn conv2d_1x1(
    input: &[f32],
    in_c: usize,
    out_c: usize,
    h: usize,
    w: usize,
    weight: &[f32],
    bias: &[f32],
) -> Vec<f32> {
    let plane = h * w;
    let mut out = Vec::with_capacity(out_c * plane);
    for oc in 0..out_c {
        let offset = bias[oc];
        // Every pixel is an independent dot product across input channels.
        out.extend((0..plane).map(|pixel| {
            (0..in_c).fold(offset, |sum, ic| {
                sum + input[ic * plane + pixel] * weight[oc * in_c + ic]
            })
        }));
    }
    out
}
/// Normalises `data` in place across channels, independently for each pixel.
///
/// `data` holds `c` planes of `h * w` values. For every pixel the mean and
/// (population) variance over the `c` channels are computed, the values are
/// standardised with `eps` added to the variance, and channel `k` is then
/// scaled by `gamma[k]` and shifted by `beta[k]`.
#[allow(dead_code)]
fn layernorm2d_nchw(
    data: &mut [f32],
    c: usize,
    h: usize,
    w: usize,
    gamma: &[f32],
    beta: &[f32],
    eps: f32,
) {
    let plane = h * w;
    let channels = c as f32;
    for pixel in 0..plane {
        let mean = (0..c).fold(0f32, |sum, k| sum + data[k * plane + pixel]) / channels;
        let var = (0..c).fold(0f32, |sum, k| {
            let d = data[k * plane + pixel] - mean;
            sum + d * d
        }) / channels;
        let inv_std = 1.0 / (var + eps).sqrt();
        for k in 0..c {
            let slot = &mut data[k * plane + pixel];
            *slot = (*slot - mean) * inv_std * gamma[k] + beta[k];
        }
    }
}
/// Applies `0.5 * x * (1 + erf(x / sqrt(2)))` to every element in place.
///
/// `erf` is evaluated with a five-term polynomial in `t = 1 / (1 + P * |s|)`
/// multiplied by `exp(-s^2)`; the sign of `x` is restored afterwards.
#[allow(dead_code)]
pub(super) fn gelu_erf_inplace(data: &mut [f32]) {
    const P: f32 = 0.327_591_1;
    // Polynomial coefficients, highest power first (a5 down to a1).
    const COEFFS: [f32; 5] = [
        1.061_405_4,
        -1.453_152_1,
        1.421_413_8,
        -0.284_496_7,
        0.254_829_6,
    ];
    for slot in data.iter_mut() {
        let x = *slot;
        let s = (x * std::f32::consts::FRAC_1_SQRT_2).abs();
        let t = 1.0 / (1.0 + P * s);
        // Horner evaluation of (a5*t^4 + ... + a1) followed by a final `* t`.
        let poly = COEFFS.iter().fold(0.0f32, |acc, &a| acc * t + a) * t;
        let magnitude = 1.0 - poly * (-s * s).exp();
        let erf = if x >= 0.0 { magnitude } else { -magnitude };
        *slot = 0.5 * x * (1.0 + erf);
    }
}
#[cfg(test)]
#[allow(dead_code)]
pub(super) fn assert_shape(label: &str, actual: usize, expected: usize) {
assert_eq!(actual, expected, "{label}: {actual} ≠ {expected}");
}
/// Reads `SAM_PROMPT_EMBED_DIM` and discards the value.
///
/// Apparently exists only so the constant's import counts as used; nothing
/// in this file calls it.
#[allow(dead_code)]
fn _silence_constant() {
    let _unused = SAM_PROMPT_EMBED_DIM;
}