1use super::config::{SAM_EMBED_HW, SAM_IMG_SIZE, SAM_PROMPT_EMBED_DIM};
24use super::prompt_mask_ir::SamPromptMaskCompiled;
25use anyhow::{Result, ensure};
26use rlx_core::weight_map::WeightMap;
27
/// Prompt-encoder parameters, as loaded by `extract_prompt_encoder_weights`.
///
/// Every tensor is stored as a flat, row-major `f32` buffer.
#[derive(Debug, Clone)]
pub struct PromptEncoderWeights {
    /// Random Fourier-feature matrix, shape `[2, embed_dim / 2]`.
    ///
    /// Row 0 multiplies the x coordinate and row 1 multiplies y.
    pub pe_gaussian: Vec<f32>,
    /// Embedding that replaces the positional encoding of points with a
    /// negative label (e.g. the padding point); expected length `embed_dim`.
    pub not_a_point_embed: Vec<f32>,
    /// Four learned embeddings packed as `[4, embed_dim]`.
    ///
    /// Slot 0 is added to label-0 points, slot 1 to positive-label points.
    /// Slots 2 and 3 are added to the first and second corner of each box.
    pub point_embeddings: Vec<f32>,
    /// `mask_downscaling.0` conv weight, `[mask_in_chans / 4, 1, 2, 2]`.
    pub mask_conv1_w: Vec<f32>,
    /// `mask_downscaling.0` conv bias.
    pub mask_conv1_b: Vec<f32>,
    /// `mask_downscaling.1` weight (per-channel scale; presumably LayerNorm gamma).
    pub mask_ln1_g: Vec<f32>,
    /// `mask_downscaling.1` bias (per-channel shift; presumably LayerNorm beta).
    pub mask_ln1_b: Vec<f32>,
    /// `mask_downscaling.3` conv weight, `[mask_in_chans, mask_in_chans / 4, 2, 2]`.
    pub mask_conv2_w: Vec<f32>,
    /// `mask_downscaling.3` conv bias.
    pub mask_conv2_b: Vec<f32>,
    /// `mask_downscaling.4` weight (per-channel scale; presumably LayerNorm gamma).
    pub mask_ln2_g: Vec<f32>,
    /// `mask_downscaling.4` bias (per-channel shift; presumably LayerNorm beta).
    pub mask_ln2_b: Vec<f32>,
    /// `mask_downscaling.6` 1×1 conv weight, `[embed_dim, mask_in_chans, 1, 1]`.
    pub mask_conv3_w: Vec<f32>,
    /// `mask_downscaling.6` conv bias.
    pub mask_conv3_b: Vec<f32>,
    /// Per-channel value broadcast over the dense grid when no mask prompt
    /// is given; expected length `embed_dim`.
    pub no_mask_embed: Vec<f32>,
    /// Width of every prompt token and of the dense embedding channels.
    pub embed_dim: usize,
    /// Output channels of the second mask-downscaling convolution.
    pub mask_in_chans: usize,
}
60
61pub(super) fn extract_prompt_encoder_weights(
62 weights: &mut WeightMap,
63 embed_dim: usize,
64 mask_in_chans: usize,
65) -> Result<PromptEncoderWeights> {
66 let half = embed_dim / 2;
67 let (pe_gaussian, sh) =
68 weights.take("prompt_encoder.pe_layer.positional_encoding_gaussian_matrix")?;
69 ensure!(
70 sh == vec![2, half],
71 "pe_gaussian expected [2, {half}], got {sh:?}"
72 );
73
74 let (not_a_point_embed, _) = weights.take("prompt_encoder.not_a_point_embed.weight")?;
75 let (no_mask_embed, _) = weights.take("prompt_encoder.no_mask_embed.weight")?;
76
77 let mut point_embeddings = vec![0f32; 4 * embed_dim];
79 for i in 0..4 {
80 let (data, _) = weights.take(&format!("prompt_encoder.point_embeddings.{i}.weight"))?;
81 point_embeddings[i * embed_dim..(i + 1) * embed_dim].copy_from_slice(&data);
82 }
83
84 let q = mask_in_chans / 4;
85 let (mask_conv1_w, sh1) = weights.take("prompt_encoder.mask_downscaling.0.weight")?;
86 ensure!(
87 sh1 == vec![q, 1, 2, 2],
88 "mask_downscaling.0.weight expected [{q}, 1, 2, 2], got {sh1:?}"
89 );
90 let (mask_conv1_b, _) = weights.take("prompt_encoder.mask_downscaling.0.bias")?;
91 let (mask_ln1_g, _) = weights.take("prompt_encoder.mask_downscaling.1.weight")?;
92 let (mask_ln1_b, _) = weights.take("prompt_encoder.mask_downscaling.1.bias")?;
93
94 let (mask_conv2_w, sh2) = weights.take("prompt_encoder.mask_downscaling.3.weight")?;
95 ensure!(
96 sh2 == vec![mask_in_chans, q, 2, 2],
97 "mask_downscaling.3.weight expected [{mask_in_chans}, {q}, 2, 2], got {sh2:?}"
98 );
99 let (mask_conv2_b, _) = weights.take("prompt_encoder.mask_downscaling.3.bias")?;
100 let (mask_ln2_g, _) = weights.take("prompt_encoder.mask_downscaling.4.weight")?;
101 let (mask_ln2_b, _) = weights.take("prompt_encoder.mask_downscaling.4.bias")?;
102
103 let (mask_conv3_w, sh3) = weights.take("prompt_encoder.mask_downscaling.6.weight")?;
104 ensure!(
105 sh3 == vec![embed_dim, mask_in_chans, 1, 1],
106 "mask_downscaling.6.weight expected [{embed_dim}, {mask_in_chans}, 1, 1], got {sh3:?}"
107 );
108 let (mask_conv3_b, _) = weights.take("prompt_encoder.mask_downscaling.6.bias")?;
109
110 Ok(PromptEncoderWeights {
111 pe_gaussian,
112 not_a_point_embed,
113 point_embeddings,
114 mask_conv1_w,
115 mask_conv1_b,
116 mask_ln1_g,
117 mask_ln1_b,
118 mask_conv2_w,
119 mask_conv2_b,
120 mask_ln2_g,
121 mask_ln2_b,
122 mask_conv3_w,
123 mask_conv3_b,
124 no_mask_embed,
125 embed_dim,
126 mask_in_chans,
127 })
128}
129
/// Result of `prompt_encoder_forward`.
#[derive(Debug, Clone)]
pub struct PromptEncoderOutput {
    /// Sparse prompt tokens, row-major `[num_sparse_tokens, embed_dim]`.
    ///
    /// Point tokens come first, including the padding point when no box is given.
    /// They are followed by two corner tokens per box.
    pub sparse_embeddings: Vec<f32>,
    /// Number of `embed_dim`-wide rows in `sparse_embeddings`.
    pub num_sparse_tokens: usize,
    /// Dense prompt embedding.
    ///
    /// Without a mask prompt this is `no_mask_embed` broadcast to
    /// `[embed_dim, SAM_EMBED_HW, SAM_EMBED_HW]`. With a mask it is whatever
    /// the compiled mask stack returns (presumably the same layout — TODO confirm).
    pub dense_embeddings: Vec<f32>,
    /// Positional encoding of the image-embedding grid,
    /// `[embed_dim, SAM_EMBED_HW, SAM_EMBED_HW]`.
    pub image_pe: Vec<f32>,
}
144
145pub fn prompt_encoder_forward(
155 w: &PromptEncoderWeights,
156 mask_stack: &mut SamPromptMaskCompiled,
157 points: Option<(&[f32], &[f32])>,
158 boxes: Option<&[f32]>,
159 masks: Option<&[f32]>,
160) -> Result<PromptEncoderOutput> {
161 let e = w.embed_dim;
162 let hw = SAM_EMBED_HW;
163
164 let pad_points = boxes.is_none();
166 let mut sparse = Vec::new();
167
168 if let Some((coords, labels)) = points {
169 let n = labels.len();
170 ensure!(
171 coords.len() == n * 2,
172 "points coords len {} ≠ N·2 ({}·2)",
173 coords.len(),
174 n
175 );
176 let mut pts: Vec<f32> = coords.iter().map(|c| c + 0.5).collect();
178 let mut lbls = labels.to_vec();
179 if pad_points {
180 pts.push(0.0);
182 pts.push(0.0);
183 lbls.push(-1.0);
184 }
185 let n_padded = lbls.len();
186 let emb = embed_points_and_boxes(w, &pts, n_padded, false, Some(&lbls))?;
187 sparse.extend_from_slice(&emb);
188 }
189 if let Some(box_coords) = boxes {
190 let m = box_coords.len() / 4;
191 ensure!(box_coords.len() == m * 4, "boxes len must be multiple of 4");
192 let coords_with_half: Vec<f32> = box_coords.iter().map(|c| c + 0.5).collect();
193 let emb = embed_points_and_boxes(w, &coords_with_half, m * 2, true, None)?;
194 sparse.extend_from_slice(&emb);
195 }
196 let num_sparse_tokens = if sparse.is_empty() {
197 0
198 } else {
199 sparse.len() / e
200 };
201
202 let dense_embeddings = match masks {
204 Some(m) => embed_mask(mask_stack, m, hw)?,
205 None => {
206 let mut out = vec![0f32; e * hw * hw];
208 for c in 0..e {
209 let v = w.no_mask_embed[c];
210 let plane = &mut out[c * hw * hw..(c + 1) * hw * hw];
211 plane.fill(v);
212 }
213 out
214 }
215 };
216
217 let image_pe = compute_image_pe(w, hw, hw);
219
220 Ok(PromptEncoderOutput {
221 sparse_embeddings: sparse,
222 num_sparse_tokens,
223 dense_embeddings,
224 image_pe,
225 })
226}
227
228fn compute_image_pe(w: &PromptEncoderWeights, h: usize, ww: usize) -> Vec<f32> {
231 let e = w.embed_dim;
232 let half = e / 2;
233 let mut out = vec![0f32; e * h * ww];
234 for y in 0..h {
237 let fy = (y as f32 + 0.5) / h as f32;
238 for x in 0..ww {
239 let fx = (x as f32 + 0.5) / ww as f32;
240 let cx = fx * 2.0 - 1.0;
242 let cy = fy * 2.0 - 1.0;
243 for k in 0..half {
245 let mut acc = cx * w.pe_gaussian[k] + cy * w.pe_gaussian[half + k];
246 acc *= 2.0 * std::f32::consts::PI;
247 out[k * h * ww + y * ww + x] = acc.sin();
248 out[(half + k) * h * ww + y * ww + x] = acc.cos();
249 }
250 }
251 }
252 out
253}
254
255fn pe_encode_normalized(w: &PromptEncoderWeights, coords: &[f32], n: usize) -> Vec<f32> {
259 let e = w.embed_dim;
260 let half = e / 2;
261 let mut out = vec![0f32; n * e];
262 for i in 0..n {
263 let cx = coords[i * 2] * 2.0 - 1.0;
264 let cy = coords[i * 2 + 1] * 2.0 - 1.0;
265 for k in 0..half {
266 let mut acc = cx * w.pe_gaussian[k] + cy * w.pe_gaussian[half + k];
267 acc *= 2.0 * std::f32::consts::PI;
268 out[i * e + k] = acc.sin();
269 out[i * e + half + k] = acc.cos();
270 }
271 }
272 out
273}
274
275fn embed_points_and_boxes(
280 w: &PromptEncoderWeights,
281 coords_in_pixels: &[f32], n: usize,
283 is_box: bool,
284 labels: Option<&[f32]>,
285) -> Result<Vec<f32>> {
286 let e = w.embed_dim;
287 let img = SAM_IMG_SIZE as f32;
289 let normed: Vec<f32> = coords_in_pixels.iter().map(|c| c / img).collect();
290 let mut emb = pe_encode_normalized(w, &normed, n);
291
292 if is_box {
293 for i in 0..n {
295 let pe_idx = if i % 2 == 0 { 2 } else { 3 };
296 for k in 0..e {
297 emb[i * e + k] += w.point_embeddings[pe_idx * e + k];
298 }
299 }
300 } else if let Some(lbls) = labels {
301 ensure!(lbls.len() == n, "labels len {} ≠ n {n}", lbls.len());
302 for i in 0..n {
303 let label = lbls[i];
304 if label < 0.0 {
305 for k in 0..e {
307 emb[i * e + k] = w.not_a_point_embed[k];
308 }
309 } else if label == 0.0 {
310 for k in 0..e {
311 emb[i * e + k] += w.point_embeddings[k];
312 }
313 } else {
314 for k in 0..e {
316 emb[i * e + k] += w.point_embeddings[e + k];
317 }
318 }
319 }
320 }
321 Ok(emb)
322}
323
324fn embed_mask(stack: &mut SamPromptMaskCompiled, mask: &[f32], hw: usize) -> Result<Vec<f32>> {
329 let in_h = 4 * hw;
330 let in_w = 4 * hw;
331 ensure!(
332 mask.len() == in_h * in_w,
333 "mask must be [1, {in_h}, {in_w}], got len {}",
334 mask.len()
335 );
336 stack.run(mask, in_h, in_w)
338}
339
/// Direct 2×2, stride-2, unpadded convolution of one CHW image.
///
/// * `input`  — `[in_c, in_h, in_w]`
/// * `weight` — `[out_c, in_c, 2, 2]`
/// * `bias`   — `[out_c]`
///
/// Returns `[out_c, in_h / 2, in_w / 2]`. A trailing odd row or column of the
/// input is ignored. Panics on out-of-bounds indexing if a buffer is too small
/// for the given dimensions.
// Not called from this file (mask prompts go through `SamPromptMaskCompiled`);
// kept as a plain CPU reference kernel.
#[allow(dead_code)]
fn conv2d_stride2_k2_pad0(
    input: &[f32],
    in_c: usize,
    out_c: usize,
    in_h: usize,
    in_w: usize,
    weight: &[f32],
    bias: &[f32],
) -> Vec<f32> {
    let (out_h, out_w) = (in_h / 2, in_w / 2);
    let in_plane = in_h * in_w;
    let mut out = Vec::with_capacity(out_c * out_h * out_w);
    for oc in 0..out_c {
        for oy in 0..out_h {
            for ox in 0..out_w {
                let mut acc = bias[oc];
                for ic in 0..in_c {
                    // Taps in (ky, kx) order: (0,0), (0,1), (1,0), (1,1).
                    let taps = &weight[(oc * in_c + ic) * 4..][..4];
                    let top = ic * in_plane + (2 * oy) * in_w + 2 * ox;
                    let bottom = top + in_w;
                    acc += input[top] * taps[0];
                    acc += input[top + 1] * taps[1];
                    acc += input[bottom] * taps[2];
                    acc += input[bottom + 1] * taps[3];
                }
                out.push(acc);
            }
        }
    }
    out
}
380
/// Direct 1×1 convolution (a per-pixel matrix multiply) of one CHW image.
///
/// * `input`  — `[in_c, h, w]`
/// * `weight` — `[out_c, in_c]` (a `[out_c, in_c, 1, 1]` kernel flattened)
/// * `bias`   — `[out_c]`
///
/// Returns `[out_c, h, w]`. Panics on out-of-bounds indexing if a buffer is too small.
// Not called from this file; kept as a plain CPU reference kernel.
#[allow(dead_code)]
fn conv2d_1x1(
    input: &[f32],
    in_c: usize,
    out_c: usize,
    h: usize,
    w: usize,
    weight: &[f32],
    bias: &[f32],
) -> Vec<f32> {
    let plane = h * w;
    let mut out = Vec::with_capacity(out_c * plane);
    for oc in 0..out_c {
        let b = bias[oc];
        let row = oc * in_c;
        out.extend((0..plane).map(|p| {
            (0..in_c).fold(b, |acc, ic| acc + input[ic * plane + p] * weight[row + ic])
        }));
    }
    out
}
407
/// In-place layer norm across channels at every pixel of a `[c, h, w]` buffer.
///
/// For each pixel, the mean and biased variance are taken over the `c` channel
/// values. Each value becomes `(x - mean) / sqrt(var + eps) * gamma[k] + beta[k]`.
// Not called from this file; kept as a plain CPU reference kernel.
#[allow(dead_code)]
fn layernorm2d_nchw(
    data: &mut [f32],
    c: usize,
    h: usize,
    w: usize,
    gamma: &[f32],
    beta: &[f32],
    eps: f32,
) {
    let plane = h * w;
    let count = c as f32;
    for px in 0..plane {
        // Flat index of channel `k` at this pixel.
        let at = |k: usize| k * plane + px;
        let mean = (0..c).fold(0f32, |sum, k| sum + data[at(k)]) / count;
        let var = (0..c).fold(0f32, |sum, k| {
            let d = data[at(k)] - mean;
            sum + d * d
        }) / count;
        let inv_std = 1.0 / (var + eps).sqrt();
        for k in 0..c {
            let idx = at(k);
            data[idx] = (data[idx] - mean) * inv_std * gamma[k] + beta[k];
        }
    }
}
440
441#[allow(dead_code)]
443pub(super) fn gelu_erf_inplace(data: &mut [f32]) {
444 const INV_SQRT2: f32 = std::f32::consts::FRAC_1_SQRT_2;
445 for v in data.iter_mut() {
446 let x = *v;
449 let s = (x * INV_SQRT2).abs();
450 let p = 0.327_591_1;
451 let a1 = 0.254_829_6;
452 let a2 = -0.284_496_7;
453 let a3 = 1.421_413_8;
454 let a4 = -1.453_152_1;
455 let a5 = 1.061_405_4;
456 let t = 1.0 / (1.0 + p * s);
457 let y = ((((a5 * t + a4) * t + a3) * t + a2) * t + a1) * t;
458 let erf_abs = 1.0 - y * (-s * s).exp();
459 let erf = if x >= 0.0 { erf_abs } else { -erf_abs };
460 *v = 0.5 * x * (1.0 + erf);
461 }
462}
463
/// Test helper: asserts that a size matches its expectation.
///
/// Panics via `assert_eq!` with a message prefixed by `label` when
/// `actual != expected`. Only compiled under `cfg(test)`.
#[cfg(test)]
#[allow(dead_code)]
pub(super) fn assert_shape(label: &str, actual: usize, expected: usize) {
    assert_eq!(actual, expected, "{label}: {actual} ≠ {expected}");
}
469
470#[allow(dead_code)]
471fn _silence_constant() {
472 let _ = SAM_PROMPT_EMBED_DIM;
473}