1use super::config::SAM2_IMG_SIZE;
36use super::prompt_mask_ir::Sam2PromptMaskCompiled;
37use anyhow::{Result, ensure};
38use rlx_core::weight_map::WeightMap;
39
/// Side length of the square prompt grid. `prompt_encoder_forward` builds the
/// no-mask dense embedding and the image positional encoding as
/// `grid × grid` planes per channel. `extract_prompt_encoder_weights` stores
/// this value in `Sam2PromptEncoderWeights::grid`.
pub const SAM2_PROMPT_GRID: usize = 64;

/// Mask-downscaling channel count (per the name).
// NOTE(review): not referenced in this chunk; presumably the default
// `mask_in_chans` passed to `extract_prompt_encoder_weights` — confirm.
pub const SAM2_MASK_IN_CHANS: usize = 16;
49
/// SAM2 prompt-encoder parameters, as produced by
/// [`extract_prompt_encoder_weights`]. Every tensor is a flat, row-major
/// `f32` buffer.
#[derive(Debug, Clone)]
pub struct Sam2PromptEncoderWeights {
    /// Random Gaussian projection used by the positional encoding,
    /// `[2, embed_dim / 2]`.
    pub pe_gaussian: Vec<f32>,
    /// Token that replaces padding / "not a point" prompts (negative labels),
    /// `[embed_dim]`.
    pub not_a_point_embed: Vec<f32>,
    /// Four learned type embeddings, `[4, embed_dim]`. Rows 0 and 1 are added
    /// to point tokens according to their label. Rows 2 and 3 go to the first
    /// and second corner of a box prompt.
    pub point_embeddings: Vec<f32>,
    /// `mask_downscaling.0` conv weight, `[mask_in_chans / 4, 1, 2, 2]`.
    pub mask_conv1_w: Vec<f32>,
    /// `mask_downscaling.0` conv bias.
    pub mask_conv1_b: Vec<f32>,
    /// `mask_downscaling.1` weight (a norm scale, judging by the field name).
    pub mask_ln1_g: Vec<f32>,
    /// `mask_downscaling.1` bias (a norm shift, judging by the field name).
    pub mask_ln1_b: Vec<f32>,
    /// `mask_downscaling.3` conv weight, `[mask_in_chans, mask_in_chans / 4, 2, 2]`.
    pub mask_conv2_w: Vec<f32>,
    /// `mask_downscaling.3` conv bias.
    pub mask_conv2_b: Vec<f32>,
    /// `mask_downscaling.4` weight (a norm scale, judging by the field name).
    pub mask_ln2_g: Vec<f32>,
    /// `mask_downscaling.4` bias (a norm shift, judging by the field name).
    pub mask_ln2_b: Vec<f32>,
    /// `mask_downscaling.6` 1×1 conv weight, `[embed_dim, mask_in_chans, 1, 1]`.
    pub mask_conv3_w: Vec<f32>,
    /// `mask_downscaling.6` conv bias.
    pub mask_conv3_b: Vec<f32>,
    /// Per-channel value broadcast over the dense grid when no mask prompt is
    /// supplied, `[embed_dim]`.
    pub no_mask_embed: Vec<f32>,
    /// Width of every prompt token and channel count of the dense outputs.
    pub embed_dim: usize,
    /// Output channels of the second mask conv. The first conv produces
    /// `mask_in_chans / 4`.
    pub mask_in_chans: usize,
    /// Side length of the dense `grid × grid` outputs.
    pub grid: usize,
}
84
85pub fn extract_prompt_encoder_weights(
88 weights: &mut WeightMap,
89 embed_dim: usize,
90 mask_in_chans: usize,
91) -> Result<Sam2PromptEncoderWeights> {
92 let half = embed_dim / 2;
93 let (pe_gaussian, sh) =
94 weights.take("sam_prompt_encoder.pe_layer.positional_encoding_gaussian_matrix")?;
95 ensure!(
96 sh == vec![2, half],
97 "pe_gaussian expected [2, {half}], got {sh:?}"
98 );
99
100 let (not_a_point_embed, _) = weights.take("sam_prompt_encoder.not_a_point_embed.weight")?;
101 let (no_mask_embed, _) = weights.take("sam_prompt_encoder.no_mask_embed.weight")?;
102
103 let mut point_embeddings = vec![0f32; 4 * embed_dim];
104 for i in 0..4 {
105 let (data, _) = weights.take(&format!("sam_prompt_encoder.point_embeddings.{i}.weight"))?;
106 point_embeddings[i * embed_dim..(i + 1) * embed_dim].copy_from_slice(&data);
107 }
108
109 let q = mask_in_chans / 4;
110 let (mask_conv1_w, sh1) = weights.take("sam_prompt_encoder.mask_downscaling.0.weight")?;
111 ensure!(
112 sh1 == vec![q, 1, 2, 2],
113 "mask_downscaling.0.weight expected [{q}, 1, 2, 2], got {sh1:?}"
114 );
115 let (mask_conv1_b, _) = weights.take("sam_prompt_encoder.mask_downscaling.0.bias")?;
116 let (mask_ln1_g, _) = weights.take("sam_prompt_encoder.mask_downscaling.1.weight")?;
117 let (mask_ln1_b, _) = weights.take("sam_prompt_encoder.mask_downscaling.1.bias")?;
118
119 let (mask_conv2_w, sh2) = weights.take("sam_prompt_encoder.mask_downscaling.3.weight")?;
120 ensure!(
121 sh2 == vec![mask_in_chans, q, 2, 2],
122 "mask_downscaling.3.weight expected [{mask_in_chans}, {q}, 2, 2], got {sh2:?}"
123 );
124 let (mask_conv2_b, _) = weights.take("sam_prompt_encoder.mask_downscaling.3.bias")?;
125 let (mask_ln2_g, _) = weights.take("sam_prompt_encoder.mask_downscaling.4.weight")?;
126 let (mask_ln2_b, _) = weights.take("sam_prompt_encoder.mask_downscaling.4.bias")?;
127
128 let (mask_conv3_w, sh3) = weights.take("sam_prompt_encoder.mask_downscaling.6.weight")?;
129 ensure!(
130 sh3 == vec![embed_dim, mask_in_chans, 1, 1],
131 "mask_downscaling.6.weight expected [{embed_dim}, {mask_in_chans}, 1, 1], got {sh3:?}"
132 );
133 let (mask_conv3_b, _) = weights.take("sam_prompt_encoder.mask_downscaling.6.bias")?;
134
135 Ok(Sam2PromptEncoderWeights {
136 pe_gaussian,
137 not_a_point_embed,
138 point_embeddings,
139 mask_conv1_w,
140 mask_conv1_b,
141 mask_ln1_g,
142 mask_ln1_b,
143 mask_conv2_w,
144 mask_conv2_b,
145 mask_ln2_g,
146 mask_ln2_b,
147 mask_conv3_w,
148 mask_conv3_b,
149 no_mask_embed,
150 embed_dim,
151 mask_in_chans,
152 grid: SAM2_PROMPT_GRID,
153 })
154}
155
/// Result of [`prompt_encoder_forward`]. All buffers are flat, row-major `f32`.
#[derive(Debug, Clone)]
pub struct Sam2PromptEncoderOutput {
    /// Point tokens followed by box-corner tokens, `[num_sparse_tokens, embed_dim]`.
    pub sparse_embeddings: Vec<f32>,
    /// Number of `embed_dim`-wide rows in `sparse_embeddings`. This is 0 when
    /// neither points nor boxes were supplied.
    pub num_sparse_tokens: usize,
    /// Dense prompt embedding. When a mask prompt was given, this is the mask
    /// stack's output. Otherwise it is `no_mask_embed` broadcast to
    /// `[embed_dim, grid, grid]`.
    pub dense_embeddings: Vec<f32>,
    /// Positional encoding of the `grid × grid` feature map, `[embed_dim, grid, grid]`.
    pub image_pe: Vec<f32>,
}
169
170pub fn prompt_encoder_forward(
180 w: &Sam2PromptEncoderWeights,
181 mask_stack: &mut Sam2PromptMaskCompiled,
182 points: Option<(&[f32], &[f32])>,
183 boxes: Option<&[f32]>,
184 masks: Option<&[f32]>,
185) -> Result<Sam2PromptEncoderOutput> {
186 let e = w.embed_dim;
187 let g = w.grid;
188
189 let pad_points = boxes.is_none();
191 let mut sparse = Vec::new();
192
193 if let Some((coords, labels)) = points {
194 let n = labels.len();
195 ensure!(
196 coords.len() == n * 2,
197 "points coords len {} ≠ N·2 ({}·2)",
198 coords.len(),
199 n
200 );
201 let mut pts: Vec<f32> = coords.iter().map(|c| c + 0.5).collect();
202 let mut lbls = labels.to_vec();
203 if pad_points {
204 pts.push(0.0);
205 pts.push(0.0);
206 lbls.push(-1.0);
207 }
208 let n_padded = lbls.len();
209 let emb = embed_points_and_boxes(w, &pts, n_padded, false, Some(&lbls))?;
210 sparse.extend_from_slice(&emb);
211 }
212 if let Some(box_coords) = boxes {
213 let m = box_coords.len() / 4;
214 ensure!(box_coords.len() == m * 4, "boxes len must be multiple of 4");
215 let coords_with_half: Vec<f32> = box_coords.iter().map(|c| c + 0.5).collect();
216 let emb = embed_points_and_boxes(w, &coords_with_half, m * 2, true, None)?;
217 sparse.extend_from_slice(&emb);
218 }
219 let num_sparse_tokens = if sparse.is_empty() {
220 0
221 } else {
222 sparse.len() / e
223 };
224
225 let dense_embeddings = match masks {
227 Some(m) => mask_stack.run(m)?,
228 None => {
229 let mut out = vec![0f32; e * g * g];
231 for c in 0..e {
232 let v = w.no_mask_embed[c];
233 out[c * g * g..(c + 1) * g * g].fill(v);
234 }
235 out
236 }
237 };
238
239 let image_pe = compute_image_pe(w, g, g);
241
242 Ok(Sam2PromptEncoderOutput {
243 sparse_embeddings: sparse,
244 num_sparse_tokens,
245 dense_embeddings,
246 image_pe,
247 })
248}
249
250pub fn compute_image_pe(w: &Sam2PromptEncoderWeights, h: usize, ww: usize) -> Vec<f32> {
253 let e = w.embed_dim;
254 let half = e / 2;
255 let mut out = vec![0f32; e * h * ww];
256 for y in 0..h {
257 let fy = (y as f32 + 0.5) / h as f32;
258 for x in 0..ww {
259 let fx = (x as f32 + 0.5) / ww as f32;
260 let cx = fx * 2.0 - 1.0;
261 let cy = fy * 2.0 - 1.0;
262 for k in 0..half {
263 let mut acc = cx * w.pe_gaussian[k] + cy * w.pe_gaussian[half + k];
264 acc *= 2.0 * std::f32::consts::PI;
265 out[k * h * ww + y * ww + x] = acc.sin();
266 out[(half + k) * h * ww + y * ww + x] = acc.cos();
267 }
268 }
269 }
270 out
271}
272
273fn pe_encode_normalized(w: &Sam2PromptEncoderWeights, coords: &[f32], n: usize) -> Vec<f32> {
276 let e = w.embed_dim;
277 let half = e / 2;
278 let mut out = vec![0f32; n * e];
279 for i in 0..n {
280 let cx = coords[i * 2] * 2.0 - 1.0;
281 let cy = coords[i * 2 + 1] * 2.0 - 1.0;
282 for k in 0..half {
283 let mut acc = cx * w.pe_gaussian[k] + cy * w.pe_gaussian[half + k];
284 acc *= 2.0 * std::f32::consts::PI;
285 out[i * e + k] = acc.sin();
286 out[i * e + half + k] = acc.cos();
287 }
288 }
289 out
290}
291
292fn embed_points_and_boxes(
294 w: &Sam2PromptEncoderWeights,
295 coords_in_pixels: &[f32],
296 n: usize,
297 is_box: bool,
298 labels: Option<&[f32]>,
299) -> Result<Vec<f32>> {
300 let e = w.embed_dim;
301 let img = SAM2_IMG_SIZE as f32;
302 let normed: Vec<f32> = coords_in_pixels.iter().map(|c| c / img).collect();
303 let mut emb = pe_encode_normalized(w, &normed, n);
304
305 if is_box {
306 for i in 0..n {
307 let pe_idx = if i % 2 == 0 { 2 } else { 3 };
308 for k in 0..e {
309 emb[i * e + k] += w.point_embeddings[pe_idx * e + k];
310 }
311 }
312 } else if let Some(lbls) = labels {
313 ensure!(lbls.len() == n, "labels len {} ≠ n {n}", lbls.len());
314 for i in 0..n {
315 let label = lbls[i];
316 if label < 0.0 {
317 for k in 0..e {
318 emb[i * e + k] = w.not_a_point_embed[k];
319 }
320 } else if label == 0.0 {
321 for k in 0..e {
322 emb[i * e + k] += w.point_embeddings[k];
323 }
324 } else {
325 for k in 0..e {
326 emb[i * e + k] += w.point_embeddings[e + k];
327 }
328 }
329 }
330 }
331 Ok(emb)
332}
333
334#[allow(dead_code)]
338pub(super) fn conv2d_stride2_k2_pad0(
339 input: &[f32],
340 in_c: usize,
341 out_c: usize,
342 in_h: usize,
343 in_w: usize,
344 weight: &[f32], bias: &[f32], ) -> Vec<f32> {
347 let out_h = in_h / 2;
348 let out_w = in_w / 2;
349 let mut out = vec![0f32; out_c * out_h * out_w];
350 for oc in 0..out_c {
351 for oy in 0..out_h {
352 for ox in 0..out_w {
353 let mut acc = bias[oc];
354 for ic in 0..in_c {
355 for ky in 0..2 {
356 let iy = oy * 2 + ky;
357 for kx in 0..2 {
358 let ix = ox * 2 + kx;
359 let v = input[ic * in_h * in_w + iy * in_w + ix];
360 let w_idx = ((oc * in_c + ic) * 2 + ky) * 2 + kx;
361 acc += v * weight[w_idx];
362 }
363 }
364 }
365 out[oc * out_h * out_w + oy * out_w + ox] = acc;
366 }
367 }
368 }
369 out
370}
371
372pub(super) fn conv2d_1x1(
374 input: &[f32],
375 in_c: usize,
376 out_c: usize,
377 h: usize,
378 w: usize,
379 weight: &[f32], bias: &[f32], ) -> Vec<f32> {
382 let mut out = vec![0f32; out_c * h * w];
383 for oc in 0..out_c {
384 let b = bias[oc];
385 for y in 0..h {
386 for x in 0..w {
387 let mut acc = b;
388 for ic in 0..in_c {
389 acc += input[ic * h * w + y * w + x] * weight[oc * in_c + ic];
390 }
391 out[oc * h * w + y * w + x] = acc;
392 }
393 }
394 }
395 out
396}
397
398pub(super) fn layernorm2d_nchw(
400 data: &mut [f32],
401 c: usize,
402 h: usize,
403 w: usize,
404 gamma: &[f32],
405 beta: &[f32],
406 eps: f32,
407) {
408 let n = h * w;
409 for i in 0..n {
410 let mut mean = 0f32;
411 for k in 0..c {
412 mean += data[k * n + i];
413 }
414 mean /= c as f32;
415 let mut var = 0f32;
416 for k in 0..c {
417 let d = data[k * n + i] - mean;
418 var += d * d;
419 }
420 var /= c as f32;
421 let inv = 1.0 / (var + eps).sqrt();
422 for k in 0..c {
423 let v = (data[k * n + i] - mean) * inv;
424 data[k * n + i] = v * gamma[k] + beta[k];
425 }
426 }
427}
428
429pub(super) fn gelu_erf_inplace(data: &mut [f32]) {
431 const INV_SQRT2: f32 = std::f32::consts::FRAC_1_SQRT_2;
432 for v in data.iter_mut() {
433 let x = *v;
434 let s = (x * INV_SQRT2).abs();
435 let p = 0.327_591_1;
436 let a1 = 0.254_829_6;
437 let a2 = -0.284_496_7;
438 let a3 = 1.421_413_8;
439 let a4 = -1.453_152_1;
440 let a5 = 1.061_405_4;
441 let t = 1.0 / (1.0 + p * s);
442 let y = ((((a5 * t + a4) * t + a3) * t + a2) * t + a1) * t;
443 let erf_abs = 1.0 - y * (-s * s).exp();
444 let erf = if x >= 0.0 { erf_abs } else { -erf_abs };
445 *v = 0.5 * x * (1.0 + erf);
446 }
447}
448
449pub(super) fn sigmoid_inplace(x: &mut [f32]) {
451 for v in x.iter_mut() {
452 *v = 1.0 / (1.0 + (-*v).exp());
453 }
454}