Skip to main content

rlx_sam2/
lib.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! SAM 2 — Meta's Segment Anything Model 2 (image + video segmentation).
17//!
18//! Mirrors `facebookresearch/sam2` so the published
19//! `sam2_hiera_{t,s,b+,l}.{pt,safetensors}` checkpoints load with no
20//! weight-key remapping.
21//!
22//! ## Components
23//!
24//! - **Phase 1** — Hiera image encoder + FpnNeck
25//!   ([`image_encoder`], [`fpn_neck`], [`preprocess`]).
26//! - **Phase 2** — prompt encoder + TwoWayTransformer + mask decoder
27//!   with object-pointer / object-score / high-res mask path
28//!   ([`prompt_encoder`], [`transformer`], [`mask_decoder`]).
29//! - **Phase 3** — memory encoder + memory attention for video
30//!   tracking ([`memory_encoder`], [`memory_attention`]).
31//! - **Top-level wrapper** — [`Sam2`] orchestrator with
32//!   `predict_image()` and `predict_video_frame()` APIs.
33//!
34//! ## Parity status
35//!
//! Synthetic-weights build tests in this crate's `tests` module exercise
//! every component (encoder, prompt encoder, decoder, memory
//! encoder/attention, end-to-end `Sam2` object) for every Hiera variant.
//! Numerical parity against the PyTorch reference is wired up in
//! `tests/sam2_parity.rs` behind the `parity-pytorch` feature flag;
//! running the bisect options there against a real
//! `sam2_hiera_*.safetensors` checkpoint is the follow-up work
//! (analogous to how SAM v1 Phase 1 reached parity in iterative passes
//! after the initial graph was wired).
44
45pub mod axial_rope;
46pub mod cli;
47pub mod config;
48pub mod flow;
49pub mod fpn_neck;
50pub mod fpn_neck_ir;
51pub mod image_encoder;
52pub mod mask_decoder;
53pub mod memory_attention;
54pub mod memory_attention_ir;
55pub mod memory_encoder;
56pub mod memory_mask_ir;
57pub mod mlp_ir;
58pub mod preprocess;
59pub mod prompt_encoder;
60pub mod prompt_mask_ir;
61#[allow(clippy::module_inception)]
62pub mod sam2;
63pub mod transformer;
64pub mod transformer_ir;
65pub mod upscale_ir;
66
67pub use rlx_sam::profile::{
68    SAM_PROFILE_FILE, sam_profile_near_weights, sam2_profile_default, sam2_profile_near_weights,
69};
70
71pub use config::{
72    SAM2_IMG_SIZE, SAM2_PATCH_GRID, SAM2_PATCH_KERNEL, SAM2_PATCH_PADDING, SAM2_PATCH_STRIDE,
73    SAM2_PIXEL_MEAN, SAM2_PIXEL_STD, SAM2_PROMPT_EMBED_DIM, SAM2_Q_POOL_COUNT, SAM2_Q_STRIDE,
74    Sam2Config, Sam2DecoderConfig, Sam2FpnConfig, Sam2HieraConfig, Sam2MemoryConfig,
75    Sam2MemoryEncoderConfig,
76};
77pub use flow::{Sam2ImageEncoderBuilt, Sam2ImageEncoderFlow, build_sam2_image_encoder_built};
78pub use fpn_neck::{FpnLevel, FpnNeckWeights, apply_fpn_neck, apply_fpn_neck_host};
79pub use fpn_neck_ir::{Sam2FpnNeckIr, compile_fpn_neck_ir};
80pub use image_encoder::{build_sam2_image_encoder_graph, build_sam2_image_encoder_hir};
81pub use mask_decoder::{Sam2MaskDecoderOutput, Sam2MaskDecoderWeights, mask_decoder_forward};
82pub use memory_attention::{Sam2MemoryAttentionWeights, memory_attention_forward};
83pub use memory_attention_ir::MemoryAttentionCompiled;
84pub use memory_encoder::{
85    Sam2MemoryEncoderOutput, Sam2MemoryEncoderWeights, memory_encoder_forward,
86};
87pub use preprocess::{Sam2PreprocessWeights, assemble_patch_tokens, preprocess_image};
88pub use prompt_encoder::{
89    SAM2_MASK_IN_CHANS, SAM2_PROMPT_GRID, Sam2PromptEncoderOutput, Sam2PromptEncoderWeights,
90    prompt_encoder_forward,
91};
92pub use rlx_sam_ir::twoway_transformer_ir::TwoWayTransformerCompiled;
93pub use sam2::{Sam2, Sam2ImagePrediction, Sam2VideoState};
94pub use transformer::{Sam2TwoWayTransformerWeights, two_way_transformer_forward};
95pub use transformer_ir::compile_two_way_transformer;
96
97#[cfg(test)]
98mod tests {
99    use super::*;
100    use rlx_core::weight_map::WeightMap;
101    use rlx_runtime::Device;
102    use std::collections::HashMap;
103
/// In-memory tensor table used to build synthetic checkpoints:
/// parameter name -> (flat data, shape).
type T = HashMap<String, (Vec<f32>, Vec<usize>)>;

/// Returns a buffer of `n` zeros — the payload of every synthetic weight.
fn z(n: usize) -> Vec<f32> {
    std::iter::repeat(0.0_f32).take(n).collect()
}
109
110    /// Insert every key required by the Hiera image encoder + FPN
111    /// neck. Mirrors the assertions in `image_encoder.rs` and
112    /// `fpn_neck.rs::extract_fpn_weights`.
113    fn add_hiera_weights(t: &mut T, cfg: &Sam2HieraConfig) {
114        let e0 = cfg.embed_dim;
115        let k = SAM2_PATCH_KERNEL;
116        let [ph, pw] = cfg.window_pos_embed_bkg_spatial_size;
117        let mu = cfg.window_size_at_stage(0);
118
119        t.insert(
120            "image_encoder.trunk.patch_embed.proj.weight".into(),
121            (z(e0 * 3 * k * k), vec![e0, 3, k, k]),
122        );
123        t.insert(
124            "image_encoder.trunk.patch_embed.proj.bias".into(),
125            (z(e0), vec![e0]),
126        );
127        t.insert(
128            "image_encoder.trunk.pos_embed".into(),
129            (z(e0 * ph * pw), vec![1, e0, ph, pw]),
130        );
131        t.insert(
132            "image_encoder.trunk.pos_embed_window".into(),
133            (z(e0 * mu * mu), vec![1, e0, mu, mu]),
134        );
135
136        let q_pool = cfg.q_pool_block_indices();
137        let total = cfg.total_blocks();
138        let mut stage = 0usize;
139        let mut dim_curr = e0;
140        for i in 0..total {
141            let is_q_pool = q_pool.contains(&i);
142            let dim_in = dim_curr;
143            let stage_after = if is_q_pool { stage + 1 } else { stage };
144            let dim_out = cfg.embed_dim_at_stage(stage_after);
145            let lp = format!("image_encoder.trunk.blocks.{i}");
146
147            t.insert(format!("{lp}.norm1.weight"), (z(dim_in), vec![dim_in]));
148            t.insert(format!("{lp}.norm1.bias"), (z(dim_in), vec![dim_in]));
149            if dim_in != dim_out {
150                t.insert(
151                    format!("{lp}.proj.weight"),
152                    (z(dim_in * dim_out), vec![dim_out, dim_in]),
153                );
154                t.insert(format!("{lp}.proj.bias"), (z(dim_out), vec![dim_out]));
155            }
156            t.insert(
157                format!("{lp}.attn.qkv.weight"),
158                (z(dim_in * 3 * dim_out), vec![3 * dim_out, dim_in]),
159            );
160            if cfg.qkv_bias {
161                t.insert(
162                    format!("{lp}.attn.qkv.bias"),
163                    (z(3 * dim_out), vec![3 * dim_out]),
164                );
165            }
166            t.insert(
167                format!("{lp}.attn.proj.weight"),
168                (z(dim_out * dim_out), vec![dim_out, dim_out]),
169            );
170            t.insert(format!("{lp}.attn.proj.bias"), (z(dim_out), vec![dim_out]));
171            t.insert(format!("{lp}.norm2.weight"), (z(dim_out), vec![dim_out]));
172            t.insert(format!("{lp}.norm2.bias"), (z(dim_out), vec![dim_out]));
173
174            let hidden = (dim_out as f64 * cfg.mlp_ratio) as usize;
175            t.insert(
176                format!("{lp}.mlp.layers.0.weight"),
177                (z(dim_out * hidden), vec![hidden, dim_out]),
178            );
179            t.insert(format!("{lp}.mlp.layers.0.bias"), (z(hidden), vec![hidden]));
180            t.insert(
181                format!("{lp}.mlp.layers.1.weight"),
182                (z(hidden * dim_out), vec![dim_out, hidden]),
183            );
184            t.insert(
185                format!("{lp}.mlp.layers.1.bias"),
186                (z(dim_out), vec![dim_out]),
187            );
188
189            if is_q_pool {
190                stage += 1;
191                dim_curr = dim_out;
192            }
193        }
194
195        let fpn = Sam2FpnConfig::for_hiera(cfg);
196        for (i, &cin) in fpn.backbone_channel_list.iter().enumerate() {
197            t.insert(
198                format!("image_encoder.neck.convs.{i}.conv.weight"),
199                (z(fpn.d_model * cin), vec![fpn.d_model, cin, 1, 1]),
200            );
201            t.insert(
202                format!("image_encoder.neck.convs.{i}.conv.bias"),
203                (z(fpn.d_model), vec![fpn.d_model]),
204            );
205        }
206    }
207
208    fn add_prompt_encoder_weights(t: &mut T, embed_dim: usize, mask_in_chans: usize) {
209        let half = embed_dim / 2;
210        let q = mask_in_chans / 4;
211        t.insert(
212            "sam_prompt_encoder.pe_layer.positional_encoding_gaussian_matrix".into(),
213            (z(2 * half), vec![2, half]),
214        );
215        t.insert(
216            "sam_prompt_encoder.not_a_point_embed.weight".into(),
217            (z(embed_dim), vec![1, embed_dim]),
218        );
219        t.insert(
220            "sam_prompt_encoder.no_mask_embed.weight".into(),
221            (z(embed_dim), vec![1, embed_dim]),
222        );
223        for i in 0..4 {
224            t.insert(
225                format!("sam_prompt_encoder.point_embeddings.{i}.weight"),
226                (z(embed_dim), vec![1, embed_dim]),
227            );
228        }
229        t.insert(
230            "sam_prompt_encoder.mask_downscaling.0.weight".into(),
231            (z(q * 4), vec![q, 1, 2, 2]),
232        );
233        t.insert(
234            "sam_prompt_encoder.mask_downscaling.0.bias".into(),
235            (z(q), vec![q]),
236        );
237        t.insert(
238            "sam_prompt_encoder.mask_downscaling.1.weight".into(),
239            (z(q), vec![q]),
240        );
241        t.insert(
242            "sam_prompt_encoder.mask_downscaling.1.bias".into(),
243            (z(q), vec![q]),
244        );
245        t.insert(
246            "sam_prompt_encoder.mask_downscaling.3.weight".into(),
247            (z(mask_in_chans * q * 4), vec![mask_in_chans, q, 2, 2]),
248        );
249        t.insert(
250            "sam_prompt_encoder.mask_downscaling.3.bias".into(),
251            (z(mask_in_chans), vec![mask_in_chans]),
252        );
253        t.insert(
254            "sam_prompt_encoder.mask_downscaling.4.weight".into(),
255            (z(mask_in_chans), vec![mask_in_chans]),
256        );
257        t.insert(
258            "sam_prompt_encoder.mask_downscaling.4.bias".into(),
259            (z(mask_in_chans), vec![mask_in_chans]),
260        );
261        t.insert(
262            "sam_prompt_encoder.mask_downscaling.6.weight".into(),
263            (
264                z(embed_dim * mask_in_chans),
265                vec![embed_dim, mask_in_chans, 1, 1],
266            ),
267        );
268        t.insert(
269            "sam_prompt_encoder.mask_downscaling.6.bias".into(),
270            (z(embed_dim), vec![embed_dim]),
271        );
272    }
273
274    fn add_two_way_transformer_weights(t: &mut T, cfg: &Sam2DecoderConfig) {
275        let e = cfg.transformer_dim;
276        let id = e / 2;
277        let mlp = cfg.transformer_mlp_dim;
278        for i in 0..cfg.transformer_depth {
279            let p = format!("sam_mask_decoder.transformer.layers.{i}");
280            // self_attn (downsample_rate=1 → internal_dim=e)
281            {
282                let sub = "self_attn";
283                t.insert(format!("{p}.{sub}.q_proj.weight"), (z(e * e), vec![e, e]));
284                t.insert(format!("{p}.{sub}.q_proj.bias"), (z(e), vec![e]));
285                t.insert(format!("{p}.{sub}.k_proj.weight"), (z(e * e), vec![e, e]));
286                t.insert(format!("{p}.{sub}.k_proj.bias"), (z(e), vec![e]));
287                t.insert(format!("{p}.{sub}.v_proj.weight"), (z(e * e), vec![e, e]));
288                t.insert(format!("{p}.{sub}.v_proj.bias"), (z(e), vec![e]));
289                t.insert(format!("{p}.{sub}.out_proj.weight"), (z(e * e), vec![e, e]));
290                t.insert(format!("{p}.{sub}.out_proj.bias"), (z(e), vec![e]));
291            }
292            t.insert(format!("{p}.norm1.weight"), (z(e), vec![e]));
293            t.insert(format!("{p}.norm1.bias"), (z(e), vec![e]));
294            // cross_attn_token_to_image, cross_attn_image_to_token (downsample_rate=2 → internal=e/2)
295            for sub in ["cross_attn_token_to_image", "cross_attn_image_to_token"] {
296                t.insert(format!("{p}.{sub}.q_proj.weight"), (z(e * id), vec![id, e]));
297                t.insert(format!("{p}.{sub}.q_proj.bias"), (z(id), vec![id]));
298                t.insert(format!("{p}.{sub}.k_proj.weight"), (z(e * id), vec![id, e]));
299                t.insert(format!("{p}.{sub}.k_proj.bias"), (z(id), vec![id]));
300                t.insert(format!("{p}.{sub}.v_proj.weight"), (z(e * id), vec![id, e]));
301                t.insert(format!("{p}.{sub}.v_proj.bias"), (z(id), vec![id]));
302                t.insert(
303                    format!("{p}.{sub}.out_proj.weight"),
304                    (z(e * id), vec![e, id]),
305                );
306                t.insert(format!("{p}.{sub}.out_proj.bias"), (z(e), vec![e]));
307            }
308            t.insert(format!("{p}.norm2.weight"), (z(e), vec![e]));
309            t.insert(format!("{p}.norm2.bias"), (z(e), vec![e]));
310            t.insert(
311                format!("{p}.mlp.layers.0.weight"),
312                (z(mlp * e), vec![mlp, e]),
313            );
314            t.insert(format!("{p}.mlp.layers.0.bias"), (z(mlp), vec![mlp]));
315            t.insert(
316                format!("{p}.mlp.layers.1.weight"),
317                (z(mlp * e), vec![e, mlp]),
318            );
319            t.insert(format!("{p}.mlp.layers.1.bias"), (z(e), vec![e]));
320            t.insert(format!("{p}.norm3.weight"), (z(e), vec![e]));
321            t.insert(format!("{p}.norm3.bias"), (z(e), vec![e]));
322            t.insert(format!("{p}.norm4.weight"), (z(e), vec![e]));
323            t.insert(format!("{p}.norm4.bias"), (z(e), vec![e]));
324        }
325        // final_attn_token_to_image (downsample_rate=2)
326        let p = "sam_mask_decoder.transformer.final_attn_token_to_image";
327        t.insert(format!("{p}.q_proj.weight"), (z(e * id), vec![id, e]));
328        t.insert(format!("{p}.q_proj.bias"), (z(id), vec![id]));
329        t.insert(format!("{p}.k_proj.weight"), (z(e * id), vec![id, e]));
330        t.insert(format!("{p}.k_proj.bias"), (z(id), vec![id]));
331        t.insert(format!("{p}.v_proj.weight"), (z(e * id), vec![id, e]));
332        t.insert(format!("{p}.v_proj.bias"), (z(id), vec![id]));
333        t.insert(format!("{p}.out_proj.weight"), (z(e * id), vec![e, id]));
334        t.insert(format!("{p}.out_proj.bias"), (z(e), vec![e]));
335        t.insert(
336            "sam_mask_decoder.transformer.norm_final_attn.weight".into(),
337            (z(e), vec![e]),
338        );
339        t.insert(
340            "sam_mask_decoder.transformer.norm_final_attn.bias".into(),
341            (z(e), vec![e]),
342        );
343    }
344
345    fn add_mask_decoder_weights(t: &mut T, cfg: &Sam2DecoderConfig) {
346        let e = cfg.transformer_dim;
347        let q4 = e / 4;
348        let q8 = e / 8;
349        t.insert(
350            "sam_mask_decoder.iou_token.weight".into(),
351            (z(e), vec![1, e]),
352        );
353        t.insert(
354            "sam_mask_decoder.mask_tokens.weight".into(),
355            (z(cfg.num_mask_tokens * e), vec![cfg.num_mask_tokens, e]),
356        );
357        if cfg.pred_obj_scores {
358            t.insert(
359                "sam_mask_decoder.obj_score_token.weight".into(),
360                (z(e), vec![1, e]),
361            );
362        }
363        t.insert(
364            "sam_mask_decoder.output_upscaling.0.weight".into(),
365            (z(e * q4 * 4), vec![e, q4, 2, 2]),
366        );
367        t.insert(
368            "sam_mask_decoder.output_upscaling.0.bias".into(),
369            (z(q4), vec![q4]),
370        );
371        t.insert(
372            "sam_mask_decoder.output_upscaling.1.weight".into(),
373            (z(q4), vec![q4]),
374        );
375        t.insert(
376            "sam_mask_decoder.output_upscaling.1.bias".into(),
377            (z(q4), vec![q4]),
378        );
379        t.insert(
380            "sam_mask_decoder.output_upscaling.3.weight".into(),
381            (z(q4 * q8 * 4), vec![q4, q8, 2, 2]),
382        );
383        t.insert(
384            "sam_mask_decoder.output_upscaling.3.bias".into(),
385            (z(q8), vec![q8]),
386        );
387        if cfg.use_high_res_features {
388            t.insert(
389                "sam_mask_decoder.conv_s0.weight".into(),
390                (z(q8 * e), vec![q8, e, 1, 1]),
391            );
392            t.insert("sam_mask_decoder.conv_s0.bias".into(), (z(q8), vec![q8]));
393            t.insert(
394                "sam_mask_decoder.conv_s1.weight".into(),
395                (z(q4 * e), vec![q4, e, 1, 1]),
396            );
397            t.insert("sam_mask_decoder.conv_s1.bias".into(), (z(q4), vec![q4]));
398        }
399        for i in 0..cfg.num_mask_tokens {
400            let p = format!("sam_mask_decoder.output_hypernetworks_mlps.{i}");
401            // 3-layer ReLU MLP: e → e → e → q8
402            t.insert(format!("{p}.layers.0.weight"), (z(e * e), vec![e, e]));
403            t.insert(format!("{p}.layers.0.bias"), (z(e), vec![e]));
404            t.insert(format!("{p}.layers.1.weight"), (z(e * e), vec![e, e]));
405            t.insert(format!("{p}.layers.1.bias"), (z(e), vec![e]));
406            t.insert(format!("{p}.layers.2.weight"), (z(e * q8), vec![q8, e]));
407            t.insert(format!("{p}.layers.2.bias"), (z(q8), vec![q8]));
408        }
409        // IoU prediction head: e → hidden → hidden → num_masks
410        let p = "sam_mask_decoder.iou_prediction_head";
411        let hidden = cfg.iou_head_hidden_dim;
412        t.insert(
413            format!("{p}.layers.0.weight"),
414            (z(e * hidden), vec![hidden, e]),
415        );
416        t.insert(format!("{p}.layers.0.bias"), (z(hidden), vec![hidden]));
417        t.insert(
418            format!("{p}.layers.1.weight"),
419            (z(hidden * hidden), vec![hidden, hidden]),
420        );
421        t.insert(format!("{p}.layers.1.bias"), (z(hidden), vec![hidden]));
422        t.insert(
423            format!("{p}.layers.2.weight"),
424            (
425                z(hidden * cfg.num_mask_tokens),
426                vec![cfg.num_mask_tokens, hidden],
427            ),
428        );
429        t.insert(
430            format!("{p}.layers.2.bias"),
431            (z(cfg.num_mask_tokens), vec![cfg.num_mask_tokens]),
432        );
433        // pred_obj_score_head MLP
434        if cfg.pred_obj_scores {
435            if cfg.pred_obj_scores_mlp {
436                let p = "sam_mask_decoder.pred_obj_score_head";
437                t.insert(format!("{p}.layers.0.weight"), (z(e * e), vec![e, e]));
438                t.insert(format!("{p}.layers.0.bias"), (z(e), vec![e]));
439                t.insert(format!("{p}.layers.1.weight"), (z(e * e), vec![e, e]));
440                t.insert(format!("{p}.layers.1.bias"), (z(e), vec![e]));
441                t.insert(format!("{p}.layers.2.weight"), (z(e), vec![1, e]));
442                t.insert(format!("{p}.layers.2.bias"), (z(1), vec![1]));
443            } else {
444                t.insert(
445                    "sam_mask_decoder.pred_obj_score_head.weight".into(),
446                    (z(e), vec![1, e]),
447                );
448                t.insert(
449                    "sam_mask_decoder.pred_obj_score_head.bias".into(),
450                    (z(1), vec![1]),
451                );
452            }
453        }
454        // obj_ptr_proj MLP — top-level under SAM2Base, not nested.
455        if cfg.use_object_pointer {
456            if cfg.use_mlp_for_obj_ptr_proj {
457                let p = "obj_ptr_proj";
458                t.insert(format!("{p}.layers.0.weight"), (z(e * e), vec![e, e]));
459                t.insert(format!("{p}.layers.0.bias"), (z(e), vec![e]));
460                t.insert(format!("{p}.layers.1.weight"), (z(e * e), vec![e, e]));
461                t.insert(format!("{p}.layers.1.bias"), (z(e), vec![e]));
462                t.insert(format!("{p}.layers.2.weight"), (z(e * e), vec![e, e]));
463                t.insert(format!("{p}.layers.2.bias"), (z(e), vec![e]));
464            } else {
465                t.insert("obj_ptr_proj.weight".into(), (z(e * e), vec![e, e]));
466                t.insert("obj_ptr_proj.bias".into(), (z(e), vec![e]));
467            }
468        }
469        add_two_way_transformer_weights(t, cfg);
470    }
471
472    fn add_memory_encoder_weights(t: &mut T, cfg: &Sam2MemoryEncoderConfig) {
473        // MaskDownSampler levels.
474        let mut in_c = 1usize;
475        let stride2 = cfg.mask_downsampler_stride * cfg.mask_downsampler_stride;
476        let mut num_levels = 0;
477        let mut acc = 1usize;
478        while acc < cfg.mask_downsampler_total_stride {
479            acc *= cfg.mask_downsampler_stride;
480            num_levels += 1;
481        }
482        for li in 0..num_levels {
483            let out_c = in_c * stride2;
484            let conv_idx = li * 3;
485            let ln_idx = conv_idx + 1;
486            let k = cfg.mask_downsampler_kernel;
487            t.insert(
488                format!("memory_encoder.mask_downsampler.encoder.{conv_idx}.weight"),
489                (z(out_c * in_c * k * k), vec![out_c, in_c, k, k]),
490            );
491            t.insert(
492                format!("memory_encoder.mask_downsampler.encoder.{conv_idx}.bias"),
493                (z(out_c), vec![out_c]),
494            );
495            t.insert(
496                format!("memory_encoder.mask_downsampler.encoder.{ln_idx}.weight"),
497                (z(out_c), vec![out_c]),
498            );
499            t.insert(
500                format!("memory_encoder.mask_downsampler.encoder.{ln_idx}.bias"),
501                (z(out_c), vec![out_c]),
502            );
503            in_c = out_c;
504        }
505        let final_idx = num_levels * 3;
506        t.insert(
507            format!("memory_encoder.mask_downsampler.encoder.{final_idx}.weight"),
508            (z(cfg.in_dim * in_c), vec![cfg.in_dim, in_c, 1, 1]),
509        );
510        t.insert(
511            format!("memory_encoder.mask_downsampler.encoder.{final_idx}.bias"),
512            (z(cfg.in_dim), vec![cfg.in_dim]),
513        );
514        // pix_feat_proj
515        t.insert(
516            "memory_encoder.pix_feat_proj.weight".into(),
517            (
518                z(cfg.in_dim * cfg.in_dim),
519                vec![cfg.in_dim, cfg.in_dim, 1, 1],
520            ),
521        );
522        t.insert(
523            "memory_encoder.pix_feat_proj.bias".into(),
524            (z(cfg.in_dim), vec![cfg.in_dim]),
525        );
526        // Fuser
527        for i in 0..cfg.fuser_num_layers {
528            let p = format!("memory_encoder.fuser.layers.{i}");
529            let dim = cfg.fuser_dim;
530            let k = cfg.fuser_kernel;
531            if cfg.fuser_use_dwconv {
532                t.insert(
533                    format!("{p}.dwconv.weight"),
534                    (z(dim * k * k), vec![dim, 1, k, k]),
535                );
536            } else {
537                t.insert(
538                    format!("{p}.dwconv.weight"),
539                    (z(dim * dim * k * k), vec![dim, dim, k, k]),
540                );
541            }
542            t.insert(format!("{p}.dwconv.bias"), (z(dim), vec![dim]));
543            t.insert(format!("{p}.norm.weight"), (z(dim), vec![dim]));
544            t.insert(format!("{p}.norm.bias"), (z(dim), vec![dim]));
545            t.insert(
546                format!("{p}.pwconv1.weight"),
547                (z(4 * dim * dim), vec![4 * dim, dim]),
548            );
549            t.insert(format!("{p}.pwconv1.bias"), (z(4 * dim), vec![4 * dim]));
550            t.insert(
551                format!("{p}.pwconv2.weight"),
552                (z(dim * 4 * dim), vec![dim, 4 * dim]),
553            );
554            t.insert(format!("{p}.pwconv2.bias"), (z(dim), vec![dim]));
555            if cfg.fuser_layer_scale_init_value > 0.0 {
556                t.insert(format!("{p}.gamma"), (z(dim), vec![dim]));
557            }
558        }
559        // out_proj (only when dims differ)
560        if cfg.in_dim != cfg.out_dim {
561            t.insert(
562                "memory_encoder.out_proj.weight".into(),
563                (
564                    z(cfg.in_dim * cfg.out_dim),
565                    vec![cfg.out_dim, cfg.in_dim, 1, 1],
566                ),
567            );
568            t.insert(
569                "memory_encoder.out_proj.bias".into(),
570                (z(cfg.out_dim), vec![cfg.out_dim]),
571            );
572        }
573    }
574
575    fn add_memory_attention_weights(t: &mut T, cfg: &Sam2MemoryConfig) {
576        let d = cfg.d_model;
577        let kv = cfg.kv_in_dim;
578        let dff = cfg.dim_feedforward;
579        for i in 0..cfg.num_layers {
580            let p = format!("memory_attention.layers.{i}");
581            // self_attn: q/k/v all from d → d
582            {
583                let sub = "self_attn";
584                t.insert(format!("{p}.{sub}.q_proj.weight"), (z(d * d), vec![d, d]));
585                t.insert(format!("{p}.{sub}.q_proj.bias"), (z(d), vec![d]));
586                t.insert(format!("{p}.{sub}.k_proj.weight"), (z(d * d), vec![d, d]));
587                t.insert(format!("{p}.{sub}.k_proj.bias"), (z(d), vec![d]));
588                t.insert(format!("{p}.{sub}.v_proj.weight"), (z(d * d), vec![d, d]));
589                t.insert(format!("{p}.{sub}.v_proj.bias"), (z(d), vec![d]));
590                t.insert(format!("{p}.{sub}.out_proj.weight"), (z(d * d), vec![d, d]));
591                t.insert(format!("{p}.{sub}.out_proj.bias"), (z(d), vec![d]));
592            }
593            // cross_attn_image: q from d → d, k/v from kv → d
594            {
595                let sub = "cross_attn_image";
596                t.insert(format!("{p}.{sub}.q_proj.weight"), (z(d * d), vec![d, d]));
597                t.insert(format!("{p}.{sub}.q_proj.bias"), (z(d), vec![d]));
598                t.insert(format!("{p}.{sub}.k_proj.weight"), (z(d * kv), vec![d, kv]));
599                t.insert(format!("{p}.{sub}.k_proj.bias"), (z(d), vec![d]));
600                t.insert(format!("{p}.{sub}.v_proj.weight"), (z(d * kv), vec![d, kv]));
601                t.insert(format!("{p}.{sub}.v_proj.bias"), (z(d), vec![d]));
602                t.insert(format!("{p}.{sub}.out_proj.weight"), (z(d * d), vec![d, d]));
603                t.insert(format!("{p}.{sub}.out_proj.bias"), (z(d), vec![d]));
604            }
605            t.insert(format!("{p}.norm1.weight"), (z(d), vec![d]));
606            t.insert(format!("{p}.norm1.bias"), (z(d), vec![d]));
607            t.insert(format!("{p}.norm2.weight"), (z(d), vec![d]));
608            t.insert(format!("{p}.norm2.bias"), (z(d), vec![d]));
609            t.insert(format!("{p}.norm3.weight"), (z(d), vec![d]));
610            t.insert(format!("{p}.norm3.bias"), (z(d), vec![d]));
611            t.insert(format!("{p}.linear1.weight"), (z(dff * d), vec![dff, d]));
612            t.insert(format!("{p}.linear1.bias"), (z(dff), vec![dff]));
613            t.insert(format!("{p}.linear2.weight"), (z(d * dff), vec![d, dff]));
614            t.insert(format!("{p}.linear2.bias"), (z(d), vec![d]));
615        }
616        t.insert("memory_attention.norm.weight".into(), (z(d), vec![d]));
617        t.insert("memory_attention.norm.bias".into(), (z(d), vec![d]));
618    }
619
620    fn synthetic_full_sam2_weights(cfg: &Sam2Config) -> WeightMap {
621        let mut t: T = HashMap::new();
622        add_hiera_weights(&mut t, &cfg.hiera);
623        add_prompt_encoder_weights(&mut t, cfg.decoder.transformer_dim, SAM2_MASK_IN_CHANS);
624        add_mask_decoder_weights(&mut t, &cfg.decoder);
625        add_memory_encoder_weights(&mut t, &cfg.memory_encoder);
626        add_memory_attention_weights(&mut t, &cfg.memory);
627        WeightMap::from_tensors(t)
628    }
629
630    fn assert_encoder_builds(cfg: Sam2HieraConfig) {
631        let mut t: T = HashMap::new();
632        add_hiera_weights(&mut t, &cfg);
633        let mut wm = WeightMap::from_tensors(t);
634        let (g, _params, _pre, _fpn) = build_sam2_image_encoder_graph(&cfg, &mut wm)
635            .unwrap_or_else(|e| panic!("encoder build failed: {e}"));
636        assert_eq!(g.outputs.len(), cfg.stages.len());
637        for (s, out_id) in g.outputs.iter().copied().enumerate() {
638            let shape = g.shape(out_id);
639            let dims: Vec<usize> = shape.dims().iter().map(|d| d.unwrap_static()).collect();
640            let hw_s = cfg.grid_size_at_stage(s);
641            let dim_s = cfg.embed_dim_at_stage(s);
642            assert_eq!(dims, vec![1, hw_s, hw_s, dim_s], "stage {s} shape mismatch");
643        }
644        let leftovers: Vec<&str> = wm.keys().collect();
645        assert!(leftovers.is_empty(), "leftover weights: {leftovers:?}");
646    }
647
/// Hiera-T: the image-encoder graph builds with per-stage shapes and drains its weights.
#[test]
fn encoder_graph_builds_tiny() {
    let cfg = Sam2HieraConfig::tiny();
    assert_encoder_builds(cfg);
}

/// Hiera-S: the image-encoder graph builds with per-stage shapes and drains its weights.
#[test]
fn encoder_graph_builds_small() {
    let cfg = Sam2HieraConfig::small();
    assert_encoder_builds(cfg);
}

/// Hiera-B+: the image-encoder graph builds with per-stage shapes and drains its weights.
#[test]
fn encoder_graph_builds_base_plus() {
    let cfg = Sam2HieraConfig::base_plus();
    assert_encoder_builds(cfg);
}

/// Hiera-L: the image-encoder graph builds with per-stage shapes and drains its weights.
#[test]
fn encoder_graph_builds_large() {
    let cfg = Sam2HieraConfig::large();
    assert_encoder_builds(cfg);
}
667
/// `preprocess_image` on a uniform 80 x 120 x 3 byte buffer yields
/// `3 * 1024 * 1024` output values.
#[test]
fn preprocess_round_trip_shapes() {
    const CHANNELS: usize = 3;
    let pixels = vec![64u8; 80 * 120 * CHANNELS];
    let nchw = preprocess_image(&pixels, 80, 120);
    assert_eq!(nchw.len(), CHANNELS * 1024 * 1024);
}
674
/// Hiera-B+ FPN neck on all-zero stage outputs: the compiled IR path and the
/// host reference must agree level by level (max |Δ| < 1e-4), and there must
/// be four levels, the first 256x256 and the last 32x32.
#[test]
fn fpn_neck_runs_on_synth_outputs() {
    let cfg = Sam2HieraConfig::base_plus();
    let mut tensors: T = HashMap::new();
    add_hiera_weights(&mut tensors, &cfg);
    let mut wm = WeightMap::from_tensors(tensors);
    let (_g, _p, _pre, neck) = build_sam2_image_encoder_graph(&cfg, &mut wm).unwrap();

    // Per-stage square grid, channel width, and a zero feature map of that size.
    let num_stages = cfg.stages.len();
    let stage_hw: Vec<(usize, usize)> = (0..num_stages)
        .map(|s| {
            let side = cfg.grid_size_at_stage(s);
            (side, side)
        })
        .collect();
    let stage_dims: Vec<usize> = (0..num_stages).map(|s| cfg.embed_dim_at_stage(s)).collect();
    let stage_outputs: Vec<Vec<f32>> = stage_hw
        .iter()
        .zip(&stage_dims)
        .map(|(&(h, w), &d)| vec![0f32; h * w * d])
        .collect();

    let profile = rlx_flow::CompileProfile::sam2();
    let mut fpn_ir =
        compile_fpn_neck_ir(&neck, &stage_hw, &stage_dims, Device::Cpu, &profile).unwrap();
    let levels =
        apply_fpn_neck(&neck, &mut fpn_ir, &stage_outputs, &stage_hw, &stage_dims).unwrap();
    let levels_host = apply_fpn_neck_host(&neck, &stage_outputs, &stage_hw, &stage_dims);

    assert_eq!(levels.len(), levels_host.len());
    for (ir, host) in levels.iter().zip(&levels_host) {
        assert_eq!(ir.features.len(), host.features.len());
        assert_eq!(ir.h, host.h);
        assert_eq!(ir.w, host.w);
        let fd = ir
            .features
            .iter()
            .zip(&host.features)
            .map(|(x, y)| (x - y).abs())
            .fold(0f32, f32::max);
        assert!(
            fd < 1e-4,
            "FPN IR vs host max |Δ| = {fd:.3e} at level {}×{}",
            ir.h,
            ir.w
        );
    }

    assert_eq!(levels.len(), 4);
    assert_eq!((levels[0].h, levels[0].w), (256, 256));
    assert_eq!((levels[3].h, levels[3].w), (32, 32));
}
728
729    #[test]
730    fn full_weight_extraction_drains_map() {
731        // End-to-end: build the synthetic WeightMap for *every* SAM 2
732        // component, instantiate via the same code paths Sam2 uses, and
733        // assert no expected keys are left over.
734        let cfg = Sam2Config::hiera_base_plus();
735        let mut wm = synthetic_full_sam2_weights(&cfg);
736
737        // Mirror Sam2::from_safetensors_on weight extraction order.
738        let (_g, _p, _pre, _fpn) = build_sam2_image_encoder_graph(&cfg.hiera, &mut wm).unwrap();
739        let _ = prompt_encoder::extract_prompt_encoder_weights(
740            &mut wm,
741            cfg.decoder.transformer_dim,
742            SAM2_MASK_IN_CHANS,
743        )
744        .unwrap();
745        let _ = mask_decoder::extract_mask_decoder_weights(&mut wm, &cfg.decoder).unwrap();
746        let _ =
747            memory_encoder::extract_memory_encoder_weights(&mut wm, &cfg.memory_encoder).unwrap();
748        let _ = memory_attention::extract_memory_attention_weights(&mut wm, &cfg.memory).unwrap();
749
750        let leftovers: Vec<&str> = wm.keys().collect();
751        assert!(
752            leftovers.is_empty(),
753            "leftover weights after full extraction: {leftovers:?}"
754        );
755    }
756
757    #[test]
758    fn prompt_encoder_no_prompt_produces_pe_and_no_mask() {
759        let cfg = Sam2Config::hiera_base_plus();
760        let mut wm = synthetic_full_sam2_weights(&cfg);
761        // Drain encoder + FPN keys to keep the test focused.
762        let (_g, _p, _pre, _fpn) = build_sam2_image_encoder_graph(&cfg.hiera, &mut wm).unwrap();
763        let pe = prompt_encoder::extract_prompt_encoder_weights(
764            &mut wm,
765            cfg.decoder.transformer_dim,
766            SAM2_MASK_IN_CHANS,
767        )
768        .unwrap();
769        let mut mask_stack =
770            super::prompt_mask_ir::Sam2PromptMaskCompiled::compile(&pe, Device::Cpu).unwrap();
771        let out = prompt_encoder_forward(&pe, &mut mask_stack, None, None, None).unwrap();
772        assert_eq!(out.num_sparse_tokens, 0);
773        assert_eq!(
774            out.dense_embeddings.len(),
775            cfg.decoder.transformer_dim * SAM2_PROMPT_GRID * SAM2_PROMPT_GRID
776        );
777        assert_eq!(
778            out.image_pe.len(),
779            cfg.decoder.transformer_dim * SAM2_PROMPT_GRID * SAM2_PROMPT_GRID
780        );
781    }
782
    /// Runs the mask decoder end to end on all-zero inputs (Hiera-B+ config,
    /// synthetic weights) with `multimask_output = true`. It then checks the
    /// output shapes: 3 masks of `4·g × 4·g`, 3 IoU predictions and a single
    /// object-score logit.
    #[test]
    fn mask_decoder_runs_on_zero_inputs() {
        let cfg = Sam2Config::hiera_base_plus();
        let mut wm = synthetic_full_sam2_weights(&cfg);
        // Extract every component in the same order as
        // `full_weight_extraction_drains_map`; only `dec` is used below.
        let (_g, _p, _pre, _fpn) = build_sam2_image_encoder_graph(&cfg.hiera, &mut wm).unwrap();
        let _pe = prompt_encoder::extract_prompt_encoder_weights(
            &mut wm,
            cfg.decoder.transformer_dim,
            SAM2_MASK_IN_CHANS,
        )
        .unwrap();
        let dec = mask_decoder::extract_mask_decoder_weights(&mut wm, &cfg.decoder).unwrap();
        let _ =
            memory_encoder::extract_memory_encoder_weights(&mut wm, &cfg.memory_encoder).unwrap();
        let _ = memory_attention::extract_memory_attention_weights(&mut wm, &cfg.memory).unwrap();

        // All decoder inputs are zero-filled. `e` is the transformer dim and `g`
        // the prompt-grid side. The image embedding, its PE and the dense prompt
        // embedding each hold e·g·g values (NOTE(review): presumably CHW layout —
        // confirm).
        let e = cfg.decoder.transformer_dim;
        let g = SAM2_PROMPT_GRID;
        let image_emb = vec![0f32; e * g * g];
        let image_pe = vec![0f32; e * g * g];
        let dense = vec![0f32; e * g * g];
        // No sparse prompt tokens.
        let sparse: Vec<f32> = Vec::new();
        // High-res feature buffers at 4× and 2× the grid side, e values per position.
        let s0 = vec![0f32; e * (4 * g) * (4 * g)];
        let s1 = vec![0f32; e * (2 * g) * (2 * g)];

        // Compile each IR sub-graph that `mask_decoder_forward` accepts, all on CPU.
        let mut upscale =
            super::upscale_ir::Sam2MaskUpscaleCompiled::compile(&dec, g, Device::Cpu).unwrap();
        // Hyper-network matmul over the mask tokens. NOTE(review): transformer_dim / 8
        // is presumably the upscaled-embedding channel count — confirm.
        let mut hyper_matmul = rlx_sam_ir::mask_hyper_matmul_ir::MaskHyperMatmulCompiled::compile(
            dec.num_mask_tokens,
            cfg.decoder.transformer_dim / 8,
            g,
            Device::Cpu,
        )
        .unwrap();
        let mut hyper_mlps_ir =
            super::mlp_ir::compile_hyper_mlps(&dec.hyper_mlps, Device::Cpu).unwrap();
        let mut iou_head_ir = super::mlp_ir::compile_hyper_mlp(&dec.iou_head, Device::Cpu).unwrap();
        // Optional heads compile to an `Option`, which is passed through with `.as_mut()`.
        let mut obj_score_head_ir =
            super::mlp_ir::compile_optional_hyper_mlp(&dec.obj_score_head, 1, Device::Cpu).unwrap();
        // The object-pointer projection's row count depends on `num_mask_tokens`
        // and `use_multimask_token_for_obj_ptr`.
        let obj_ptr_rows = super::mlp_ir::obj_ptr_proj_rows(
            dec.num_mask_tokens,
            dec.use_multimask_token_for_obj_ptr,
        );
        let mut obj_ptr_proj_ir =
            super::mlp_ir::compile_optional_hyper_mlp(&dec.obj_ptr_proj, obj_ptr_rows, Device::Cpu)
                .unwrap();
        // Base query count for the two-way transformer: an optional object-score
        // token, one more token (presumably the IoU token — confirm), and the
        // mask tokens.
        let s_tok = if dec.obj_score_token.is_some() { 1 } else { 0 };
        let base_q_n = s_tok + 1 + dec.num_mask_tokens;
        let mut tw_ir = super::transformer_ir::compile_two_way_transformer(
            &dec.transformer,
            base_q_n,
            g,
            Device::Cpu,
        )
        .unwrap();
        let out = mask_decoder_forward(
            &dec,
            &mut upscale,
            Some(&mut hyper_matmul),
            Some(&mut hyper_mlps_ir),
            Some(&mut iou_head_ir),
            obj_score_head_ir.as_mut(),
            obj_ptr_proj_ir.as_mut(),
            Some(&mut tw_ir),
            &image_emb,
            &image_pe,
            &sparse,
            // NOTE(review): presumably the sparse-token count, matching the empty `sparse`.
            0,
            &dense,
            Some((&s0, &s1)),
            /*multimask_output=*/ true,
            g,
        )
        .unwrap();
        assert_eq!(out.num_masks, 3);
        assert_eq!(out.h_out, 4 * g);
        assert_eq!(out.w_out, 4 * g);
        assert_eq!(out.masks.len(), 3 * out.h_out * out.w_out);
        assert_eq!(out.iou_pred.len(), 3);
        // pred_obj_scores=true → object_score_logits is from the MLP head (single scalar).
        assert_eq!(out.object_score_logits.len(), 1);
    }
865
866    #[test]
867    fn memory_encoder_prefix_matches_split_ir() {
868        let cfg = Sam2Config::hiera_base_plus();
869        let mut wm = synthetic_full_sam2_weights(&cfg);
870        let (_g, _p, _pre, _fpn) = build_sam2_image_encoder_graph(&cfg.hiera, &mut wm).unwrap();
871        let mem =
872            memory_encoder::extract_memory_encoder_weights(&mut wm, &cfg.memory_encoder).unwrap();
873
874        let pix = vec![0.1f32; cfg.memory_encoder.in_dim * SAM2_PROMPT_GRID * SAM2_PROMPT_GRID];
875        let mask = vec![0.5f32; SAM2_IMG_SIZE * SAM2_IMG_SIZE];
876
877        let mut md = memory_mask_ir::Sam2MemoryMaskDownCompiled::compile(
878            &mem.mask_downsampler,
879            SAM2_IMG_SIZE,
880            SAM2_IMG_SIZE,
881            Device::Cpu,
882        )
883        .unwrap();
884        let mut pp = memory_mask_ir::Sam2MemoryConv1x1Compiled::compile(
885            mem.in_dim,
886            mem.in_dim,
887            SAM2_PROMPT_GRID,
888            SAM2_PROMPT_GRID,
889            &mem.pix_feat_proj_w,
890            &mem.pix_feat_proj_b,
891            Device::Cpu,
892        )
893        .unwrap();
894        let m_down = md.run(&mask).unwrap();
895        let mut split = pp.run(&pix).unwrap();
896        for i in 0..split.len() {
897            split[i] += m_down[i];
898        }
899
900        let mut prefix = memory_mask_ir::Sam2MemoryPrefixCompiled::compile(
901            &mem.mask_downsampler,
902            mem.in_dim,
903            SAM2_IMG_SIZE,
904            SAM2_IMG_SIZE,
905            SAM2_PROMPT_GRID,
906            SAM2_PROMPT_GRID,
907            &mem.pix_feat_proj_w,
908            &mem.pix_feat_proj_b,
909            Device::Cpu,
910        )
911        .unwrap();
912        let fused = prefix.run(&mask, &pix).unwrap();
913        assert_eq!(split.len(), fused.len());
914        let fd = split
915            .iter()
916            .zip(&fused)
917            .map(|(a, b)| (a - b).abs())
918            .fold(0f32, f32::max);
919        assert!(fd < 1e-4, "prefix vs split max |Δ| = {fd:.3e}");
920    }
921
922    #[test]
923    fn memory_encoder_shapes_match_for_b_plus() {
924        let cfg = Sam2Config::hiera_base_plus();
925        let mut wm = synthetic_full_sam2_weights(&cfg);
926        let (_g, _p, _pre, _fpn) = build_sam2_image_encoder_graph(&cfg.hiera, &mut wm).unwrap();
927        let _ = prompt_encoder::extract_prompt_encoder_weights(
928            &mut wm,
929            cfg.decoder.transformer_dim,
930            SAM2_MASK_IN_CHANS,
931        )
932        .unwrap();
933        let _ = mask_decoder::extract_mask_decoder_weights(&mut wm, &cfg.decoder).unwrap();
934        let mut mem =
935            memory_encoder::extract_memory_encoder_weights(&mut wm, &cfg.memory_encoder).unwrap();
936        memory_encoder::compile_memory_encoder_ir(
937            &mut mem,
938            SAM2_IMG_SIZE,
939            SAM2_IMG_SIZE,
940            SAM2_PROMPT_GRID,
941            SAM2_PROMPT_GRID,
942            Device::Cpu,
943            &rlx_flow::CompileProfile::sam2(),
944        )
945        .unwrap();
946        let _ = memory_attention::extract_memory_attention_weights(&mut wm, &cfg.memory).unwrap();
947
948        let pix = vec![0f32; cfg.memory_encoder.in_dim * SAM2_PROMPT_GRID * SAM2_PROMPT_GRID];
949        let mask = vec![0f32; SAM2_IMG_SIZE * SAM2_IMG_SIZE];
950        let out = memory_encoder_forward(
951            &mut mem,
952            &pix,
953            &mask,
954            SAM2_PROMPT_GRID,
955            SAM2_PROMPT_GRID,
956            /*skip_mask_sigmoid=*/ true,
957        )
958        .unwrap();
959        assert_eq!(out.h, SAM2_PROMPT_GRID);
960        assert_eq!(out.w, SAM2_PROMPT_GRID);
961        assert_eq!(
962            out.features.len(),
963            cfg.memory_encoder.out_dim * SAM2_PROMPT_GRID * SAM2_PROMPT_GRID
964        );
965        // PE channel count = 2 · num_pos_feats.
966        assert_eq!(
967            out.pos.len(),
968            2 * cfg.memory_encoder.pe_num_pos_feats * SAM2_PROMPT_GRID * SAM2_PROMPT_GRID
969        );
970    }
971
972    #[test]
973    fn memory_attention_runs_on_zero_inputs() {
974        let cfg = Sam2Config::hiera_base_plus();
975        let mut wm = synthetic_full_sam2_weights(&cfg);
976        let (_g, _p, _pre, _fpn) = build_sam2_image_encoder_graph(&cfg.hiera, &mut wm).unwrap();
977        let _ = prompt_encoder::extract_prompt_encoder_weights(
978            &mut wm,
979            cfg.decoder.transformer_dim,
980            SAM2_MASK_IN_CHANS,
981        )
982        .unwrap();
983        let _ = mask_decoder::extract_mask_decoder_weights(&mut wm, &cfg.decoder).unwrap();
984        let _ =
985            memory_encoder::extract_memory_encoder_weights(&mut wm, &cfg.memory_encoder).unwrap();
986        let mat = memory_attention::extract_memory_attention_weights(&mut wm, &cfg.memory).unwrap();
987
988        let [end_x, end_y] = cfg.memory.rope_feat_size;
989        let n_img = end_x * end_y;
990        let d = cfg.memory.d_model;
991        let kv = cfg.memory.kv_in_dim;
992        let curr = vec![0f32; n_img * d];
993        let curr_pos = vec![0f32; n_img * d];
994        // 1 frame of memory.
995        let n_mem = end_x * end_y;
996        let memory = vec![0f32; n_mem * kv];
997        let memory_pos = vec![0f32; n_mem * kv];
998        let out = memory_attention_forward(
999            &mat,
1000            &curr,
1001            &curr_pos,
1002            &memory,
1003            &memory_pos,
1004            n_img,
1005            n_mem,
1006            kv,
1007            /*num_obj_ptr_tokens=*/ 0,
1008        )
1009        .unwrap();
1010        assert_eq!(out.len(), n_img * d);
1011        assert!(out.iter().all(|v| v.is_finite()));
1012    }
1013}