Skip to main content

rlx_sam2/
sam2.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! SAM 2 top-level orchestrator — ties together the IR-graph Hiera
//! image encoder, the FPN neck, prompt encoder, mask decoder, memory
//! encoder, and memory attention into the two reference APIs:
//!
//!   - [`Sam2::predict_image`] — single-image segmentation (matches
//!     `SAM2ImagePredictor.predict` in spirit).
//!   - [`Sam2::predict_video_frame`] — stateful per-frame call with a
//!     [`Sam2VideoState`] memory bank (mirrors `SAM2VideoPredictor`).
//!
//! The image encoder is compiled once on the chosen
//! [`rlx_runtime::Device`], together with IR graphs for the FPN neck,
//! the prompt-mask stack, the mask-decoder sub-modules (two-way
//! transformer, mask upscaling, hypernetwork MLPs and matmul, IoU /
//! object-score / object-pointer heads) and the memory encoder. Memory
//! attention is compiled lazily on the first video frame that has a
//! memory bank, and falls back to a host implementation when a frame's
//! shapes do not fit the compiled graph. The remaining glue (layout
//! transposes, mask upsampling, memory-bank assembly) runs host-side.
31
32use super::config::{SAM2_IMG_SIZE, Sam2Config, Sam2DecoderConfig};
33use super::fpn_neck::{FpnLevel, FpnNeckWeights, apply_fpn_neck};
34use super::fpn_neck_ir::{Sam2FpnNeckIr, compile_fpn_neck_ir};
35use super::image_encoder::build_sam2_image_encoder_graph;
36use super::mask_decoder::{
37    Sam2MaskDecoderOutput, Sam2MaskDecoderWeights, extract_mask_decoder_weights,
38    mask_decoder_forward,
39};
40use super::memory_attention::{
41    Sam2MemoryAttentionWeights, extract_memory_attention_weights, memory_attention_forward,
42};
43use super::memory_attention_ir::{MemoryAttentionCompiled, max_memory_slots};
44use super::memory_encoder::{
45    Sam2MemoryEncoderOutput, Sam2MemoryEncoderWeights, extract_memory_encoder_weights,
46    memory_encoder_forward,
47};
48use super::preprocess::{Sam2PreprocessWeights, assemble_patch_tokens, preprocess_image};
49use super::prompt_encoder::{
50    SAM2_MASK_IN_CHANS, SAM2_PROMPT_GRID, Sam2PromptEncoderOutput, Sam2PromptEncoderWeights,
51    extract_prompt_encoder_weights, prompt_encoder_forward,
52};
53use super::prompt_mask_ir::Sam2PromptMaskCompiled;
54use super::upscale_ir::Sam2MaskUpscaleCompiled;
55use anyhow::{Result, ensure};
56use rlx_flow::CompileProfile;
57use rlx_runtime::{CompiledGraph, Device, Session};
58use rlx_sam::profile::sam2_profile_near_weights;
59use rlx_sam_ir::mask_hyper_matmul_ir::MaskHyperMatmulCompiled;
60use rlx_sam_ir::mlp_relu_ir::MlpReluCompiled;
61use std::path::Path;
62
/// SAM 2 image-encoder Hiera stage spec — needed by the FPN neck
/// (both `apply_fpn_neck` and the compiled neck IR consume it).
///
/// Both vectors hold one entry per Hiera stage, in stage order, so they
/// always have the same length.
#[derive(Clone, Debug)]
struct HieraOutputShapes {
    /// `(h, w)` token grid of each stage's output.
    stage_hw: Vec<(usize, usize)>,
    /// Channel count (embedding dim) of each stage's output.
    stage_dims: Vec<usize>,
}
69
/// Full SAM 2 model — owns the compiled image encoder + every
/// host-side weight bundle. The encoder result is recomputed per call
/// (no encoder-caching here; layer above can wrap if needed).
///
/// Built by [`Sam2::from_safetensors`] / [`Sam2::from_safetensors_on`].
/// Every compiled member targets the load-time device; `mem_attn_ir` is
/// the only one compiled lazily.
pub struct Sam2 {
    /// Configuration every component below was built from.
    cfg: Sam2Config,
    /// Compiled Hiera image-encoder graph: takes the `"hidden"` patch
    /// tokens and yields one output per Hiera stage.
    encoder: CompiledGraph,
    /// Preprocessing / patch-assembly weights that build the encoder input.
    pre: Sam2PreprocessWeights,
    /// FPN-neck weights (host copy).
    fpn: FpnNeckWeights,
    /// Compiled FPN-neck graph, used together with `fpn` by `apply_fpn_neck`.
    fpn_ir: Sam2FpnNeckIr,
    /// Prompt-encoder weights.
    prompt_enc: Sam2PromptEncoderWeights,
    /// Mask-decoder weights; also the source of the compiled decoder pieces below.
    mask_dec: Sam2MaskDecoderWeights,
    /// Compiled prompt-mask stack passed to `prompt_encoder_forward`.
    mask_stack: Sam2PromptMaskCompiled,
    /// Compiled mask-upscaling graph for the decoder.
    upscale: Sam2MaskUpscaleCompiled,
    /// Compiled hypernetwork matmul (mask tokens × upscaled embedding).
    hyper_matmul: MaskHyperMatmulCompiled,
    /// Compiled hypernetwork MLPs, one per entry of `mask_dec.hyper_mlps`.
    hyper_mlps_ir: Vec<MlpReluCompiled>,
    /// Compiled IoU-prediction head.
    iou_head_ir: MlpReluCompiled,
    /// Compiled object-score head; `None` when the optional compile yields
    /// nothing (presumably a checkpoint without that head — confirm).
    obj_score_head_ir: Option<MlpReluCompiled>,
    /// Compiled object-pointer projection; optional like `obj_score_head_ir`.
    obj_ptr_proj_ir: Option<MlpReluCompiled>,
    /// Compiled two-way transformer of the mask decoder.
    tw_ir: rlx_sam_ir::twoway_transformer_ir::TwoWayTransformerCompiled,
    /// Memory-encoder weights (its IR is compiled into this bundle at load).
    mem_enc: Sam2MemoryEncoderWeights,
    /// Memory-attention weights; also used by the host fallback path.
    mem_attn: Sam2MemoryAttentionWeights,
    /// Compiled on first video frame (avoids multi-minute compile at checkpoint load).
    mem_attn_ir: Option<MemoryAttentionCompiled>,
    /// Device the lazily-compiled `mem_attn_ir` targets (the load-time device).
    mem_attn_device: Device,
    /// Per-stage grid sizes and channel dims consumed by the FPN neck.
    hiera_shapes: HieraOutputShapes,
    /// Compile profile applied to every compiled graph above.
    compile_profile: CompileProfile,
}
97
98impl Sam2 {
99    /// Load every SAM 2 component from a safetensors checkpoint and
100    /// compile the image encoder for the CPU backend. For GPU/Metal,
101    /// see [`Sam2::from_safetensors_on`].
102    pub fn from_safetensors(weights_path: &str, cfg: Sam2Config) -> Result<Self> {
103        Self::from_safetensors_on(weights_path, cfg, Device::Cpu)
104    }
105
106    /// Same as [`Sam2::from_safetensors`] but compiles the image
107    /// encoder for the given backend. The cross-backend feature flags
108    /// match SAM v1's [`rlx_sam::Sam::from_safetensors_on`].
109    pub fn from_safetensors_on(
110        weights_path: &str,
111        cfg: Sam2Config,
112        device: Device,
113    ) -> Result<Self> {
114        rlx_core::validate_sam_device("sam2", device)?;
115        let mut wm =
116            rlx_core::load_weight_map(Path::new(weights_path), rlx_core::SAM2_GGUF_ARCHES)?;
117        let compile_profile = sam2_profile_near_weights(Path::new(weights_path));
118
119        // 1) Hiera image encoder graph (drains its weight keys + the
120        //    preprocess + FPN-neck weights).
121        let (graph, params, pre, fpn) = build_sam2_image_encoder_graph(&cfg.hiera, &mut wm)?;
122
123        let hiera_shapes = HieraOutputShapes {
124            stage_hw: (0..cfg.hiera.stages.len())
125                .map(|s| {
126                    (
127                        cfg.hiera.grid_size_at_stage(s),
128                        cfg.hiera.grid_size_at_stage(s),
129                    )
130                })
131                .collect(),
132            stage_dims: (0..cfg.hiera.stages.len())
133                .map(|s| cfg.hiera.embed_dim_at_stage(s))
134                .collect(),
135        };
136
137        // 2) Prompt encoder.
138        let prompt_enc = extract_prompt_encoder_weights(
139            &mut wm,
140            cfg.decoder.transformer_dim,
141            SAM2_MASK_IN_CHANS,
142        )?;
143
144        // 3) Mask decoder.
145        let mask_dec = extract_mask_decoder_weights(&mut wm, &cfg.decoder)?;
146
147        // 4) Memory encoder.
148        let mut mem_enc = extract_memory_encoder_weights(&mut wm, &cfg.memory_encoder)?;
149        super::memory_encoder::compile_memory_encoder_ir(
150            &mut mem_enc,
151            SAM2_IMG_SIZE,
152            SAM2_IMG_SIZE,
153            SAM2_PROMPT_GRID,
154            SAM2_PROMPT_GRID,
155            device,
156            &compile_profile,
157        )?;
158
159        // 5) Memory attention.
160        let mem_attn = extract_memory_attention_weights(&mut wm, &cfg.memory)?;
161        let grid = cfg.hiera.grid_size_at_stage(cfg.hiera.stages.len() - 1);
162        let mask_stack =
163            Sam2PromptMaskCompiled::compile_with_profile(&prompt_enc, device, &compile_profile)?;
164        let upscale = Sam2MaskUpscaleCompiled::compile_with_profile(
165            &mask_dec,
166            grid,
167            device,
168            &compile_profile,
169        )?;
170        let hyper_matmul = MaskHyperMatmulCompiled::compile_with_profile(
171            mask_dec.num_mask_tokens,
172            cfg.decoder.transformer_dim / 8,
173            grid,
174            device,
175            &compile_profile,
176        )?;
177        let hyper_mlps_ir = super::mlp_ir::compile_hyper_mlps_with_profile(
178            &mask_dec.hyper_mlps,
179            device,
180            &compile_profile,
181        )?;
182        let iou_head_ir = super::mlp_ir::compile_hyper_mlp_with_profile(
183            &mask_dec.iou_head,
184            device,
185            &compile_profile,
186        )?;
187        let obj_score_head_ir = super::mlp_ir::compile_optional_hyper_mlp_with_profile(
188            &mask_dec.obj_score_head,
189            1,
190            device,
191            &compile_profile,
192        )?;
193        let obj_ptr_rows = super::mlp_ir::obj_ptr_proj_rows(
194            mask_dec.num_mask_tokens,
195            mask_dec.use_multimask_token_for_obj_ptr,
196        );
197        let obj_ptr_proj_ir = super::mlp_ir::compile_optional_hyper_mlp_with_profile(
198            &mask_dec.obj_ptr_proj,
199            obj_ptr_rows,
200            device,
201            &compile_profile,
202        )?;
203        let s_tok = if mask_dec.obj_score_token.is_some() {
204            1
205        } else {
206            0
207        };
208        let base_q_n = s_tok + 1 + mask_dec.num_mask_tokens;
209        let grid = cfg.hiera.grid_size_at_stage(cfg.hiera.stages.len() - 1);
210        let tw_ir = super::transformer_ir::compile_two_way_transformer_with_profile(
211            &mask_dec.transformer,
212            base_q_n,
213            grid,
214            device,
215            &compile_profile,
216        )?;
217        let fpn_ir = compile_fpn_neck_ir(
218            &fpn,
219            &hiera_shapes.stage_hw,
220            &hiera_shapes.stage_dims,
221            device,
222            &compile_profile,
223        )?;
224
225        let opts = rlx_core::flow_bridge::compile_options_for_profile(&compile_profile, device);
226        let mut encoder = Session::new(device).compile_with(graph, &opts);
227        for (name, data) in &params {
228            encoder.set_param(name, data);
229        }
230
231        // Preflight: at least the most-important keys should be drained.
232        // We don't assert the full map is empty because the published
233        // sam2 checkpoints include training-only buffers we choose to
234        // ignore (e.g. `maskmem_tpos_enc`, optimizer state remnants).
235        Ok(Self {
236            cfg,
237            encoder,
238            pre,
239            fpn,
240            fpn_ir,
241            prompt_enc,
242            mask_dec,
243            mask_stack,
244            upscale,
245            hyper_matmul,
246            hyper_mlps_ir,
247            iou_head_ir,
248            obj_score_head_ir,
249            obj_ptr_proj_ir,
250            tw_ir,
251            mem_enc,
252            mem_attn,
253            mem_attn_ir: None,
254            mem_attn_device: device,
255            hiera_shapes,
256            compile_profile,
257        })
258    }
259
260    /// Tier-1 compile profile (`sam.rlx.toml` next to weights when present).
261    pub fn compile_profile(&self) -> &CompileProfile {
262        &self.compile_profile
263    }
264
265    pub fn config(&self) -> &Sam2Config {
266        &self.cfg
267    }
268
269    fn ensure_mem_attn_ir(&mut self) -> Result<()> {
270        if self.mem_attn_ir.is_some() {
271            return Ok(());
272        }
273        let [rope_x, rope_y] = self.cfg.memory.rope_feat_size;
274        let n_img_mem = rope_x * rope_y;
275        let max_n_mem = max_memory_slots(n_img_mem, self.cfg.memory.max_obj_ptrs_in_encoder);
276        self.mem_attn_ir = Some(if self.cfg.memory.mem_attn_in_graph_rope {
277            MemoryAttentionCompiled::compile_in_graph_rope_with_profile(
278                &self.mem_attn,
279                n_img_mem,
280                max_n_mem,
281                self.cfg.memory.max_obj_ptrs_in_encoder,
282                self.mem_attn_device,
283                &self.compile_profile,
284            )?
285        } else {
286            MemoryAttentionCompiled::compile_with_profile(
287                &self.mem_attn,
288                n_img_mem,
289                max_n_mem,
290                self.cfg.memory.max_obj_ptrs_in_encoder,
291                self.mem_attn_device,
292                &self.compile_profile,
293            )?
294        });
295        Ok(())
296    }
297
298    /// Run the encoder + FPN host-side neck and return per-level
299    /// features ordered fine → coarse (stride 4, 8, 16, 32).
300    fn encode(&mut self, image_u8: &[u8], h_in: usize, w_in: usize) -> Result<Vec<FpnLevel>> {
301        let image_nchw = preprocess_image(image_u8, h_in, w_in);
302        let hidden = assemble_patch_tokens(&self.pre, &image_nchw)?;
303        let outputs = self.encoder.run(&[("hidden", hidden.as_slice())]);
304        ensure!(
305            outputs.len() == self.hiera_shapes.stage_dims.len(),
306            "encoder produced {} outputs (expected {})",
307            outputs.len(),
308            self.hiera_shapes.stage_dims.len()
309        );
310        apply_fpn_neck(
311            &self.fpn,
312            &mut self.fpn_ir,
313            &outputs,
314            &self.hiera_shapes.stage_hw,
315            &self.hiera_shapes.stage_dims,
316        )
317    }
318
319    /// Image-segmentation API.
320    ///
321    /// `image_u8`: row-major RGB `h_in × w_in × 3` u8.
322    /// `points`: optional `(coords [N,2], labels [N])` — coords in
323    ///     input-image pixels (0..max(h_in, w_in)), labels per
324    ///     [`prompt_encoder_forward`].
325    /// `boxes`: optional `[M, 4]` boxes (x0, y0, x1, y1) in input
326    ///     pixels.
327    /// `mask_input`: optional `[1, 256, 256]` low-res mask logits.
328    /// `multimask_output`: true → 3 masks; false → 1 (with optional
329    ///     dynamic-stability fallback).
330    ///
331    /// Returns `(mask_logits, iou_pred, num_masks, h_out, w_out)`
332    /// where `(h_out, w_out)` = `(4·SAM2_PROMPT_GRID, 4·SAM2_PROMPT_GRID)`
333    /// = 256×256 — caller resizes to the original image resolution.
334    pub fn predict_image(
335        &mut self,
336        image_u8: &[u8],
337        h_in: usize,
338        w_in: usize,
339        points: Option<(&[f32], &[f32])>,
340        boxes: Option<&[f32]>,
341        mask_input: Option<&[f32]>,
342        multimask_output: bool,
343    ) -> Result<Sam2ImagePrediction> {
344        let levels = self.encode(image_u8, h_in, w_in)?;
345        // FPN levels are fine→coarse: stride 4, 8, 16, 32.
346        // Image embedding for the mask decoder is the stride-16 level
347        // (index 2). High-res features are stride-4 + stride-8.
348        let prompt = self.run_prompt(points, boxes, mask_input)?;
349        let dec = self.run_decoder(&levels, &prompt, multimask_output)?;
350
351        Ok(Sam2ImagePrediction {
352            masks: dec.masks,
353            iou_pred: dec.iou_pred,
354            num_masks: dec.num_masks,
355            h_out: dec.h_out,
356            w_out: dec.w_out,
357            object_score_logits: dec.object_score_logits,
358            object_pointer: dec.object_pointer,
359        })
360    }
361
362    fn run_prompt(
363        &mut self,
364        points: Option<(&[f32], &[f32])>,
365        boxes: Option<&[f32]>,
366        mask_input: Option<&[f32]>,
367    ) -> Result<Sam2PromptEncoderOutput> {
368        prompt_encoder_forward(
369            &self.prompt_enc,
370            &mut self.mask_stack,
371            points,
372            boxes,
373            mask_input,
374        )
375    }
376
377    fn run_decoder(
378        &mut self,
379        levels: &[FpnLevel],
380        prompt: &Sam2PromptEncoderOutput,
381        multimask_output: bool,
382    ) -> Result<Sam2MaskDecoderOutput> {
383        let lvl_stride16 = &levels[2]; // stride 16 → 64×64
384        let lvl_stride8 = &levels[1]; // stride 8  → 128×128
385        let lvl_stride4 = &levels[0]; // stride 4  → 256×256
386
387        let high_res_features = if self.mask_dec.use_high_res_features {
388            Some((
389                lvl_stride4.features.as_slice(),
390                lvl_stride8.features.as_slice(),
391            ))
392        } else {
393            None
394        };
395
396        ensure!(
397            lvl_stride16.h == SAM2_PROMPT_GRID && lvl_stride16.w == SAM2_PROMPT_GRID,
398            "stride-16 FPN level must be {}×{} (got {}×{})",
399            SAM2_PROMPT_GRID,
400            SAM2_PROMPT_GRID,
401            lvl_stride16.h,
402            lvl_stride16.w
403        );
404
405        mask_decoder_forward(
406            &self.mask_dec,
407            &mut self.upscale,
408            Some(&mut self.hyper_matmul),
409            Some(&mut self.hyper_mlps_ir),
410            Some(&mut self.iou_head_ir),
411            self.obj_score_head_ir.as_mut(),
412            self.obj_ptr_proj_ir.as_mut(),
413            Some(&mut self.tw_ir),
414            &lvl_stride16.features,
415            &lvl_stride16.pos,
416            &prompt.sparse_embeddings,
417            prompt.num_sparse_tokens,
418            &prompt.dense_embeddings,
419            high_res_features,
420            multimask_output,
421            SAM2_PROMPT_GRID,
422        )
423    }
424
425    /// Per-frame video API. Wraps [`Sam2::predict_image`] with the
426    /// memory-attention path (cross-attend the current frame's stride-32
427    /// features to the bank) and the memory-encoder path (encode the
428    /// chosen mask + features into the bank).
429    ///
430    /// Mirrors `SAM2VideoPredictor.add_new_points_or_box` +
431    /// `propagate_in_video` semantics: when `state` is empty, this acts
432    /// as image-predict; otherwise it conditions on stored frames.
433    pub fn predict_video_frame(
434        &mut self,
435        state: &mut Sam2VideoState,
436        image_u8: &[u8],
437        h_in: usize,
438        w_in: usize,
439        points: Option<(&[f32], &[f32])>,
440        boxes: Option<&[f32]>,
441        mask_input: Option<&[f32]>,
442        multimask_output: bool,
443    ) -> Result<Sam2ImagePrediction> {
444        let levels = self.encode(image_u8, h_in, w_in)?;
445
446        // Stride-32 level (index 3) is the queries source for memory
447        // attention — matches the reference's `vision_features` at
448        // 32×32 resolution.
449        let stride32 = &levels[3];
450        let mut conditioned_stride32: Vec<f32> = stride32.features.clone();
451        if !state.memory.is_empty() {
452            let curr = nchw_to_seq_c(
453                &stride32.features,
454                self.cfg.memory.d_model,
455                stride32.h,
456                stride32.w,
457            );
458            let curr_pos = nchw_to_seq_c(
459                &stride32.pos,
460                self.cfg.memory.d_model,
461                stride32.h,
462                stride32.w,
463            );
464
465            let (memory_flat, memory_pos_flat, n_mem) =
466                state.assembled_memory(self.cfg.memory.kv_in_dim, self.cfg.memory.mem_dim);
467            let n_img = stride32.h * stride32.w;
468            let num_ptr = state.num_obj_ptr_tokens(self.cfg.memory.mem_dim);
469            self.ensure_mem_attn_ir()?;
470            let ir = self.mem_attn_ir.as_mut().expect("mem_attn_ir");
471            let attn_out = if n_img == ir.n_img && n_mem <= ir.max_n_mem {
472                ir.run(
473                    &curr,
474                    &curr_pos,
475                    &memory_flat,
476                    &memory_pos_flat,
477                    n_mem,
478                    num_ptr,
479                )?
480            } else {
481                memory_attention_forward(
482                    &self.mem_attn,
483                    &curr,
484                    &curr_pos,
485                    &memory_flat,
486                    &memory_pos_flat,
487                    n_img,
488                    n_mem,
489                    self.cfg.memory.kv_in_dim,
490                    num_ptr,
491                )?
492            };
493            // Reshape back to NCHW.
494            conditioned_stride32 =
495                seq_c_to_nchw(&attn_out, self.cfg.memory.d_model, stride32.h, stride32.w);
496        }
497
498        // Splice the conditioned features back into level[3] for the
499        // decoder. Decoder reads stride-16 (level[2]) for image_emb +
500        // dense, so we only condition the memory-attention output for
501        // *propagation* — the stride-16 path is unmodified per the
502        // reference.
503        let mut levels = levels;
504        levels[3].features = conditioned_stride32;
505
506        let prompt = self.run_prompt(points, boxes, mask_input)?;
507        let dec = self.run_decoder(&levels, &prompt, multimask_output)?;
508
509        // Encode the chosen mask + stride-16 features into memory and
510        // push them onto the state's bank.
511        let stride16 = &levels[2];
512        let mem = run_memory_encoder(&mut self.mem_enc, &stride16.features, &dec)?;
513        state.push_frame_memory(
514            mem,
515            dec.object_pointer.clone(),
516            self.cfg.memory.max_obj_ptrs_in_encoder,
517        );
518
519        Ok(Sam2ImagePrediction {
520            masks: dec.masks,
521            iou_pred: dec.iou_pred,
522            num_masks: dec.num_masks,
523            h_out: dec.h_out,
524            w_out: dec.w_out,
525            object_score_logits: dec.object_score_logits,
526            object_pointer: dec.object_pointer,
527        })
528    }
529}
530
/// One frame's worth of mask-decoder output, as returned by both
/// [`Sam2::predict_image`] and [`Sam2::predict_video_frame`].
#[derive(Debug, Clone, PartialEq)]
pub struct Sam2ImagePrediction {
    /// Mask logits, `num_masks` planes of `h_out × w_out`, mask-major
    /// (`[num_masks, h_out, w_out]` flattened).
    pub masks: Vec<f32>,
    /// Predicted IoU scores (presumably one per mask — confirm against
    /// the mask decoder).
    pub iou_pred: Vec<f32>,
    /// Number of masks stored in `masks`.
    pub num_masks: usize,
    /// Height of each mask plane.
    pub h_out: usize,
    /// Width of each mask plane.
    pub w_out: usize,
    /// Object-score logits as produced by the mask decoder.
    pub object_score_logits: Vec<f32>,
    /// Object pointer produced by the decoder, if any; the video path
    /// pushes it onto the state's pointer queue.
    pub object_pointer: Option<Vec<f32>>,
}
542
543/// Per-track state for [`Sam2::predict_video_frame`]. Stores up to
544/// `max_obj_ptrs_in_encoder` past memory tokens + the rolling
545/// object-pointer queue.
546pub struct Sam2VideoState {
547    /// Each entry: `(features [out_dim, h, w] flat, pos [..., h, w] flat, h, w)`.
548    pub memory: Vec<Sam2MemoryEncoderOutput>,
549    pub obj_ptr_queue: Vec<Vec<f32>>,
550}
551
552impl Sam2VideoState {
553    pub fn new() -> Self {
554        Self {
555            memory: Vec::new(),
556            obj_ptr_queue: Vec::new(),
557        }
558    }
559
560    /// Total number of memory tokens (spatial + obj-ptr) in the
561    /// concatenated memory bank. `mem_dim` is the obj-pointer
562    /// channel dim (typically 64).
563    pub fn num_obj_ptr_tokens(&self, _mem_dim: usize) -> usize {
564        // Each stored obj-ptr is a single token (the reference splits a
565        // higher-dim ptr into 4 sub-tokens via `obj_ptr_proj`, but at
566        // the level we expose here we treat each frame's pointer as a
567        // single token). When training a sub-token split, the user can
568        // extend this fn.
569        self.obj_ptr_queue.len()
570    }
571
572    /// Concatenate the per-frame memories into a single
573    /// `(memory [N_mem, kv_in_dim], memory_pos [N_mem, kv_in_dim])`
574    /// pair for the memory-attention call. Spatial tokens go first,
575    /// object-pointer tokens at the tail (so `num_k_exclude_rope`
576    /// works correctly).
577    pub fn assembled_memory(
578        &self,
579        kv_in_dim: usize,
580        _mem_dim: usize,
581    ) -> (Vec<f32>, Vec<f32>, usize) {
582        let mut features = Vec::new();
583        let mut positions = Vec::new();
584        let mut total_tokens = 0usize;
585
586        for m in &self.memory {
587            let tokens = m.h * m.w;
588            // Flatten [out_dim, h, w] → [tokens, out_dim] (matches kv_in_dim).
589            let mut feat_seq = vec![0f32; tokens * kv_in_dim];
590            let mut pos_seq = vec![0f32; tokens * kv_in_dim];
591            let pe_chans = m.pos.len() / (m.h * m.w);
592            for t in 0..tokens {
593                for c in 0..kv_in_dim {
594                    feat_seq[t * kv_in_dim + c] = m.features[c * tokens + t];
595                }
596                // PE may have more channels than kv_in_dim (e.g. 128 vs 64).
597                // We only copy the first `kv_in_dim` to match memory's channel layout.
598                for c in 0..kv_in_dim.min(pe_chans) {
599                    pos_seq[t * kv_in_dim + c] = m.pos[c * tokens + t];
600                }
601            }
602            features.extend_from_slice(&feat_seq);
603            positions.extend_from_slice(&pos_seq);
604            total_tokens += tokens;
605        }
606
607        // Append object-pointer tokens (no PE — they go in the
608        // `num_k_exclude_rope` band).
609        for ptr in &self.obj_ptr_queue {
610            ensure_or_zero(&mut features, &mut positions, ptr, kv_in_dim);
611            total_tokens += 1;
612        }
613
614        (features, positions, total_tokens)
615    }
616
617    fn push_frame_memory(
618        &mut self,
619        mem: Sam2MemoryEncoderOutput,
620        obj_ptr: Option<Vec<f32>>,
621        max_ptrs: usize,
622    ) {
623        self.memory.push(mem);
624        if let Some(p) = obj_ptr {
625            self.obj_ptr_queue.push(p);
626            while self.obj_ptr_queue.len() > max_ptrs {
627                self.obj_ptr_queue.remove(0);
628            }
629        }
630    }
631}
632
633impl Default for Sam2VideoState {
634    fn default() -> Self {
635        Self::new()
636    }
637}
638
/// Append one object-pointer token, exactly `kv_in_dim` wide, to
/// `features`, and a matching all-zero row to `positions`.
///
/// A pointer longer than `kv_in_dim` keeps only its leading channels; a
/// shorter one is zero-padded. The reference's `obj_ptr_proj` emits
/// `transformer_dim`-sized pointers (256) that the loader splits into
/// `mem_dim` (64) chunks. Taking the leading channels only approximates
/// that split, so callers wanting the exact behaviour should pre-project.
fn ensure_or_zero(
    features: &mut Vec<f32>,
    positions: &mut Vec<f32>,
    ptr: &[f32],
    kv_in_dim: usize,
) {
    let kept = &ptr[..ptr.len().min(kv_in_dim)];
    let pad = kv_in_dim - kept.len();
    features.extend_from_slice(kept);
    features.resize(features.len() + pad, 0.0);
    positions.resize(positions.len() + kv_in_dim, 0.0);
}
664
665fn run_memory_encoder(
666    mem_enc: &mut Sam2MemoryEncoderWeights,
667    pix_feat: &[f32],
668    dec: &Sam2MaskDecoderOutput,
669) -> Result<Sam2MemoryEncoderOutput> {
670    // We always pick the first (top-IoU) mask to encode. Reference
671    // `SAM2Base._encode_new_memory` does the same when caller doesn't
672    // override.
673    // dec.masks shape: [num_masks, h_out, w_out]. Take mask 0.
674    let m_chunk = dec.h_out * dec.w_out;
675    ensure!(
676        dec.masks.len() >= m_chunk,
677        "decoder produced empty mask buffer"
678    );
679    let mask0 = &dec.masks[..m_chunk];
680
681    // Reference upsamples the 256×256 mask to 1024×1024 before
682    // memory-encoding (`F.interpolate(masks, size=(1024, 1024),
683    // mode="bilinear")`). We do the same with a cheap bilinear.
684    let mut up_mask = vec![0f32; SAM2_IMG_SIZE * SAM2_IMG_SIZE];
685    bilinear_upsample_1ch(
686        mask0,
687        dec.h_out,
688        dec.w_out,
689        &mut up_mask,
690        SAM2_IMG_SIZE,
691        SAM2_IMG_SIZE,
692    );
693
694    memory_encoder_forward(
695        mem_enc,
696        pix_feat,
697        &up_mask,
698        SAM2_PROMPT_GRID,
699        SAM2_PROMPT_GRID,
700        /*skip_mask_sigmoid=*/ false,
701    )
702}
703
/// Bilinearly resample a single-channel row-major `sh × sw` image into
/// `dst` (`dh × dw`, row-major).
///
/// Matches PyTorch `F.interpolate(mode="bilinear", align_corners=False)`:
/// sample positions use half-pixel centres, and source coordinates that
/// fall before the first sample are clamped to 0. Border pixels therefore
/// replicate instead of blending in a neighbour with the wrong weight.
///
/// Does nothing if either the source or the destination is empty.
///
/// # Panics
/// Panics if `src` holds fewer than `sh * sw` values or `dst` fewer than
/// `dh * dw`.
fn bilinear_upsample_1ch(src: &[f32], sh: usize, sw: usize, dst: &mut [f32], dh: usize, dw: usize) {
    if sh == 0 || sw == 0 || dh == 0 || dw == 0 {
        return;
    }
    let src = &src[..sh * sw];
    let dst = &mut dst[..dh * dw];

    // (low index, high index, weight of high) for one output coordinate.
    let taps = |d: usize, s_len: usize, scale: f32| -> (usize, usize, f32) {
        let pos = ((d as f32 + 0.5) * scale - 0.5).max(0.0);
        let lo = (pos as usize).min(s_len - 1);
        let hi = (lo + 1).min(s_len - 1);
        (lo, hi, (pos - lo as f32).min(1.0))
    };

    let sx = sw as f32 / dw as f32;
    let sy = sh as f32 / dh as f32;
    // Horizontal taps are identical for every row — compute them once.
    let x_taps: Vec<(usize, usize, f32)> = (0..dw).map(|x| taps(x, sw, sx)).collect();

    for (y, out_row) in dst.chunks_exact_mut(dw).enumerate() {
        let (y0, y1, dy) = taps(y, sh, sy);
        let row0 = &src[y0 * sw..(y0 + 1) * sw];
        let row1 = &src[y1 * sw..(y1 + 1) * sw];
        for (out, &(x0, x1, dx)) in out_row.iter_mut().zip(&x_taps) {
            let top = row0[x0] * (1.0 - dx) + row0[x1] * dx;
            let bot = row1[x0] * (1.0 - dx) + row1[x1] * dx;
            *out = top * (1.0 - dy) + bot * dy;
        }
    }
}
727
/// Transpose a channel-major `[c, h, w]` buffer into token-major
/// `[h * w, c]`, with one row of `c` channel values per spatial position.
fn nchw_to_seq_c(src: &[f32], c: usize, h: usize, w: usize) -> Vec<f32> {
    let plane = h * w;
    let mut seq = vec![0f32; plane * c];
    for pos in 0..plane {
        let row = pos * c;
        for ch in 0..c {
            seq[row + ch] = src[ch * plane + pos];
        }
    }
    seq
}
739
/// Inverse of the token-major transpose: turn `[h * w, c]` (one row of
/// `c` channels per position) back into channel-major `[c, h, w]`.
fn seq_c_to_nchw(src: &[f32], c: usize, h: usize, w: usize) -> Vec<f32> {
    let plane = h * w;
    let mut nchw = vec![0f32; c * plane];
    for pos in 0..plane {
        let row = pos * c;
        for ch in 0..c {
            nchw[ch * plane + pos] = src[row + ch];
        }
    }
    nchw
}
751
752#[allow(dead_code)]
753fn _silence_decoder_cfg(_d: &Sam2DecoderConfig) {}