Skip to main content

rlx_sam3/
sam3.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! SAM3 top-level API.
17//!
18//! This module owns the shipped checkpoint-facing surface for base SAM3.
19//! The native surface exposes SAM3 preprocessing and ViT patch embeddings.
20//! Full image/video inference runs through native Rust modules. The Python
21//! helper in this directory is kept only as a parity reference.
22
23use super::config::Sam3Config;
24use super::detector::{Sam3DetectorWeights, detector_forward_native};
25use super::detector_decoder::{
26    Sam3DecoderOutput, Sam3DecoderWeights, extract_decoder_weights, forward_decoder,
27};
28use super::detector_encoder::{Sam3EncoderWeights, extract_encoder_weights, forward_encoder};
29use super::detector_encoder_ir::{forward_encoder_ir, forward_encoder_ir_on_with_profile};
30use super::geometry::{Sam3GeometryWeights, encode_geometry_native};
31use super::neck::{
32    Sam3NeckWeights, apply_neck_native, compile_neck_branches, extract_neck_weights,
33};
34use super::preprocess::{assemble_patch_tokens, preprocess_image};
35use super::segmentation_head::{
36    Sam3DotProductScoringWeights, Sam3SegmentationHeadWeights, Sam3SegmentationOutput,
37    compile_segmentation_ir, extract_dot_product_scoring_weights,
38    extract_segmentation_head_weights, forward_dot_prod_scoring, forward_segmentation,
39    segmentation_forward_native,
40};
41use super::text_encoder::{
42    Sam3TextEncoded, Sam3TextEncoderWeights, encode_text_native, encode_tokens,
43    extract_text_encoder_weights,
44};
45use super::tracker::{Sam3TrackerWeights, extract_tracker_weights, tracker_forward_native};
46use super::vision_encoder::{
47    Sam3VisionEncoderWeights, encode_image_native, extract_vision_encoder_weights,
48};
49use anyhow::{Context, Result, ensure};
50use rlx_flow::CompileProfile;
51use rlx_runtime::Device;
52use rlx_sam::profile::sam3_profile_near_weights;
53use std::path::Path;
54
/// Patch-token result produced by `Sam3::encode_image` and
/// `Sam3::patch_embed_image`.
#[derive(Debug, Clone)]
pub struct Sam3EncodedImage {
    /// `[grid * grid, embed_dim]` flattened row-major patch tokens.
    pub patch_tokens: Vec<f32>,
    /// Patch-grid size as reported by the producing stage.
    pub grid: usize,
    /// Channel width of each patch token.
    pub embed_dim: usize,
    /// Second value returned by `preprocess_image`; presumably the
    /// `(height, width)` after resizing — TODO confirm the order.
    pub resized_hw: (usize, usize),
}
63
/// Flat prediction buffers for one image, each paired with its shape.
#[derive(Debug, Clone)]
pub struct Sam3ImagePrediction {
    /// Mask logits/probabilities flattened in row-major order. The shape
    /// is available in `mask_shape`.
    pub masks: Vec<f32>,
    /// Shape of `masks`; `predict_image_text` fills it with
    /// `[batch, num_queries, mask_h, mask_w]`.
    pub mask_shape: Vec<usize>,
    /// Flattened boxes; `predict_image_text` stores them as
    /// `x0, y0, x1, y1` derived from sigmoid-activated centre/size values.
    pub boxes: Vec<f32>,
    /// Shape of `boxes`; `predict_image_text` uses `[batch, num_queries, 4]`.
    pub boxes_shape: Vec<usize>,
    /// Flattened per-query scores.
    pub scores: Vec<f32>,
    /// Shape of `scores`; `predict_image_text` uses `[batch, num_queries]`.
    pub scores_shape: Vec<usize>,
    /// Number of query slots reported in this prediction.
    pub num_instances: usize,
    /// First component of the preprocessor's resized size (not the mask
    /// height, which lives in `mask_shape`).
    pub h_out: usize,
    /// Second component of the preprocessor's resized size.
    pub w_out: usize,
}
78
/// Mutable per-video state handed to `tracker_forward_native` on every
/// frame. `Default` yields frame index 0, no memory, and no prediction.
#[derive(Debug, Clone, Default)]
pub struct Sam3VideoState {
    /// Frame counter; presumably advanced by the tracker — TODO confirm.
    pub frame_index: usize,
    /// One flat buffer per remembered entry; the exact layout is defined
    /// by the tracker module, not here.
    pub memory_tokens: Vec<Vec<f32>>,
    /// Most recent image prediction, if any has been recorded.
    pub last_prediction: Option<Sam3ImagePrediction>,
}
85
/// Per-frame result returned by `Sam3::predict_video_frame`.
#[derive(Debug, Clone)]
pub struct Sam3VideoFramePrediction {
    /// Index of the frame this prediction belongs to.
    pub frame_index: usize,
    /// Image-level prediction for the frame.
    pub image: Sam3ImagePrediction,
    /// Presumably the number of entries held in the video memory after
    /// this frame — TODO confirm against the tracker.
    pub memory_len: usize,
}
92
/// Loaded SAM3 model: configuration, per-stage weights, and the execution
/// settings chosen at load time.
pub struct Sam3 {
    /// Model configuration supplied by the caller at load time.
    cfg: Sam3Config,
    /// Vision encoder weights. `from_safetensors_on` always stores `Some`;
    /// image entry points return an error when this is `None`.
    vision: Option<Sam3VisionEncoderWeights>,
    /// Neck weights; branches are compiled during loading.
    neck: Sam3NeckWeights,
    /// Text encoder weights.
    text: Sam3TextEncoderWeights,
    /// Geometry prompt weights. Left at `Default::default()` by
    /// `from_safetensors_on`; read by `predict_image_native`.
    geometry: Sam3GeometryWeights,
    /// Detector weights. Left at `Default::default()` by
    /// `from_safetensors_on`; read by `predict_image_native`.
    detector: Sam3DetectorWeights,
    /// Detector encoder weights extracted from the checkpoint.
    encoder: Sam3EncoderWeights,
    /// Detector decoder weights extracted from the checkpoint.
    decoder: Sam3DecoderWeights,
    /// Segmentation head extracted from the checkpoint and compiled at
    /// load time; used by `predict_image_text` and `run_segmentation`.
    seg_head: Sam3SegmentationHeadWeights,
    /// Dot-product scoring weights extracted from the checkpoint.
    scoring: Sam3DotProductScoringWeights,
    /// Second segmentation-head slot, distinct from `seg_head`. Left at
    /// `Default::default()` by `from_safetensors_on`; read by
    /// `predict_image_native`.
    seg: Sam3SegmentationHeadWeights,
    /// Tracker weights extracted from the checkpoint.
    tracker: Sam3TrackerWeights,
    /// Device requested at load time.
    device: Device,
    /// Compile profile resolved from the weights path at load time.
    compile_profile: CompileProfile,
    /// Packed GGUF parameters; `Some` only when the checkpoint is a GGUF
    /// file that reports packed linears.
    gguf_packed: Option<rlx_flow::GgufPackedParams>,
}
110
111impl Sam3 {
    /// Load a SAM3 checkpoint for native inference on `Device::Cpu`.
    ///
    /// Load from `.safetensors` or `.gguf` (`general.architecture` = `sam3`).
    ///
    /// For PyTorch checkpoints, convert with `tests/sam3_parity_helpers/pt_to_safetensors.py`
    /// or use community GGUF (e.g. rob-laz/sam3-gguf).
    ///
    /// Equivalent to `from_checkpoint_on(weights_path, cfg, Device::Cpu)`.
    pub fn from_checkpoint(weights_path: &str, cfg: Sam3Config) -> Result<Self> {
        Self::from_checkpoint_on(weights_path, cfg, Device::Cpu)
    }
121
    /// Load a SAM3 checkpoint for native inference on `device`.
    ///
    /// Forwards unchanged to `from_safetensors_on`, which also handles
    /// `.gguf` paths.
    pub fn from_checkpoint_on(weights_path: &str, cfg: Sam3Config, device: Device) -> Result<Self> {
        Self::from_safetensors_on(weights_path, cfg, device)
    }
125
    /// Load SAM3 weights on `Device::Cpu`.
    ///
    /// Equivalent to `from_safetensors_on(weights_path, cfg, Device::Cpu)`.
    pub fn from_safetensors(weights_path: &str, cfg: Sam3Config) -> Result<Self> {
        Self::from_safetensors_on(weights_path, cfg, Device::Cpu)
    }
129
130    pub fn from_safetensors_on(
131        weights_path: &str,
132        cfg: Sam3Config,
133        device: Device,
134    ) -> Result<Self> {
135        rlx_core::validate_sam_device("sam3", device)?;
136
137        let path = Path::new(weights_path);
138        let is_gguf = path.extension().is_some_and(|e| e == "gguf");
139        if is_gguf {
140            rlx_core::gguf_validate_arch(path, rlx_core::SAM3_GGUF_ARCHES)?;
141        }
142        let (mut wm, gguf_packed) = if is_gguf && crate::packed_gguf::gguf_has_packed_linears(path)?
143        {
144            eprintln!("[sam3] loading GGUF with packed ViT matmul {path:?}");
145            let (wm, packed) = crate::packed_gguf::load_sam3_from_gguf(path)?;
146            (wm, Some(packed))
147        } else {
148            (
149                rlx_core::load_weight_map(path, rlx_core::SAM3_GGUF_ARCHES)?,
150                None,
151            )
152        };
153        let compile_profile = sam3_profile_near_weights(path);
154        let vision = extract_vision_encoder_weights(&mut wm, &cfg.vit, gguf_packed.as_ref())?;
155        let mut neck = extract_neck_weights(&mut wm)?;
156        compile_neck_branches(
157            &mut neck,
158            cfg.vit.embed_dim,
159            cfg.vit.patch_grid(),
160            device,
161            &compile_profile,
162        )?;
163        let text = extract_text_encoder_weights(&mut wm, &cfg.text, gguf_packed.as_ref())?;
164        let encoder = extract_encoder_weights(&mut wm, gguf_packed.as_ref())?;
165        let decoder = extract_decoder_weights(&mut wm, gguf_packed.as_ref())?;
166        let mut seg_head = extract_segmentation_head_weights(&mut wm, gguf_packed.as_ref())?;
167        compile_segmentation_ir(
168            &mut seg_head,
169            gguf_packed.as_ref(),
170            cfg.vit.patch_grid(),
171            device,
172            &compile_profile,
173        )?;
174        let scoring = extract_dot_product_scoring_weights(&mut wm, gguf_packed.as_ref())?;
175        let tracker = extract_tracker_weights(&mut wm)?;
176        Ok(Self {
177            cfg,
178            vision: Some(vision),
179            neck,
180            text,
181            geometry: Sam3GeometryWeights::default(),
182            detector: Sam3DetectorWeights::default(),
183            encoder,
184            seg: Sam3SegmentationHeadWeights::default(),
185            tracker,
186            decoder,
187            seg_head,
188            scoring,
189            device,
190            compile_profile,
191            gguf_packed,
192        })
193    }
194
    /// Tier-1 compile profile (`sam.rlx.toml` next to weights when present).
    ///
    /// Returns the profile resolved once at load time.
    pub fn compile_profile(&self) -> &CompileProfile {
        &self.compile_profile
    }
199
    /// Returns the configuration this model was loaded with.
    pub fn config(&self) -> &Sam3Config {
        &self.cfg
    }
203
    /// Returns the loaded tracker weights (used by the video basic test
    /// to confirm checkpoint coverage).
    pub fn tracker_weights(&self) -> &Sam3TrackerWeights {
        &self.tracker
    }
209
    /// Returns the detector encoder weights extracted at load time.
    pub fn encoder_weights(&self) -> &Sam3EncoderWeights {
        &self.encoder
    }
213
    /// Returns the detector decoder weights extracted at load time.
    pub fn decoder_weights(&self) -> &Sam3DecoderWeights {
        &self.decoder
    }
217
    /// Returns the device requested when the model was loaded.
    pub fn device(&self) -> Device {
        self.device
    }
221
222    pub fn encode_image(
223        &self,
224        image_u8: &[u8],
225        h_in: usize,
226        w_in: usize,
227    ) -> Result<Sam3EncodedImage> {
228        let vision = self
229            .vision
230            .as_ref()
231            .context("SAM3 encode_image requires native vision weights")?;
232        let (image_nchw, resized_hw) = preprocess_image(image_u8, h_in, w_in);
233        let encoded = encode_image_native(
234            vision,
235            self.gguf_packed.as_ref(),
236            &self.cfg.vit,
237            &image_nchw,
238        )?;
239        Ok(Sam3EncodedImage {
240            patch_tokens: encoded.tokens,
241            grid: encoded.grid,
242            embed_dim: encoded.dim,
243            resized_hw,
244        })
245    }
246
247    /// Run the vision trunk + 4-scale neck and return per-level
248    /// `[channels, h, w]` feature maps with matching sinusoidal positional
249    /// encodings. Used by the detector and as a parity gate.
250    /// End-to-end image inference with a pre-tokenized text prompt
251    /// (`tokens` has length `seq_len` == decoder context length, usually
252    /// 32). Returns the same 3-tuple the public `Sam3Processor.set_text_
253    /// prompt` exposes — without NMS / score thresholding, which we leave
254    /// to callers so parity tests can compare raw model outputs.
255    pub fn predict_image_text(
256        &mut self,
257        image_u8: &[u8],
258        h_in: usize,
259        w_in: usize,
260        tokens: &[u32],
261    ) -> Result<Sam3ImagePrediction> {
262        let cfg = &self.cfg;
263        let nq = 200;
264        let seq_len = tokens.len();
265
266        // Vision + neck.
267        let vision = self
268            .vision
269            .as_ref()
270            .context("predict_image_text requires native vision weights")?;
271        let (image_nchw, resized_hw) = preprocess_image(image_u8, h_in, w_in);
272        let vision_out = super::vision_encoder::encode_image_native(
273            vision,
274            self.gguf_packed.as_ref(),
275            &cfg.vit,
276            &image_nchw,
277        )?;
278        let levels = apply_neck_native(&mut self.neck, &vision_out)?;
279        // Drop the last (scale=0.5) level per scalp=1.
280        let kept = &levels[..3];
281        let backbone_fpn: Vec<Vec<f32>> = kept.iter().map(|l| l.features.clone()).collect();
282        let backbone_shapes: Vec<(usize, usize)> = kept.iter().map(|l| (l.h, l.w)).collect();
283        // Encoder input = level scale=1.0 (index 2).
284        let src_level = &kept[2];
285        let h = src_level.h;
286        let w = src_level.w;
287        let batch = 1;
288
289        // Text encoder.
290        let text_out = encode_tokens(
291            &self.text,
292            tokens,
293            batch,
294            seq_len,
295            self.gguf_packed.as_ref(),
296        )?;
297
298        // Detector encoder (single-level fusion).
299        let memory_bf = forward_encoder(
300            &self.encoder,
301            &src_level.features,
302            &src_level.pos,
303            &text_out.text_memory_resized,
304            &text_out.attention_mask,
305            batch,
306            h,
307            w,
308            seq_len,
309            self.gguf_packed.as_ref(),
310        )?;
311        // Convert FPN pos to batch-first for the decoder.
312        let mut memory_pos = vec![0f32; batch * h * w * 256];
313        for b in 0..batch {
314            for y in 0..h {
315                for xc in 0..w {
316                    for c in 0..256 {
317                        memory_pos[(b * h * w + y * w + xc) * 256 + c] =
318                            src_level.pos[((b * 256 + c) * h + y) * w + xc];
319                    }
320                }
321            }
322        }
323
324        // Detector decoder.
325        let dec = forward_decoder(
326            &self.decoder,
327            &memory_bf,
328            &memory_pos,
329            &text_out.text_memory_resized,
330            &text_out.attention_mask,
331            batch,
332            h,
333            w,
334            seq_len,
335            self.gguf_packed.as_ref(),
336        )?;
337
338        // Last-layer queries batch-first.
339        let num_layers = dec.num_layers;
340        let mut queries_last_bf = vec![0f32; batch * nq * 256];
341        let li = num_layers - 1;
342        for q in 0..nq {
343            for b in 0..batch {
344                let src = ((li * nq + q) * batch + b) * 256;
345                let dst = (b * nq + q) * 256;
346                queries_last_bf[dst..dst + 256].copy_from_slice(&dec.intermediate[src..src + 256]);
347            }
348        }
349
350        // Last layer's refined reference boxes.
351        let mut ref_last_bf = vec![0f32; batch * nq * 4];
352        for q in 0..nq {
353            for b in 0..batch {
354                let src = ((li * nq + q) * batch + b) * 4;
355                let dst = (b * nq + q) * 4;
356                ref_last_bf[dst..dst + 4]
357                    .copy_from_slice(&dec.intermediate_ref_boxes[src..src + 4]);
358            }
359        }
360
361        // Final boxes: sigmoid(inv_sigmoid(ref_last) + bbox_embed(queries_last)).
362        let delta = super::detector_decoder::bbox_embed_forward(
363            &self.decoder,
364            &queries_last_bf,
365            batch * nq,
366            self.gguf_packed.as_ref(),
367        )?;
368        let mut final_boxes_cxcywh = vec![0f32; batch * nq * 4];
369        for q in 0..nq {
370            for b in 0..batch {
371                let rb = &ref_last_bf[(b * nq + q) * 4..(b * nq + q + 1) * 4];
372                let d = &delta[(b * nq + q) * 4..(b * nq + q + 1) * 4];
373                let out_off = (b * nq + q) * 4;
374                for k in 0..4 {
375                    let inv = if rb[k] <= 0.0 {
376                        (1e-3f32 / (1.0 - 1e-3)).ln()
377                    } else if rb[k] >= 1.0 {
378                        ((1.0 - 1e-3) / 1e-3f32).ln()
379                    } else {
380                        (rb[k].max(1e-3) / (1.0 - rb[k]).max(1e-3)).ln()
381                    };
382                    let s = inv + d[k];
383                    final_boxes_cxcywh[out_off + k] = 1.0 / (1.0 + (-s).exp());
384                }
385            }
386        }
387        // Convert to xyxy.
388        let mut boxes_xyxy = vec![0f32; batch * nq * 4];
389        for i in 0..(batch * nq) {
390            let cx = final_boxes_cxcywh[i * 4];
391            let cy = final_boxes_cxcywh[i * 4 + 1];
392            let bw = final_boxes_cxcywh[i * 4 + 2];
393            let bh = final_boxes_cxcywh[i * 4 + 3];
394            boxes_xyxy[i * 4] = cx - 0.5 * bw;
395            boxes_xyxy[i * 4 + 1] = cy - 0.5 * bh;
396            boxes_xyxy[i * 4 + 2] = cx + 0.5 * bw;
397            boxes_xyxy[i * 4 + 3] = cy + 0.5 * bh;
398        }
399
400        // Scores: dot product scoring, last layer.
401        let mut hs_bf = vec![0f32; num_layers * batch * nq * 256];
402        for l in 0..num_layers {
403            for q in 0..nq {
404                for b in 0..batch {
405                    let src = ((l * nq + q) * batch + b) * 256;
406                    let dst = ((l * batch + b) * nq + q) * 256;
407                    hs_bf[dst..dst + 256].copy_from_slice(&dec.intermediate[src..src + 256]);
408                }
409            }
410        }
411        let all_scores = forward_dot_prod_scoring(
412            &self.scoring,
413            &hs_bf,
414            &text_out.text_memory_resized,
415            &text_out.attention_mask,
416            num_layers,
417            batch,
418            nq,
419            seq_len,
420            self.gguf_packed.as_ref(),
421        )?;
422        let last_scores =
423            all_scores[(num_layers - 1) * batch * nq..num_layers * batch * nq].to_vec();
424
425        // Segmentation.
426        let seg = forward_segmentation(
427            &mut self.seg_head,
428            &memory_bf,
429            &backbone_fpn,
430            &backbone_shapes,
431            &queries_last_bf,
432            &text_out.text_memory_resized,
433            &text_out.attention_mask,
434            batch,
435            h,
436            w,
437            nq,
438            seq_len,
439            self.gguf_packed.as_ref(),
440        )?;
441
442        Ok(Sam3ImagePrediction {
443            masks: seg.mask_pred,
444            mask_shape: vec![batch, nq, seg.h_out, seg.w_out],
445            boxes: boxes_xyxy,
446            boxes_shape: vec![batch, nq, 4],
447            scores: last_scores,
448            scores_shape: vec![batch, nq],
449            num_instances: nq,
450            h_out: resized_hw.0,
451            w_out: resized_hw.1,
452        })
453    }
454
    /// Forward the segmentation head: cross-attend encoder memory to the
    /// text prompt, run the pixel decoder, and emit per-query mask logits
    /// plus the semantic mask.
    ///
    /// Thin wrapper that passes every argument, in order, to
    /// `forward_segmentation` together with this model's `seg_head` weights
    /// and packed GGUF parameters. Takes `&mut self` because the head is
    /// handed over mutably.
    ///
    /// Buffer layouts are defined by `forward_segmentation`; the `_bf`
    /// suffix presumably means batch-first and `prompt_kpm` a key-padding
    /// mask — TODO confirm.
    #[allow(clippy::too_many_arguments)]
    pub fn run_segmentation(
        &mut self,
        enc_memory_bf: &[f32],
        backbone_fpn: &[Vec<f32>],
        backbone_shapes: &[(usize, usize)],
        obj_queries_last_bf: &[f32],
        prompt_seq_first: &[f32],
        prompt_kpm: &[u8],
        batch: usize,
        enc_h: usize,
        enc_w: usize,
        num_queries: usize,
        seq_len: usize,
    ) -> Result<Sam3SegmentationOutput> {
        forward_segmentation(
            &mut self.seg_head,
            enc_memory_bf,
            backbone_fpn,
            backbone_shapes,
            obj_queries_last_bf,
            prompt_seq_first,
            prompt_kpm,
            batch,
            enc_h,
            enc_w,
            num_queries,
            seq_len,
            self.gguf_packed.as_ref(),
        )
    }
489
    /// Compute per-query, per-layer scores via mean-pooled text + linear
    /// projections + dot product.
    ///
    /// Thin wrapper over `forward_dot_prod_scoring` using this model's
    /// scoring weights and packed GGUF parameters. `predict_image_text`
    /// feeds the same routine a `[num_layers, batch, num_queries, 256]`
    /// buffer and reads the result as `[num_layers, batch, num_queries]`.
    #[allow(clippy::too_many_arguments)]
    pub fn run_dot_prod_scoring(
        &self,
        hs_bf: &[f32],
        prompt_seq_first: &[f32],
        prompt_kpm: &[u8],
        num_layers: usize,
        batch: usize,
        num_queries: usize,
        seq_len: usize,
    ) -> Result<Vec<f32>> {
        forward_dot_prod_scoring(
            &self.scoring,
            hs_bf,
            prompt_seq_first,
            prompt_kpm,
            num_layers,
            batch,
            num_queries,
            seq_len,
            self.gguf_packed.as_ref(),
        )
    }
515
516    /// Run the detector decoder. Inputs are the encoder memory in
517    /// batch-first flat `[batch, h*w, 256]` plus matching positional
518    /// encoding, and the text memory in seq-first `[seq, batch, 256]`.
519    /// Returns intermediate layer outputs, refined boxes, and presence
520    /// logits — the same triple the upstream model uses to derive scores
521    /// and final box predictions.
522    #[allow(clippy::too_many_arguments)]
523    pub fn run_decoder(
524        &self,
525        memory: &[f32],
526        memory_pos: &[f32],
527        memory_text: &[f32],
528        text_attention_mask: &[u8],
529        batch: usize,
530        h: usize,
531        w: usize,
532        seq_len: usize,
533    ) -> Result<Sam3DecoderOutput> {
534        if rlx_ir::env::flag("RLX_SAM3_DECODER_HOST") {
535            return forward_decoder(
536                &self.decoder,
537                memory,
538                memory_pos,
539                memory_text,
540                text_attention_mask,
541                batch,
542                h,
543                w,
544                seq_len,
545                self.gguf_packed.as_ref(),
546            );
547        }
548        let dev = match rlx_ir::env::var("RLX_SAM3_DECODER_DEVICE").as_deref() {
549            Some("metal") => Device::Metal,
550            Some("mlx") => Device::Mlx,
551            Some("cuda") => Device::Cuda,
552            _ => self.device,
553        };
554        super::detector_decoder_ir::forward_decoder_ir_on_with_profile(
555            &self.decoder,
556            memory,
557            memory_pos,
558            memory_text,
559            text_attention_mask,
560            batch,
561            h,
562            w,
563            seq_len,
564            dev,
565            &self.compile_profile,
566            self.gguf_packed.as_ref(),
567        )
568    }
569
570    /// Run the detector encoder fusion on a single FPN level + text
571    /// prompt. Returns the encoded image memory in batch-first flat
572    /// `[batch, h*w, 256]`.
573    #[allow(clippy::too_many_arguments)]
574    pub fn run_encoder(
575        &self,
576        src_bchw: &[f32],
577        src_pos_bchw: &[f32],
578        prompt_seq_first: &[f32],
579        prompt_kpm: &[u8],
580        batch: usize,
581        src_h: usize,
582        src_w: usize,
583        prompt_len: usize,
584    ) -> Result<Vec<f32>> {
585        // Backend selection:
586        //   RLX_SAM3_ENCODER_HOST=1 → host-side per-head sgemm (legacy).
587        //   RLX_SAM3_ENCODER_DEVICE=metal → IR on Metal (default Cpu).
588        if rlx_ir::env::flag("RLX_SAM3_ENCODER_HOST") {
589            return forward_encoder(
590                &self.encoder,
591                src_bchw,
592                src_pos_bchw,
593                prompt_seq_first,
594                prompt_kpm,
595                batch,
596                src_h,
597                src_w,
598                prompt_len,
599                self.gguf_packed.as_ref(),
600            );
601        }
602        let dev = match rlx_ir::env::var("RLX_SAM3_ENCODER_DEVICE").as_deref() {
603            Some("metal") => Device::Metal,
604            Some("mlx") => Device::Mlx,
605            _ => Device::Cpu,
606        };
607        let _ = forward_encoder_ir; // silence unused if always _on
608        forward_encoder_ir_on_with_profile(
609            &self.encoder,
610            src_bchw,
611            src_pos_bchw,
612            prompt_seq_first,
613            prompt_kpm,
614            batch,
615            src_h,
616            src_w,
617            prompt_len,
618            dev,
619            &self.compile_profile,
620            self.gguf_packed.as_ref(),
621        )
622    }
623
    /// Run the text encoder on already-tokenized inputs. Returns the
    /// resized memory the detector consumes.
    ///
    /// Thin wrapper over `encode_tokens` using this model's text weights
    /// and packed GGUF parameters. `tokens` presumably holds
    /// `batch * seq_len` ids — TODO confirm against `encode_tokens`.
    pub fn encode_text_tokens(
        &self,
        tokens: &[u32],
        batch: usize,
        seq_len: usize,
    ) -> Result<Sam3TextEncoded> {
        encode_tokens(
            &self.text,
            tokens,
            batch,
            seq_len,
            self.gguf_packed.as_ref(),
        )
    }
640
641    pub fn predict_neck(
642        &mut self,
643        image_u8: &[u8],
644        h_in: usize,
645        w_in: usize,
646    ) -> Result<Vec<super::neck::Sam3FeatureLevel>> {
647        let vision = self
648            .vision
649            .as_ref()
650            .context("SAM3 predict_neck requires native vision weights")?;
651        let (image_nchw, _) = preprocess_image(image_u8, h_in, w_in);
652        let vision_out = super::vision_encoder::encode_image_native(
653            vision,
654            self.gguf_packed.as_ref(),
655            &self.cfg.vit,
656            &image_nchw,
657        )?;
658        apply_neck_native(&mut self.neck, &vision_out)
659    }
660
661    pub fn patch_embed_image(
662        &self,
663        image_u8: &[u8],
664        h_in: usize,
665        w_in: usize,
666    ) -> Result<Sam3EncodedImage> {
667        let vision = self
668            .vision
669            .as_ref()
670            .context("SAM3 patch_embed_image requires native vision weights")?;
671        let (image_nchw, resized_hw) = preprocess_image(image_u8, h_in, w_in);
672        let patch_tokens = assemble_patch_tokens(&vision.pre, &image_nchw)?;
673        Ok(Sam3EncodedImage {
674            patch_tokens,
675            grid: vision.pre.grid,
676            embed_dim: vision.pre.embed_dim,
677            resized_hw,
678        })
679    }
680
    /// Predict on a single image with optional text, box, and point
    /// prompts.
    ///
    /// Forwards every argument unchanged to the private
    /// `predict_image_native`, which requires `image_u8` to hold exactly
    /// `h_in * w_in * 3` bytes (RGB). The layout of `boxes` and of the two
    /// `points` slices is defined by `encode_geometry_native`.
    pub fn predict_image(
        &mut self,
        image_u8: &[u8],
        h_in: usize,
        w_in: usize,
        text_prompt: Option<&str>,
        boxes: Option<&[f32]>,
        points: Option<(&[f32], &[f32])>,
    ) -> Result<Sam3ImagePrediction> {
        self.predict_image_native(image_u8, h_in, w_in, text_prompt, boxes, points)
    }
692
693    pub fn predict_video_frame(
694        &mut self,
695        state: &mut Sam3VideoState,
696        image_u8: &[u8],
697        h_in: usize,
698        w_in: usize,
699        text_prompt: Option<&str>,
700    ) -> Result<Sam3VideoFramePrediction> {
701        let pred = self.predict_image_native(image_u8, h_in, w_in, text_prompt, None, None)?;
702        Ok(tracker_forward_native(&self.tracker, state, pred))
703    }
704
705    fn predict_image_native(
706        &mut self,
707        image_u8: &[u8],
708        h_in: usize,
709        w_in: usize,
710        text_prompt: Option<&str>,
711        boxes: Option<&[f32]>,
712        points: Option<(&[f32], &[f32])>,
713    ) -> Result<Sam3ImagePrediction> {
714        ensure!(
715            image_u8.len() == h_in * w_in * 3,
716            "SAM3 image must be RGB u8 with len {} (got {})",
717            h_in * w_in * 3,
718            image_u8.len()
719        );
720        let vision = self
721            .vision
722            .as_ref()
723            .context("SAM3 predict_image requires native vision weights")?;
724        let (image_nchw, resized_hw) = preprocess_image(image_u8, h_in, w_in);
725        let vision_out = encode_image_native(
726            vision,
727            self.gguf_packed.as_ref(),
728            &self.cfg.vit,
729            &image_nchw,
730        )?;
731        let levels = apply_neck_native(&mut self.neck, &vision_out)?;
732        let text = encode_text_native(
733            &self.text,
734            &self.cfg.text,
735            text_prompt,
736            self.gguf_packed.as_ref(),
737        )?;
738        let geometry = encode_geometry_native(&self.geometry, boxes, points);
739        let det = detector_forward_native(
740            &self.detector,
741            &self.cfg.detector,
742            &levels,
743            &text,
744            &geometry,
745        )?;
746        Ok(segmentation_forward_native(
747            &self.seg,
748            &det,
749            resized_hw.0,
750            resized_hw.1,
751        ))
752    }
753}