1pub mod axial_rope;
46pub mod cli;
47pub mod config;
48pub mod flow;
49pub mod fpn_neck;
50pub mod fpn_neck_ir;
51pub mod image_encoder;
52pub mod mask_decoder;
53pub mod memory_attention;
54pub mod memory_attention_ir;
55pub mod memory_encoder;
56pub mod memory_mask_ir;
57pub mod mlp_ir;
58pub mod preprocess;
59pub mod prompt_encoder;
60pub mod prompt_mask_ir;
61#[allow(clippy::module_inception)]
62pub mod sam2;
63pub mod transformer;
64pub mod transformer_ir;
65pub mod upscale_ir;
66
67pub use rlx_sam::profile::{
68 SAM_PROFILE_FILE, sam_profile_near_weights, sam2_profile_default, sam2_profile_near_weights,
69};
70
71pub use config::{
72 SAM2_IMG_SIZE, SAM2_PATCH_GRID, SAM2_PATCH_KERNEL, SAM2_PATCH_PADDING, SAM2_PATCH_STRIDE,
73 SAM2_PIXEL_MEAN, SAM2_PIXEL_STD, SAM2_PROMPT_EMBED_DIM, SAM2_Q_POOL_COUNT, SAM2_Q_STRIDE,
74 Sam2Config, Sam2DecoderConfig, Sam2FpnConfig, Sam2HieraConfig, Sam2MemoryConfig,
75 Sam2MemoryEncoderConfig,
76};
77pub use flow::{Sam2ImageEncoderBuilt, Sam2ImageEncoderFlow, build_sam2_image_encoder_built};
78pub use fpn_neck::{FpnLevel, FpnNeckWeights, apply_fpn_neck, apply_fpn_neck_host};
79pub use fpn_neck_ir::{Sam2FpnNeckIr, compile_fpn_neck_ir};
80pub use image_encoder::{build_sam2_image_encoder_graph, build_sam2_image_encoder_hir};
81pub use mask_decoder::{Sam2MaskDecoderOutput, Sam2MaskDecoderWeights, mask_decoder_forward};
82pub use memory_attention::{Sam2MemoryAttentionWeights, memory_attention_forward};
83pub use memory_attention_ir::MemoryAttentionCompiled;
84pub use memory_encoder::{
85 Sam2MemoryEncoderOutput, Sam2MemoryEncoderWeights, memory_encoder_forward,
86};
87pub use preprocess::{Sam2PreprocessWeights, assemble_patch_tokens, preprocess_image};
88pub use prompt_encoder::{
89 SAM2_MASK_IN_CHANS, SAM2_PROMPT_GRID, Sam2PromptEncoderOutput, Sam2PromptEncoderWeights,
90 prompt_encoder_forward,
91};
92pub use rlx_sam_ir::twoway_transformer_ir::TwoWayTransformerCompiled;
93pub use sam2::{Sam2, Sam2ImagePrediction, Sam2VideoState};
94pub use transformer::{Sam2TwoWayTransformerWeights, two_way_transformer_forward};
95pub use transformer_ir::compile_two_way_transformer;
96
97#[cfg(test)]
98mod tests {
99 use super::*;
100 use rlx_core::weight_map::WeightMap;
101 use rlx_runtime::Device;
102 use std::collections::HashMap;
103
104 type T = HashMap<String, (Vec<f32>, Vec<usize>)>;
105
106 fn z(n: usize) -> Vec<f32> {
107 vec![0.0f32; n]
108 }
109
110 fn add_hiera_weights(t: &mut T, cfg: &Sam2HieraConfig) {
114 let e0 = cfg.embed_dim;
115 let k = SAM2_PATCH_KERNEL;
116 let [ph, pw] = cfg.window_pos_embed_bkg_spatial_size;
117 let mu = cfg.window_size_at_stage(0);
118
119 t.insert(
120 "image_encoder.trunk.patch_embed.proj.weight".into(),
121 (z(e0 * 3 * k * k), vec![e0, 3, k, k]),
122 );
123 t.insert(
124 "image_encoder.trunk.patch_embed.proj.bias".into(),
125 (z(e0), vec![e0]),
126 );
127 t.insert(
128 "image_encoder.trunk.pos_embed".into(),
129 (z(e0 * ph * pw), vec![1, e0, ph, pw]),
130 );
131 t.insert(
132 "image_encoder.trunk.pos_embed_window".into(),
133 (z(e0 * mu * mu), vec![1, e0, mu, mu]),
134 );
135
136 let q_pool = cfg.q_pool_block_indices();
137 let total = cfg.total_blocks();
138 let mut stage = 0usize;
139 let mut dim_curr = e0;
140 for i in 0..total {
141 let is_q_pool = q_pool.contains(&i);
142 let dim_in = dim_curr;
143 let stage_after = if is_q_pool { stage + 1 } else { stage };
144 let dim_out = cfg.embed_dim_at_stage(stage_after);
145 let lp = format!("image_encoder.trunk.blocks.{i}");
146
147 t.insert(format!("{lp}.norm1.weight"), (z(dim_in), vec![dim_in]));
148 t.insert(format!("{lp}.norm1.bias"), (z(dim_in), vec![dim_in]));
149 if dim_in != dim_out {
150 t.insert(
151 format!("{lp}.proj.weight"),
152 (z(dim_in * dim_out), vec![dim_out, dim_in]),
153 );
154 t.insert(format!("{lp}.proj.bias"), (z(dim_out), vec![dim_out]));
155 }
156 t.insert(
157 format!("{lp}.attn.qkv.weight"),
158 (z(dim_in * 3 * dim_out), vec![3 * dim_out, dim_in]),
159 );
160 if cfg.qkv_bias {
161 t.insert(
162 format!("{lp}.attn.qkv.bias"),
163 (z(3 * dim_out), vec![3 * dim_out]),
164 );
165 }
166 t.insert(
167 format!("{lp}.attn.proj.weight"),
168 (z(dim_out * dim_out), vec![dim_out, dim_out]),
169 );
170 t.insert(format!("{lp}.attn.proj.bias"), (z(dim_out), vec![dim_out]));
171 t.insert(format!("{lp}.norm2.weight"), (z(dim_out), vec![dim_out]));
172 t.insert(format!("{lp}.norm2.bias"), (z(dim_out), vec![dim_out]));
173
174 let hidden = (dim_out as f64 * cfg.mlp_ratio) as usize;
175 t.insert(
176 format!("{lp}.mlp.layers.0.weight"),
177 (z(dim_out * hidden), vec![hidden, dim_out]),
178 );
179 t.insert(format!("{lp}.mlp.layers.0.bias"), (z(hidden), vec![hidden]));
180 t.insert(
181 format!("{lp}.mlp.layers.1.weight"),
182 (z(hidden * dim_out), vec![dim_out, hidden]),
183 );
184 t.insert(
185 format!("{lp}.mlp.layers.1.bias"),
186 (z(dim_out), vec![dim_out]),
187 );
188
189 if is_q_pool {
190 stage += 1;
191 dim_curr = dim_out;
192 }
193 }
194
195 let fpn = Sam2FpnConfig::for_hiera(cfg);
196 for (i, &cin) in fpn.backbone_channel_list.iter().enumerate() {
197 t.insert(
198 format!("image_encoder.neck.convs.{i}.conv.weight"),
199 (z(fpn.d_model * cin), vec![fpn.d_model, cin, 1, 1]),
200 );
201 t.insert(
202 format!("image_encoder.neck.convs.{i}.conv.bias"),
203 (z(fpn.d_model), vec![fpn.d_model]),
204 );
205 }
206 }
207
208 fn add_prompt_encoder_weights(t: &mut T, embed_dim: usize, mask_in_chans: usize) {
209 let half = embed_dim / 2;
210 let q = mask_in_chans / 4;
211 t.insert(
212 "sam_prompt_encoder.pe_layer.positional_encoding_gaussian_matrix".into(),
213 (z(2 * half), vec![2, half]),
214 );
215 t.insert(
216 "sam_prompt_encoder.not_a_point_embed.weight".into(),
217 (z(embed_dim), vec![1, embed_dim]),
218 );
219 t.insert(
220 "sam_prompt_encoder.no_mask_embed.weight".into(),
221 (z(embed_dim), vec![1, embed_dim]),
222 );
223 for i in 0..4 {
224 t.insert(
225 format!("sam_prompt_encoder.point_embeddings.{i}.weight"),
226 (z(embed_dim), vec![1, embed_dim]),
227 );
228 }
229 t.insert(
230 "sam_prompt_encoder.mask_downscaling.0.weight".into(),
231 (z(q * 4), vec![q, 1, 2, 2]),
232 );
233 t.insert(
234 "sam_prompt_encoder.mask_downscaling.0.bias".into(),
235 (z(q), vec![q]),
236 );
237 t.insert(
238 "sam_prompt_encoder.mask_downscaling.1.weight".into(),
239 (z(q), vec![q]),
240 );
241 t.insert(
242 "sam_prompt_encoder.mask_downscaling.1.bias".into(),
243 (z(q), vec![q]),
244 );
245 t.insert(
246 "sam_prompt_encoder.mask_downscaling.3.weight".into(),
247 (z(mask_in_chans * q * 4), vec![mask_in_chans, q, 2, 2]),
248 );
249 t.insert(
250 "sam_prompt_encoder.mask_downscaling.3.bias".into(),
251 (z(mask_in_chans), vec![mask_in_chans]),
252 );
253 t.insert(
254 "sam_prompt_encoder.mask_downscaling.4.weight".into(),
255 (z(mask_in_chans), vec![mask_in_chans]),
256 );
257 t.insert(
258 "sam_prompt_encoder.mask_downscaling.4.bias".into(),
259 (z(mask_in_chans), vec![mask_in_chans]),
260 );
261 t.insert(
262 "sam_prompt_encoder.mask_downscaling.6.weight".into(),
263 (
264 z(embed_dim * mask_in_chans),
265 vec![embed_dim, mask_in_chans, 1, 1],
266 ),
267 );
268 t.insert(
269 "sam_prompt_encoder.mask_downscaling.6.bias".into(),
270 (z(embed_dim), vec![embed_dim]),
271 );
272 }
273
274 fn add_two_way_transformer_weights(t: &mut T, cfg: &Sam2DecoderConfig) {
275 let e = cfg.transformer_dim;
276 let id = e / 2;
277 let mlp = cfg.transformer_mlp_dim;
278 for i in 0..cfg.transformer_depth {
279 let p = format!("sam_mask_decoder.transformer.layers.{i}");
280 {
282 let sub = "self_attn";
283 t.insert(format!("{p}.{sub}.q_proj.weight"), (z(e * e), vec![e, e]));
284 t.insert(format!("{p}.{sub}.q_proj.bias"), (z(e), vec![e]));
285 t.insert(format!("{p}.{sub}.k_proj.weight"), (z(e * e), vec![e, e]));
286 t.insert(format!("{p}.{sub}.k_proj.bias"), (z(e), vec![e]));
287 t.insert(format!("{p}.{sub}.v_proj.weight"), (z(e * e), vec![e, e]));
288 t.insert(format!("{p}.{sub}.v_proj.bias"), (z(e), vec![e]));
289 t.insert(format!("{p}.{sub}.out_proj.weight"), (z(e * e), vec![e, e]));
290 t.insert(format!("{p}.{sub}.out_proj.bias"), (z(e), vec![e]));
291 }
292 t.insert(format!("{p}.norm1.weight"), (z(e), vec![e]));
293 t.insert(format!("{p}.norm1.bias"), (z(e), vec![e]));
294 for sub in ["cross_attn_token_to_image", "cross_attn_image_to_token"] {
296 t.insert(format!("{p}.{sub}.q_proj.weight"), (z(e * id), vec![id, e]));
297 t.insert(format!("{p}.{sub}.q_proj.bias"), (z(id), vec![id]));
298 t.insert(format!("{p}.{sub}.k_proj.weight"), (z(e * id), vec![id, e]));
299 t.insert(format!("{p}.{sub}.k_proj.bias"), (z(id), vec![id]));
300 t.insert(format!("{p}.{sub}.v_proj.weight"), (z(e * id), vec![id, e]));
301 t.insert(format!("{p}.{sub}.v_proj.bias"), (z(id), vec![id]));
302 t.insert(
303 format!("{p}.{sub}.out_proj.weight"),
304 (z(e * id), vec![e, id]),
305 );
306 t.insert(format!("{p}.{sub}.out_proj.bias"), (z(e), vec![e]));
307 }
308 t.insert(format!("{p}.norm2.weight"), (z(e), vec![e]));
309 t.insert(format!("{p}.norm2.bias"), (z(e), vec![e]));
310 t.insert(
311 format!("{p}.mlp.layers.0.weight"),
312 (z(mlp * e), vec![mlp, e]),
313 );
314 t.insert(format!("{p}.mlp.layers.0.bias"), (z(mlp), vec![mlp]));
315 t.insert(
316 format!("{p}.mlp.layers.1.weight"),
317 (z(mlp * e), vec![e, mlp]),
318 );
319 t.insert(format!("{p}.mlp.layers.1.bias"), (z(e), vec![e]));
320 t.insert(format!("{p}.norm3.weight"), (z(e), vec![e]));
321 t.insert(format!("{p}.norm3.bias"), (z(e), vec![e]));
322 t.insert(format!("{p}.norm4.weight"), (z(e), vec![e]));
323 t.insert(format!("{p}.norm4.bias"), (z(e), vec![e]));
324 }
325 let p = "sam_mask_decoder.transformer.final_attn_token_to_image";
327 t.insert(format!("{p}.q_proj.weight"), (z(e * id), vec![id, e]));
328 t.insert(format!("{p}.q_proj.bias"), (z(id), vec![id]));
329 t.insert(format!("{p}.k_proj.weight"), (z(e * id), vec![id, e]));
330 t.insert(format!("{p}.k_proj.bias"), (z(id), vec![id]));
331 t.insert(format!("{p}.v_proj.weight"), (z(e * id), vec![id, e]));
332 t.insert(format!("{p}.v_proj.bias"), (z(id), vec![id]));
333 t.insert(format!("{p}.out_proj.weight"), (z(e * id), vec![e, id]));
334 t.insert(format!("{p}.out_proj.bias"), (z(e), vec![e]));
335 t.insert(
336 "sam_mask_decoder.transformer.norm_final_attn.weight".into(),
337 (z(e), vec![e]),
338 );
339 t.insert(
340 "sam_mask_decoder.transformer.norm_final_attn.bias".into(),
341 (z(e), vec![e]),
342 );
343 }
344
345 fn add_mask_decoder_weights(t: &mut T, cfg: &Sam2DecoderConfig) {
346 let e = cfg.transformer_dim;
347 let q4 = e / 4;
348 let q8 = e / 8;
349 t.insert(
350 "sam_mask_decoder.iou_token.weight".into(),
351 (z(e), vec![1, e]),
352 );
353 t.insert(
354 "sam_mask_decoder.mask_tokens.weight".into(),
355 (z(cfg.num_mask_tokens * e), vec![cfg.num_mask_tokens, e]),
356 );
357 if cfg.pred_obj_scores {
358 t.insert(
359 "sam_mask_decoder.obj_score_token.weight".into(),
360 (z(e), vec![1, e]),
361 );
362 }
363 t.insert(
364 "sam_mask_decoder.output_upscaling.0.weight".into(),
365 (z(e * q4 * 4), vec![e, q4, 2, 2]),
366 );
367 t.insert(
368 "sam_mask_decoder.output_upscaling.0.bias".into(),
369 (z(q4), vec![q4]),
370 );
371 t.insert(
372 "sam_mask_decoder.output_upscaling.1.weight".into(),
373 (z(q4), vec![q4]),
374 );
375 t.insert(
376 "sam_mask_decoder.output_upscaling.1.bias".into(),
377 (z(q4), vec![q4]),
378 );
379 t.insert(
380 "sam_mask_decoder.output_upscaling.3.weight".into(),
381 (z(q4 * q8 * 4), vec![q4, q8, 2, 2]),
382 );
383 t.insert(
384 "sam_mask_decoder.output_upscaling.3.bias".into(),
385 (z(q8), vec![q8]),
386 );
387 if cfg.use_high_res_features {
388 t.insert(
389 "sam_mask_decoder.conv_s0.weight".into(),
390 (z(q8 * e), vec![q8, e, 1, 1]),
391 );
392 t.insert("sam_mask_decoder.conv_s0.bias".into(), (z(q8), vec![q8]));
393 t.insert(
394 "sam_mask_decoder.conv_s1.weight".into(),
395 (z(q4 * e), vec![q4, e, 1, 1]),
396 );
397 t.insert("sam_mask_decoder.conv_s1.bias".into(), (z(q4), vec![q4]));
398 }
399 for i in 0..cfg.num_mask_tokens {
400 let p = format!("sam_mask_decoder.output_hypernetworks_mlps.{i}");
401 t.insert(format!("{p}.layers.0.weight"), (z(e * e), vec![e, e]));
403 t.insert(format!("{p}.layers.0.bias"), (z(e), vec![e]));
404 t.insert(format!("{p}.layers.1.weight"), (z(e * e), vec![e, e]));
405 t.insert(format!("{p}.layers.1.bias"), (z(e), vec![e]));
406 t.insert(format!("{p}.layers.2.weight"), (z(e * q8), vec![q8, e]));
407 t.insert(format!("{p}.layers.2.bias"), (z(q8), vec![q8]));
408 }
409 let p = "sam_mask_decoder.iou_prediction_head";
411 let hidden = cfg.iou_head_hidden_dim;
412 t.insert(
413 format!("{p}.layers.0.weight"),
414 (z(e * hidden), vec![hidden, e]),
415 );
416 t.insert(format!("{p}.layers.0.bias"), (z(hidden), vec![hidden]));
417 t.insert(
418 format!("{p}.layers.1.weight"),
419 (z(hidden * hidden), vec![hidden, hidden]),
420 );
421 t.insert(format!("{p}.layers.1.bias"), (z(hidden), vec![hidden]));
422 t.insert(
423 format!("{p}.layers.2.weight"),
424 (
425 z(hidden * cfg.num_mask_tokens),
426 vec![cfg.num_mask_tokens, hidden],
427 ),
428 );
429 t.insert(
430 format!("{p}.layers.2.bias"),
431 (z(cfg.num_mask_tokens), vec![cfg.num_mask_tokens]),
432 );
433 if cfg.pred_obj_scores {
435 if cfg.pred_obj_scores_mlp {
436 let p = "sam_mask_decoder.pred_obj_score_head";
437 t.insert(format!("{p}.layers.0.weight"), (z(e * e), vec![e, e]));
438 t.insert(format!("{p}.layers.0.bias"), (z(e), vec![e]));
439 t.insert(format!("{p}.layers.1.weight"), (z(e * e), vec![e, e]));
440 t.insert(format!("{p}.layers.1.bias"), (z(e), vec![e]));
441 t.insert(format!("{p}.layers.2.weight"), (z(e), vec![1, e]));
442 t.insert(format!("{p}.layers.2.bias"), (z(1), vec![1]));
443 } else {
444 t.insert(
445 "sam_mask_decoder.pred_obj_score_head.weight".into(),
446 (z(e), vec![1, e]),
447 );
448 t.insert(
449 "sam_mask_decoder.pred_obj_score_head.bias".into(),
450 (z(1), vec![1]),
451 );
452 }
453 }
454 if cfg.use_object_pointer {
456 if cfg.use_mlp_for_obj_ptr_proj {
457 let p = "obj_ptr_proj";
458 t.insert(format!("{p}.layers.0.weight"), (z(e * e), vec![e, e]));
459 t.insert(format!("{p}.layers.0.bias"), (z(e), vec![e]));
460 t.insert(format!("{p}.layers.1.weight"), (z(e * e), vec![e, e]));
461 t.insert(format!("{p}.layers.1.bias"), (z(e), vec![e]));
462 t.insert(format!("{p}.layers.2.weight"), (z(e * e), vec![e, e]));
463 t.insert(format!("{p}.layers.2.bias"), (z(e), vec![e]));
464 } else {
465 t.insert("obj_ptr_proj.weight".into(), (z(e * e), vec![e, e]));
466 t.insert("obj_ptr_proj.bias".into(), (z(e), vec![e]));
467 }
468 }
469 add_two_way_transformer_weights(t, cfg);
470 }
471
472 fn add_memory_encoder_weights(t: &mut T, cfg: &Sam2MemoryEncoderConfig) {
473 let mut in_c = 1usize;
475 let stride2 = cfg.mask_downsampler_stride * cfg.mask_downsampler_stride;
476 let mut num_levels = 0;
477 let mut acc = 1usize;
478 while acc < cfg.mask_downsampler_total_stride {
479 acc *= cfg.mask_downsampler_stride;
480 num_levels += 1;
481 }
482 for li in 0..num_levels {
483 let out_c = in_c * stride2;
484 let conv_idx = li * 3;
485 let ln_idx = conv_idx + 1;
486 let k = cfg.mask_downsampler_kernel;
487 t.insert(
488 format!("memory_encoder.mask_downsampler.encoder.{conv_idx}.weight"),
489 (z(out_c * in_c * k * k), vec![out_c, in_c, k, k]),
490 );
491 t.insert(
492 format!("memory_encoder.mask_downsampler.encoder.{conv_idx}.bias"),
493 (z(out_c), vec![out_c]),
494 );
495 t.insert(
496 format!("memory_encoder.mask_downsampler.encoder.{ln_idx}.weight"),
497 (z(out_c), vec![out_c]),
498 );
499 t.insert(
500 format!("memory_encoder.mask_downsampler.encoder.{ln_idx}.bias"),
501 (z(out_c), vec![out_c]),
502 );
503 in_c = out_c;
504 }
505 let final_idx = num_levels * 3;
506 t.insert(
507 format!("memory_encoder.mask_downsampler.encoder.{final_idx}.weight"),
508 (z(cfg.in_dim * in_c), vec![cfg.in_dim, in_c, 1, 1]),
509 );
510 t.insert(
511 format!("memory_encoder.mask_downsampler.encoder.{final_idx}.bias"),
512 (z(cfg.in_dim), vec![cfg.in_dim]),
513 );
514 t.insert(
516 "memory_encoder.pix_feat_proj.weight".into(),
517 (
518 z(cfg.in_dim * cfg.in_dim),
519 vec![cfg.in_dim, cfg.in_dim, 1, 1],
520 ),
521 );
522 t.insert(
523 "memory_encoder.pix_feat_proj.bias".into(),
524 (z(cfg.in_dim), vec![cfg.in_dim]),
525 );
526 for i in 0..cfg.fuser_num_layers {
528 let p = format!("memory_encoder.fuser.layers.{i}");
529 let dim = cfg.fuser_dim;
530 let k = cfg.fuser_kernel;
531 if cfg.fuser_use_dwconv {
532 t.insert(
533 format!("{p}.dwconv.weight"),
534 (z(dim * k * k), vec![dim, 1, k, k]),
535 );
536 } else {
537 t.insert(
538 format!("{p}.dwconv.weight"),
539 (z(dim * dim * k * k), vec![dim, dim, k, k]),
540 );
541 }
542 t.insert(format!("{p}.dwconv.bias"), (z(dim), vec![dim]));
543 t.insert(format!("{p}.norm.weight"), (z(dim), vec![dim]));
544 t.insert(format!("{p}.norm.bias"), (z(dim), vec![dim]));
545 t.insert(
546 format!("{p}.pwconv1.weight"),
547 (z(4 * dim * dim), vec![4 * dim, dim]),
548 );
549 t.insert(format!("{p}.pwconv1.bias"), (z(4 * dim), vec![4 * dim]));
550 t.insert(
551 format!("{p}.pwconv2.weight"),
552 (z(dim * 4 * dim), vec![dim, 4 * dim]),
553 );
554 t.insert(format!("{p}.pwconv2.bias"), (z(dim), vec![dim]));
555 if cfg.fuser_layer_scale_init_value > 0.0 {
556 t.insert(format!("{p}.gamma"), (z(dim), vec![dim]));
557 }
558 }
559 if cfg.in_dim != cfg.out_dim {
561 t.insert(
562 "memory_encoder.out_proj.weight".into(),
563 (
564 z(cfg.in_dim * cfg.out_dim),
565 vec![cfg.out_dim, cfg.in_dim, 1, 1],
566 ),
567 );
568 t.insert(
569 "memory_encoder.out_proj.bias".into(),
570 (z(cfg.out_dim), vec![cfg.out_dim]),
571 );
572 }
573 }
574
575 fn add_memory_attention_weights(t: &mut T, cfg: &Sam2MemoryConfig) {
576 let d = cfg.d_model;
577 let kv = cfg.kv_in_dim;
578 let dff = cfg.dim_feedforward;
579 for i in 0..cfg.num_layers {
580 let p = format!("memory_attention.layers.{i}");
581 {
583 let sub = "self_attn";
584 t.insert(format!("{p}.{sub}.q_proj.weight"), (z(d * d), vec![d, d]));
585 t.insert(format!("{p}.{sub}.q_proj.bias"), (z(d), vec![d]));
586 t.insert(format!("{p}.{sub}.k_proj.weight"), (z(d * d), vec![d, d]));
587 t.insert(format!("{p}.{sub}.k_proj.bias"), (z(d), vec![d]));
588 t.insert(format!("{p}.{sub}.v_proj.weight"), (z(d * d), vec![d, d]));
589 t.insert(format!("{p}.{sub}.v_proj.bias"), (z(d), vec![d]));
590 t.insert(format!("{p}.{sub}.out_proj.weight"), (z(d * d), vec![d, d]));
591 t.insert(format!("{p}.{sub}.out_proj.bias"), (z(d), vec![d]));
592 }
593 {
595 let sub = "cross_attn_image";
596 t.insert(format!("{p}.{sub}.q_proj.weight"), (z(d * d), vec![d, d]));
597 t.insert(format!("{p}.{sub}.q_proj.bias"), (z(d), vec![d]));
598 t.insert(format!("{p}.{sub}.k_proj.weight"), (z(d * kv), vec![d, kv]));
599 t.insert(format!("{p}.{sub}.k_proj.bias"), (z(d), vec![d]));
600 t.insert(format!("{p}.{sub}.v_proj.weight"), (z(d * kv), vec![d, kv]));
601 t.insert(format!("{p}.{sub}.v_proj.bias"), (z(d), vec![d]));
602 t.insert(format!("{p}.{sub}.out_proj.weight"), (z(d * d), vec![d, d]));
603 t.insert(format!("{p}.{sub}.out_proj.bias"), (z(d), vec![d]));
604 }
605 t.insert(format!("{p}.norm1.weight"), (z(d), vec![d]));
606 t.insert(format!("{p}.norm1.bias"), (z(d), vec![d]));
607 t.insert(format!("{p}.norm2.weight"), (z(d), vec![d]));
608 t.insert(format!("{p}.norm2.bias"), (z(d), vec![d]));
609 t.insert(format!("{p}.norm3.weight"), (z(d), vec![d]));
610 t.insert(format!("{p}.norm3.bias"), (z(d), vec![d]));
611 t.insert(format!("{p}.linear1.weight"), (z(dff * d), vec![dff, d]));
612 t.insert(format!("{p}.linear1.bias"), (z(dff), vec![dff]));
613 t.insert(format!("{p}.linear2.weight"), (z(d * dff), vec![d, dff]));
614 t.insert(format!("{p}.linear2.bias"), (z(d), vec![d]));
615 }
616 t.insert("memory_attention.norm.weight".into(), (z(d), vec![d]));
617 t.insert("memory_attention.norm.bias".into(), (z(d), vec![d]));
618 }
619
620 fn synthetic_full_sam2_weights(cfg: &Sam2Config) -> WeightMap {
621 let mut t: T = HashMap::new();
622 add_hiera_weights(&mut t, &cfg.hiera);
623 add_prompt_encoder_weights(&mut t, cfg.decoder.transformer_dim, SAM2_MASK_IN_CHANS);
624 add_mask_decoder_weights(&mut t, &cfg.decoder);
625 add_memory_encoder_weights(&mut t, &cfg.memory_encoder);
626 add_memory_attention_weights(&mut t, &cfg.memory);
627 WeightMap::from_tensors(t)
628 }
629
630 fn assert_encoder_builds(cfg: Sam2HieraConfig) {
631 let mut t: T = HashMap::new();
632 add_hiera_weights(&mut t, &cfg);
633 let mut wm = WeightMap::from_tensors(t);
634 let (g, _params, _pre, _fpn) = build_sam2_image_encoder_graph(&cfg, &mut wm)
635 .unwrap_or_else(|e| panic!("encoder build failed: {e}"));
636 assert_eq!(g.outputs.len(), cfg.stages.len());
637 for (s, out_id) in g.outputs.iter().copied().enumerate() {
638 let shape = g.shape(out_id);
639 let dims: Vec<usize> = shape.dims().iter().map(|d| d.unwrap_static()).collect();
640 let hw_s = cfg.grid_size_at_stage(s);
641 let dim_s = cfg.embed_dim_at_stage(s);
642 assert_eq!(dims, vec![1, hw_s, hw_s, dim_s], "stage {s} shape mismatch");
643 }
644 let leftovers: Vec<&str> = wm.keys().collect();
645 assert!(leftovers.is_empty(), "leftover weights: {leftovers:?}");
646 }
647
648 #[test]
649 fn encoder_graph_builds_tiny() {
650 assert_encoder_builds(Sam2HieraConfig::tiny());
651 }
652
653 #[test]
654 fn encoder_graph_builds_small() {
655 assert_encoder_builds(Sam2HieraConfig::small());
656 }
657
658 #[test]
659 fn encoder_graph_builds_base_plus() {
660 assert_encoder_builds(Sam2HieraConfig::base_plus());
661 }
662
663 #[test]
664 fn encoder_graph_builds_large() {
665 assert_encoder_builds(Sam2HieraConfig::large());
666 }
667
668 #[test]
669 fn preprocess_round_trip_shapes() {
670 let img = vec![64u8; 80 * 120 * 3];
671 let nchw = preprocess_image(&img, 80, 120);
672 assert_eq!(nchw.len(), 3 * 1024 * 1024);
673 }
674
675 #[test]
676 fn fpn_neck_runs_on_synth_outputs() {
677 let cfg = Sam2HieraConfig::base_plus();
678 let mut t: T = HashMap::new();
679 add_hiera_weights(&mut t, &cfg);
680 let mut wm = WeightMap::from_tensors(t);
681 let (_g, _p, _pre, neck) = build_sam2_image_encoder_graph(&cfg, &mut wm).unwrap();
682
683 let stage_hw: Vec<(usize, usize)> = (0..cfg.stages.len())
684 .map(|s| (cfg.grid_size_at_stage(s), cfg.grid_size_at_stage(s)))
685 .collect();
686 let stage_dims: Vec<usize> = (0..cfg.stages.len())
687 .map(|s| cfg.embed_dim_at_stage(s))
688 .collect();
689 let stage_outputs: Vec<Vec<f32>> = stage_hw
690 .iter()
691 .zip(&stage_dims)
692 .map(|(&(h, w), &d)| vec![0f32; h * w * d])
693 .collect();
694
695 let mut fpn_ir = super::fpn_neck_ir::compile_fpn_neck_ir(
696 &neck,
697 &stage_hw,
698 &stage_dims,
699 Device::Cpu,
700 &rlx_flow::CompileProfile::sam2(),
701 )
702 .unwrap();
703 let levels =
704 apply_fpn_neck(&neck, &mut fpn_ir, &stage_outputs, &stage_hw, &stage_dims).unwrap();
705 let levels_host = apply_fpn_neck_host(&neck, &stage_outputs, &stage_hw, &stage_dims);
706 assert_eq!(levels.len(), levels_host.len());
707 for (a, b) in levels.iter().zip(&levels_host) {
708 assert_eq!(a.features.len(), b.features.len());
709 assert_eq!(a.h, b.h);
710 assert_eq!(a.w, b.w);
711 let fd = a
712 .features
713 .iter()
714 .zip(&b.features)
715 .map(|(x, y)| (x - y).abs())
716 .fold(0f32, f32::max);
717 assert!(
718 fd < 1e-4,
719 "FPN IR vs host max |Δ| = {fd:.3e} at level {}×{}",
720 a.h,
721 a.w
722 );
723 }
724 assert_eq!(levels.len(), 4);
725 assert_eq!((levels[0].h, levels[0].w), (256, 256));
726 assert_eq!((levels[3].h, levels[3].w), (32, 32));
727 }
728
729 #[test]
730 fn full_weight_extraction_drains_map() {
731 let cfg = Sam2Config::hiera_base_plus();
735 let mut wm = synthetic_full_sam2_weights(&cfg);
736
737 let (_g, _p, _pre, _fpn) = build_sam2_image_encoder_graph(&cfg.hiera, &mut wm).unwrap();
739 let _ = prompt_encoder::extract_prompt_encoder_weights(
740 &mut wm,
741 cfg.decoder.transformer_dim,
742 SAM2_MASK_IN_CHANS,
743 )
744 .unwrap();
745 let _ = mask_decoder::extract_mask_decoder_weights(&mut wm, &cfg.decoder).unwrap();
746 let _ =
747 memory_encoder::extract_memory_encoder_weights(&mut wm, &cfg.memory_encoder).unwrap();
748 let _ = memory_attention::extract_memory_attention_weights(&mut wm, &cfg.memory).unwrap();
749
750 let leftovers: Vec<&str> = wm.keys().collect();
751 assert!(
752 leftovers.is_empty(),
753 "leftover weights after full extraction: {leftovers:?}"
754 );
755 }
756
757 #[test]
758 fn prompt_encoder_no_prompt_produces_pe_and_no_mask() {
759 let cfg = Sam2Config::hiera_base_plus();
760 let mut wm = synthetic_full_sam2_weights(&cfg);
761 let (_g, _p, _pre, _fpn) = build_sam2_image_encoder_graph(&cfg.hiera, &mut wm).unwrap();
763 let pe = prompt_encoder::extract_prompt_encoder_weights(
764 &mut wm,
765 cfg.decoder.transformer_dim,
766 SAM2_MASK_IN_CHANS,
767 )
768 .unwrap();
769 let mut mask_stack =
770 super::prompt_mask_ir::Sam2PromptMaskCompiled::compile(&pe, Device::Cpu).unwrap();
771 let out = prompt_encoder_forward(&pe, &mut mask_stack, None, None, None).unwrap();
772 assert_eq!(out.num_sparse_tokens, 0);
773 assert_eq!(
774 out.dense_embeddings.len(),
775 cfg.decoder.transformer_dim * SAM2_PROMPT_GRID * SAM2_PROMPT_GRID
776 );
777 assert_eq!(
778 out.image_pe.len(),
779 cfg.decoder.transformer_dim * SAM2_PROMPT_GRID * SAM2_PROMPT_GRID
780 );
781 }
782
783 #[test]
784 fn mask_decoder_runs_on_zero_inputs() {
785 let cfg = Sam2Config::hiera_base_plus();
786 let mut wm = synthetic_full_sam2_weights(&cfg);
787 let (_g, _p, _pre, _fpn) = build_sam2_image_encoder_graph(&cfg.hiera, &mut wm).unwrap();
788 let _pe = prompt_encoder::extract_prompt_encoder_weights(
789 &mut wm,
790 cfg.decoder.transformer_dim,
791 SAM2_MASK_IN_CHANS,
792 )
793 .unwrap();
794 let dec = mask_decoder::extract_mask_decoder_weights(&mut wm, &cfg.decoder).unwrap();
795 let _ =
796 memory_encoder::extract_memory_encoder_weights(&mut wm, &cfg.memory_encoder).unwrap();
797 let _ = memory_attention::extract_memory_attention_weights(&mut wm, &cfg.memory).unwrap();
798
799 let e = cfg.decoder.transformer_dim;
800 let g = SAM2_PROMPT_GRID;
801 let image_emb = vec![0f32; e * g * g];
802 let image_pe = vec![0f32; e * g * g];
803 let dense = vec![0f32; e * g * g];
804 let sparse: Vec<f32> = Vec::new();
805 let s0 = vec![0f32; e * (4 * g) * (4 * g)];
806 let s1 = vec![0f32; e * (2 * g) * (2 * g)];
807
808 let mut upscale =
809 super::upscale_ir::Sam2MaskUpscaleCompiled::compile(&dec, g, Device::Cpu).unwrap();
810 let mut hyper_matmul = rlx_sam_ir::mask_hyper_matmul_ir::MaskHyperMatmulCompiled::compile(
811 dec.num_mask_tokens,
812 cfg.decoder.transformer_dim / 8,
813 g,
814 Device::Cpu,
815 )
816 .unwrap();
817 let mut hyper_mlps_ir =
818 super::mlp_ir::compile_hyper_mlps(&dec.hyper_mlps, Device::Cpu).unwrap();
819 let mut iou_head_ir = super::mlp_ir::compile_hyper_mlp(&dec.iou_head, Device::Cpu).unwrap();
820 let mut obj_score_head_ir =
821 super::mlp_ir::compile_optional_hyper_mlp(&dec.obj_score_head, 1, Device::Cpu).unwrap();
822 let obj_ptr_rows = super::mlp_ir::obj_ptr_proj_rows(
823 dec.num_mask_tokens,
824 dec.use_multimask_token_for_obj_ptr,
825 );
826 let mut obj_ptr_proj_ir =
827 super::mlp_ir::compile_optional_hyper_mlp(&dec.obj_ptr_proj, obj_ptr_rows, Device::Cpu)
828 .unwrap();
829 let s_tok = if dec.obj_score_token.is_some() { 1 } else { 0 };
830 let base_q_n = s_tok + 1 + dec.num_mask_tokens;
831 let mut tw_ir = super::transformer_ir::compile_two_way_transformer(
832 &dec.transformer,
833 base_q_n,
834 g,
835 Device::Cpu,
836 )
837 .unwrap();
838 let out = mask_decoder_forward(
839 &dec,
840 &mut upscale,
841 Some(&mut hyper_matmul),
842 Some(&mut hyper_mlps_ir),
843 Some(&mut iou_head_ir),
844 obj_score_head_ir.as_mut(),
845 obj_ptr_proj_ir.as_mut(),
846 Some(&mut tw_ir),
847 &image_emb,
848 &image_pe,
849 &sparse,
850 0,
851 &dense,
852 Some((&s0, &s1)),
853 true,
854 g,
855 )
856 .unwrap();
857 assert_eq!(out.num_masks, 3);
858 assert_eq!(out.h_out, 4 * g);
859 assert_eq!(out.w_out, 4 * g);
860 assert_eq!(out.masks.len(), 3 * out.h_out * out.w_out);
861 assert_eq!(out.iou_pred.len(), 3);
862 assert_eq!(out.object_score_logits.len(), 1);
864 }
865
866 #[test]
867 fn memory_encoder_prefix_matches_split_ir() {
868 let cfg = Sam2Config::hiera_base_plus();
869 let mut wm = synthetic_full_sam2_weights(&cfg);
870 let (_g, _p, _pre, _fpn) = build_sam2_image_encoder_graph(&cfg.hiera, &mut wm).unwrap();
871 let mem =
872 memory_encoder::extract_memory_encoder_weights(&mut wm, &cfg.memory_encoder).unwrap();
873
874 let pix = vec![0.1f32; cfg.memory_encoder.in_dim * SAM2_PROMPT_GRID * SAM2_PROMPT_GRID];
875 let mask = vec![0.5f32; SAM2_IMG_SIZE * SAM2_IMG_SIZE];
876
877 let mut md = memory_mask_ir::Sam2MemoryMaskDownCompiled::compile(
878 &mem.mask_downsampler,
879 SAM2_IMG_SIZE,
880 SAM2_IMG_SIZE,
881 Device::Cpu,
882 )
883 .unwrap();
884 let mut pp = memory_mask_ir::Sam2MemoryConv1x1Compiled::compile(
885 mem.in_dim,
886 mem.in_dim,
887 SAM2_PROMPT_GRID,
888 SAM2_PROMPT_GRID,
889 &mem.pix_feat_proj_w,
890 &mem.pix_feat_proj_b,
891 Device::Cpu,
892 )
893 .unwrap();
894 let m_down = md.run(&mask).unwrap();
895 let mut split = pp.run(&pix).unwrap();
896 for i in 0..split.len() {
897 split[i] += m_down[i];
898 }
899
900 let mut prefix = memory_mask_ir::Sam2MemoryPrefixCompiled::compile(
901 &mem.mask_downsampler,
902 mem.in_dim,
903 SAM2_IMG_SIZE,
904 SAM2_IMG_SIZE,
905 SAM2_PROMPT_GRID,
906 SAM2_PROMPT_GRID,
907 &mem.pix_feat_proj_w,
908 &mem.pix_feat_proj_b,
909 Device::Cpu,
910 )
911 .unwrap();
912 let fused = prefix.run(&mask, &pix).unwrap();
913 assert_eq!(split.len(), fused.len());
914 let fd = split
915 .iter()
916 .zip(&fused)
917 .map(|(a, b)| (a - b).abs())
918 .fold(0f32, f32::max);
919 assert!(fd < 1e-4, "prefix vs split max |Δ| = {fd:.3e}");
920 }
921
922 #[test]
923 fn memory_encoder_shapes_match_for_b_plus() {
924 let cfg = Sam2Config::hiera_base_plus();
925 let mut wm = synthetic_full_sam2_weights(&cfg);
926 let (_g, _p, _pre, _fpn) = build_sam2_image_encoder_graph(&cfg.hiera, &mut wm).unwrap();
927 let _ = prompt_encoder::extract_prompt_encoder_weights(
928 &mut wm,
929 cfg.decoder.transformer_dim,
930 SAM2_MASK_IN_CHANS,
931 )
932 .unwrap();
933 let _ = mask_decoder::extract_mask_decoder_weights(&mut wm, &cfg.decoder).unwrap();
934 let mut mem =
935 memory_encoder::extract_memory_encoder_weights(&mut wm, &cfg.memory_encoder).unwrap();
936 memory_encoder::compile_memory_encoder_ir(
937 &mut mem,
938 SAM2_IMG_SIZE,
939 SAM2_IMG_SIZE,
940 SAM2_PROMPT_GRID,
941 SAM2_PROMPT_GRID,
942 Device::Cpu,
943 &rlx_flow::CompileProfile::sam2(),
944 )
945 .unwrap();
946 let _ = memory_attention::extract_memory_attention_weights(&mut wm, &cfg.memory).unwrap();
947
948 let pix = vec![0f32; cfg.memory_encoder.in_dim * SAM2_PROMPT_GRID * SAM2_PROMPT_GRID];
949 let mask = vec![0f32; SAM2_IMG_SIZE * SAM2_IMG_SIZE];
950 let out = memory_encoder_forward(
951 &mut mem,
952 &pix,
953 &mask,
954 SAM2_PROMPT_GRID,
955 SAM2_PROMPT_GRID,
956 true,
957 )
958 .unwrap();
959 assert_eq!(out.h, SAM2_PROMPT_GRID);
960 assert_eq!(out.w, SAM2_PROMPT_GRID);
961 assert_eq!(
962 out.features.len(),
963 cfg.memory_encoder.out_dim * SAM2_PROMPT_GRID * SAM2_PROMPT_GRID
964 );
965 assert_eq!(
967 out.pos.len(),
968 2 * cfg.memory_encoder.pe_num_pos_feats * SAM2_PROMPT_GRID * SAM2_PROMPT_GRID
969 );
970 }
971
972 #[test]
973 fn memory_attention_runs_on_zero_inputs() {
974 let cfg = Sam2Config::hiera_base_plus();
975 let mut wm = synthetic_full_sam2_weights(&cfg);
976 let (_g, _p, _pre, _fpn) = build_sam2_image_encoder_graph(&cfg.hiera, &mut wm).unwrap();
977 let _ = prompt_encoder::extract_prompt_encoder_weights(
978 &mut wm,
979 cfg.decoder.transformer_dim,
980 SAM2_MASK_IN_CHANS,
981 )
982 .unwrap();
983 let _ = mask_decoder::extract_mask_decoder_weights(&mut wm, &cfg.decoder).unwrap();
984 let _ =
985 memory_encoder::extract_memory_encoder_weights(&mut wm, &cfg.memory_encoder).unwrap();
986 let mat = memory_attention::extract_memory_attention_weights(&mut wm, &cfg.memory).unwrap();
987
988 let [end_x, end_y] = cfg.memory.rope_feat_size;
989 let n_img = end_x * end_y;
990 let d = cfg.memory.d_model;
991 let kv = cfg.memory.kv_in_dim;
992 let curr = vec![0f32; n_img * d];
993 let curr_pos = vec![0f32; n_img * d];
994 let n_mem = end_x * end_y;
996 let memory = vec![0f32; n_mem * kv];
997 let memory_pos = vec![0f32; n_mem * kv];
998 let out = memory_attention_forward(
999 &mat,
1000 &curr,
1001 &curr_pos,
1002 &memory,
1003 &memory_pos,
1004 n_img,
1005 n_mem,
1006 kv,
1007 0,
1008 )
1009 .unwrap();
1010 assert_eq!(out.len(), n_img * d);
1011 assert!(out.iter().all(|v| v.is_finite()));
1012 }
1013}